Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a59636bb
Unverified
Commit
a59636bb
authored
Aug 14, 2024
by
Lianmin Zheng
Committed by
GitHub
Aug 14, 2024
Browse files
Update grok 1 model (#1095)
parent
fe502432
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
813 additions
and
513 deletions
+813
-513
benchmark/gsm8k/bench_sglang.py
benchmark/gsm8k/bench_sglang.py
+3
-0
python/sglang/bench_latency.py
python/sglang/bench_latency.py
+1
-0
python/sglang/srt/layers/activation.py
python/sglang/srt/layers/activation.py
+0
-1
python/sglang/srt/layers/fused_moe/__init__.py
python/sglang/srt/layers/fused_moe/__init__.py
+1
-0
python/sglang/srt/layers/fused_moe/fused_moe.py
python/sglang/srt/layers/fused_moe/fused_moe.py
+165
-108
python/sglang/srt/layers/fused_moe/layer.py
python/sglang/srt/layers/fused_moe/layer.py
+587
-0
python/sglang/srt/layers/logits_processor.py
python/sglang/srt/layers/logits_processor.py
+4
-4
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+2
-2
python/sglang/srt/models/grok.py
python/sglang/srt/models/grok.py
+49
-395
python/sglang/srt/models/mixtral.py
python/sglang/srt/models/mixtral.py
+0
-1
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+1
-2
No files found.
benchmark/gsm8k/bench_sglang.py
View file @
a59636bb
...
@@ -88,6 +88,9 @@ def main(args):
...
@@ -88,6 +88,9 @@ def main(args):
for
i
in
range
(
len
(
states
)):
for
i
in
range
(
len
(
states
)):
preds
.
append
(
get_answer_value
(
states
[
i
][
"answer"
]))
preds
.
append
(
get_answer_value
(
states
[
i
][
"answer"
]))
# print(f"{preds=}")
# print(f"{labels=}")
# Compute accuracy
# Compute accuracy
acc
=
np
.
mean
(
np
.
array
(
preds
)
==
np
.
array
(
labels
))
acc
=
np
.
mean
(
np
.
array
(
preds
)
==
np
.
array
(
labels
))
invalid
=
np
.
mean
(
np
.
array
(
preds
)
==
INVALID
)
invalid
=
np
.
mean
(
np
.
array
(
preds
)
==
INVALID
)
...
...
python/sglang/bench_latency.py
View file @
a59636bb
...
@@ -221,6 +221,7 @@ def correctness_test(
...
@@ -221,6 +221,7 @@ def correctness_test(
# Prepare inputs
# Prepare inputs
input_ids
,
reqs
=
prepare_inputs_for_correctness_test
(
bench_args
,
tokenizer
)
input_ids
,
reqs
=
prepare_inputs_for_correctness_test
(
bench_args
,
tokenizer
)
rank_print
(
f
"
{
input_ids
=
}
"
)
if
bench_args
.
cut_len
>
0
:
if
bench_args
.
cut_len
>
0
:
# Prefill
# Prefill
...
...
python/sglang/srt/layers/activation.py
View file @
a59636bb
...
@@ -14,7 +14,6 @@ limitations under the License.
...
@@ -14,7 +14,6 @@ limitations under the License.
"""Fused operators for activation layers."""
"""Fused operators for activation layers."""
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
flashinfer.activation
import
silu_and_mul
from
flashinfer.activation
import
silu_and_mul
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
...
...
python/sglang/srt/layers/fused_moe/__init__.py
0 → 100644
View file @
a59636bb
from
sglang.srt.layers.fused_moe.layer
import
FusedMoE
,
FusedMoEMethodBase
python/sglang/srt/layers/fused_moe.py
→
python/sglang/srt/layers/fused_moe
/fused_moe
.py
View file @
a59636bb
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Adapted from
# Adapted from
# https://github.com/vllm-project/vllm/
blob/c7f2cf2b7f67bce5842fedfdba508440fe257375
/vllm/model_executor/layers/fused_moe
/fused_moe.py#L1
# https://github.com/vllm-project/vllm/
tree/v0.5.4
/vllm/model_executor/layers/fused_moe
"""Fused MoE kernel."""
"""Fused MoE kernel."""
import
functools
import
functools
import
json
import
json
...
@@ -24,6 +9,7 @@ from typing import Any, Dict, Optional, Tuple
...
@@ -24,6 +9,7 @@ from typing import Any, Dict, Optional, Tuple
import
torch
import
torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -373,6 +359,31 @@ def get_default_config(
...
@@ -373,6 +359,31 @@ def get_default_config(
return
config
return
config
def
try_get_optimal_moe_config
(
w1_shape
:
Tuple
[
int
,
...],
w2_shape
:
Tuple
[
int
,
...],
top_k
:
int
,
dtype
:
Optional
[
str
],
M
:
int
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
):
if
override_config
:
config
=
override_config
else
:
# First try to load optimal config from the file
E
,
_
,
N
=
w2_shape
configs
=
get_moe_configs
(
E
,
N
,
dtype
)
if
configs
:
# If an optimal configuration map has been found, look up the
# optimal config
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
# Else use the default config
config
=
get_default_config
(
M
,
E
,
N
,
w1_shape
[
2
],
top_k
,
dtype
)
return
config
def
fused_topk
(
def
fused_topk
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
...
@@ -403,6 +414,41 @@ def fused_topk(
...
@@ -403,6 +414,41 @@ def fused_topk(
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
# This is used by the Deepseek-V2 model
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
,
):
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
"Number of tokens mismatch"
scores
=
torch
.
softmax
(
gating_output
,
dim
=-
1
)
num_token
=
scores
.
shape
[
0
]
group_scores
=
(
scores
.
view
(
num_token
,
num_expert_group
,
-
1
).
max
(
dim
=-
1
).
values
)
# [n, n_group]
group_idx
=
torch
.
topk
(
group_scores
,
k
=
topk_group
,
dim
=-
1
,
sorted
=
False
)[
1
]
# [n, top_k_group]
group_mask
=
torch
.
zeros_like
(
group_scores
)
# [n, n_group]
group_mask
.
scatter_
(
1
,
group_idx
,
1
)
# [n, n_group]
score_mask
=
(
group_mask
.
unsqueeze
(
-
1
)
.
expand
(
num_token
,
num_expert_group
,
scores
.
shape
[
-
1
]
//
num_expert_group
)
.
reshape
(
num_token
,
-
1
)
)
# [n, e]
tmp_scores
=
scores
.
masked_fill
(
~
score_mask
.
bool
(),
0.0
)
# [n, e]
topk_weights
,
topk_ids
=
torch
.
topk
(
tmp_scores
,
k
=
topk
,
dim
=-
1
,
sorted
=
False
)
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
return
topk_weights
,
topk_ids
def
fused_experts
(
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
@@ -425,24 +471,23 @@ def fused_experts(
...
@@ -425,24 +471,23 @@ def fused_experts(
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
M
,
_
=
hidden_states
.
shape
num_tokens
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
E
,
N
,
_
=
w1
.
shape
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
topk_ids
.
shape
[
1
],
"float8"
if
use_fp8
else
None
,
override_config
=
override_config
,
)
if
override_config
:
config
=
get_config_func
(
M
)
config
=
override_config
else
:
# First try to load optimal config from the file
configs
=
get_moe_configs
(
E
,
w2
.
shape
[
2
],
"float8"
if
use_fp8
else
None
)
if
configs
:
# If an optimal configuration map has been found, look up the
# optimal config
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
# Else use the default config
config
=
get_default_config
(
M
,
E
,
N
,
w1
.
shape
[
2
],
topk_ids
.
shape
[
1
],
"float8"
if
use_fp8
else
None
)
intermediate_cache1
=
torch
.
empty
(
intermediate_cache1
=
torch
.
empty
(
(
M
,
topk_ids
.
shape
[
1
],
N
),
(
M
,
topk_ids
.
shape
[
1
],
N
),
...
@@ -460,56 +505,85 @@ def fused_experts(
...
@@ -460,56 +505,85 @@ def fused_experts(
dtype
=
hidden_states
.
dtype
,
dtype
=
hidden_states
.
dtype
,
)
)
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
moe_align_block_size
(
topk_ids
,
config
[
"BLOCK_SIZE_M"
],
E
)
compute_type
=
tl
.
bfloat16
if
hidden_states
.
dtype
==
torch
.
bfloat16
else
tl
.
float16
compute_type
=
tl
.
bfloat16
if
hidden_states
.
dtype
==
torch
.
bfloat16
else
tl
.
float16
invoke_fused_moe_kernel
(
if
inplace
:
hidden_states
,
out_hidden_states
=
hidden_states
w1
,
else
:
intermediate_cache1
,
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
a1_scale
,
w1_scale
,
topk_weights
,
topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
topk_ids
.
shape
[
1
],
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
,
)
ops
.
gelu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
for
chunk
in
range
((
num_tokens
//
CHUNK_SIZE
)
+
1
):
begin_chunk_idx
,
end_chunk_idx
=
(
chunk
*
CHUNK_SIZE
,
min
((
chunk
+
1
)
*
CHUNK_SIZE
,
num_tokens
),
)
curr_hidden_states
=
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
if
tokens_in_chunk
==
0
:
break
if
tokens_in_chunk
<
CHUNK_SIZE
and
chunk
>
0
:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1
=
intermediate_cache1
[:
tokens_in_chunk
]
intermediate_cache2
=
intermediate_cache2
[:
tokens_in_chunk
]
intermediate_cache3
=
intermediate_cache3
[:
tokens_in_chunk
]
config
=
get_config_func
(
tokens_in_chunk
)
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
moe_align_block_size
(
curr_topk_ids
,
config
[
"BLOCK_SIZE_M"
],
E
)
invoke_fused_moe_kernel
(
invoke_fused_moe_kernel
(
intermediate_cache2
,
curr_hidden_states
,
w2
,
w1
,
intermediate_cache
3
,
intermediate_cache
1
,
a2
_scale
,
a1
_scale
,
w2
_scale
,
w1
_scale
,
topk_weights
,
curr_
topk_weights
,
topk_ids
,
curr_
topk_ids
,
sorted_token_ids
,
sorted_token_ids
,
expert_ids
,
expert_ids
,
num_tokens_post_padded
,
num_tokens_post_padded
,
Tru
e
,
Fals
e
,
1
,
topk_ids
.
shape
[
1
]
,
config
,
config
,
compute_type
=
compute_type
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
,
use_fp8
=
use_fp8
,
)
)
if
inplace
:
ops
.
gelu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
return
torch
.
sum
(
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
intermediate_cache3
,
a2_scale
,
w2_scale
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
True
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
,
)
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
,
dim
=
1
,
out
=
hidden_states
,
out
=
out_
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
,
)
)
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
)
return
out_hidden_states
def
fused_moe
(
def
fused_moe
(
...
@@ -521,6 +595,9 @@ def fused_moe(
...
@@ -521,6 +595,9 @@ def fused_moe(
renormalize
:
bool
,
renormalize
:
bool
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
use_fp8
:
bool
=
False
,
use_fp8
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -543,6 +620,10 @@ def fused_moe(
...
@@ -543,6 +620,10 @@ def fused_moe(
Defaults to False.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
...
@@ -556,12 +637,18 @@ def fused_moe(
...
@@ -556,12 +637,18 @@ def fused_moe(
# Check constraints.
# Check constraints.
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
if
hasattr
(
ops
,
"topk_softmax"
):
if
use_grouped_topk
:
topk_weights
,
topk_ids
=
fused_topk
(
assert
num_expert_group
is
not
None
and
topk_group
is
not
None
hidden_states
,
gating_output
,
topk
,
renormalize
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
,
num_expert_group
,
topk_group
,
)
)
else
:
else
:
topk_weights
,
topk_ids
=
fused_topk
_v0_4_3
(
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
hidden_states
,
gating_output
,
topk
,
renormalize
)
)
...
@@ -579,33 +666,3 @@ def fused_moe(
...
@@ -579,33 +666,3 @@ def fused_moe(
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
)
)
def
fused_topk_v0_4_3
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
):
import
vllm._moe_C
as
moe_kernels
M
,
_
=
hidden_states
.
shape
topk_weights
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
)
topk_ids
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
token_expert_indicies
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
moe_kernels
.
topk_softmax
(
topk_weights
,
topk_ids
,
token_expert_indicies
,
gating_output
.
float
(),
# TODO(woosuk): Optimize this.
)
del
token_expert_indicies
# Not used. Will be used in the future.
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
return
topk_weights
,
topk_ids
python/sglang/srt/layers/fused_moe/layer.py
0 → 100644
View file @
a59636bb
# Adapted from
# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
from
abc
import
abstractmethod
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
,
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
class
FusedMoEMethodBase
(
QuantizeMethodBase
):
@
abstractmethod
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
raise
NotImplementedError
@
abstractmethod
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
,
CustomOp
):
"""MoE method without quantization."""
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
# Fused gate_up_proj (column parallel)
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
# down_proj (row parallel)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
top_k
,
renormalize
,
use_grouped_topk
,
num_expert_group
,
topk_group
,
)
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
num_expert_group
:
Optional
[
int
],
topk_group
:
Optional
[
int
],
)
->
torch
.
Tensor
:
from
sglang.srt.layers.fused_moe.fused_moe
import
fused_moe
return
fused_moe
(
x
,
w1
,
w2
,
router_logits
,
top_k
,
renormalize
=
renormalize
,
inplace
=
True
,
use_grouped_topk
=
use_grouped_topk
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
)
def
forward_cpu
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
(
"The CPU backend currently does not support MoE."
)
def
forward_tpu
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
num_expert_group
:
Optional
[
int
],
topk_group
:
Optional
[
int
],
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.moe_pallas
import
fused_moe
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
return
fused_moe
(
x
,
w1
,
w2
,
router_logits
,
top_k
,
renormalize
)
class
FusedMoE
(
torch
.
nn
.
Module
):
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj/ w2).
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
copy that naming convention here and handle any remapping in the
load_weights function in each model implementation.
Args:
num_experts: Number of experts in the model
top_k: Number of experts selected for each token
hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters.
reduce_results: Whether to all all_reduce on the output of the layer
renomalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure.
"""
def
__init__
(
self
,
num_experts
:
int
,
top_k
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
False
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
tp_size
=
(
tp_size
if
tp_size
is
not
None
else
get_tensor_model_parallel_world_size
()
)
self
.
top_k
=
top_k
self
.
num_experts
=
num_experts
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
reduce_results
=
reduce_results
self
.
renormalize
=
renormalize
self
.
use_grouped_topk
=
use_grouped_topk
if
self
.
use_grouped_topk
:
assert
num_expert_group
is
not
None
and
topk_group
is
not
None
self
.
num_expert_group
=
num_expert_group
self
.
topk_group
=
topk_group
if
quant_config
is
None
:
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
(
UnquantizedFusedMoEMethod
()
)
else
:
if
isinstance
(
quant_config
,
Fp8Config
):
self
.
quant_method
=
Fp8MoEMethod
(
quant_config
)
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
)
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
layer
=
self
,
num_experts
=
num_experts
,
hidden_size
=
hidden_size
,
intermediate_size
=
self
.
intermediate_size_per_partition
,
params_dtype
=
params_dtype
,
weight_loader
=
self
.
weight_loader
,
)
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
int
,
expert_id
:
int
,
pre_sharded
:
bool
,
):
param_data
=
param
.
data
# Input scales can be loaded directly and should be equal.
if
"input_scale"
in
weight_name
:
if
(
param_data
[
expert_id
]
!=
1
and
(
param_data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
):
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param_data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
param_data
[
expert_id
]
=
loaded_weight
# Weight scales
elif
"weight_scale"
in
weight_name
:
# If we are in merged column case (gate_up_proj)
# shard_id 0 == gate_proj / w1
# shard_id 2 == up_proj / w3
if
shard_id
==
0
or
shard_id
==
2
:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx
=
0
if
shard_id
==
0
else
1
param_data
[
expert_id
][
idx
]
=
loaded_weight
# If we are in the row parallel case (down_proj)
# shard_id 1 == down_proj / w2
else
:
param_data
[
expert_id
]
=
loaded_weight
# Weights
else
:
tp_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
intermediate_size_per_partition
if
pre_sharded
:
shard
=
slice
(
None
)
else
:
shard
=
slice
(
tp_rank
*
shard_size
,
(
tp_rank
+
1
)
*
shard_size
)
# w1, gate_proj case: Load into first shard of w13.
if
shard_id
==
0
:
param_data
[
expert_id
,
0
:
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
# w3, up_proj case: Load into second shard of w13.
elif
shard_id
==
2
:
param_data
[
expert_id
,
shard_size
:
2
*
shard_size
,
:]
=
loaded_weight
[
shard
,
:
]
# w2, down_proj case: Load into only shard of w2.
elif
shard_id
==
1
:
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
else
:
raise
ValueError
(
f
"Shard id must be in [0,1,2] but got
{
shard_id
}
"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
assert
self
.
quant_method
is
not
None
# Matrix multiply.
final_hidden_states
=
self
.
quant_method
.
apply
(
self
,
x
=
hidden_states
,
router_logits
=
router_logits
,
top_k
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
use_grouped_topk
=
self
.
use_grouped_topk
,
num_expert_group
=
self
.
num_expert_group
,
topk_group
=
self
.
topk_group
,
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
@
classmethod
def
make_expert_params_mapping
(
cls
,
ckpt_gate_proj_name
:
str
,
ckpt_down_proj_name
:
str
,
ckpt_up_proj_name
:
str
,
num_experts
:
int
,
)
->
List
[
Tuple
[
str
,
str
,
int
,
int
]]:
gate_up
=
[
ckpt_gate_proj_name
,
ckpt_up_proj_name
]
gate_down_up
=
[
ckpt_gate_proj_name
,
ckpt_down_proj_name
,
ckpt_up_proj_name
]
return
(
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
(
(
"experts.w13_scale"
if
weight_name
in
gate_up
else
"experts.w2_scale"
),
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight_scale"
,
expert_id
,
shard_id
,
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
enumerate
(
gate_down_up
)
]
+
[
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id)
(
(
"experts.w13_weight"
if
weight_name
in
gate_up
else
"experts.w2_weight"
),
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
,
shard_id
,
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
enumerate
(
gate_down_up
)
]
+
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
(
(
"experts.a13_scale"
if
weight_name
in
gate_up
else
"experts.a2_scale"
),
f
"experts.
{
expert_id
}
.
{
weight_name
}
.input_scale"
,
expert_id
,
shard_id
,
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
enumerate
(
gate_down_up
)
]
)
import
torch
from
torch.nn
import
Module
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
per_tensor_dequantize
,
)
from
vllm.utils
import
print_warning_once
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def
__init__
(
self
,
quant_config
:
Fp8Config
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
params_dtype
=
torch
.
float8_e4m3fn
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_scale"
,
w13_scale
)
w2_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_scale"
,
w2_scale
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
set_weight_attrs
(
w13_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_scale
,
extra_weight_attrs
)
# INPUT_SCALES
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
raise
ValueError
(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
a13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"a13_scale"
,
a13_scale
)
set_weight_attrs
(
a13_scale
,
extra_weight_attrs
)
a2_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"a2_scale"
,
a2_scale
)
set_weight_attrs
(
a2_scale
,
extra_weight_attrs
)
else
:
layer
.
a13_scale
=
None
layer
.
a2_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# If checkpoint is fp16, quantize in place.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer
.
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
layer
.
num_experts
,
dtype
=
torch
.
float32
,
device
=
w13_weight
.
device
),
requires_grad
=
False
,
)
for
expert
in
range
(
layer
.
num_experts
):
w13_weight
[
expert
,
:,
:],
layer
.
w13_scale
[
expert
]
=
(
ops
.
scaled_fp8_quant
(
layer
.
w13_weight
.
data
[
expert
,
:,
:])
)
w2_weight
[
expert
,
:,
:],
layer
.
w2_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
.
data
[
expert
,
:,
:]
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else
:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
layer
.
a13_scale
is
None
or
layer
.
a2_scale
is
None
:
raise
ValueError
(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
not
all_close_1d
(
layer
.
a13_scale
)
or
not
all_close_1d
(
layer
.
a2_scale
):
print_warning_once
(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer
.
a13_scale
=
torch
.
nn
.
Parameter
(
layer
.
a13_scale
.
max
(),
requires_grad
=
False
)
layer
.
a2_scale
=
torch
.
nn
.
Parameter
(
layer
.
a2_scale
.
max
(),
requires_grad
=
False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert
layer
.
w13_scale
is
not
None
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
layer
.
w13_scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
layer
.
num_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
layer
.
w13_scale
[
expert_id
][
shard_id
],
)
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
_
=
(
ops
.
scaled_fp8_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
)
start
+=
shard_size
layer
.
w13_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
return
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
from
sglang.srt.layers.fused_moe.fused_moe
import
fused_moe
return
fused_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
top_k
,
renormalize
=
renormalize
,
inplace
=
True
,
use_fp8
=
True
,
w1_scale
=
layer
.
w13_scale
,
w2_scale
=
layer
.
w2_scale
,
a1_scale
=
layer
.
a13_scale
,
a2_scale
=
layer
.
a2_scale
,
use_grouped_topk
=
use_grouped_topk
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
)
python/sglang/srt/layers/logits_processor.py
View file @
a59636bb
...
@@ -164,9 +164,9 @@ class LogitsProcessor(nn.Module):
...
@@ -164,9 +164,9 @@ class LogitsProcessor(nn.Module):
last_logits
=
last_logits
[:,
:
self
.
config
.
vocab_size
].
float
()
last_logits
=
last_logits
[:,
:
self
.
config
.
vocab_size
].
float
()
if
hasattr
(
self
.
config
,
"final_logit_softcapping"
):
if
hasattr
(
self
.
config
,
"final_logit_softcapping"
):
last_logits
/=
self
.
config
.
final_logit_softcapping
last_logits
.
div_
(
self
.
config
.
final_logit_softcapping
)
last_logits
=
torch
.
tanh
(
last_logits
)
last_logits
=
torch
.
tanh
(
last_logits
)
last_logits
*=
self
.
config
.
final_logit_softcapping
last_logits
.
mul_
(
self
.
config
.
final_logit_softcapping
)
# Return only last_logits if logprob is not requested
# Return only last_logits if logprob is not requested
if
not
logits_metadata
.
return_logprob
:
if
not
logits_metadata
.
return_logprob
:
...
@@ -209,9 +209,9 @@ class LogitsProcessor(nn.Module):
...
@@ -209,9 +209,9 @@ class LogitsProcessor(nn.Module):
all_logits
=
all_logits
[:,
:
self
.
config
.
vocab_size
].
float
()
all_logits
=
all_logits
[:,
:
self
.
config
.
vocab_size
].
float
()
if
hasattr
(
self
.
config
,
"final_logit_softcapping"
):
if
hasattr
(
self
.
config
,
"final_logit_softcapping"
):
all_logits
/=
self
.
config
.
final_logit_softcapping
all_logits
.
div_
(
self
.
config
.
final_logit_softcapping
)
all_logits
=
torch
.
tanh
(
all_logits
)
all_logits
=
torch
.
tanh
(
all_logits
)
all_logits
*=
self
.
config
.
final_logit_softcapping
all_logits
.
mul_
(
self
.
config
.
final_logit_softcapping
)
all_logprobs
=
all_logits
all_logprobs
=
all_logits
del
all_logits
,
hidden_states
del
all_logits
,
hidden_states
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
a59636bb
...
@@ -53,7 +53,7 @@ from sglang.srt.server_args import ServerArgs
...
@@ -53,7 +53,7 @@ from sglang.srt.server_args import ServerArgs
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
get_available_gpu_memory
,
get_available_gpu_memory
,
is_generation_model
,
is_generation_model
,
is_llama3_405b_fp8
,
is_llama3_405b_fp8
_head_16
,
is_multimodal_model
,
is_multimodal_model
,
monkey_patch_vllm_dummy_weight_loader
,
monkey_patch_vllm_dummy_weight_loader
,
monkey_patch_vllm_p2p_access_check
,
monkey_patch_vllm_p2p_access_check
,
...
@@ -158,7 +158,7 @@ class ModelRunner:
...
@@ -158,7 +158,7 @@ class ModelRunner:
skip_tokenizer_init
=
True
,
skip_tokenizer_init
=
True
,
)
)
if
is_llama3_405b_fp8
(
self
.
model_config
)
and
self
.
tp_size
<=
8
:
if
is_llama3_405b_fp8
_head_16
(
self
.
model_config
)
and
self
.
tp_size
<=
8
:
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
self
.
model_config
.
hf_config
.
num_key_value_heads
=
8
self
.
model_config
.
hf_config
.
num_key_value_heads
=
8
vllm_model_config
.
hf_config
.
num_key_value_heads
=
8
vllm_model_config
.
hf_config
.
num_key_value_heads
=
8
...
...
python/sglang/srt/models/grok.py
View file @
a59636bb
...
@@ -16,20 +16,17 @@ limitations under the License.
...
@@ -16,20 +16,17 @@ limitations under the License.
# Adapted from
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model."""
"""Inference-only Grok1 model."""
import
warnings
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
numpy
as
np
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
tqdm
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
,
)
)
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -37,7 +34,6 @@ from vllm.model_executor.layers.linear import (
...
@@ -37,7 +34,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
RowParallelLinear
,
)
)
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
ParallelLMHead
,
...
@@ -45,141 +41,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -45,141 +41,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
)
from
vllm.model_executor.model_loader.loader
import
DefaultModelLoader
from
vllm.model_executor.model_loader.loader
import
DefaultModelLoader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
print_warning_once
from
sglang.srt.layers.fused_moe
import
f
used
_moe
from
sglang.srt.layers.fused_moe
import
F
used
MoE
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
use_fused
=
True
class
Grok1MLP
(
nn
.
Module
):
def
__init__
(
self
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
num_experts
=
num_experts
self
.
ffn_dim
=
intermediate_size
self
.
hidden_dim
=
hidden_size
self
.
w1
=
ReplicatedLinear
(
self
.
hidden_dim
,
self
.
ffn_dim
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
w2
=
ReplicatedLinear
(
self
.
ffn_dim
,
self
.
hidden_dim
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
w3
=
ReplicatedLinear
(
self
.
hidden_dim
,
self
.
ffn_dim
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
act_fn
=
nn
.
GELU
()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
w1_out
,
_
=
self
.
w1
(
hidden_states
)
w1_out
=
self
.
act_fn
(
w1_out
)
w3_out
,
_
=
self
.
w3
(
hidden_states
)
current_hidden_states
=
w1_out
*
w3_out
current_hidden_states
,
_
=
self
.
w2
(
current_hidden_states
)
return
current_hidden_states
class
Grok1MoEUnfused
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_total_experts
=
config
.
num_local_experts
self
.
top_k
=
config
.
num_experts_per_tok
if
self
.
tp_size
>
self
.
num_total_experts
:
raise
ValueError
(
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"the number of experts
{
self
.
num_total_experts
}
."
)
# Split experts equally between ranks
self
.
expert_indicies
=
np
.
array_split
(
range
(
self
.
num_total_experts
),
self
.
tp_size
)[
self
.
rank
].
tolist
()
if
not
self
.
expert_indicies
:
raise
ValueError
(
f
"Rank
{
self
.
rank
}
has no experts assigned to it."
)
self
.
experts
=
nn
.
ModuleList
(
[
(
Grok1MLP
(
self
.
num_total_experts
,
config
.
hidden_size
,
config
.
intermediate_size
,
quant_config
=
quant_config
,
)
if
idx
in
self
.
expert_indicies
else
None
)
for
idx
in
range
(
self
.
num_total_experts
)
]
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
num_total_experts
,
bias
=
False
,
quant_config
=
None
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
=
30
*
F
.
tanh
(
router_logits
/
30
)
routing_weights
=
F
.
softmax
(
router_logits
,
dim
=
1
,
dtype
=
torch
.
float
)
routing_weights
,
selected_experts
=
torch
.
topk
(
routing_weights
,
self
.
top_k
,
dim
=-
1
)
routing_weights
=
routing_weights
.
to
(
hidden_states
.
dtype
)
hidden_dim
=
hidden_states
.
shape
[
1
]
final_hidden_states
=
torch
.
zeros
(
(
hidden_states
.
shape
[
0
],
hidden_dim
),
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
,
)
expert_mask
=
torch
.
nn
.
functional
.
one_hot
(
selected_experts
,
num_classes
=
self
.
num_total_experts
).
permute
(
2
,
1
,
0
)
for
expert_idx
in
self
.
expert_indicies
:
expert_layer
=
self
.
experts
[
expert_idx
]
idx
,
top_x
=
torch
.
where
(
expert_mask
[
expert_idx
])
if
top_x
.
shape
[
0
]
==
0
:
continue
# in torch it is faster to index using lists than torch tensors
top_x_list
=
top_x
.
tolist
()
idx_list
=
idx
.
tolist
()
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state
=
hidden_states
[
None
,
top_x_list
].
reshape
(
-
1
,
hidden_dim
)
current_hidden_states
=
(
expert_layer
(
current_state
)
*
routing_weights
[
top_x_list
,
idx_list
,
None
]
)
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states
.
index_add_
(
0
,
top_x
,
current_hidden_states
)
return
tensor_model_parallel_all_reduce
(
final_hidden_states
)
class
Grok1MoE
(
nn
.
Module
):
class
Grok1MoE
(
nn
.
Module
):
"""A tensor-parallel MoE implementation for Grok1 that shards each expert
"""A tensor-parallel MoE implementation for Grok1 that shards each expert
...
@@ -197,221 +65,42 @@ class Grok1MoE(nn.Module):
...
@@ -197,221 +65,42 @@ class Grok1MoE(nn.Module):
hidden_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
tp_size
=
tp_size
or
get_tensor_model_parallel_world_size
()
self
.
num_total_experts
=
num_experts
self
.
top_k
=
top_k
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
//
self
.
tp_size
self
.
quant_config
=
quant_config
# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
self
.
use_fp8
=
isinstance
(
quant_config
,
Fp8Config
)
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
# Gate always runs at half / full precision for now.
# Gate always runs at half / full precision for now.
self
.
gate
=
ReplicatedLinear
(
self
.
gate
=
ReplicatedLinear
(
self
.
hidden_size
,
hidden_size
,
self
.
num_total
_experts
,
num
_experts
,
bias
=
False
,
bias
=
False
,
params_dtype
=
self
.
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
None
,
quant_config
=
None
,
)
)
if
self
.
use_fp8
and
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
self
.
experts
=
FusedMoE
(
params_dtype
=
torch
.
float8_e4m3fn
num_experts
=
num_experts
,
top_k
=
top_k
,
self
.
w13_weight
=
nn
.
Parameter
(
hidden_size
=
hidden_size
,
torch
.
empty
(
intermediate_size
=
intermediate_size
,
self
.
num_total_experts
,
params_dtype
=
params_dtype
,
2
*
self
.
intermediate_size
,
reduce_results
=
True
,
self
.
hidden_size
,
renormalize
=
False
,
dtype
=
params_dtype
,
quant_config
=
quant_config
,
)
tp_size
=
tp_size
,
)
self
.
w2_weight
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
self
.
hidden_size
,
self
.
intermediate_size
,
dtype
=
params_dtype
,
)
)
set_weight_attrs
(
self
.
w13_weight
,
{
"weight_loader"
:
self
.
weight_loader
,
},
)
set_weight_attrs
(
self
.
w2_weight
,
{
"weight_loader"
:
self
.
weight_loader
,
},
)
)
# Used for fp8.
self
.
w13_scale
=
None
self
.
w2_scale
=
None
self
.
a13_scale
=
None
self
.
a2_scale
=
None
if
self
.
use_fp8
:
# WEIGHT_SCALE (for fp8)
self
.
w13_scale
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_total_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
self
.
w2_scale
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_total_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if
quant_config
.
is_checkpoint_fp8_serialized
:
set_weight_attrs
(
self
.
w13_scale
,
{
"weight_loader"
:
self
.
weight_loader
,
},
)
set_weight_attrs
(
self
.
w2_scale
,
{
"weight_loader"
:
self
.
weight_loader
,
},
)
# ACT_SCALE (for fp8)
if
quant_config
.
activation_scheme
==
"static"
:
if
not
quant_config
.
is_checkpoint_fp8_serialized
:
raise
ValueError
(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
self
.
a13_scale
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_total_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
self
.
a2_scale
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_total_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
set_weight_attrs
(
self
.
a13_scale
,
{
"weight_loader"
:
self
.
weight_loader
,
},
)
set_weight_attrs
(
self
.
a2_scale
,
{
"weight_loader"
:
self
.
weight_loader
,
},
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
expert_id
:
int
,
pre_sharded
:
bool
,
):
param_data
=
param
.
data
shard_size
=
self
.
intermediate_size
if
pre_sharded
:
# The weight is already sharded. Readl the full shard
shard
=
slice
(
None
)
else
:
tp_rank
=
get_tensor_model_parallel_rank
()
shard
=
slice
(
tp_rank
*
shard_size
,
(
tp_rank
+
1
)
*
shard_size
)
if
weight_name
.
endswith
(
"w1.weight"
):
param_data
[
expert_id
,
0
:
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
if
weight_name
.
endswith
(
"w3.weight"
):
param_data
[
expert_id
,
shard_size
:
2
*
shard_size
,
:]
=
loaded_weight
[
shard
,
:
]
if
weight_name
.
endswith
(
"w2.weight"
):
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
if
"act_scale"
in
weight_name
or
"weight_scale"
in
weight_name
:
param_data
[
expert_id
]
=
loaded_weight
def
process_weights_after_loading
(
self
):
# Fp8 is the only case where we need to process after loading.
if
not
self
.
use_fp8
:
return
# If checkpoint is fp16, quantize here.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
w13_weight
=
torch
.
empty_like
(
self
.
w13_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
w2_weight
=
torch
.
empty_like
(
self
.
w2_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
for
expert
in
range
(
self
.
num_total_experts
):
w13_weight
[
expert
,
:,
:],
self
.
w13_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
self
.
w13_weight
.
data
[
expert
,
:,
:]
)
w2_weight
[
expert
,
:,
:],
self
.
w2_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
self
.
w2_weight
.
data
[
expert
,
:,
:]
)
self
.
w13_weight
=
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
self
.
w2_weight
=
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
# If checkpoint is fp8 + static, cleanup act_scales.
# Since state_dict has an act_scale per expert but our kernels
# are passed one act_scale shared across all experts.
elif
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
a13_scale
is
None
or
self
.
a2_scale
is
None
:
raise
ValueError
(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
not
all_close_1d
(
self
.
a13_scale
)
or
not
all_close_1d
(
self
.
a2_scale
):
print_warning_once
(
"Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. "
)
self
.
a13_scale
=
nn
.
Parameter
(
self
.
a13_scale
.
max
(),
requires_grad
=
False
)
self
.
a2_scale
=
nn
.
Parameter
(
self
.
a2_scale
.
max
(),
requires_grad
=
False
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_size
=
hidden_states
.
shape
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_size
)
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_size
)
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
fused_moe
(
router_logits
=
30.0
*
F
.
tanh
(
router_logits
/
30.0
)
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
,
router_logits
)
self
.
w13_weight
,
return
final_hidden_states
.
view
(
orig_shape
)
self
.
w2_weight
,
router_logits
,
self
.
top_k
,
renormalize
=
False
,
inplace
=
True
,
use_fp8
=
self
.
use_fp8
,
w1_scale
=
self
.
w13_scale
,
w2_scale
=
self
.
w2_scale
,
a1_scale
=
self
.
a13_scale
,
a2_scale
=
self
.
a2_scale
,
)
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_size
)
class
Grok1Attention
(
nn
.
Module
):
class
Grok1Attention
(
nn
.
Module
):
...
@@ -478,6 +167,7 @@ class Grok1Attention(nn.Module):
...
@@ -478,6 +167,7 @@ class Grok1Attention(nn.Module):
layer_id
=
layer_id
,
layer_id
=
layer_id
,
logit_cap
=
logit_cap
,
logit_cap
=
logit_cap
,
)
)
# TODO(lianmin): load logit cap from config
def
forward
(
def
forward
(
self
,
self
,
...
@@ -502,7 +192,7 @@ class Grok1DecoderLayer(nn.Module):
...
@@ -502,7 +192,7 @@ class Grok1DecoderLayer(nn.Module):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
# Requires transformers > 4.32.0
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
self
.
self_attn
=
Grok1Attention
(
self
.
self_attn
=
Grok1Attention
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
...
@@ -513,18 +203,13 @@ class Grok1DecoderLayer(nn.Module):
...
@@ -513,18 +203,13 @@ class Grok1DecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
if
use_fused
:
self
.
block_sparse_moe
=
Grok1MoE
(
self
.
block_sparse_moe
=
Grok1MoE
(
num_experts
=
config
.
num_local_experts
,
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
else
:
self
.
block_sparse_moe
=
Grok1MoEUnfused
(
config
=
config
,
quant_config
=
quant_config
)
self
.
pre_attn_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
pre_attn_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attn_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attn_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
pre_moe_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
pre_moe_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -536,6 +221,7 @@ class Grok1DecoderLayer(nn.Module):
...
@@ -536,6 +221,7 @@ class Grok1DecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Self Attention
hidden_states
=
(
hidden_states
=
(
self
.
post_attn_norm
(
self
.
post_attn_norm
(
self
.
self_attn
(
self
.
self_attn
(
...
@@ -547,11 +233,11 @@ class Grok1DecoderLayer(nn.Module):
...
@@ -547,11 +233,11 @@ class Grok1DecoderLayer(nn.Module):
+
hidden_states
+
hidden_states
)
)
# Fully Connected
hidden_states
=
(
hidden_states
=
(
self
.
post_moe_norm
(
self
.
block_sparse_moe
(
self
.
pre_moe_norm
(
hidden_states
)))
self
.
post_moe_norm
(
self
.
block_sparse_moe
(
self
.
pre_moe_norm
(
hidden_states
)))
+
hidden_states
+
hidden_states
)
)
return
hidden_states
return
hidden_states
...
@@ -593,7 +279,6 @@ class Grok1Model(nn.Module):
...
@@ -593,7 +279,6 @@ class Grok1Model(nn.Module):
for
i
in
range
(
len
(
self
.
layers
)):
for
i
in
range
(
len
(
self
.
layers
)):
hidden_states
=
self
.
layers
[
i
](
positions
,
hidden_states
,
input_metadata
)
hidden_states
=
self
.
layers
[
i
](
positions
,
hidden_states
,
input_metadata
)
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
.
mul_
(
self
.
config
.
output_multiplier_scale
)
hidden_states
.
mul_
(
self
.
config
.
output_multiplier_scale
)
return
hidden_states
return
hidden_states
...
@@ -615,8 +300,8 @@ class Grok1ModelForCausalLM(nn.Module):
...
@@ -615,8 +300,8 @@ class Grok1ModelForCausalLM(nn.Module):
# Monkey patch _prepare_weights to load pre-sharded weights
# Monkey patch _prepare_weights to load pre-sharded weights
setattr
(
DefaultModelLoader
,
"_prepare_weights"
,
_prepare_presharded_weights
)
setattr
(
DefaultModelLoader
,
"_prepare_weights"
,
_prepare_presharded_weights
)
warnings
.
filterwarnings
(
"ignore"
,
category
=
FutureWarning
)
@
torch
.
no_grad
()
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -637,50 +322,17 @@ class Grok1ModelForCausalLM(nn.Module):
...
@@ -637,50 +322,17 @@ class Grok1ModelForCausalLM(nn.Module):
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
]
if
use_fused
:
# Params for weights, fp8 weight scales, fp8 activation scales
expert_params_mapping
=
(
# (param_name, weight_name, expert_id, shard_id)
[
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
# These are the weight scales for the experts
ckpt_gate_proj_name
=
"w1"
,
# (param_name, weight_name, expert_id)
ckpt_down_proj_name
=
"w2"
,
(
ckpt_up_proj_name
=
"w3"
,
"w13_scale"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"w2_scale"
,
num_experts
=
self
.
config
.
num_local_experts
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight_scale"
,
)
expert_id
,
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
]
+
[
# These are the weights for the experts
# (param_name, weight_name, expert_id)
(
"w13_weight"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"w2_weight"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
,
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
]
+
[
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
(
"a13_scale"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"a2_scale"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.act_scale"
,
expert_id
,
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
]
)
else
:
expert_params_mapping
=
[]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
if
get_tensor_model_parallel_rank
()
==
0
:
weights
=
tqdm
.
tqdm
(
weights
,
total
=
int
(
len
(
params_dict
)
*
3.4
))
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
# print(get_tensor_model_parallel_rank(), name)
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
@@ -691,21 +343,25 @@ class Grok1ModelForCausalLM(nn.Module):
...
@@ -691,21 +343,25 @@ class Grok1ModelForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
for
param_name
,
weight_name
,
expert_id
in
expert_params_mapping
:
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
weight_loader
(
param
,
param
,
loaded_weight
,
loaded_weight
,
weight_name
,
weight_name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
,
expert_id
=
expert_id
,
pre_sharded
=
get_tensor_model_parallel_world_size
()
>
1
,
pre_sharded
=
get_tensor_model_parallel_world_size
()
>
1
,
)
)
...
@@ -714,6 +370,9 @@ class Grok1ModelForCausalLM(nn.Module):
...
@@ -714,6 +370,9 @@ class Grok1ModelForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
if
name
is
None
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
param
,
"weight_loader"
,
default_weight_loader
...
@@ -721,11 +380,6 @@ class Grok1ModelForCausalLM(nn.Module):
...
@@ -721,11 +380,6 @@ class Grok1ModelForCausalLM(nn.Module):
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
def
all_close_1d
(
x
:
torch
.
Tensor
)
->
bool
:
assert
len
(
x
.
shape
)
==
1
return
all
(
torch
.
allclose
(
x
[
0
],
x
[
i
])
for
i
in
range
(
x
.
shape
[
0
]))
old_prepare_weights
=
getattr
(
DefaultModelLoader
,
"_prepare_weights"
)
old_prepare_weights
=
getattr
(
DefaultModelLoader
,
"_prepare_weights"
)
...
...
python/sglang/srt/models/mixtral.py
View file @
a59636bb
...
@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (
...
@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
ParallelLMHead
,
VocabParallelEmbedding
,
VocabParallelEmbedding
,
)
)
...
...
python/sglang/srt/utils.py
View file @
a59636bb
...
@@ -35,7 +35,6 @@ import torch
...
@@ -35,7 +35,6 @@ import torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
fastapi.responses
import
JSONResponse
from
fastapi.responses
import
JSONResponse
from
packaging
import
version
as
pkg_version
from
packaging
import
version
as
pkg_version
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
triton.runtime.cache
import
(
from
triton.runtime.cache
import
(
FileCacheManager
,
FileCacheManager
,
...
@@ -644,7 +643,7 @@ def set_ulimit(target_soft_limit=65535):
...
@@ -644,7 +643,7 @@ def set_ulimit(target_soft_limit=65535):
logger
.
warn
(
f
"Fail to set RLIMIT_NOFILE:
{
e
}
"
)
logger
.
warn
(
f
"Fail to set RLIMIT_NOFILE:
{
e
}
"
)
def
is_llama3_405b_fp8
(
model_config
):
def
is_llama3_405b_fp8
_head_16
(
model_config
):
"""Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
"""Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
if
(
if
(
model_config
.
hf_config
.
architectures
[
0
]
==
"LlamaForCausalLM"
model_config
.
hf_config
.
architectures
[
0
]
==
"LlamaForCausalLM"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment