sglang: Unverified commit 381dd57b
Authored Aug 28, 2024 by Liangsheng Yin; committed by GitHub on Aug 28, 2024
Parent: 8153168c

Sampler cudagraph (#1253)

Changes: 29 files in the commit; this page shows 9 changed files with 111 additions and 24 deletions (+111, -24).
Files changed on this page:

  python/sglang/srt/models/minicpm.py                  +5   -1
  python/sglang/srt/models/mixtral.py                  +5   -1
  python/sglang/srt/models/mixtral_quant.py            +5   -1
  python/sglang/srt/models/qwen.py                     +5   -2
  python/sglang/srt/models/qwen2.py                    +6   -2
  python/sglang/srt/models/qwen2_moe.py                +5   -14
  python/sglang/srt/models/stablelm.py                 +5   -1
  python/sglang/srt/sampling/sampling_batch_info.py    +74  -1
  python/sglang/test/runners.py                        +1   -1
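Across the seven model files below, the change follows one pattern: each *ForCausalLM module constructs a Sampler next to its LogitsProcessor, and forward() now performs sampling inside the model and returns (sample_output, logits_output) instead of returning the logits alone, so the sampling step can be captured together with the decode pass in a CUDA graph. The following is a minimal, self-contained sketch of that pattern; ToyForCausalLM and SimpleSampler are illustrative stand-ins, not sglang's actual LogitsProcessor or Sampler.

import torch
from torch import nn


class SimpleSampler(nn.Module):
    """Illustrative stand-in for a sampler: temperature-scaled greedy pick."""

    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
        probs = torch.softmax(logits / temperatures, dim=-1)
        return torch.argmax(probs, dim=-1)  # token ids, shape (batch,)


class ToyForCausalLM(nn.Module):
    """Mirrors the diff's pattern: the sampler lives inside the model's forward()."""

    def __init__(self, vocab_size: int = 128, hidden: int = 16):
        super().__init__()
        self.backbone = nn.Linear(hidden, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size)
        self.sampler = SimpleSampler()  # analogous to `self.sampler = Sampler()` in the diff

    def forward(self, hidden_states: torch.Tensor, temperatures: torch.Tensor):
        # Stand-in for the logits_processor call in the real models.
        logits_output = self.lm_head(self.backbone(hidden_states))
        sample_output = self.sampler(logits_output, temperatures)
        # Previously the models returned only the logits; now both are returned.
        return sample_output, logits_output


model = ToyForCausalLM()
tokens, logits = model(torch.randn(4, 16), torch.ones(4, 1))
print(tokens.shape, logits.shape)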
python/sglang/srt/models/minicpm.py

@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -297,6 +298,7 @@ class MiniCPMForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base

         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -314,9 +316,11 @@ class MiniCPMForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/mixtral.py

@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -299,6 +300,7 @@ class MixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     def forward(
         self,
@@ -308,9 +310,11 @@ class MixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/mixtral_quant.py

@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -333,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -343,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/qwen.py

@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -260,10 +262,11 @@ class QWenLMHeadModel(nn.Module):
         input_metadata: InputMetadata,
     ):
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        next_tokens = self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/qwen2.py

@@ -38,8 +38,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 Qwen2Config = None
@@ -276,6 +277,7 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     @torch.no_grad()
@@ -289,9 +291,11 @@ class Qwen2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         if not get_embedding:
-            return self.logits_processor(
+            logits_output = self.logits_processor(
                 input_ids, hidden_states, self.lm_head.weight, input_metadata
             )
+            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+            return sample_output, logits_output
         else:
             return self.pooler(hidden_states, input_metadata)
python/sglang/srt/models/qwen2_moe.py

@@ -35,10 +35,8 @@ from vllm.model_executor.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -49,6 +47,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-    def compute_logits(
-        self,
-        input_ids: torch.Tensor,
-        hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        logits = self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
-        return logits
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/models/stablelm.py

@@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module):
         self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/sampling/sampling_batch_info.py

@@ -21,10 +21,63 @@ class SamplingBatchInfo:
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
     min_ps: torch.Tensor = None
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool = False
+
+    # Bias Tensors
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None
+
+    # Penalizer
+    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
+    linear_penalties: torch.Tensor = None
+    scaling_penalties: torch.Tensor = None
+
+    def has_bias(self):
+        return (
+            self.logit_bias is not None
+            or self.vocab_mask is not None
+            or self.linear_penalties is not None
+            or self.scaling_penalties is not None
+        )
+
+    @classmethod
+    def dummy_one(cls, max_bs: int, vocab_size: int):
+        ret = cls(vocab_size=vocab_size)
+        ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
+        ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
+        ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
+        ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
+        return ret
+
+    def __getitem__(self, key):
+        if isinstance(key, slice):
+            # NOTE: We do not use cuda graph when there is bias tensors
+            assert not self.has_bias()
+            return SamplingBatchInfo(
+                vocab_size=self.vocab_size,
+                temperatures=self.temperatures[key],
+                top_ps=self.top_ps[key],
+                top_ks=self.top_ks[key],
+                min_ps=self.min_ps[key],
+                need_min_p_sampling=self.need_min_p_sampling,
+            )
+        else:
+            raise NotImplementedError
+
+    def inplace_assign(self, bs: int, other: SamplingBatchInfo):
+        # NOTE: We do not use cuda graph when there is bias tensors
+        assert not self.has_bias()
+
+        self.vocab_size = other.vocab_size
+        self.need_min_p_sampling = other.need_min_p_sampling
+        self.temperatures[:bs] = other.temperatures
+        self.top_ps[:bs] = other.top_ps
+        self.top_ks[:bs] = other.top_ks
+        self.min_ps[:bs] = other.min_ps
+
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         device = "cuda"
@@ -45,6 +98,7 @@ class SamplingBatchInfo:
         ret.min_ps = torch.tensor(
             [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
         )
+        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)

         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
@@ -72,6 +126,25 @@ class SamplingBatchInfo:
         return ret

+    def prepare_penalties(self):
+        self.scaling_penalties = None
+        self.linear_penalties = None
+
+        for penalizer in self.penalizer_orchestrator.penalizers.values():
+            if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
+                if penalizer.is_prepared():
+                    self.scaling_penalties = penalizer.cumulated_repetition_penalties
+            else:
+                if penalizer.is_prepared():
+                    if self.linear_penalties is None:
+                        bs = self.penalizer_orchestrator.batch.batch_size()
+                        self.linear_penalties = torch.zeros(
+                            (bs, self.vocab_size),
+                            dtype=torch.float32,
+                            device="cuda",
+                        )
+                    self.linear_penalties = penalizer.apply(self.linear_penalties)
+
     def update_regex_vocab_mask(self, batch: ScheduleBatch):
         bs, reqs = batch.batch_size(), batch.reqs
         device = "cuda"
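The new SamplingBatchInfo helpers are what make the sampler capturable: dummy_one() builds max-batch-size placeholder tensors to capture against, inplace_assign() copies a live batch's sampling parameters into the leading rows of those fixed buffers before replay, and __getitem__ slices them back down (both refuse to run when bias tensors are present, since those paths skip CUDA graphs). A rough sketch of that capture-and-replay flow is below; it assumes a CUDA device, and the buffer names and the greedy sampling step are illustrative rather than sglang's actual CUDA graph runner.

import torch

# Fixed-size "dummy" buffers, sized for the largest batch we will ever replay
# (this mirrors the role of SamplingBatchInfo.dummy_one()).
max_bs, vocab_size = 8, 32000
static_logits = torch.zeros((max_bs, vocab_size), device="cuda")
static_temperatures = torch.ones((max_bs, 1), device="cuda")
static_next_tokens = torch.zeros((max_bs,), dtype=torch.long, device="cuda")


def sample_step():
    # The work recorded into the graph: temperature scaling + greedy pick.
    probs = torch.softmax(static_logits / static_temperatures, dim=-1)
    static_next_tokens.copy_(torch.argmax(probs, dim=-1))


# Warm up on a side stream (standard CUDA graph practice), then capture once.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    sample_step()
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    sample_step()

# At decode time: copy the live batch into the leading slots of the static
# buffers (the inplace_assign idea) and replay; only the first `bs` rows of
# the output are meaningful.
bs = 3
static_logits[:bs].copy_(torch.randn(bs, vocab_size, device="cuda"))
static_temperatures[:bs].copy_(torch.full((bs, 1), 0.7, device="cuda"))
graph.replay()
print(static_next_tokens[:bs])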
python/sglang/test/runners.py

@@ -180,7 +180,7 @@ class SRTRunner:
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-            mem_fraction_static=0.7,
+            mem_fraction_static=0.69,
             trust_remote_code=False,
             is_embedding=not self.is_generation,
         )