sglang / Commits / 5282a473

Commit 5282a473 (unverified), authored Dec 12, 2024 by Lianmin Zheng, committed by GitHub on Dec 12, 2024

[Minor] Fix grok model loader (#2473)
parent f0ed9c35

Showing 1 changed file with 72 additions and 8 deletions.

python/sglang/srt/models/grok.py  (+72, -8)
@@ -25,9 +25,11 @@ from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
@@ -40,10 +42,43 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+
+
+class Grok1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results=True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+            reduce_results=reduce_results,
+        )
+        self.act_fn = GeluAndMul(approximate="tanh")
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
 class Grok1MoE(nn.Module):
     """A tensor-parallel MoE implementation for Grok1 that shards each expert
     across all ranks.
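The new Grok1MLP above fuses the gate and up projections into a single MergedColumnParallelLinear and applies GeluAndMul(approximate="tanh") to the fused output. As a rough illustration of what that activation computes, here is a plain-PyTorch sketch (not the fused sglang/vLLM kernel; it is assumed to behave equivalently):

# Plain-PyTorch sketch of the GeluAndMul(approximate="tanh") semantics used by
# Grok1MLP above: the first half of the fused projection is gated by a
# tanh-approximated GELU and multiplied elementwise with the second half.
import torch
import torch.nn.functional as F

def gelu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up is the output of gate_up_proj: shape [..., 2 * intermediate_size].
    gate, up = gate_up.chunk(2, dim=-1)
    return F.gelu(gate, approximate="tanh") * up

gate_up = torch.randn(4, 2 * 32)  # toy batch, intermediate_size = 32
print(gelu_and_mul_reference(gate_up).shape)  # torch.Size([4, 32])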
@@ -55,6 +90,7 @@ class Grok1MoE(nn.Module):
     def __init__(
         self,
+        config: PretrainedConfig,
         num_experts: int,
         top_k: int,
         hidden_size: int,
@@ -62,6 +98,7 @@ class Grok1MoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        reduce_results=True,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -75,13 +112,16 @@ class Grok1MoE(nn.Module):
             quant_config=None,
         )
+        self.router_logit_softcapping = getattr(
+            config, "router_logit_softcapping", 30.0
+        )
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=True,
+            reduce_results=reduce_results,
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -91,9 +131,12 @@ class Grok1MoE(nn.Module):
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
+        router_logits = 30.0 * F.tanh(router_logits / 30.0)
+        # need to assert self.gate.quant_method is unquantized
         final_hidden_states = self.experts(hidden_states, router_logits)
         return final_hidden_states.view(orig_shape)
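The forward pass above softcaps the router logits with a tanh before expert dispatch. A standalone sketch of the formula follows (plain PyTorch assumed; the cap of 30.0 matches the hard-coded value in the hunk):

# Illustration of tanh softcapping: cap * tanh(logits / cap) keeps every logit
# inside (-cap, cap) while remaining roughly linear near zero.
import torch

cap = 30.0
router_logits = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = cap * torch.tanh(router_logits / cap)
print(capped)  # approx. [-29.93, -1.00, 0.00, 1.00, 29.93]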
@@ -101,16 +144,18 @@ class Grok1MoE(nn.Module):
 class Grok1Attention(nn.Module):
     def __init__(
         self,
+        config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
         layer_id: int = 0,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         logit_cap: float = 30,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+        self.config = config
         self.layer_id = layer_id
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
@@ -126,7 +171,7 @@ class Grok1Attention(nn.Module):
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = 128
+        self.head_dim = getattr(config, "head_dim", 128)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -140,7 +185,6 @@ class Grok1Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -154,6 +198,9 @@ class Grok1Attention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
+        logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
@@ -162,7 +209,6 @@ class Grok1Attention(nn.Module):
             layer_id=layer_id,
             logit_cap=logit_cap,
         )
-        # TODO(lianmin): load logit cap from config

     def forward(
         self,
@@ -186,10 +232,12 @@ class Grok1DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+        self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
@@ -199,11 +247,17 @@ class Grok1DecoderLayer(nn.Module):
             quant_config=quant_config,
         )
         self.block_sparse_moe = Grok1MoE(
+            config=config,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
+            intermediate_size=getattr(
+                config,
+                "moe_intermediate_size",
+                getattr(config, "intermediate_size", None),
+            ),
             quant_config=quant_config,
+            reduce_results=True,
         )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
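The decoder layer now resolves the MoE width through nested getattr calls, so configs that define moe_intermediate_size and older ones that only define intermediate_size both load. A hedged sketch of that fallback with a hypothetical bare config object (SimpleNamespace stands in for PretrainedConfig; the sizes are made up):

# Hypothetical configs illustrating the nested getattr fallback from the diff.
from types import SimpleNamespace

def resolve_moe_intermediate_size(config):
    return getattr(
        config, "moe_intermediate_size", getattr(config, "intermediate_size", None)
    )

cfg_dense_only = SimpleNamespace(intermediate_size=32768)
cfg_with_moe = SimpleNamespace(intermediate_size=32768, moe_intermediate_size=16384)

print(resolve_moe_intermediate_size(cfg_dense_only))  # 32768 (fallback)
print(resolve_moe_intermediate_size(cfg_with_moe))    # 16384 (explicit MoE size)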
@@ -284,6 +338,7 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -310,6 +365,8 @@ class Grok1ForCausalLM(nn.Module):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
         # Params for weights, fp8 weight scales, fp8 activation scales
@@ -345,6 +402,11 @@ class Grok1ForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
+                if (
+                    name.endswith(".bias") or name.endswith("_bias")
+                ) and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
@@ -357,7 +419,9 @@ class Grok1ForCausalLM(nn.Module):
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if (
+                    name.endswith(".bias") or name.endswith("_bias")
+                ) and name not in params_dict:
                     continue
                 # Skip loading kv_scale from ckpts towards new design.
                 if name.endswith(".kv_scale") and name not in params_dict:
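Both branches of the weight loader now skip checkpoint tensors whose names end in ".bias" or "_bias" when the model has no matching parameter. A small sketch of that predicate against hypothetical checkpoint keys (the names below are illustrative, not taken from a real Grok checkpoint):

# Illustrative-only parameter names; the predicate mirrors the check added above.
params_dict = {"model.layers.0.self_attn.qkv_proj.weight": None}

def skip_missing_bias(name: str) -> bool:
    return (name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict

for name in [
    "model.layers.0.self_attn.qkv_proj.weight",      # present -> load
    "model.layers.0.self_attn.qkv_proj.bias",        # ".bias", absent -> skip
    "model.layers.0.block_sparse_moe.router_bias",   # "_bias", absent -> skip
]:
    print(name, "skip" if skip_missing_bias(name) else "load")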