change / sglang · Commits

Commit 75ce37f4 (Unverified)
Move sampler into CUDA graph (#1201)

Authored Aug 26, 2024 by Liangsheng Yin; committed via GitHub on Aug 26, 2024
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Parent: 97589a60
Showing 8 changed files with 106 additions and 23 deletions (+106 -23)
python/sglang/srt/models/mixtral.py                 +5  -1
python/sglang/srt/models/mixtral_quant.py           +5  -1
python/sglang/srt/models/qwen.py                    +5  -2
python/sglang/srt/models/qwen2.py                   +6  -2
python/sglang/srt/models/qwen2_moe.py               +5  -14
python/sglang/srt/models/stablelm.py                +5  -1
python/sglang/srt/sampling/sampling_batch_info.py   +74 -1
python/sglang/test/runners.py                       +1  -1
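
For context: the point of this commit is that sampling now executes inside the region captured by the CUDA graph, so each decode replay launches the model forward and the sampler together instead of paying separate kernel-launch overhead for sampling. A minimal sketch of the idea, not sglang's actual runner code (run_forward, sample, and static_batch are placeholder names):

    import torch

    def capture_decode_graph(run_forward, sample, static_batch):
        """Capture model forward + sampling in one CUDA graph (illustrative only)."""
        # Warm up on a side stream first, as the PyTorch CUDA-graph docs recommend.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            sample(run_forward(static_batch))
        torch.cuda.current_stream().wait_stream(s)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            logits = run_forward(static_batch)
            next_ids = sample(logits)  # the sampler is now part of the captured graph
        # graph.replay() reruns forward + sampling as one unit; inputs and outputs
        # live in the static tensors captured above.
        return graph, next_ids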
No files found.
python/sglang/srt/models/mixtral.py (view file @ 75ce37f4)

@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -299,6 +300,7 @@ class MixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     def forward(
         self,

@@ -308,9 +310,11 @@ class MixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
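
The same edit pattern repeats in each model file below: construct a Sampler next to the LogitsProcessor, then have forward return a (sample_output, logits_output) pair instead of bare logits, so sampling runs inside the captured graph. A hedged caller-side sketch of the new contract (the variable names are assumptions, not code from this commit):

    # Hypothetical caller; input_ids, positions, and input_metadata are assumed
    # to be prepared by the model runner as elsewhere in sglang.
    sample_output, logits_output = model.forward(input_ids, positions, input_metadata)
    # sample_output carries the sampled next tokens; logits_output is kept so the
    # runner can still read logits/logprobs when a request asks for them.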
python/sglang/srt/models/mixtral_quant.py (view file @ 75ce37f4)

@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -333,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -343,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/qwen.py (view file @ 75ce37f4)

@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -260,10 +262,11 @@ class QWenLMHeadModel(nn.Module):
         input_metadata: InputMetadata,
     ):
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        next_tokens = self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/qwen2.py (view file @ 75ce37f4)

@@ -38,8 +38,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 Qwen2Config = None

@@ -276,6 +277,7 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     @torch.no_grad()

@@ -289,9 +291,11 @@ class Qwen2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         if not get_embedding:
-            return self.logits_processor(
+            logits_output = self.logits_processor(
                 input_ids, hidden_states, self.lm_head.weight, input_metadata
             )
+            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+            return sample_output, logits_output
         else:
             return self.pooler(hidden_states, input_metadata)
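
Qwen2 is the one model here with a dual contract: when get_embedding is set, forward still returns the Pooler's output rather than the (sample_output, logits_output) pair. A hedged sketch of the two call paths (argument setup assumed, as above):

    # Generation path: sampler output plus logits.
    sample_output, logits_output = model.forward(
        input_ids, positions, input_metadata, get_embedding=False
    )

    # Embedding path: the Pooler's output (PoolingType.LAST, normalize=True).
    embedding_output = model.forward(
        input_ids, positions, input_metadata, get_embedding=True
    )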
python/sglang/srt/models/qwen2_moe.py (view file @ 75ce37f4)

@@ -35,10 +35,8 @@ from vllm.model_executor.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

@@ -49,6 +47,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-    def compute_logits(
-        self,
-        input_ids: torch.Tensor,
-        hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        logits = self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
-        return logits
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/stablelm.py (view file @ 75ce37f4)

@@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module):
         self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/sampling/sampling_batch_info.py (view file @ 75ce37f4)

@@ -21,10 +21,63 @@ class SamplingBatchInfo:
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
     min_ps: torch.Tensor = None
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool = False
+
+    # Bias Tensors
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None

+    # Penalizer
+    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
+    linear_penalties: torch.Tensor = None
+    scaling_penalties: torch.Tensor = None
+
+    def has_bias(self):
+        return (
+            self.logit_bias is not None
+            or self.vocab_mask is not None
+            or self.linear_penalties is not None
+            or self.scaling_penalties is not None
+        )
+
+    @classmethod
+    def dummy_one(cls, max_bs: int, vocab_size: int):
+        ret = cls(vocab_size=vocab_size)
+        ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
+        ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
+        ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
+        ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
+        return ret
+
+    def __getitem__(self, key):
+        if isinstance(key, slice):
+            # NOTE: We do not use cuda graph when there is bias tensors
+            assert not self.has_bias()
+            return SamplingBatchInfo(
+                vocab_size=self.vocab_size,
+                temperatures=self.temperatures[key],
+                top_ps=self.top_ps[key],
+                top_ks=self.top_ks[key],
+                min_ps=self.min_ps[key],
+                need_min_p_sampling=self.need_min_p_sampling,
+            )
+        else:
+            raise NotImplementedError
+
+    def inplace_assign(self, bs: int, other: SamplingBatchInfo):
+        # NOTE: We do not use cuda graph when there is bias tensors
+        assert not self.has_bias()
+
+        self.vocab_size = other.vocab_size
+        self.need_min_p_sampling = other.need_min_p_sampling
+
+        self.temperatures[:bs] = other.temperatures
+        self.top_ps[:bs] = other.top_ps
+        self.top_ks[:bs] = other.top_ks
+        self.min_ps[:bs] = other.min_ps
+
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         device = "cuda"

@@ -45,6 +98,7 @@ class SamplingBatchInfo:
         ret.min_ps = torch.tensor(
             [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
         )
+        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)

         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this

@@ -72,6 +126,25 @@ class SamplingBatchInfo:
         return ret

+    def prepare_penalties(self):
+        self.scaling_penalties = None
+        self.linear_penalties = None
+
+        for penalizer in self.penalizer_orchestrator.penalizers.values():
+            if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
+                if penalizer.is_prepared():
+                    self.scaling_penalties = penalizer.cumulated_repetition_penalties
+            else:
+                if penalizer.is_prepared():
+                    if self.linear_penalties is None:
+                        bs = self.penalizer_orchestrator.batch.batch_size()
+                        self.linear_penalties = torch.zeros(
+                            (bs, self.vocab_size),
+                            dtype=torch.float32,
+                            device="cuda",
+                        )
+                    self.linear_penalties = penalizer.apply(self.linear_penalties)
+
     def update_regex_vocab_mask(self, batch: ScheduleBatch):
         bs, reqs = batch.batch_size(), batch.reqs
         device = "cuda"
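
The new SamplingBatchInfo helpers exist so a CUDA-graph runner can capture against fixed-size placeholder tensors and then copy live values in before each replay; slicing returns views into the same storage, which is what lets the captured graph see updated values. A hedged sketch of the intended call sequence (max_bs, vocab_size, bs, and live_batch_info are assumed runner-side names):

    # Capture time: allocate placeholder sampling tensors sized for the largest
    # batch the graph supports, then slice a view for the batch size being captured.
    static_info = SamplingBatchInfo.dummy_one(max_bs, vocab_size)
    captured_info = static_info[:bs]  # sliced tensors share storage with static_info

    # Replay time: write the real batch's sampling parameters into the captured
    # buffers; the graph reads the updated values on its next replay. Note both
    # paths assert has_bias() is False: bias tensors disable CUDA-graph use.
    static_info.inplace_assign(bs, live_batch_info)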
python/sglang/test/runners.py (view file @ 75ce37f4)

@@ -180,7 +180,7 @@ class SRTRunner:
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-            mem_fraction_static=0.7,
+            mem_fraction_static=0.69,
             trust_remote_code=False,
             is_embedding=not self.is_generation,
         )