Commit 70b68029 — Optimize conflicts between CUDA graph and vocab mask tensors (#1392)

Unverified commit in the sglang repository, authored by Liangsheng Yin on Sep 13, 2024 and committed via GitHub on Sep 13, 2024.
Parent: f3d32f88
Changes (32 files in the commit) — this page shows 12 changed files, with 22 additions and 75 deletions (+22, −75).
python/sglang/srt/models/llama_classification.py   +1  −20
python/sglang/srt/models/minicpm.py                +1  −5
python/sglang/srt/models/minicpm3.py               +1  −5
python/sglang/srt/models/mixtral.py                +1  −5
python/sglang/srt/models/mixtral_quant.py          +1  −5
python/sglang/srt/models/qwen.py                   +1  −5
python/sglang/srt/models/qwen2.py                  +1  −5
python/sglang/srt/models/qwen2_moe.py              +1  −5
python/sglang/srt/models/stablelm.py               +1  −5
python/sglang/srt/models/xverse.py                 +1  −6
python/sglang/srt/models/xverse_moe.py             +1  −5
python/sglang/srt/sampling/sampling_batch_info.py  +11 −4
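The removed "and self.vocab_mask is None" check in sampling_batch_info.py (last file below) previously forced any batch carrying a vocab mask to skip CUDA graph execution. This commit instead preallocates a persistent boolean mask in the dummy CUDA-graph batch and updates it in place when real batches are copied in; the model files, in turn, stop constructing a Sampler and simply return the LogitsProcessor output from forward.

A minimal sketch of the underlying constraint follows. It is illustrative only: the buffer names and the masking function are assumptions rather than sglang code, and it needs a CUDA device. A captured graph replays fixed kernels against fixed storage, so the mask has to be mutated in place instead of being swapped for a freshly allocated tensor.

import torch

max_bs, vocab_size = 4, 1024

# Buffers read by a captured graph must keep the same storage address, so they
# are allocated once up front (as dummy_one() does for vocab_mask below).
logits = torch.randn(max_bs, vocab_size, device="cuda")
vocab_mask = torch.zeros(max_bs, vocab_size, dtype=torch.bool, device="cuda")

def masked_logits():
    # Disallowed token ids are pushed to -inf before sampling.
    return logits.masked_fill(vocab_mask, float("-inf"))

# Warm up on a side stream, then capture (the standard torch.cuda.graph pattern).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    masked_logits()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = masked_logits()

# Between replays, only in-place writes to the captured buffer are observed,
# which is what the fill_(False) / slice assignment in the diff below rely on.
vocab_mask.zero_()
vocab_mask[0, 100:] = True  # e.g. constrain request 0 to the first 100 token ids
g.replay()                  # `out` now reflects the updated mask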
python/sglang/srt/models/llama_classification.py

...
@@ -23,7 +23,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
...
@@ -75,25 +74,7 @@ class LlamaForClassification(nn.Module):
             output_top_logprobs=None,
         )
-        # A dummy to make this work
-        sample_output = SampleOutput(
-            success=torch.full(
-                size=(scores.shape[0],),
-                fill_value=True,
-                dtype=torch.bool,
-            ),
-            probs=torch.full(
-                size=(scores.shape[0], 1),
-                fill_value=1.0,
-                dtype=torch.float16,
-            ),
-            batch_next_token_ids=torch.full(
-                size=(scores.shape[0],),
-                fill_value=0,
-                dtype=torch.long,
-            ),
-        )
-        return sample_output, logits_output
+        return logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = self.param_dict
...
python/sglang/srt/models/minicpm.py

...
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
...
@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
...
@@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
python/sglang/srt/models/minicpm3.py

...
@@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
...
@@ -572,7 +571,6 @@ class MiniCPM3ForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
...
@@ -590,11 +588,9 @@ class MiniCPM3ForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
python/sglang/srt/models/mixtral.py

...
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
...
@@ -300,7 +299,6 @@ class MixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     def forward(
         self,
...
@@ -310,11 +308,9 @@ class MixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
python/sglang/srt/models/mixtral_quant.py

...
@@ -45,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
...
@@ -334,7 +333,6 @@ class QuantMixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
...
@@ -345,11 +343,9 @@ class QuantMixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
python/sglang/srt/models/qwen.py

...
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
...
@@ -252,7 +251,6 @@ class QWenLMHeadModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
...
@@ -262,11 +260,9 @@ class QWenLMHeadModel(nn.Module):
         input_metadata: InputMetadata,
     ):
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
python/sglang/srt/models/qwen2.py

...
@@ -40,7 +40,6 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 Qwen2Config = None
...
@@ -277,7 +276,6 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     @torch.no_grad()
...
@@ -291,11 +289,9 @@ class Qwen2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         if not get_embedding:
-            logits_output = self.logits_processor(
+            return self.logits_processor(
                 input_ids, hidden_states, self.lm_head.weight, input_metadata
             )
-            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-            return sample_output, logits_output
         else:
             return self.pooler(hidden_states, input_metadata)
...
python/sglang/srt/models/qwen2_moe.py

...
@@ -47,7 +47,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
...
@@ -365,7 +364,6 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
...
@@ -376,11 +374,9 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
python/sglang/srt/models/stablelm.py

...
@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
...
@@ -250,7 +249,6 @@ class StableLmForCausalLM(nn.Module):
         self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
...
@@ -261,11 +259,9 @@ class StableLmForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
python/sglang/srt/models/xverse.py

...
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.model_runner import InputMetadata
...
@@ -307,7 +306,6 @@ class XverseForCausalLM(nn.Module):
         self.model = XverseModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.param_dict = dict(self.named_parameters())
...
@@ -320,12 +318,9 @@ class XverseForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        # print(f"{hidden_states=}")
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(
         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
...
python/sglang/srt/models/xverse_moe.py

...
@@ -44,7 +44,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
...
@@ -383,7 +382,6 @@ class XverseMoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.param_dict = dict(self.named_parameters())
...
@@ -395,11 +393,9 @@ class XverseMoeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
python/sglang/srt/sampling/sampling_batch_info.py

...
@@ -41,7 +41,6 @@ class SamplingBatchInfo:
         # Vocab bias and min_ps are not supported in CUDA graph
         return (
             self.logit_bias is None
-            and self.vocab_mask is None
             and self.linear_penalties is None
             and self.scaling_penalties is None
             and not self.need_min_p_sampling
...
@@ -50,9 +49,11 @@ class SamplingBatchInfo:
     @classmethod
     def dummy_one(cls, max_bs: int, vocab_size: int):
         ret = cls(vocab_size=vocab_size)
-        ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
-        ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
-        ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
+        with torch.device("cuda"):
+            ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float)
+            ret.top_ps = torch.ones((max_bs,), dtype=torch.float)
+            ret.top_ks = torch.ones((max_bs,), dtype=torch.int)
+            ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
         return ret

     def __getitem__(self, key):
...
@@ -64,6 +65,7 @@ class SamplingBatchInfo:
                 temperatures=self.temperatures[key],
                 top_ps=self.top_ps[key],
                 top_ks=self.top_ks[key],
+                vocab_mask=self.vocab_mask[key],
             )
         else:
             raise NotImplementedError
...
@@ -77,6 +79,11 @@ class SamplingBatchInfo:
         self.top_ps[:bs] = other.top_ps
         self.top_ks[:bs] = other.top_ks

+        if other.vocab_mask is None:
+            self.vocab_mask[:bs].fill_(False)
+        else:
+            self.vocab_mask[:bs] = other.vocab_mask
+
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         device = "cuda"
...
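As a reference for the in-place update semantics added above, here is a small, CPU-runnable sketch. It is simplified and uses made-up names; the real buffers live on the GPU and are created once in dummy_one.

import torch

max_bs, vocab_size = 8, 16
vocab_mask = torch.zeros(max_bs, vocab_size, dtype=torch.bool)  # allocated once

def write_masks(bs, other_mask):
    # Mirror of the new copy logic: clear the active prefix when the incoming
    # batch has no masks, otherwise copy its masks into the existing storage.
    # Either way the buffer object (and its address) never changes.
    if other_mask is None:
        vocab_mask[:bs].fill_(False)
    else:
        vocab_mask[:bs] = other_mask

write_masks(2, torch.ones(2, vocab_size, dtype=torch.bool))  # constrained batch
write_masks(2, None)                                         # unconstrained batch
print(vocab_mask.any().item())  # False: cleared in place, never reallocated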