change / sglang / Commits / 85e1a6f3

Unverified commit 85e1a6f3, authored Dec 02, 2024 by Yineng Zhang, committed via GitHub Dec 02, 2024

Update model_loader deps and qqq quantization deps (#2220) (#2318)

Co-authored-by: HandH1998 <1335248067@qq.com>

parent 33deca81
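Every file shown on this page follows the same two-part refactor: the default_weight_loader helper is now imported from sglang's own sglang.srt.model_loader.weight_utils instead of vllm.model_executor.model_loader.weight_utils, and the unused cache_config (and, where present, lora_config) constructor parameters are dropped from the model classes. A minimal before/after sketch of that pattern follows; ExampleForCausalLM is a hypothetical class used only for illustration, not a file from this commit.

    from torch import nn

    # Old import, removed throughout this commit:
    #   from vllm.model_executor.model_loader.weight_utils import default_weight_loader
    # New import, added throughout this commit:
    from sglang.srt.model_loader.weight_utils import default_weight_loader


    class ExampleForCausalLM(nn.Module):  # hypothetical class, for illustration only
        # The old signature also took cache_config=None (and sometimes lora_config);
        # both were unused and are removed in this commit.
        def __init__(self, config, quant_config=None) -> None:
            super().__init__()
            self.config = config
            self.quant_config = quant_config
            # assumes config provides hidden_size and vocab_size, as HF configs do
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        def load_weights(self, weights):
            # default_weight_loader(param, loaded_weight) copies a checkpoint tensor
            # into the matching model parameter, regardless of which package it is
            # imported from.
            params = dict(self.named_parameters())
            for name, loaded_weight in weights:
                default_weight_loader(params[name], loaded_weight)

The per-file diffs below apply exactly this substitution to each model.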
Changes: 58

Showing 20 changed files with 33 additions and 111 deletions (+33 -111)
python/sglang/srt/models/deepseek.py              +3  -10
python/sglang/srt/models/deepseek_v2.py           +2  -10
python/sglang/srt/models/exaone.py                +1  -2
python/sglang/srt/models/gemma.py                 +1  -5
python/sglang/srt/models/gemma2.py                +2  -13
python/sglang/srt/models/gemma2_reward.py         +0  -1
python/sglang/srt/models/gpt2.py                  +4  -11
python/sglang/srt/models/gpt_bigcode.py           +5  -21
python/sglang/srt/models/grok.py                  +2  -2
python/sglang/srt/models/internlm2.py             +1  -2
python/sglang/srt/models/internlm2_reward.py      +0  -1
python/sglang/srt/models/llama.py                 +1  -2
python/sglang/srt/models/llama_classification.py  +1  -2
python/sglang/srt/models/llama_embedding.py       +1  -2
python/sglang/srt/models/llama_reward.py          +2  -3
python/sglang/srt/models/llava.py                 +1  -4
python/sglang/srt/models/llavavid.py              +1  -2
python/sglang/srt/models/minicpm.py               +1  -2
python/sglang/srt/models/minicpm3.py              +3  -14
python/sglang/srt/models/mixtral.py               +1  -2
python/sglang/srt/models/deepseek.py

@@ -27,7 +27,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.fused_moe_triton import fused_moe
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class DeepseekMLP(nn.Module):
@@ -184,7 +184,6 @@ class DeepseekAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -261,7 +260,6 @@ class DeepseekDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -277,7 +275,6 @@ class DeepseekDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
         if (
@@ -330,7 +327,6 @@ class DeepseekModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -343,9 +339,7 @@ class DeepseekModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                DeepseekDecoderLayer(
-                    config, layer_id, cache_config, quant_config=quant_config
-                )
+                DeepseekDecoderLayer(config, layer_id, quant_config=quant_config)
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -373,13 +367,12 @@ class DeepseekForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekModel(config, cache_config, quant_config)
+        self.model = DeepseekModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
python/sglang/srt/models/deepseek_v2.py

@@ -28,7 +28,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_flashinfer_available

 if is_flashinfer_available():
@@ -189,7 +189,6 @@ class DeepseekV2Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -337,7 +336,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
         use_dp=False,
@@ -568,7 +566,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -599,7 +596,6 @@ class DeepseekV2DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
                 use_dp=self.enable_dp_attention,
@@ -619,7 +615,6 @@ class DeepseekV2DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -685,7 +680,6 @@ class DeepseekV2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -702,7 +696,6 @@ class DeepseekV2Model(nn.Module):
                 DeepseekV2DecoderLayer(
                     config,
                     layer_id,
-                    cache_config=cache_config,
                     quant_config=quant_config,
                 )
                 for layer_id in range(config.num_hidden_layers)
@@ -733,13 +726,12 @@ class DeepseekV2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekV2Model(config, cache_config, quant_config)
+        self.model = DeepseekV2Model(config, quant_config)
         if global_server_args_dict["enable_dp_attention"]:
             self.lm_head = ReplicatedLinear(
                 config.hidden_size,
python/sglang/srt/models/exaone.py

@@ -22,7 +22,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class ExaoneGatedMLP(nn.Module):
@@ -293,7 +293,6 @@ class ExaoneForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
python/sglang/srt/models/gemma.py

@@ -21,10 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +36,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class GemmaMLP(nn.Module):
@@ -278,10 +277,7 @@ class GemmaForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
-        cache_config=None,
     ) -> None:
-        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.quant_config = quant_config
python/sglang/srt/models/gemma2.py

@@ -20,12 +20,8 @@ from typing import Iterable, Optional, Set, Tuple, Union
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.linear import (
@@ -38,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
@@ -106,7 +103,6 @@ class Gemma2Attention(nn.Module):
         head_dim: int,
         max_position_embeddings: int,
         rope_theta: float,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -191,7 +187,6 @@ class Gemma2DecoderLayer(nn.Module):
         self,
         layer_id: int,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -205,7 +200,6 @@ class Gemma2DecoderLayer(nn.Module):
             head_dim=config.head_dim,
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.hidden_size = config.hidden_size
@@ -258,7 +252,6 @@ class Gemma2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -273,7 +266,6 @@ class Gemma2Model(nn.Module):
             lambda idx, prefix: Gemma2DecoderLayer(
                 layer_id=idx,
                 config=config,
-                cache_config=cache_config,
                 quant_config=quant_config,
             ),
             prefix="",
@@ -342,15 +334,12 @@ class Gemma2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
-        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Gemma2Model(config, cache_config, quant_config)
+        self.model = Gemma2Model(config, quant_config)
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
python/sglang/srt/models/gemma2_reward.py

@@ -29,7 +29,6 @@ class Gemma2ForSequenceClassification(nn.Module):
         self,
         config: Gemma2Config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
python/sglang/srt/models/gpt2.py

@@ -22,11 +22,9 @@ from typing import Iterable, List, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPT2Config
-from vllm.config import CacheConfig
 from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 # from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
@@ -39,6 +37,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class GPT2Attention(nn.Module):
@@ -47,7 +46,6 @@ class GPT2Attention(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -140,7 +138,6 @@ class GPT2Block(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -150,7 +147,7 @@ class GPT2Block(nn.Module):
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPT2Attention(
-            layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn"
+            layer_id, config, quant_config, prefix=f"{prefix}.attn"
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
@@ -182,7 +179,6 @@ class GPT2Model(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -196,7 +192,7 @@ class GPT2Model(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPT2Block(i, config, cache_config, quant_config)
+                GPT2Block(i, config, quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -226,15 +222,12 @@ class GPT2LMHeadModel(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPT2Model(
-            config, cache_config, quant_config, prefix="transformer"
-        )
+        self.transformer = GPT2Model(config, quant_config, prefix="transformer")
         self.lm_head = self.transformer.wte
         self.logits_processor = LogitsProcessor(config)
python/sglang/srt/models/gpt_bigcode.py

@@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
@@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class GPTBigCodeAttention(nn.Module):
@@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -153,7 +150,7 @@ class GPTBigCodeBlock(nn.Module):
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
+        self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPTBigMLP(inner_dim, config, quant_config)
@@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert not config.add_cross_attention
         self.embed_dim = config.hidden_size
-        lora_vocab = (
-            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
-            if lora_config
-            else 0
-        )
+        lora_vocab = 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.wte = VocabParallelEmbedding(
             self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
@@ -204,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPTBigCodeBlock(i, config, cache_config, quant_config)
+                GPTBigCodeBlock(i, config, quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.lora_config = lora_config
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(
-            config, cache_config, quant_config, lora_config
-        )
+        self.transformer = GPTBigCodeModel(config, quant_config)
         self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
python/sglang/srt/models/grok.py

@@ -24,7 +24,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
@@ -43,6 +42,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.loader import DefaultModelLoader
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class Grok1MoE(nn.Module):
@@ -285,7 +286,6 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
python/sglang/srt/models/internlm2.py

@@ -21,7 +21,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class InternLM2MLP(nn.Module):
@@ -251,7 +251,6 @@ class InternLM2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
python/sglang/srt/models/internlm2_reward.py

@@ -29,7 +29,6 @@ class InternLM2ForRewardModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
python/sglang/srt/models/llama.py

@@ -24,7 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -44,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
 from sglang.utils import get_exception_traceback
@@ -300,7 +300,6 @@ class LlamaForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
python/sglang/srt/models/llama_classification.py

@@ -17,11 +17,11 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -30,7 +30,6 @@ class LlamaForClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
python/sglang/srt/models/llama_embedding.py

@@ -3,10 +3,10 @@ from typing import Iterable, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.model_executor.model_runner import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaModel
@@ -15,7 +15,6 @@ class LlamaEmbeddingModel(nn.Module):
         self,
         config: LlamaConfig,
         quant_config=None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.model = LlamaModel(config, quant_config=quant_config)
python/sglang/srt/models/llama_reward.py

@@ -21,6 +21,7 @@ from transformers import LlamaConfig
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -29,7 +30,6 @@ class LlamaForSequenceClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -84,9 +84,8 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
-        super().__init__(config, quant_config, cache_config)
+        super().__init__(config, quant_config)
         self.weights = self.Weights(config.hidden_size, self.num_labels)

     @torch.no_grad()
python/sglang/srt/models/llava.py

@@ -29,7 +29,6 @@ from transformers import (
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
@@ -39,6 +38,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -451,7 +451,6 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
@@ -473,7 +472,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
@@ -506,7 +504,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
python/sglang/srt/models/llavavid.py

@@ -20,11 +20,11 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -33,7 +33,6 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
python/sglang/srt/models/minicpm.py

@@ -20,7 +20,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -37,6 +36,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class MiniCPMMLP(nn.Module):
@@ -275,7 +275,6 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
python/sglang/srt/models/minicpm3.py

@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_flashinfer_available

 if is_flashinfer_available():
@@ -105,7 +105,6 @@ class MiniCPM3Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -249,7 +248,6 @@ class MiniCPM3AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -406,7 +404,6 @@ class MiniCPM3DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -430,7 +427,6 @@ class MiniCPM3DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -449,7 +445,6 @@ class MiniCPM3DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -498,7 +493,6 @@ class MiniCPM3Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -512,9 +506,7 @@ class MiniCPM3Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPM3DecoderLayer(
-                    config, i, cache_config=cache_config, quant_config=quant_config
-                )
+                MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -549,7 +541,6 @@ class MiniCPM3ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -557,9 +548,7 @@ class MiniCPM3ForCausalLM(nn.Module):
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPM3Model(
-            config, cache_config=cache_config, quant_config=quant_config
-        )
+        self.model = MiniCPM3Model(config, quant_config=quant_config)
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
python/sglang/srt/models/mixtral.py

@@ -23,7 +23,6 @@ from torch import nn
 from transformers import MixtralConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class MixtralMoE(nn.Module):
@@ -291,7 +291,6 @@ class MixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config