sglang · Commits · bdf946bf

Unverified commit bdf946bf, authored Jan 02, 2025 by Lianmin Zheng; committed by GitHub on Jan 02, 2025.
Support loading pre-sharded moe weights (#2716)
Parent: 8c8779cd

Showing 2 changed files with 111 additions and 32 deletions (+111, -32)
  python/sglang/srt/layers/moe/fused_moe_triton/layer.py  (+14, -6)
  python/sglang/srt/models/grok.py  (+97, -26)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py (view file @ bdf946bf)
@@ -321,9 +321,12 @@ class FusedMoE(torch.nn.Module):
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                shard_dim, shard_size * tp_rank, shard_size
+            )

         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         if shard_id == "w1":
@@ -347,9 +350,12 @@ class FusedMoE(torch.nn.Module):
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                shard_dim, shard_size * tp_rank, shard_size
+            )
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
@@ -389,7 +395,9 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
+        use_presharded_weights: bool = False,
     ) -> None:
+        self.use_presharded_weights = use_presharded_weights

         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
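Taken together, the two `narrow` guards and the new `use_presharded_weights` argument to `weight_loader` express one idea: when the checkpoint on disk already stores one shard per tensor-parallel rank, the loader must not slice the incoming tensor a second time. The toy sketch below illustrates the two paths with made-up shapes and plain PyTorch; `load_shard` is a hypothetical helper, not the FusedMoE API.

import torch

def load_shard(expert_data, loaded_weight, shard_dim, tp_rank, use_presharded_weights):
    # expert_data is this rank's slice of the parameter; copy the matching data in.
    shard_size = expert_data.shape[shard_dim]
    if not use_presharded_weights:
        # Full checkpoint tensor: cut out this rank's shard along shard_dim.
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
    # Pre-sharded checkpoint: loaded_weight already has the rank-local shape.
    expert_data.copy_(loaded_weight)

full = torch.arange(32.0).reshape(4, 8)   # unsharded weight in the checkpoint
local = torch.empty(4, 4)                 # rank-local parameter, tp_size = 2, rank 1
load_shard(local, full, shard_dim=1, tp_rank=1, use_presharded_weights=False)
load_shard(local, full.narrow(1, 4, 4).clone(), shard_dim=1, tp_rank=1, use_presharded_weights=True)

Both calls leave the same values in `local`; the second one simply trusts that the file already contains the rank-1 slice.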
python/sglang/srt/models/grok.py (view file @ bdf946bf)
@@ -16,13 +16,16 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""

-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.layers.activation import GeluAndMul
@@ -42,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -347,6 +351,16 @@ class Grok1ForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

+        # Monkey patch _prepare_weights to load pre-sharded weights
+        if (
+            self.config.num_local_experts > 0
+            and get_tensor_model_parallel_world_size() > 1
+        ):
+            self.use_presharded_weights = True
+            setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+        else:
+            self.use_presharded_weights = False
+
     def forward(
         self,
         input_ids: torch.Tensor,
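The block above installs a replacement for `DefaultModelLoader._prepare_weights` with a class-level `setattr`, and the rest of this diff keeps a module-level `old_prepare_weights` reference so the patched function can delegate to the stock behavior when pre-sharded loading does not apply. A self-contained sketch of that monkey-patch pattern, using a hypothetical `Loader` class rather than sglang's loader:

class Loader:
    def _prepare_weights(self, path: str) -> str:
        return f"full checkpoint at {path}"

# Keep a handle to the original implementation so the patch can fall back to it.
_old_prepare_weights = Loader._prepare_weights

def _prepare_presharded_weights(self, path: str) -> str:
    if not path.endswith("-presharded"):
        return _old_prepare_weights(self, path)
    return f"per-rank shard files under {path}"

setattr(Loader, "_prepare_weights", _prepare_presharded_weights)

loader = Loader()
print(loader._prepare_weights("grok-1"))             # full checkpoint at grok-1
print(loader._prepare_weights("grok-1-presharded"))  # per-rank shard files under grok-1-presharded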
@@ -359,7 +373,15 @@ class Grok1ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+        use_presharded_weights: bool | None = None,
+    ):
+        if use_presharded_weights is None:
+            use_presharded_weights = self.use_presharded_weights
+        num_experts = self.config.num_local_experts
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -375,10 +397,23 @@ class Grok1ForCausalLM(nn.Module):
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
-            num_experts=self.config.num_local_experts,
+            num_experts=num_experts,
         )

         params_dict = dict(self.named_parameters())
+        all_names = set(params_dict.keys())
+        hit_names = set()
+
+        def load_weight_wrapper(name, loaded_weight, *args, **kwargs):
+            if name not in params_dict:
+                return
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight, *args, **kwargs)
+            hit_names.add(name)
+
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
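`load_weight_wrapper` is a closure over `params_dict`: it silently skips names the model does not own, falls back to `default_weight_loader` when a parameter has no custom `weight_loader`, and records every name it actually loads in `hit_names`. A minimal standalone sketch of the same pattern with a toy parameter dict (hypothetical names, not the Grok1 state dict):

import torch

def default_loader(param, loaded_weight):
    param.data.copy_(loaded_weight)

params_dict = {"norm.weight": torch.nn.Parameter(torch.zeros(4))}
hit_names = set()

def load_weight_wrapper(name, loaded_weight, *args, **kwargs):
    # Names that are not model parameters (extra biases, scales, ...) are ignored.
    if name not in params_dict:
        return
    param = params_dict[name]
    loader = getattr(param, "weight_loader", default_loader)
    loader(param, loaded_weight, *args, **kwargs)
    hit_names.add(name)

load_weight_wrapper("norm.weight", torch.ones(4))
load_weight_wrapper("unknown.bias", torch.ones(1))  # skipped, no error
print(hit_names)                                    # {'norm.weight'}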
@@ -391,9 +426,7 @@ class Grok1ForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                load_weight_wrapper(name, loaded_weight, shard_id)
                 break
             else:
                 for mapping in expert_params_mapping:
@@ -402,38 +435,76 @@ class Grok1ForCausalLM(nn.Module):
                         continue
                     name = name.replace(weight_name, param_name)

-                    if (
-                        name.endswith(".bias") or name.endswith("_bias")
-                    ) and name not in params_dict:
-                        continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
+                    if use_presharded_weights:
+                        extra_kwargs = {
+                            "use_presharded_weights": use_presharded_weights
+                        }
+                    else:
+                        extra_kwargs = {}
+
+                    load_weight_wrapper(
+                        name,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                        **extra_kwargs,
+                    )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
                         continue
                     # Skip loading kv_scale from ckpts towards new design.
                     if name.endswith(".kv_scale") and name not in params_dict:
                         continue
                     if name is None:
                         continue

-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                    weight_loader(param, loaded_weight)
+                    load_weight_wrapper(name=name, loaded_weight=loaded_weight)
+
+
+old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
+
+
+def _prepare_presharded_weights(
+    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
+) -> Tuple[str, List[str], bool]:
+    import glob
+    import os
+
+    if get_tensor_model_parallel_world_size() == 1:
+        return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
+
+    if not os.path.isdir(model_name_or_path):
+        from sglang.srt.model_loader.weight_utils import download_weights_from_hf
+
+        allow_patterns = ["*.safetensors", "*.bin"]
+        hf_folder = download_weights_from_hf(
+            model_name_or_path,
+            self.load_config.download_dir,
+            allow_patterns,
+            revision,
+            ignore_patterns=self.load_config.ignore_patterns,
+        )
+    else:
+        hf_folder = model_name_or_path
+
+    tp_rank = get_tensor_model_parallel_rank()
+
+    # The old format
+    allow_patterns = [f"*-{tp_rank:03d}.bin"]
+
+    # The new format
+    allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]
+
+    hf_weights_files: List[str] = []
+    for pattern in allow_patterns:
+        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+
+    if hf_weights_files[0].endswith("safetensors"):
+        use_safetensors = True
+    else:
+        use_safetensors = False
+
+    return hf_folder, hf_weights_files, use_safetensors
+
+
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
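`_prepare_presharded_weights` selects checkpoint files by tensor-parallel rank through filename conventions: `*-{rank:03d}.bin` for the old format, and `*-TP-{rank:03d}.safetensors` plus a shared `*-TP-common.safetensors` for the new one. The sketch below replays just that glob logic against a temporary directory of made-up file names; it is an illustration, not the loader itself.

import glob
import os
import tempfile
from typing import List

def presharded_files(folder: str, tp_rank: int) -> List[str]:
    # Old per-rank .bin names first, then the new TP-suffixed safetensors names.
    patterns = [f"*-{tp_rank:03d}.bin"]
    patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]
    files: List[str] = []
    for pattern in patterns:
        files += glob.glob(os.path.join(folder, pattern))
    return files

with tempfile.TemporaryDirectory() as d:
    for fname in ("model-TP-000.safetensors", "model-TP-001.safetensors", "model-TP-common.safetensors"):
        open(os.path.join(d, fname), "w").close()
    print([os.path.basename(f) for f in presharded_files(d, tp_rank=1)])
    # ['model-TP-001.safetensors', 'model-TP-common.safetensors']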