change/sglang · Commits · bdf946bf

Commit bdf946bf (unverified)
Authored Jan 02, 2025 by Lianmin Zheng; committed by GitHub, Jan 02, 2025
Parent: 8c8779cd

Support loading pre-sharded moe weights (#2716)

Showing 2 changed files with 111 additions and 32 deletions:

  python/sglang/srt/layers/moe/fused_moe_triton/layer.py   (+14, -6)
  python/sglang/srt/models/grok.py                         (+97, -26)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -321,9 +321,12 @@ class FusedMoE(torch.nn.Module):
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                shard_dim, shard_size * tp_rank, shard_size
+            )

         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         if shard_id == "w1":
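For intuition, here is a minimal, self-contained sketch of what the guarded narrow does for the merged gate_up buffer. All sizes and tensor values are invented for illustration; only the narrow arithmetic follows the hunk above.

import torch

# Invented sizes: full (unsharded) w1 for one expert, tp_size = 2.
tp_rank = 1
full_w1 = torch.randn(8, 4)        # [intermediate_size, hidden_size]

# The per-rank w13 buffer stacks this rank's w1 and w3, so one logical
# weight owns half of its output dim: 8 // 2 = 4 rows.
expert_data = torch.empty(8, 4)    # per-rank merged gate_up buffer
shard_dim = 0
shard_size = expert_data.shape[shard_dim] // 2   # -> 4

# Unsharded checkpoint: slice this rank's rows out of the full tensor.
shard = full_w1.narrow(shard_dim, shard_size * tp_rank, shard_size)
assert shard.shape == (4, 4)       # rows 4..7 of full_w1

# Pre-sharded checkpoint: the file already holds only this rank's rows,
# so the narrow must be skipped -- which is what the new
# `if not self.use_presharded_weights:` guard implements.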
@@ -347,9 +350,12 @@ class FusedMoE(torch.nn.Module):
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                shard_dim, shard_size * tp_rank, shard_size
+            )

         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
@@ -389,7 +395,9 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
+        use_presharded_weights: bool = False,
     ) -> None:
+        self.use_presharded_weights = use_presharded_weights

         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
python/sglang/srt/models/grok.py

@@ -16,13 +16,16 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.layers.activation import GeluAndMul
@@ -42,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -347,6 +351,16 @@ class Grok1ForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

+        # Monkey patch _prepare_weights to load pre-sharded weights
+        if (
+            self.config.num_local_experts > 0
+            and get_tensor_model_parallel_world_size() > 1
+        ):
+            self.use_presharded_weights = True
+            setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+        else:
+            self.use_presharded_weights = False
+
     def forward(
         self,
         input_ids: torch.Tensor,
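The constructor above swaps out DefaultModelLoader._prepare_weights process-wide. A minimal sketch of this monkey-patch pattern, with a toy class standing in for the real loader (everything except the save-then-setattr shape is invented):

class Loader:                       # stand-in for DefaultModelLoader
    def _prepare_weights(self, path: str) -> str:
        return f"default:{path}"

old_prepare = getattr(Loader, "_prepare_weights")

def _patched(self, path: str) -> str:
    # Fall back to the saved original when the special case does not
    # apply, mirroring the world_size == 1 early return in the diff.
    if "presharded" not in path:
        return old_prepare(self, path)
    return f"presharded:{path}"

setattr(Loader, "_prepare_weights", _patched)

assert Loader()._prepare_weights("plain") == "default:plain"
assert Loader()._prepare_weights("presharded-ckpt") == "presharded:presharded-ckpt"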
@@ -359,7 +373,15 @@ class Grok1ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+        use_presharded_weights: bool | None = None,
+    ):
+        if use_presharded_weights is None:
+            use_presharded_weights = self.use_presharded_weights
+        num_experts = self.config.num_local_experts
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
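The new use_presharded_weights parameter defaults to the value the constructor derived. A tiny sketch of that resolution logic in isolation (the class below is a stand-in, not the real model):

class _Demo:                        # stand-in carrying the constructor-derived flag
    use_presharded_weights = True

    def load_weights(self, weights, use_presharded_weights=None):
        if use_presharded_weights is None:        # no explicit override given
            use_presharded_weights = self.use_presharded_weights
        return use_presharded_weights

d = _Demo()
assert d.load_weights([]) is True                 # falls back to the attribute
assert d.load_weights([], use_presharded_weights=False) is False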
@@ -375,10 +397,23 @@ class Grok1ForCausalLM(nn.Module):
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
-            num_experts=self.config.num_local_experts,
+            num_experts=num_experts,
         )

         params_dict = dict(self.named_parameters())
+        all_names = set(params_dict.keys())
+        hit_names = set()
+
+        def load_weight_wrapper(name, loaded_weight, *args, **kwargs):
+            if name not in params_dict:
+                return
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight, *args, **kwargs)
+            hit_names.add(name)
+
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
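A toy run of the bookkeeping that load_weight_wrapper introduces: unknown names are skipped silently, and hit_names records what actually landed so it can later be compared against all_names (the hunks shown here do not include that comparison; the tensors and names below are invented):

import torch

params_dict = {"a.weight": torch.zeros(2), "b.weight": torch.zeros(2)}
all_names = set(params_dict.keys())
hit_names = set()

def load_weight_wrapper(name, loaded_weight):
    if name not in params_dict:      # weight this rank does not own
        return
    params_dict[name].copy_(loaded_weight)
    hit_names.add(name)

load_weight_wrapper("a.weight", torch.ones(2))
load_weight_wrapper("c.weight", torch.ones(2))   # unknown name: ignored

assert all_names - hit_names == {"b.weight"}     # never loaded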
@@ -391,9 +426,7 @@ class Grok1ForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                load_weight_wrapper(name, loaded_weight, shard_id)
                 break
             else:
                 for mapping in expert_params_mapping:
@@ -402,38 +435,76 @@ class Grok1ForCausalLM(nn.Module):
                         continue
                     name = name.replace(weight_name, param_name)
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
+                        continue
+
+                    if use_presharded_weights:
+                        extra_kwargs = {
+                            "use_presharded_weights": use_presharded_weights
+                        }
+                    else:
+                        extra_kwargs = {}
+
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
+                    load_weight_wrapper(
+                        name,
                         loaded_weight,
                         name,
                         shard_id=shard_id,
                         expert_id=expert_id,
+                        **extra_kwargs,
                     )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
                         continue
                     # Skip loading kv_scale from ckpts towards new design.
                     if name.endswith(".kv_scale") and name not in params_dict:
                         continue
                     if name is None:
                         continue

-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                    weight_loader(param, loaded_weight)
+                    load_weight_wrapper(name=name, loaded_weight=loaded_weight)
+
+
+old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
+
+
+def _prepare_presharded_weights(
+    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
+) -> Tuple[str, List[str], bool]:
+    import glob
+    import os
+
+    if get_tensor_model_parallel_world_size() == 1:
+        return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
+
+    if not os.path.isdir(model_name_or_path):
+        from sglang.srt.model_loader.weight_utils import download_weights_from_hf
+
+        allow_patterns = ["*.safetensors", "*.bin"]
+        hf_folder = download_weights_from_hf(
+            model_name_or_path,
+            self.load_config.download_dir,
+            allow_patterns,
+            revision,
+            ignore_patterns=self.load_config.ignore_patterns,
+        )
+    else:
+        hf_folder = model_name_or_path
+
+    tp_rank = get_tensor_model_parallel_rank()
+
+    # The old format
+    allow_patterns = [f"*-{tp_rank:03d}.bin"]
+
+    # The new format
+    allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]
+
+    hf_weights_files: List[str] = []
+    for pattern in allow_patterns:
+        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+
+    if hf_weights_files[0].endswith("safetensors"):
+        use_safetensors = True
+    else:
+        use_safetensors = False
+
+    return hf_folder, hf_weights_files, use_safetensors
+
+
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
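For reference, the per-rank glob patterns that _prepare_presharded_weights assembles expand as below (plain f-string expansion; tp_rank = 3 is an arbitrary example, not from the diff):

tp_rank = 3
patterns = [f"*-{tp_rank:03d}.bin"]                 # the old format
patterns += [f"*-TP-{tp_rank:03d}.safetensors",     # the new format
             "*-TP-common.safetensors"]
print(patterns)
# ['*-003.bin', '*-TP-003.safetensors', '*-TP-common.safetensors']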