Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
70b68029
Unverified
Commit
70b68029
authored
Sep 13, 2024
by
Liangsheng Yin
Committed by
GitHub
Sep 13, 2024
Browse files
Optimize conflicts between CUDA graph and vocab mask tensors (#1392)
parent
f3d32f88
Changes
32
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
81 additions
and
149 deletions
+81
-149
python/sglang/bench_latency.py
python/sglang/bench_latency.py
+4
-4
python/sglang/srt/layers/sampler.py
python/sglang/srt/layers/sampler.py
+0
-23
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+0
-19
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+5
-4
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+2
-19
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+0
-5
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+57
-9
python/sglang/srt/models/baichuan.py
python/sglang/srt/models/baichuan.py
+1
-6
python/sglang/srt/models/chatglm.py
python/sglang/srt/models/chatglm.py
+1
-5
python/sglang/srt/models/commandr.py
python/sglang/srt/models/commandr.py
+1
-5
python/sglang/srt/models/dbrx.py
python/sglang/srt/models/dbrx.py
+1
-5
python/sglang/srt/models/deepseek.py
python/sglang/srt/models/deepseek.py
+1
-5
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+1
-5
python/sglang/srt/models/exaone.py
python/sglang/srt/models/exaone.py
+1
-5
python/sglang/srt/models/gemma.py
python/sglang/srt/models/gemma.py
+1
-5
python/sglang/srt/models/gemma2.py
python/sglang/srt/models/gemma2.py
+1
-5
python/sglang/srt/models/gpt_bigcode.py
python/sglang/srt/models/gpt_bigcode.py
+1
-5
python/sglang/srt/models/grok.py
python/sglang/srt/models/grok.py
+1
-5
python/sglang/srt/models/internlm2.py
python/sglang/srt/models/internlm2.py
+1
-5
python/sglang/srt/models/llama.py
python/sglang/srt/models/llama.py
+1
-5
No files found.
python/sglang/bench_latency.py
View file @
70b68029
...
@@ -207,15 +207,15 @@ def extend(reqs, model_runner):
...
@@ -207,15 +207,15 @@ def extend(reqs, model_runner):
tree_cache
=
None
,
tree_cache
=
None
,
)
)
batch
.
prepare_for_extend
(
model_runner
.
model_config
.
vocab_size
)
batch
.
prepare_for_extend
(
model_runner
.
model_config
.
vocab_size
)
sample_output
,
logits_output
=
model_runner
.
forward
(
batch
)
logits_output
=
model_runner
.
forward
(
batch
)
next_token_ids
=
sample
_output
.
batch
_next_token_ids
.
tolist
()
next_token_ids
=
model_runner
.
sample
(
logits
_output
,
batch
)
.
tolist
()
return
next_token_ids
,
logits_output
.
next_token_logits
,
batch
return
next_token_ids
,
logits_output
.
next_token_logits
,
batch
def
decode
(
input_token_ids
,
batch
,
model_runner
):
def
decode
(
input_token_ids
,
batch
,
model_runner
):
batch
.
prepare_for_decode
(
input_token_ids
)
batch
.
prepare_for_decode
(
input_token_ids
)
sample_output
,
logits_output
=
model_runner
.
forward
(
batch
)
logits_output
=
model_runner
.
forward
(
batch
)
next_token_ids
=
sample
_output
.
batch
_next_token_ids
.
tolist
()
next_token_ids
=
model_runner
.
sample
(
logits
_output
,
batch
)
.
tolist
()
return
next_token_ids
,
logits_output
.
next_token_logits
return
next_token_ids
,
logits_output
.
next_token_logits
...
...
python/sglang/srt/layers/sampler.py
View file @
70b68029
...
@@ -35,21 +35,6 @@ class Sampler(CustomOp):
...
@@ -35,21 +35,6 @@ class Sampler(CustomOp):
self
.
forward_native
=
self
.
forward_cuda
self
.
forward_native
=
self
.
forward_cuda
self
.
is_torch_compile
=
False
self
.
is_torch_compile
=
False
def
_apply_penalties
(
self
,
logits
:
torch
.
Tensor
,
sampling_info
:
SamplingBatchInfo
):
# min-token, presence, frequency
if
sampling_info
.
linear_penalties
is
not
None
:
logits
+=
sampling_info
.
linear_penalties
# repetition
if
sampling_info
.
scaling_penalties
is
not
None
:
logits
=
torch
.
where
(
logits
>
0
,
logits
/
sampling_info
.
scaling_penalties
,
logits
*
sampling_info
.
scaling_penalties
,
)
return
logits
def
_get_probs
(
self
,
logits
:
torch
.
Tensor
,
sampling_info
:
SamplingBatchInfo
):
def
_get_probs
(
self
,
logits
:
torch
.
Tensor
,
sampling_info
:
SamplingBatchInfo
):
# Post process logits
# Post process logits
logits
=
logits
.
contiguous
()
logits
=
logits
.
contiguous
()
...
@@ -58,14 +43,6 @@ class Sampler(CustomOp):
...
@@ -58,14 +43,6 @@ class Sampler(CustomOp):
# FIXME: Temporary workaround for unknown bugs in torch.compile
# FIXME: Temporary workaround for unknown bugs in torch.compile
logits
.
add_
(
0
)
logits
.
add_
(
0
)
if
sampling_info
.
logit_bias
is
not
None
:
logits
.
add_
(
sampling_info
.
logit_bias
)
if
sampling_info
.
vocab_mask
is
not
None
:
logits
=
logits
.
masked_fill
(
sampling_info
.
vocab_mask
,
float
(
"-inf"
))
logits
=
self
.
_apply_penalties
(
logits
,
sampling_info
)
return
torch
.
softmax
(
logits
,
dim
=-
1
)
return
torch
.
softmax
(
logits
,
dim
=-
1
)
def
forward_cuda
(
def
forward_cuda
(
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
70b68029
...
@@ -33,10 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
...
@@ -33,10 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
if
TYPE_CHECKING
:
from
sglang.srt.layers.sampler
import
SampleOutput
INIT_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
INIT_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
# Put some global args for easy access
# Put some global args for easy access
...
@@ -710,18 +706,3 @@ class ScheduleBatch:
...
@@ -710,18 +706,3 @@ class ScheduleBatch:
self
.
out_cache_loc
=
None
self
.
out_cache_loc
=
None
self
.
top_logprobs_nums
.
extend
(
other
.
top_logprobs_nums
)
self
.
top_logprobs_nums
.
extend
(
other
.
top_logprobs_nums
)
self
.
return_logprob
=
any
(
req
.
return_logprob
for
req
in
self
.
reqs
)
self
.
return_logprob
=
any
(
req
.
return_logprob
for
req
in
self
.
reqs
)
def
check_sample_results
(
self
,
sample_output
:
SampleOutput
):
if
not
torch
.
all
(
sample_output
.
success
):
probs
=
sample_output
.
probs
batch_next_token_ids
=
sample_output
.
batch_next_token_ids
logging
.
warning
(
"Sampling failed, fallback to top_k=1 strategy"
)
probs
=
probs
.
masked_fill
(
torch
.
isnan
(
probs
),
0.0
)
argmax_ids
=
torch
.
argmax
(
probs
,
dim
=-
1
)
batch_next_token_ids
=
torch
.
where
(
sample_output
.
success
,
batch_next_token_ids
,
argmax_ids
)
sample_output
.
probs
=
probs
sample_output
.
batch_next_token_ids
=
batch_next_token_ids
return
sample_output
.
batch_next_token_ids
python/sglang/srt/managers/tp_worker.py
View file @
70b68029
...
@@ -547,8 +547,9 @@ class ModelTpServer:
...
@@ -547,8 +547,9 @@ class ModelTpServer:
if
self
.
model_runner
.
is_generation
:
if
self
.
model_runner
.
is_generation
:
# Forward and sample the next tokens
# Forward and sample the next tokens
if
batch
.
extend_num_tokens
!=
0
:
if
batch
.
extend_num_tokens
!=
0
:
sample_output
,
logits_output
=
self
.
model_runner
.
forward
(
batch
)
logits_output
=
self
.
model_runner
.
forward
(
batch
)
next_token_ids
=
batch
.
check_sample_results
(
sample_output
)
next_token_ids
=
self
.
model_runner
.
sample
(
logits_output
,
batch
)
batch
.
sampling_info
.
penalizer_orchestrator
.
cumulate_output_tokens
(
batch
.
sampling_info
.
penalizer_orchestrator
.
cumulate_output_tokens
(
next_token_ids
next_token_ids
)
)
...
@@ -723,8 +724,8 @@ class ModelTpServer:
...
@@ -723,8 +724,8 @@ class ModelTpServer:
batch
.
prepare_for_decode
()
batch
.
prepare_for_decode
()
# Forward and sample the next tokens
# Forward and sample the next tokens
sample_output
,
logits_output
=
self
.
model_runner
.
forward
(
batch
)
logits_output
=
self
.
model_runner
.
forward
(
batch
)
next_token_ids
=
batch
.
check_sample_results
(
sample_output
)
next_token_ids
=
self
.
model_runner
.
sample
(
logits_output
,
batch
)
batch
.
sampling_info
.
penalizer_orchestrator
.
cumulate_output_tokens
(
batch
.
sampling_info
.
penalizer_orchestrator
.
cumulate_output_tokens
(
next_token_ids
next_token_ids
)
)
...
...
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
70b68029
...
@@ -30,10 +30,8 @@ from sglang.srt.layers.logits_processor import (
...
@@ -30,10 +30,8 @@ from sglang.srt.layers.logits_processor import (
LogitsProcessor
,
LogitsProcessor
,
LogitsProcessorOutput
,
LogitsProcessorOutput
,
)
)
from
sglang.srt.layers.sampler
import
SampleOutput
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
,
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
,
InputMetadata
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.utils
import
monkey_patch_vllm_all_gather
from
sglang.srt.utils
import
monkey_patch_vllm_all_gather
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -129,10 +127,6 @@ class CudaGraphRunner:
...
@@ -129,10 +127,6 @@ class CudaGraphRunner:
self
.
model_runner
.
attn_backend
.
get_cuda_graph_seq_len_fill_value
()
self
.
model_runner
.
attn_backend
.
get_cuda_graph_seq_len_fill_value
()
)
)
# Sampling info
vocab_size
=
model_runner
.
model_config
.
vocab_size
self
.
sampling_info
=
SamplingBatchInfo
.
dummy_one
(
self
.
max_bs
,
vocab_size
)
if
self
.
use_torch_compile
:
if
self
.
use_torch_compile
:
set_torch_compile_config
()
set_torch_compile_config
()
...
@@ -191,7 +185,6 @@ class CudaGraphRunner:
...
@@ -191,7 +185,6 @@ class CudaGraphRunner:
def
run_once
():
def
run_once
():
input_metadata
=
InputMetadata
(
input_metadata
=
InputMetadata
(
forward_mode
=
ForwardMode
.
DECODE
,
forward_mode
=
ForwardMode
.
DECODE
,
sampling_info
=
self
.
sampling_info
[:
bs
],
batch_size
=
bs
,
batch_size
=
bs
,
req_pool_indices
=
req_pool_indices
,
req_pool_indices
=
req_pool_indices
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
...
@@ -250,14 +243,9 @@ class CudaGraphRunner:
...
@@ -250,14 +243,9 @@ class CudaGraphRunner:
bs
,
self
.
req_pool_indices
,
self
.
seq_lens
bs
,
self
.
req_pool_indices
,
self
.
seq_lens
)
)
# Sampling inputs
self
.
sampling_info
.
inplace_assign
(
raw_bs
,
batch
.
sampling_info
)
# Replay
# Replay
torch
.
cuda
.
synchronize
()
self
.
graphs
[
bs
].
replay
()
self
.
graphs
[
bs
].
replay
()
torch
.
cuda
.
synchronize
()
logits_output
=
self
.
output_buffers
[
bs
]
sample_output
,
logits_output
=
self
.
output_buffers
[
bs
]
# Unpad
# Unpad
if
bs
!=
raw_bs
:
if
bs
!=
raw_bs
:
...
@@ -269,11 +257,6 @@ class CudaGraphRunner:
...
@@ -269,11 +257,6 @@ class CudaGraphRunner:
input_top_logprobs
=
None
,
input_top_logprobs
=
None
,
output_top_logprobs
=
None
,
output_top_logprobs
=
None
,
)
)
sample_output
=
SampleOutput
(
sample_output
.
success
[:
raw_bs
],
sample_output
.
probs
[:
raw_bs
],
sample_output
.
batch_next_token_ids
[:
raw_bs
],
)
# Extract logprobs
# Extract logprobs
if
batch
.
return_logprob
:
if
batch
.
return_logprob
:
...
@@ -290,4 +273,4 @@ class CudaGraphRunner:
...
@@ -290,4 +273,4 @@ class CudaGraphRunner:
logits_output
.
next_token_logprobs
,
logits_metadata
logits_output
.
next_token_logprobs
,
logits_metadata
)[
1
]
)[
1
]
return
sample_output
,
logits_output
return
logits_output
python/sglang/srt/model_executor/forward_batch_info.py
View file @
70b68029
...
@@ -28,7 +28,6 @@ if TYPE_CHECKING:
...
@@ -28,7 +28,6 @@ if TYPE_CHECKING:
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
class
ForwardMode
(
IntEnum
):
class
ForwardMode
(
IntEnum
):
...
@@ -59,7 +58,6 @@ class InputMetadata:
...
@@ -59,7 +58,6 @@ class InputMetadata:
"""Store all inforamtion of a forward pass."""
"""Store all inforamtion of a forward pass."""
forward_mode
:
ForwardMode
forward_mode
:
ForwardMode
sampling_info
:
SamplingBatchInfo
batch_size
:
int
batch_size
:
int
req_pool_indices
:
torch
.
Tensor
req_pool_indices
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
...
@@ -170,7 +168,6 @@ class InputMetadata:
...
@@ -170,7 +168,6 @@ class InputMetadata:
):
):
ret
=
cls
(
ret
=
cls
(
forward_mode
=
batch
.
forward_mode
,
forward_mode
=
batch
.
forward_mode
,
sampling_info
=
batch
.
sampling_info
,
batch_size
=
batch
.
batch_size
(),
batch_size
=
batch
.
batch_size
(),
req_pool_indices
=
batch
.
req_pool_indices
,
req_pool_indices
=
batch
.
req_pool_indices
,
seq_lens
=
batch
.
seq_lens
,
seq_lens
=
batch
.
seq_lens
,
...
@@ -182,8 +179,6 @@ class InputMetadata:
...
@@ -182,8 +179,6 @@ class InputMetadata:
top_logprobs_nums
=
batch
.
top_logprobs_nums
,
top_logprobs_nums
=
batch
.
top_logprobs_nums
,
)
)
ret
.
sampling_info
.
update_penalties
()
ret
.
sampling_info
.
update_regex_vocab_mask
(
batch
)
ret
.
compute_positions
(
batch
)
ret
.
compute_positions
(
batch
)
if
not
batch
.
forward_mode
.
is_decode
():
if
not
batch
.
forward_mode
.
is_decode
():
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
70b68029
...
@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
...
@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
from
sglang.srt.configs.model_config
import
AttentionArch
,
ModelConfig
from
sglang.srt.configs.model_config
import
AttentionArch
,
ModelConfig
from
sglang.srt.layers.attention_backend
import
FlashInferAttnBackend
,
TritonAttnBackend
from
sglang.srt.layers.attention_backend
import
FlashInferAttnBackend
,
TritonAttnBackend
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.sampler
import
SampleOutput
from
sglang.srt.layers.sampler
import
SampleOutput
,
Sampler
from
sglang.srt.lora.lora_manager
import
LoRAManager
from
sglang.srt.lora.lora_manager
import
LoRAManager
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
,
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
,
global_server_args_dict
from
sglang.srt.mem_cache.memory_pool
import
(
from
sglang.srt.mem_cache.memory_pool
import
(
...
@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
...
@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
ReqToTokenPool
,
ReqToTokenPool
,
)
)
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
get_available_gpu_memory
,
get_available_gpu_memory
,
...
@@ -107,6 +108,7 @@ class ModelRunner:
...
@@ -107,6 +108,7 @@ class ModelRunner:
# Init componnets
# Init componnets
min_per_gpu_memory
=
self
.
init_torch_distributed
()
min_per_gpu_memory
=
self
.
init_torch_distributed
()
self
.
sampler
=
Sampler
()
self
.
load_model
()
self
.
load_model
()
if
server_args
.
lora_paths
is
not
None
:
if
server_args
.
lora_paths
is
not
None
:
self
.
init_lora_manager
()
self
.
init_lora_manager
()
...
@@ -466,11 +468,8 @@ class ModelRunner:
...
@@ -466,11 +468,8 @@ class ModelRunner:
def
forward_decode
(
self
,
batch
:
ScheduleBatch
):
def
forward_decode
(
self
,
batch
:
ScheduleBatch
):
if
self
.
server_args
.
lora_paths
is
not
None
:
if
self
.
server_args
.
lora_paths
is
not
None
:
self
.
lora_manager
.
prepare_lora_batch
(
batch
)
self
.
lora_manager
.
prepare_lora_batch
(
batch
)
if
(
self
.
cuda_graph_runner
if
self
.
cuda_graph_runner
and
self
.
cuda_graph_runner
.
can_run
(
len
(
batch
.
reqs
)):
and
self
.
cuda_graph_runner
.
can_run
(
len
(
batch
.
reqs
))
and
batch
.
sampling_info
.
can_run_in_cuda_graph
()
):
return
self
.
cuda_graph_runner
.
replay
(
batch
)
return
self
.
cuda_graph_runner
.
replay
(
batch
)
input_metadata
=
InputMetadata
.
from_schedule_batch
(
self
,
batch
)
input_metadata
=
InputMetadata
.
from_schedule_batch
(
self
,
batch
)
...
@@ -510,9 +509,7 @@ class ModelRunner:
...
@@ -510,9 +509,7 @@ class ModelRunner:
input_metadata
.
image_offsets
,
input_metadata
.
image_offsets
,
)
)
def
forward
(
def
forward
(
self
,
batch
:
ScheduleBatch
)
->
Tuple
[
LogitsProcessorOutput
]:
self
,
batch
:
ScheduleBatch
)
->
Tuple
[
SampleOutput
,
LogitsProcessorOutput
]:
assert
batch
.
forward_mode
is
not
None
assert
batch
.
forward_mode
is
not
None
if
self
.
is_multimodal_model
and
batch
.
forward_mode
.
is_extend
():
if
self
.
is_multimodal_model
and
batch
.
forward_mode
.
is_extend
():
...
@@ -524,6 +521,57 @@ class ModelRunner:
...
@@ -524,6 +521,57 @@ class ModelRunner:
else
:
else
:
raise
ValueError
(
f
"Invaid forward mode:
{
batch
.
forward_mode
}
"
)
raise
ValueError
(
f
"Invaid forward mode:
{
batch
.
forward_mode
}
"
)
def
_check_sample_results
(
self
,
sample_output
:
SampleOutput
):
if
not
torch
.
all
(
sample_output
.
success
):
probs
=
sample_output
.
probs
batch_next_token_ids
=
sample_output
.
batch_next_token_ids
logging
.
warning
(
"Sampling failed, fallback to top_k=1 strategy"
)
probs
=
probs
.
masked_fill
(
torch
.
isnan
(
probs
),
0.0
)
argmax_ids
=
torch
.
argmax
(
probs
,
dim
=-
1
)
batch_next_token_ids
=
torch
.
where
(
sample_output
.
success
,
batch_next_token_ids
,
argmax_ids
)
sample_output
.
probs
=
probs
sample_output
.
batch_next_token_ids
=
batch_next_token_ids
return
sample_output
.
batch_next_token_ids
def
_apply_logits_bias
(
self
,
logits
:
torch
.
Tensor
,
sampling_info
:
SamplingBatchInfo
):
# Apply logit_bias
if
sampling_info
.
logit_bias
is
not
None
:
logits
.
add_
(
sampling_info
.
logit_bias
)
# min-token, presence, frequency
if
sampling_info
.
linear_penalties
is
not
None
:
logits
+=
sampling_info
.
linear_penalties
# repetition
if
sampling_info
.
scaling_penalties
is
not
None
:
logits
=
torch
.
where
(
logits
>
0
,
logits
/
sampling_info
.
scaling_penalties
,
logits
*
sampling_info
.
scaling_penalties
,
)
# Apply regex vocab_mask
if
sampling_info
.
vocab_mask
is
not
None
:
logits
=
logits
.
masked_fill
(
sampling_info
.
vocab_mask
,
float
(
"-inf"
))
return
logits
def
sample
(
self
,
logits_output
:
LogitsProcessorOutput
,
batch
:
ScheduleBatch
)
->
torch
.
Tensor
:
batch
.
sampling_info
.
update_regex_vocab_mask
(
batch
)
batch
.
sampling_info
.
update_penalties
()
logits
=
self
.
_apply_logits_bias
(
logits_output
.
next_token_logits
,
batch
.
sampling_info
)
sample_output
=
self
.
sampler
(
logits
,
batch
.
sampling_info
)
return
self
.
_check_sample_results
(
sample_output
)
@
lru_cache
()
@
lru_cache
()
def
import_model_classes
():
def
import_model_classes
():
...
...
python/sglang/srt/models/baichuan.py
View file @
70b68029
...
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
...
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
@@ -346,7 +345,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
...
@@ -346,7 +345,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
if
self
.
config
.
tie_word_embeddings
:
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -355,12 +353,9 @@ class BaiChuanBaseForCausalLM(nn.Module):
...
@@ -355,12 +353,9 @@ class BaiChuanBaseForCausalLM(nn.Module):
input_metadata
:
InputMetadata
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
sample_output
,
logits_output
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
...
python/sglang/srt/models/chatglm.py
View file @
70b68029
...
@@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul
...
@@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
LoraConfig
=
None
LoraConfig
=
None
...
@@ -371,7 +370,6 @@ class ChatGLMForCausalLM(nn.Module):
...
@@ -371,7 +370,6 @@ class ChatGLMForCausalLM(nn.Module):
self
.
transformer
=
ChatGLMModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
ChatGLMModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
self
.
transformer
.
output_layer
self
.
lm_head
=
self
.
transformer
.
output_layer
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward
(
def
forward
(
...
@@ -381,11 +379,9 @@ class ChatGLMForCausalLM(nn.Module):
...
@@ -381,11 +379,9 @@ class ChatGLMForCausalLM(nn.Module):
input_metadata
:
InputMetadata
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
input_metadata
)
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
input_metadata
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
sample_output
,
logits_output
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
...
...
python/sglang/srt/models/commandr.py
View file @
70b68029
...
@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
from
sglang.srt.layers.activation
import
SiluAndMul
from
sglang.srt.layers.activation
import
SiluAndMul
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
...
@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
self
.
model
=
CohereModel
(
config
,
quant_config
)
self
.
model
=
CohereModel
(
config
,
quant_config
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
...
@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
positions
,
positions
,
input_metadata
,
input_metadata
,
)
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
model
.
embed_tokens
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
model
.
embed_tokens
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
sample_output
,
logits_output
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
...
python/sglang/srt/models/dbrx.py
View file @
70b68029
...
@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
...
@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
...
@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
)
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward
(
def
forward
(
...
@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
...
@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
input_metadata
:
InputMetadata
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
input_metadata
)
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
input_metadata
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
sample_output
,
logits_output
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
expert_params_mapping
=
[
expert_params_mapping
=
[
...
...
python/sglang/srt/models/deepseek.py
View file @
70b68029
...
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
...
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
...
@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward
(
def
forward
(
...
@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
...
@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
input_metadata
:
InputMetadata
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
sample_output
,
logits_output
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
...
python/sglang/srt/models/deepseek_v2.py
View file @
70b68029
...
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
...
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
@@ -649,7 +648,6 @@ class DeepseekV2ForCausalLM(nn.Module):
...
@@ -649,7 +648,6 @@ class DeepseekV2ForCausalLM(nn.Module):
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -658,11 +656,9 @@ class DeepseekV2ForCausalLM(nn.Module):
...
@@ -658,11 +656,9 @@ class DeepseekV2ForCausalLM(nn.Module):
input_metadata
:
InputMetadata
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
sample_output
,
logits_output
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
...
python/sglang/srt/models/exaone.py
View file @
70b68029
...
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
...
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
,
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
,
LogitsProcessorOutput
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
@@ -304,7 +303,6 @@ class ExaoneForCausalLM(nn.Module):
...
@@ -304,7 +303,6 @@ class ExaoneForCausalLM(nn.Module):
self
.
transformer
=
ExaoneModel
(
config
,
quant_config
=
quant_config
)
self
.
transformer
=
ExaoneModel
(
config
,
quant_config
=
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward
(
def
forward
(
...
@@ -317,11 +315,9 @@ class ExaoneForCausalLM(nn.Module):
...
@@ -317,11 +315,9 @@ class ExaoneForCausalLM(nn.Module):
hidden_states
=
self
.
transformer
(
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
input_metadata
,
input_embeds
input_ids
,
positions
,
input_metadata
,
input_embeds
)
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
sample_output
,
logits_output
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
...
python/sglang/srt/models/gemma.py
View file @
70b68029
...
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
...
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
...
@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
model
=
GemmaModel
(
config
,
quant_config
=
quant_config
)
self
.
model
=
GemmaModel
(
config
,
quant_config
=
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward
(
def
forward
(
...
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
...
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
input_embeds
:
torch
.
Tensor
=
None
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
,
input_embeds
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
,
input_embeds
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
model
.
embed_tokens
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
model
.
embed_tokens
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
(
sample_output
,
logits_output
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
...
python/sglang/srt/models/gemma2.py
View file @
70b68029
...
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
...
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
from
sglang.srt.layers.layernorm
import
GemmaRMSNorm
from
sglang.srt.layers.layernorm
import
GemmaRMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
@@ -347,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module):
...
@@ -347,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
model
=
Gemma2Model
(
config
,
cache_config
,
quant_config
)
self
.
model
=
Gemma2Model
(
config
,
cache_config
,
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward
(
def
forward
(
...
@@ -358,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module):
...
@@ -358,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module):
input_embeds
:
torch
.
Tensor
=
None
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
,
input_embeds
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
,
input_embeds
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
model
.
embed_tokens
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
model
.
embed_tokens
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
sample_output
,
logits_output
def
get_attention_sliding_window_size
(
self
):
def
get_attention_sliding_window_size
(
self
):
return
get_attention_sliding_window_size
(
self
.
config
)
return
get_attention_sliding_window_size
(
self
.
config
)
...
...
python/sglang/srt/models/gpt_bigcode.py
View file @
70b68029
...
@@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
sglang.srt.layers.activation
import
get_act_fn
from
sglang.srt.layers.activation
import
get_act_fn
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
...
@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
if
lora_config
:
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward
(
def
forward
(
...
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
...
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata
:
InputMetadata
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
input_metadata
)
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
input_metadata
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
sample_output
,
logits_output
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
...
...
python/sglang/srt/models/grok.py
View file @
70b68029
...
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
...
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
@@ -298,7 +297,6 @@ class Grok1ForCausalLM(nn.Module):
...
@@ -298,7 +297,6 @@ class Grok1ForCausalLM(nn.Module):
self
.
model
=
Grok1Model
(
config
,
quant_config
=
quant_config
)
self
.
model
=
Grok1Model
(
config
,
quant_config
=
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
# Monkey patch _prepare_weights to load pre-sharded weights
# Monkey patch _prepare_weights to load pre-sharded weights
setattr
(
DefaultModelLoader
,
"_prepare_weights"
,
_prepare_presharded_weights
)
setattr
(
DefaultModelLoader
,
"_prepare_weights"
,
_prepare_presharded_weights
)
...
@@ -315,11 +313,9 @@ class Grok1ForCausalLM(nn.Module):
...
@@ -315,11 +313,9 @@ class Grok1ForCausalLM(nn.Module):
input_embeds
:
torch
.
Tensor
=
None
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
,
input_embeds
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
,
input_embeds
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
sample_output
,
logits_output
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
...
python/sglang/srt/models/internlm2.py
View file @
70b68029
...
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
...
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
self
.
model
=
InternLM2Model
(
config
,
quant_config
)
self
.
model
=
InternLM2Model
(
config
,
quant_config
)
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward
(
def
forward
(
...
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
input_embeds
:
torch
.
Tensor
=
None
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
,
input_embeds
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
,
input_embeds
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
output
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
output
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
sample_output
,
logits_output
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
...
python/sglang/srt/models/llama.py
View file @
70b68029
...
@@ -41,7 +41,6 @@ from sglang.srt.layers.activation import SiluAndMul
...
@@ -41,7 +41,6 @@ from sglang.srt.layers.activation import SiluAndMul
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
,
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
,
LogitsProcessorOutput
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.sampler
import
Sampler
from
sglang.srt.layers.torchao_utils
import
torchao_quantize_param_data
from
sglang.srt.layers.torchao_utils
import
torchao_quantize_param_data
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
@@ -305,7 +304,6 @@ class LlamaForCausalLM(nn.Module):
...
@@ -305,7 +304,6 @@ class LlamaForCausalLM(nn.Module):
self
.
model
=
LlamaModel
(
config
,
quant_config
=
quant_config
)
self
.
model
=
LlamaModel
(
config
,
quant_config
=
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
sampler
=
Sampler
()
self
.
param_dict
=
dict
(
self
.
named_parameters
())
self
.
param_dict
=
dict
(
self
.
named_parameters
())
...
@@ -318,11 +316,9 @@ class LlamaForCausalLM(nn.Module):
...
@@ -318,11 +316,9 @@ class LlamaForCausalLM(nn.Module):
input_embeds
:
torch
.
Tensor
=
None
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
LogitsProcessorOutput
:
)
->
LogitsProcessorOutput
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
,
input_embeds
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
,
input_embeds
)
logits_output
=
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
)
)
sample_output
=
self
.
sampler
(
logits_output
,
input_metadata
.
sampling_info
)
return
sample_output
,
logits_output
def
get_hidden_dim
(
self
,
module_name
):
def
get_hidden_dim
(
self
,
module_name
):
if
module_name
in
[
"q_proj"
,
"o_proj"
,
"qkv_proj"
]:
if
module_name
in
[
"q_proj"
,
"o_proj"
,
"qkv_proj"
]:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment