change / sglang · Commits · dd5eba4c

Unverified commit dd5eba4c, authored Nov 27, 2024 by Lianmin Zheng, committed via GitHub on Nov 27, 2024.

Remove fused_moe_grok (#2223)

Parent: a4fd2f9b

Showing 7 changed files with 12 additions and 1372 deletions:

3rdparty/amd/tuning/benchmark_moe_rocm.py (+1, -1)
python/sglang/srt/layers/fused_moe_grok/__init__.py (+0, -1)
python/sglang/srt/layers/fused_moe_grok/fused_moe.py (+0, -692)
python/sglang/srt/layers/fused_moe_grok/layer.py (+0, -630)
python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json (+0, -0, moved)
python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json (+0, -0, moved)
python/sglang/srt/models/grok.py (+11, -48)
3rdparty/amd/tuning/benchmark_moe_rocm.py  (view file @ dd5eba4c)

@@ -10,7 +10,7 @@ import triton.language as tl
 from tqdm import tqdm
 from transformers import AutoConfig
-from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe, get_config_file_name
+from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe, get_config_file_name

 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
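The only change the tuning script needs is the import path; the kernel entry points keep their names. A minimal sketch of the post-commit usage (the comments about the benchmark loop are illustrative, not copied from the script):

# Post-commit: resolve the shared Triton kernels instead of the Grok-specific copy.
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe, get_config_file_name

# Call sites stay the same: the tuning loop still benchmarks fused_moe(...) and
# writes its best-found config to the file named by get_config_file_name(E, N, dtype).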
python/sglang/srt/layers/fused_moe_grok/__init__.py  (deleted, 100644 → 0, view file @ a4fd2f9b)

from sglang.srt.layers.fused_moe_grok.layer import FusedMoE, FusedMoEMethodBase
python/sglang/srt/layers/fused_moe_grok/fused_moe.py  (deleted, 100644 → 0, view file @ a4fd2f9b)

# Adapted from
# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
"""Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Dict, Optional, Tuple

import torch
import triton
import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger

logger = init_logger(__name__)

padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0


@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8: tl.constexpr,
    even_Ks: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )

    if use_fp8:
        a_scale = tl.load(a_scale_ptr)
        b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        if even_Ks:
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None],
                other=0.0,
            )
            b = tl.load(b_ptrs)
        else:
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                other=0.0,
            )
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # We accumulate along the K dimension.
        if use_fp8:
            accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    if use_fp8:
        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
        top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according
        to their allocated expert.
    - expert_ids: A tensor indicating the assigned expert index for each block.
    - num_tokens_post_padded: The total number of tokens after padding,
        ensuring divisibility by block_size.

    This function pads the number of tokens that each expert needs to process
    so that it is divisible by block_size.
    Padding ensures that during block matrix multiplication, the dimensions
    align correctly.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
        with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Then append padding tokens [12, 12, 12, 12] for each block.
    - After sorting by expert index, we obtain token_ids
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Tokens 12 are non-existent (padding) and are ignored in
        the subsequent matrix multiplication.
    - The padding ensures that the total number of tokens is now divisible
        by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
    ops.moe_align_block_size(
        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
    )
    return sorted_ids, expert_ids, num_tokens_post_pad


def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: Dict[str, Any],
    compute_type: tl.dtype,
    use_fp8: bool,
) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    padded_size = padding_size
    if not use_fp8:
        assert A_scale is None
        assert B_scale is None
        # MOE_PADDING FP8 only
        padded_size = 0
    else:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None

    grid = lambda META: (
        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
    )

    K = B.shape[2] - padded_size
    if K % config["BLOCK_SIZE_K"] == 0:
        even_ks = True
    else:
        even_ks = False

    fused_moe_kernel[grid](
        A,
        B,
        C,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.shape[1],
        B.shape[2] - padded_size,
        sorted_token_ids.shape[0],
        topk_ids.numel(),
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8=use_fp8,
        even_Ks=even_ks,
        **config,
    )


def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"


@functools.lru_cache
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the fused_moe kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.
    """

    # First look up if an optimized configuration is available in the configs
    # directory
    json_file_name = get_config_file_name(E, N, dtype)

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info("Using configuration from %s for MoE layer.", config_file_path)
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    return None


def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
) -> Dict[str, int]:
    if dtype == "float8":
        config = {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 32,
            "num_warps": 8,
            "num_stages": 4,
        }
        if M <= E:
            config = {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 1,
                "num_warps": 4,
                "num_stages": 4,
            }
    else:
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
        }
        if M <= E:
            config = {
                "BLOCK_SIZE_M": 16,
                "BLOCK_SIZE_N": 32,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 1,
            }
    return config


def try_get_optimal_moe_config(
    w1_shape: Tuple[int, ...],
    w2_shape: Tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    override_config: Optional[Dict[str, Any]] = None,
):
    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        E, _, N = w2_shape
        configs = get_moe_configs(E, N, dtype)

        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
    return config


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    M, _ = hidden_states.shape

    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
    token_expert_indicies = torch.empty(
        M, topk, dtype=torch.int32, device=hidden_states.device
    )
    ops.topk_softmax(
        topk_weights,
        topk_ids,
        token_expert_indicies,
        gating_output.float(),  # TODO(woosuk): Optimize this.
    )
    del token_expert_indicies  # Not used. Will be used in the future.

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids


# This is used by the Deepseek-V2 model
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    padded_size = padding_size
    if not use_fp8:
        # MOE_PADDING FP8 only
        padded_size = 0

    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
    M = min(num_tokens, CHUNK_SIZE)

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
        topk_ids.shape[1],
        "float8" if use_fp8 else None,
        override_config=override_config,
    )

    config = get_config_func(M)

    intermediate_cache1 = torch.empty(
        (M, topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, topk_ids.shape[1], w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust the intermediate cache size and config for the last
            # chunk. Note that in most cases we only have one chunk
            # so the cache size and config are already set correctly and
            # do not need to be adjusted.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            curr_topk_ids, config["BLOCK_SIZE_M"], E
        )

        invoke_fused_moe_kernel(
            curr_hidden_states,
            w1,
            intermediate_cache1,
            a1_scale,
            w1_scale,
            curr_topk_weights,
            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            False,
            topk_ids.shape[1],
            config,
            compute_type=compute_type,
            use_fp8=use_fp8,
        )

        ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

        invoke_fused_moe_kernel(
            intermediate_cache2,
            w2,
            intermediate_cache3,
            a2_scale,
            w2_scale,
            curr_topk_weights,
            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            True,
            1,
            config,
            compute_type=compute_type,
            use_fp8=use_fp8,
        )

        torch.sum(
            intermediate_cache3.view(*intermediate_cache3.shape),
            dim=1,
            out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )
    return out_hidden_states


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - num_expert_group: Optional[int]: additional parameter for grouped_topk
    - topk_group: Optional[int]: additional parameter for grouped_topk
    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
        note: Deepseekv2 model uses grouped_topk
    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
        w2.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        topk_weights, topk_ids = grouped_topk(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            num_expert_group,
            topk_group,
        )
    else:
        topk_weights, topk_ids = fused_topk(
            hidden_states, gating_output, topk, renormalize
        )

    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        override_config=override_config,
        use_fp8=use_fp8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
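For orientation, a minimal sketch of how this (now deleted) fused_moe entry point was invoked. The shapes follow the asserts and docstring above; the concrete sizes, the CUDA device, and the use of random weights are illustrative only, and the call assumes the vLLM custom ops imported by this module are available.

# Hypothetical sizes: 8 experts, top-2 routing, hidden size 1024, expert intermediate size 2048.
import torch

E, topk, hidden, intermediate = 8, 2, 1024, 2048
hidden_states = torch.randn(16, hidden, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(E, 2 * intermediate, hidden, dtype=torch.bfloat16, device="cuda")  # fused gate/up per expert
w2 = torch.randn(E, hidden, intermediate, dtype=torch.bfloat16, device="cuda")      # down projection per expert
gating_output = torch.randn(16, E, dtype=torch.float32, device="cuda")              # router logits, one per expert

out = fused_moe(
    hidden_states,
    w1,
    w2,
    gating_output,
    topk=topk,
    renormalize=True,
    inplace=False,
)
assert out.shape == hidden_states.shape  # one combined output row per token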
python/sglang/srt/layers/fused_moe_grok/layer.py  (deleted, 100644 → 0, view file @ a4fd2f9b)

# Adapted from
# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
import os
from abc import abstractmethod
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.utils import set_weight_attrs

from sglang.srt.layers.fused_moe_grok.fused_moe import padding_size
from sglang.srt.utils import is_hip

logger = init_logger(__name__)


class FusedMoEMethodBase(QuantizeMethodBase):
    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
    ) -> torch.Tensor:
        raise NotImplementedError


class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
    ) -> torch.Tensor:
        return self.forward(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            top_k,
            renormalize,
            use_grouped_topk,
            num_expert_group,
            topk_group,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
    ) -> torch.Tensor:
        from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe

        return fused_moe(
            x,
            w1,
            w2,
            router_logits,
            top_k,
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
        )

    def forward_cpu(self, *args, **kwargs):
        raise NotImplementedError("The CPU backend currently does not support MoE.")

    def forward_tpu(
        self,
        x: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
    ) -> torch.Tensor:
        raise NotImplementedError("The TPU backend currently does not support MoE.")


class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj/ w2).

    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
    copy that naming convention here and handle any remapping in the
    load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model
        top_k: Number of experts selected for each token
        hidden_size: Input hidden state size of the transformer
        intermediate_size: Intermediate size of the experts
        params_dtype: Data type for the parameters.
        reduce_results: Whether to all all_reduce on the output of the layer
        renomalize: Whether to renormalize the logits in the fused_moe kernel
        quant_config: Quantization configure.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (
            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
        )
        self.top_k = top_k
        self.num_experts = num_experts
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedFusedMoEMethod()
            )
        else:
            if isinstance(quant_config, Fp8Config):
                self.quant_method = Fp8MoEMethod(quant_config)
            else:
                self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader,
        )

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: int,
        expert_id: int,
        use_presharded_weights: bool = False,
    ):
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            if (
                param_data[expert_id] != 1
                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}"
                )
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            # If we are in merged column case (gate_up_proj)
            # shard_id 0 == gate_proj / w1
            # shard_id 2 == up_proj / w3
            if shard_id == 0 or shard_id == 2:
                # We have to keep the weight scales of w1 and w3 because
                # we need to re-quantize w1/w3 weights after weight loading.
                idx = 0 if shard_id == 0 else 1
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            # shard_id 1 == down_proj / w2
            else:
                param_data[expert_id] = loaded_weight
        # Weights
        else:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = self.intermediate_size_per_partition
            if use_presharded_weights:
                shard = slice(None)
            else:
                shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

            # w1, gate_proj case: Load into first shard of w13.
            if shard_id == 0:
                param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
            # w3, up_proj case: Load into second shard of w13.
            elif shard_id == 2:
                param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
                    shard, :
                ]
            # w2, down_proj case: Load into only shard of w2.
            elif shard_id == 1:
                param_data[expert_id, :, :] = loaded_weight[:, shard]
            else:
                raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}")

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            num_expert_group=self.num_expert_group,
            topk_group=self.topk_group,
        )

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states

    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, int]]:

        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
        gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name]

        return (
            [
                # These are the weight scales for the experts
                # (param_name, weight_name, expert_id, shard_id)
                (
                    (
                        "experts.w13_scale"
                        if weight_name in gate_up
                        else "experts.w2_scale"
                    ),
                    f"experts.{expert_id}.{weight_name}.weight_scale",
                    expert_id,
                    shard_id,
                )
                for expert_id in range(num_experts)
                for shard_id, weight_name in enumerate(gate_down_up)
            ]
            + [
                # These are the weights for the experts
                # (param_name, weight_name, expert_id, shard_id)
                (
                    (
                        "experts.w13_weight"
                        if weight_name in gate_up
                        else "experts.w2_weight"
                    ),
                    f"experts.{expert_id}.{weight_name}.weight",
                    expert_id,
                    shard_id,
                )
                for expert_id in range(num_experts)
                for shard_id, weight_name in enumerate(gate_down_up)
            ]
            + [
                # These are the weight scales for the experts
                # (param_name, weight_name, expert_id, shard_id)
                (
                    (
                        "experts.a13_scale"
                        if weight_name in gate_up
                        else "experts.a2_scale"
                    ),
                    f"experts.{expert_id}.{weight_name}.input_scale",
                    expert_id,
                    shard_id,
                )
                for expert_id in range(num_experts)
                for shard_id, weight_name in enumerate(gate_down_up)
            ]
        )


import torch
from torch.nn import Module
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.utils import print_warning_once


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w2_scale", w2_scale)

        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_scale, extra_weight_attrs)
            set_weight_attrs(w2_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            a13_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("a13_scale", a13_scale)
            set_weight_attrs(a13_scale, extra_weight_attrs)

            a2_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("a2_scale", a2_scale)
            set_weight_attrs(a2_scale, extra_weight_attrs)
        else:
            layer.a13_scale = None
            layer.a2_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:

        # If checkpoint is fp16 or bfloat16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_scale = torch.nn.Parameter(
                torch.ones(
                    layer.num_experts, dtype=torch.float32, device=w13_weight.device
                ),
                requires_grad=False,
            )
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant(
                    layer.w2_weight.data[expert, :, :]
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

            # If ROCm, apply weight padding (min. Mem channel contention) only if set
            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
                layer.w13_weight = torch.nn.Parameter(
                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
                    requires_grad=False,
                )
                torch.cuda.empty_cache()
                layer.w2_weight = torch.nn.Parameter(
                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
                    requires_grad=False,
                )
                torch.cuda.empty_cache()
            return

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.a13_scale is None or layer.a2_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.a13_scale) or not all_close_1d(
                    layer.a2_scale
                ):
                    print_warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer. "
                    )
                layer.a13_scale = torch.nn.Parameter(
                    layer.a13_scale.max(), requires_grad=False
                )
                layer.a2_scale = torch.nn.Parameter(
                    layer.a2_scale.max(), requires_grad=False
                )

            # If ROCm, normalize the weights and scales to e4m3fnuz
            if is_hip():
                # Normalize the weights and scales
                w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_scale, layer.a13_scale
                )
                w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_scale, layer.a2_scale
                )
                # Reset the parameters
                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
                layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
                if a13_scale is not None:
                    layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
                layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
                if a2_scale is not None:
                    layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_scale.max(dim=1).values
            for expert_id in range(layer.num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)

            # If ROCm, apply weight padding (min. Mem channel contention) only if set
            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
                layer.w13_weight = torch.nn.Parameter(
                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
                    requires_grad=False,
                )
                torch.cuda.empty_cache()
                layer.w2_weight = torch.nn.Parameter(
                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
                    requires_grad=False,
                )
                torch.cuda.empty_cache()
            return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
    ) -> torch.Tensor:
        from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe

        return fused_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            top_k,
            renormalize=renormalize,
            inplace=True,
            use_fp8=True,
            w1_scale=layer.w13_scale,
            w2_scale=layer.w2_scale,
            a1_scale=layer.a13_scale,
            a2_scale=layer.a2_scale,
            use_grouped_topk=use_grouped_topk,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
        )
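For orientation, a minimal sketch of how a model would have constructed this (now deleted) FusedMoE module. Argument names come from the __init__ signature above; the numeric sizes are made up for illustration.

import torch

experts = FusedMoE(
    num_experts=8,
    top_k=2,
    hidden_size=1024,
    intermediate_size=2048,
    params_dtype=torch.bfloat16,
    reduce_results=True,          # all-reduce the combined expert output across TP ranks
    renormalize=True,
    quant_config=None,            # None selects UnquantizedFusedMoEMethod
    tp_size=1,
)

# forward() routes each token with the router logits and returns a tensor with
# the same shape as hidden_states:
#   hidden_states: [num_tokens, hidden_size], router_logits: [num_tokens, num_experts]
# out = experts(hidden_states, router_logits)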
python/sglang/srt/layers/fused_moe_grok/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json
→ python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json
(file moved, view file @ dd5eba4c)

python/sglang/srt/layers/fused_moe_grok/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json
→ python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json
(file moved, view file @ dd5eba4c)
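These two JSON files are the tuned kernel configs that get_moe_configs looks up by file name; the sketch below shows how that name is derived, assuming the Triton package uses the same naming scheme as the deleted fused_moe.py above. On an MI300X, torch.cuda.get_device_name() reports "AMD Instinct MI300X", which after replacing spaces with underscores reproduces the first moved file exactly.

# Sketch: derive the config file name the kernel will look for in configs/.
name = get_config_file_name(E=8, N=4096, dtype="float8")
# -> "E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json" on an MI300X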
python/sglang/srt/models/grok.py  (view file @ dd5eba4c)

@@ -16,22 +16,17 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
-import warnings
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.layers.fused_moe_grok import FusedMoE
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,

@@ -41,10 +36,12 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -293,17 +290,11 @@ class Grok1ForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

-        # Monkey patch _prepare_weights to load pre-sharded weights
-        setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
-        self.use_presharded_weights = True
-
-        warnings.filterwarnings("ignore", category=FutureWarning)
-
     def forward(
         self,
         input_ids: torch.Tensor,

@@ -357,28 +348,23 @@ class Grok1ForCausalLM(nn.Module):
                         continue
                     name = name.replace(weight_name, param_name)
-                    if self.use_presharded_weights:
-                        extra_kwargs = {
-                            "use_presharded_weights": self.use_presharded_weights
-                        }
-                    else:
-                        extra_kwargs = {}
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
                         param,
                         loaded_weight,
-                        weight_name,
+                        name,
                         shard_id=shard_id,
                         expert_id=expert_id,
-                        **extra_kwargs,
                     )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
                     # Skip loading kv_scale from ckpts towards new design.
                     if name.endswith(".kv_scale") and name not in params_dict:
                         continue
                     if name is None:
                         continue

@@ -388,30 +374,7 @@ class Grok1ForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)

-old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
-
-
-def _prepare_presharded_weights(
-    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
-) -> Tuple[str, List[str], bool]:
-    import glob
-    import os
-
-    if get_tensor_model_parallel_world_size() == 1:
-        return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
-
-    tp_rank = get_tensor_model_parallel_rank()
-    allow_patterns = [f"*-{tp_rank:03d}.bin"]
-
-    hf_folder = model_name_or_path
-
-    hf_weights_files: List[str] = []
-    for pattern in allow_patterns:
-        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
-    use_safetensors = False
-
-    return hf_folder, hf_weights_files, use_safetensors
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))


 class Grok1ModelForCausalLM(Grok1ForCausalLM):
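A hedged sketch of the expert-weight loading loop the last two hunks sit in, assuming the surrounding load_weights follows the usual FusedMoE.make_expert_params_mapping pattern (the mapping call itself is not visible in this diff, and the checkpoint names "w1"/"w2"/"w3" and config.num_local_experts are illustrative assumptions). After this commit the per-expert weight_loader receives the resolved parameter name `name` instead of `weight_name`, with no presharded-weight kwargs.

# Hypothetical surrounding code; only the loop body mirrors the hunk above.
expert_params_mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="w1",
    ckpt_down_proj_name="w2",
    ckpt_up_proj_name="w3",
    num_experts=config.num_local_experts,
)
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
    if weight_name not in name:
        continue
    name = name.replace(weight_name, param_name)
    param = params_dict[name]
    weight_loader = param.weight_loader
    weight_loader(param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id)
    break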