change / sglang / Commits / 70b68029

Commit 70b68029 (Unverified)
Authored Sep 13, 2024 by Liangsheng Yin; committed via GitHub on Sep 13, 2024
Parent: f3d32f88

Optimize conflicts between CUDA graph and vocab mask tensors (#1392)

Changes: 32 files in total. This page (1 of 2) shows 20 changed files with 81 additions and 149 deletions (+81, -149).
python/sglang/bench_latency.py                            +4   -4
python/sglang/srt/layers/sampler.py                       +0   -23
python/sglang/srt/managers/schedule_batch.py              +0   -19
python/sglang/srt/managers/tp_worker.py                   +5   -4
python/sglang/srt/model_executor/cuda_graph_runner.py     +2   -19
python/sglang/srt/model_executor/forward_batch_info.py    +0   -5
python/sglang/srt/model_executor/model_runner.py          +57  -9
python/sglang/srt/models/baichuan.py                      +1   -6
python/sglang/srt/models/chatglm.py                       +1   -5
python/sglang/srt/models/commandr.py                      +1   -5
python/sglang/srt/models/dbrx.py                          +1   -5
python/sglang/srt/models/deepseek.py                      +1   -5
python/sglang/srt/models/deepseek_v2.py                   +1   -5
python/sglang/srt/models/exaone.py                        +1   -5
python/sglang/srt/models/gemma.py                         +1   -5
python/sglang/srt/models/gemma2.py                        +1   -5
python/sglang/srt/models/gpt_bigcode.py                   +1   -5
python/sglang/srt/models/grok.py                          +1   -5
python/sglang/srt/models/internlm2.py                     +1   -5
python/sglang/srt/models/llama.py                         +1   -5
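Taken together, the diffs below move sampling out of the model forward pass and out of the captured CUDA graph: ModelRunner.forward now returns only a LogitsProcessorOutput, and a new ModelRunner.sample(logits_output, batch) applies the logit bias, penalties, and regex vocab mask eagerly before invoking the Sampler. A minimal sketch of the resulting two-step decode call; generate_one_token is a hypothetical helper, while model_runner and batch stand for the ModelRunner and ScheduleBatch objects used throughout the diffs:

    def generate_one_token(model_runner, batch):
        """Hypothetical helper showing the split API this commit introduces."""
        # forward() now returns only a LogitsProcessorOutput; any CUDA graph
        # replay inside forward_decode() no longer captures sampling tensors.
        logits_output = model_runner.forward(batch)
        # sample() applies logit bias, penalties, and the regex vocab mask
        # eagerly, outside the graph, then runs the Sampler.
        next_token_ids = model_runner.sample(logits_output, batch)
        return next_token_ids.tolist()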
python/sglang/bench_latency.py

@@ -207,15 +207,15 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    sample_output, logits_output = model_runner.forward(batch)
-    next_token_ids = sample_output.batch_next_token_ids.tolist()
+    logits_output = model_runner.forward(batch)
+    next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    sample_output, logits_output = model_runner.forward(batch)
-    next_token_ids = sample_output.batch_next_token_ids.tolist()
+    logits_output = model_runner.forward(batch)
+    next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits
python/sglang/srt/layers/sampler.py

@@ -35,21 +35,6 @@ class Sampler(CustomOp):
             self.forward_native = self.forward_cuda
             self.is_torch_compile = False

-    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        return logits
-
     def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()

@@ -58,14 +43,6 @@ class Sampler(CustomOp):
             # FIXME: Temporary workaround for unknown bugs in torch.compile
             logits.add_(0)

-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
-
-        logits = self._apply_penalties(logits, sampling_info)
-
         return torch.softmax(logits, dim=-1)

     def forward_cuda(
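After this change, Sampler only turns already post-processed logits into probabilities; the bias/mask/penalty logic it used to own moves verbatim into ModelRunner._apply_logits_bias (see model_runner.py below). The repetition-penalty arithmetic itself is unchanged; here is a standalone toy demonstration of that formula with made-up tensors (not sglang API):

    import torch

    # Positive logits are divided by the penalty and negative logits are
    # multiplied by it, so both cases lower the token's probability.
    logits = torch.tensor([[2.0, -1.0, 0.5]])
    scaling_penalties = torch.tensor([[1.2, 1.2, 1.2]])
    penalized = torch.where(
        logits > 0,
        logits / scaling_penalties,
        logits * scaling_penalties,
    )
    print(penalized)  # tensor([[ 1.6667, -1.2000,  0.4167]])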
python/sglang/srt/managers/schedule_batch.py

@@ -33,10 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs

-if TYPE_CHECKING:
-    from sglang.srt.layers.sampler import SampleOutput
-
-
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access

@@ -710,18 +706,3 @@ class ScheduleBatch:
             self.out_cache_loc = None

         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-    def check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
-
-        return sample_output.batch_next_token_ids
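check_sample_results does not disappear; it reappears as ModelRunner._check_sample_results in model_runner.py below. A toy demonstration of the fallback it implements, using made-up tensors: where sampling failed, NaN probabilities are zeroed and the argmax (top_k=1) token is substituted:

    import torch

    success = torch.tensor([True, False])
    probs = torch.tensor([[0.1, 0.9], [float("nan"), 0.3]])
    sampled_ids = torch.tensor([1, 0])

    # Zero out NaNs, then replace failed rows with their argmax token.
    probs = probs.masked_fill(torch.isnan(probs), 0.0)
    argmax_ids = torch.argmax(probs, dim=-1)
    next_ids = torch.where(success, sampled_ids, argmax_ids)
    print(next_ids)  # tensor([1, 1])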
python/sglang/srt/managers/tp_worker.py

@@ -547,8 +547,9 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                sample_output, logits_output = self.model_runner.forward(batch)
-                next_token_ids = batch.check_sample_results(sample_output)
+                logits_output = self.model_runner.forward(batch)
+                next_token_ids = self.model_runner.sample(logits_output, batch)
+
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )

@@ -723,8 +724,8 @@ class ModelTpServer:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        sample_output, logits_output = self.model_runner.forward(batch)
-        next_token_ids = batch.check_sample_results(sample_output)
+        logits_output = self.model_runner.forward(batch)
+        next_token_ids = self.model_runner.sample(logits_output, batch)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )
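The worker now drives each generation step in two explicit phases, and still feeds the sampled tokens back into the penalizer. A hypothetical condensation of that flow (names taken from the diff above):

    def forward_and_sample(model_runner, batch):
        # Phase 1: model forward (may be a CUDA graph replay) -> logits only.
        logits_output = model_runner.forward(batch)
        # Phase 2: eager sampling, including the freshly built vocab mask.
        next_token_ids = model_runner.sample(logits_output, batch)
        # Penalties must observe the newly emitted tokens for the next step.
        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
            next_token_ids
        )
        return next_token_ids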
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -30,10 +30,8 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather

 if TYPE_CHECKING:

@@ -129,10 +127,6 @@ class CudaGraphRunner:
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )

-        # Sampling info
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
         if self.use_torch_compile:
             set_torch_compile_config()

@@ -191,7 +185,6 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,

@@ -250,14 +243,9 @@ class CudaGraphRunner:
             bs, self.req_pool_indices, self.seq_lens
         )

-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
         # Replay
         torch.cuda.synchronize()
         self.graphs[bs].replay()
         torch.cuda.synchronize()
-        sample_output, logits_output = self.output_buffers[bs]
+        logits_output = self.output_buffers[bs]

         # Unpad
         if bs != raw_bs:

@@ -269,11 +257,6 @@ class CudaGraphRunner:
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )

         # Extract logprobs
         if batch.return_logprob:

@@ -290,4 +273,4 @@ class CudaGraphRunner:
                 logits_output.next_token_logprobs, logits_metadata
             )[1]

-        return sample_output, logits_output
+        return logits_output
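This file is the heart of the fix: the captured graph no longer holds a dummy SamplingBatchInfo, so per-step tensors such as the regex vocab mask never have to be copied into graph-resident buffers. The underlying constraint, illustrated generically below (plain PyTorch, not sglang code; requires a CUDA device, and real capture code usually warms the ops up first), is that replay reuses the memory addresses recorded at capture time, so every graph input must be a static buffer that new data is copied into:

    import torch

    static_input = torch.zeros(8, device="cuda")
    static_output = torch.zeros(8, device="cuda")

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # Recorded once; replay() reruns exactly these kernels on exactly
        # these addresses.
        static_output.copy_(static_input * 2)

    # A freshly allocated tensor would be invisible to the graph; new data
    # must be copied into the captured buffer instead.
    static_input.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
    g.replay()
    print(static_output)  # tensor([ 0.,  2.,  4., ..., 14.], device='cuda:0')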
python/sglang/srt/model_executor/forward_batch_info.py

@@ -28,7 +28,6 @@ if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo


 class ForwardMode(IntEnum):

@@ -59,7 +58,6 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""

     forward_mode: ForwardMode
-    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor

@@ -170,7 +168,6 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=batch.forward_mode,
-            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,

@@ -182,8 +179,6 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )

-        ret.sampling_info.update_penalties()
-        ret.sampling_info.update_regex_vocab_mask(batch)
         ret.compute_positions(batch)

         if not batch.forward_mode.is_decode():
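With sampling_info gone, InputMetadata, the struct a CUDA graph capture sees, carries only attention and bookkeeping state, while sampling state stays on ScheduleBatch. An abridged schematic of the struct after this change (field list shortened; not the full class):

    from dataclasses import dataclass

    import torch


    @dataclass
    class InputMetadataSketch:
        """Abridged InputMetadata after #1392: sampling_info is gone."""

        forward_mode: int  # ForwardMode (an IntEnum) in the real code
        batch_size: int
        req_pool_indices: torch.Tensor
        seq_lens: torch.Tensor
        # sampling_info: SamplingBatchInfo  <- removed; lives on ScheduleBatch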
python/sglang/srt/model_executor/model_runner.py

@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
+from sglang.srt.layers.sampler import SampleOutput, Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (

@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
 )
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,

@@ -107,6 +108,7 @@ class ModelRunner:
         # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
+        self.sampler = Sampler()
         self.load_model()
         if server_args.lora_paths is not None:
             self.init_lora_manager()

@@ -466,11 +468,8 @@ class ModelRunner:
     def forward_decode(self, batch: ScheduleBatch):
         if self.server_args.lora_paths is not None:
             self.lora_manager.prepare_lora_batch(batch)

-        if (
-            self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and batch.sampling_info.can_run_in_cuda_graph()
-        ):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)

         input_metadata = InputMetadata.from_schedule_batch(self, batch)

@@ -510,9 +509,7 @@ class ModelRunner:
             input_metadata.image_offsets,
         )

-    def forward(
-        self, batch: ScheduleBatch
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
+    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
         assert batch.forward_mode is not None

         if self.is_multimodal_model and batch.forward_mode.is_extend():

@@ -524,6 +521,57 @@ class ModelRunner:
         else:
             raise ValueError(f"Invaid forward mode: {batch.forward_mode}")

+    def _check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
+
+        return sample_output.batch_next_token_ids
+
+    def _apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        # Apply logit_bias
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+
+        return logits
+
+    def sample(
+        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
+    ) -> torch.Tensor:
+        batch.sampling_info.update_regex_vocab_mask(batch)
+        batch.sampling_info.update_penalties()
+        logits = self._apply_logits_bias(
+            logits_output.next_token_logits, batch.sampling_info
+        )
+        sample_output = self.sampler(logits, batch.sampling_info)
+        return self._check_sample_results(sample_output)
+

 @lru_cache()
 def import_model_classes():
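The new ModelRunner.sample above fixes the post-processing order: logit_bias, then linear (min-token/presence/frequency) penalties, then repetition scaling, then the regex vocab mask, and only then the Sampler. A toy walk-through of two of those stages with made-up tensors (not SamplingBatchInfo fields):

    import torch

    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    logit_bias = torch.tensor([[0.0, 0.0, 0.0, -2.0]])
    vocab_mask = torch.tensor([[False, True, False, False]])  # True = banned

    logits = logits + logit_bias                            # logit_bias stage
    logits = logits.masked_fill(vocab_mask, float("-inf"))  # vocab mask stage
    probs = torch.softmax(logits, dim=-1)
    print(torch.argmax(probs, dim=-1))  # tensor([2]): bias and mask both applied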
python/sglang/srt/models/baichuan.py

@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -346,7 +345,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     def forward(
         self,

@@ -355,12 +353,9 @@ class BaiChuanBaseForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/chatglm.py

@@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 LoraConfig = None

@@ -371,7 +370,6 @@ class ChatGLMForCausalLM(nn.Module):
         self.transformer = ChatGLMModel(config, cache_config, quant_config)
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -381,11 +379,9 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
python/sglang/srt/models/commandr.py

@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)

     @torch.no_grad()

@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/dbrx.py

@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
python/sglang/srt/models/deepseek.py

@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/deepseek_v2.py

@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -649,7 +648,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     def forward(
         self,

@@ -658,11 +656,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/exaone.py

@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -304,7 +303,6 @@ class ExaoneForCausalLM(nn.Module):
         self.transformer = ExaoneModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -317,11 +315,9 @@ class ExaoneForCausalLM(nn.Module):
         hidden_states = self.transformer(
             input_ids, positions, input_metadata, input_embeds
         )
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/gemma.py

@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return (sample_output, logits_output)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/gemma2.py

@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -347,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -358,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
python/sglang/srt/models/gpt_bigcode.py

@@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
python/sglang/srt/models/grok.py

@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -298,7 +297,6 @@ class Grok1ForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)

@@ -315,11 +313,9 @@ class Grok1ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/internlm2.py

@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/llama.py

@@ -41,7 +41,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -305,7 +304,6 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.param_dict = dict(self.named_parameters())

@@ -318,11 +316,9 @@ class LlamaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def get_hidden_dim(self, module_name):
         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
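Every model file in this commit gets the same mechanical edit: drop the Sampler import and attribute, and have forward() return the LogitsProcessorOutput directly. A self-contained schematic of the resulting contract (TinyCausalLM is a stand-in, not an sglang module):

    import torch
    from torch import nn


    class TinyCausalLM(nn.Module):
        def __init__(self, vocab_size: int = 16, hidden: int = 8):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)
            self.lm_head = nn.Linear(hidden, vocab_size, bias=False)

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            hidden_states = self.embed(input_ids)
            # No self.sampler any more: return logits and let
            # ModelRunner.sample() pick the next token.
            return self.lm_head(hidden_states[:, -1])


    logits = TinyCausalLM()(torch.tensor([[1, 2, 3]]))
    print(logits.shape)  # torch.Size([1, 16])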