Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
15ba07ef
Unverified
Commit
15ba07ef
authored
Apr 03, 2025
by
bnellnm
Committed by
GitHub
Apr 03, 2025
Browse files
[Minor] Fused experts refactor (#15914)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
d2b58ca2
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
790 additions
and
737 deletions
+790
-737
tests/kernels/test_block_fp8.py
tests/kernels/test_block_fp8.py
+6
-3
tests/kernels/test_cutlass_moe.py
tests/kernels/test_cutlass_moe.py
+8
-8
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+4
-2
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+144
-0
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+294
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+43
-724
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+243
-0
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+48
-0
No files found.
tests/kernels/test_block_fp8.py
View file @
15ba07ef
...
...
@@ -9,8 +9,11 @@ import torch
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
deep_gemm_moe_fp8
,
fused_topk
,
moe_align_block_size
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
,
w8a8_block_fp8_matmul
)
from
vllm.platforms
import
current_platform
...
...
@@ -437,7 +440,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
pytest
.
skip
(
f
"Skipping test; bad size m=
{
M
}
, n=
{
N
}
, k=
{
K
}
, topk=
{
topk
}
, E=
{
E
}
"
)
if
(
N
<=
512
)
:
if
N
<=
512
:
pytest
.
skip
(
"Skipping N <= 512 until performance issues solved."
)
vllm_config
=
VllmConfig
()
...
...
tests/kernels/test_cutlass_moe.py
View file @
15ba07ef
...
...
@@ -4,8 +4,8 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.
fused
_moe
import
(
cutlass_moe_fp8
,
fused_experts
,
from
vllm.model_executor.layers.fused_moe.
cutlass
_moe
import
cutlass_moe_fp8
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_topk
)
from
vllm.platforms
import
current_platform
...
...
@@ -131,9 +131,9 @@ def test_cutlass_moe_no_graph(
c_strides2
,
a1_scale
=
a_scale1
)
print
(
triton_output
)
print
(
cutlass_output
)
print
(
"*"
)
#
print(triton_output)
#
print(cutlass_output)
#
print("*")
torch
.
testing
.
assert_close
(
triton_output
,
cutlass_output
,
...
...
@@ -234,9 +234,9 @@ def test_cutlass_moe_cuda_graph(
graph
.
replay
()
torch
.
cuda
.
synchronize
()
print
(
triton_output
)
print
(
cutlass_output
)
print
(
"*"
)
#
print(triton_output)
#
print(cutlass_output)
#
print("*")
torch
.
testing
.
assert_close
(
triton_output
,
cutlass_output
,
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
15ba07ef
...
...
@@ -35,9 +35,11 @@ if HAS_TRITON:
# import to register the custom ops
import
vllm.model_executor.layers.fused_moe.fused_marlin_moe
# noqa
import
vllm.model_executor.layers.fused_moe.fused_moe
# noqa
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
cutlass_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
cutlass_moe_fp8
,
fused_experts
,
fused_moe
,
fused_topk
,
get_config_file_name
,
grouped_topk
)
fused_experts
,
fused_moe
,
fused_topk
,
get_config_file_name
,
grouped_topk
)
__all__
+=
[
"fused_moe"
,
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
0 → 100644
View file @
15ba07ef
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
from
typing
import
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def
cutlass_moe_fp8
(
a
:
torch
.
Tensor
,
w1_q
:
torch
.
Tensor
,
w2_q
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
out_dtype
:
torch
.
dtype
=
torch
.
half
,
)
->
torch
.
Tensor
:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- ab_strides1 (torch.Tensor): The input and weights strides of the first
grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
- ab_strides2 (torch.Tensor): The input and weights strides of the second
grouped gemm.
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- out_dtype (torch.Tensor): The output tensor type.
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
w1_q
.
dtype
==
torch
.
float8_e4m3fn
assert
w2_q
.
dtype
==
torch
.
float8_e4m3fn
assert
a
.
shape
[
1
]
==
w1_q
.
shape
[
1
],
"Hidden size mismatch w1"
assert
w1_q
.
shape
[
2
]
==
w2_q
.
shape
[
1
]
*
2
,
"Hidden size mismatch w2"
assert
w1_q
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"Expert number mismatch"
assert
a1_scale
is
None
or
a1_scale
.
dim
(
)
==
0
or
a1_scale
.
shape
[
0
]
==
1
or
a1_scale
.
shape
[
0
]
==
a
.
shape
[
0
],
"Input scale shape mismatch"
assert
w1_scale
.
dim
()
==
1
or
w1_scale
.
shape
[
1
]
==
1
or
w1_scale
.
shape
[
1
]
==
w1_q
.
shape
[
2
],
"W1 scale shape mismatch"
assert
w2_scale
.
dim
()
==
1
or
w2_scale
.
shape
[
1
]
==
1
or
w2_scale
.
shape
[
1
]
==
w2_q
.
shape
[
2
],
"W2 scale shape mismatch"
assert
w1_q
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"Weights expert number mismatch"
assert
w1_q
.
shape
[
0
]
==
w1_scale
.
shape
[
0
],
"w1 scales expert number mismatch"
assert
w1_q
.
shape
[
0
]
==
w2_scale
.
shape
[
0
],
"w2 scales expert number mismatch"
assert
a2_scale
is
None
or
a1_scale
is
None
or
a2_scale
.
shape
==
a1_scale
.
shape
,
"Intermediate scale shape mismatch"
# noqa: E501
assert
ab_strides1
.
shape
[
0
]
==
w1_q
.
shape
[
0
],
"AB Strides 1 expert number mismatch"
assert
c_strides1
.
shape
[
0
]
==
w1_q
.
shape
[
0
],
"C Strides 1 expert number mismatch"
assert
ab_strides2
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"AB Strides 2 expert number mismatch"
assert
c_strides2
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"C Strides 2 expert number mismatch"
assert
out_dtype
in
[
torch
.
half
,
torch
.
bfloat16
],
"Invalid output dtype"
num_experts
=
w1_q
.
size
(
0
)
m
=
a
.
size
(
0
)
k
=
w1_q
.
size
(
1
)
n
=
w2_q
.
size
(
1
)
topk
=
topk_ids
.
size
(
1
)
per_act_token
=
a1_scale
.
numel
()
!=
1
if
a1_scale
is
not
None
else
(
a2_scale
.
numel
()
!=
1
if
a2_scale
is
not
None
else
False
)
a_q
,
a1_scale
=
ops
.
scaled_fp8_quant
(
a
,
a1_scale
,
use_per_token_if_dynamic
=
per_act_token
)
device
=
a_q
.
device
expert_offsets
=
torch
.
empty
((
num_experts
+
1
),
dtype
=
torch
.
int32
,
device
=
device
)
problem_sizes1
=
torch
.
empty
((
num_experts
,
3
),
dtype
=
torch
.
int32
,
device
=
device
)
problem_sizes2
=
torch
.
empty
((
num_experts
,
3
),
dtype
=
torch
.
int32
,
device
=
device
)
a_map
=
torch
.
empty
((
topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
c_map
=
torch
.
empty
((
topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
ops
.
get_cutlass_moe_mm_data
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
a_map
,
c_map
,
num_experts
,
n
,
k
)
rep_a_q
=
a_q
.
view
(
dtype
=
torch
.
uint8
)[
a_map
].
view
(
dtype
=
a_q
.
dtype
)
rep_a1_scales
=
a1_scale
[
a_map
]
if
per_act_token
else
a1_scale
c1
=
torch
.
empty
((
m
*
topk
,
n
*
2
),
device
=
device
,
dtype
=
out_dtype
)
c2
=
torch
.
empty
((
m
*
topk
,
k
),
device
=
device
,
dtype
=
out_dtype
)
ops
.
cutlass_moe_mm
(
c1
,
rep_a_q
,
w1_q
,
rep_a1_scales
,
w1_scale
,
expert_offsets
[:
-
1
],
problem_sizes1
,
ab_strides1
,
ab_strides1
,
c_strides1
)
intermediate
=
torch
.
empty
((
m
*
topk
,
n
),
device
=
device
,
dtype
=
out_dtype
)
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate
,
c1
)
intemediate_q
,
a2_scale
=
ops
.
scaled_fp8_quant
(
intermediate
,
a2_scale
,
use_per_token_if_dynamic
=
per_act_token
)
ops
.
cutlass_moe_mm
(
c2
,
intemediate_q
,
w2_q
,
a2_scale
,
w2_scale
,
expert_offsets
[:
-
1
],
problem_sizes2
,
ab_strides2
,
ab_strides2
,
c_strides2
)
return
(
c2
[
c_map
].
view
(
m
,
topk
,
k
)
*
topk_weights
.
view
(
m
,
topk
,
1
).
to
(
out_dtype
)).
sum
(
dim
=
1
)
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
0 → 100644
View file @
15ba07ef
# SPDX-License-Identifier: Apache-2.0
import
importlib.util
from
typing
import
Optional
,
Tuple
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_fp8_perm
,
_fp8_quantize
,
_resize_cache
)
from
vllm.utils
import
round_up
logger
=
init_logger
(
__name__
)
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
def _valid_deep_gemm(hidden_states: torch.Tensor,
                     w1: torch.Tensor,
                     w2: torch.Tensor,
                     expert_map: Optional[torch.Tensor] = None) -> bool:
    """
    Decide whether the DeepGemm grouped gemm kernel can be used for this
    problem.

    Returns False when DeepGemm is not installed, when an expert map is
    supplied, when N (last dim of `w2`) is at most 512, when the number of
    tokens is smaller than `dg.get_m_alignment_for_contiguous_layout()`,
    when N or K (dims 2 and 1 of `w2`) are not multiples of that alignment,
    or when any of the three tensors is non-contiguous.
    """
    if not has_deep_gemm:
        return False

    # Lazy import to avoid CUDA initialization problems.
    import deep_gemm as dg

    # Expert maps not supported yet.
    if expert_map is not None:
        return False

    alignment = dg.get_m_alignment_for_contiguous_layout()
    num_rows = hidden_states.shape[0]
    _, k_dim, n_dim = w2.shape

    # For now, disable DeepGemm for small N until better permute/unpermute
    # ops are available.
    if n_dim <= 512:
        return False

    dims_ok = (num_rows >= alignment and n_dim % alignment == 0
               and k_dim % alignment == 0)
    if not dims_ok:
        return False

    return all(t.is_contiguous() for t in (hidden_states, w1, w2))
def _moe_permute(
    curr_hidden_states: torch.Tensor,
    a1q_scale: Optional[torch.Tensor],
    curr_topk_ids: torch.Tensor,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    block_m: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
           Optional[torch.Tensor]]:
    """
    Compute `sorted_token_ids` / `expert_ids` for this chunk via
    `moe_align_block_size` and gather the hidden states (and their scales,
    when present) into that order.

    Returns the permuted hidden states, the permuted scales (or None), the
    clamped sorted token ids, the per-row expert ids and the inverse
    permutation used to undo the reordering.
    """
    experts_per_token = curr_topk_ids.shape[1]
    chunk_rows, _ = curr_hidden_states.shape

    sorted_token_ids, expert_ids, num_tokens_post_padded = (
        moe_align_block_size(curr_topk_ids,
                             block_m,
                             global_num_experts,
                             expert_map,
                             pad_sorted_ids=True))

    total_slots = experts_per_token * chunk_rows

    # Padding entries carry an id past the last slot; clamp them so they can
    # be used as gather indices.
    sorted_token_ids = sorted_token_ids.clamp(max=total_slots - 1)

    # Expand the per-block expert ids to one entry per row.
    expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)

    inv_perm: Optional[torch.Tensor] = (
        torch.argsort(sorted_token_ids)[:total_slots])

    # Each slot id maps back to its source token row by dividing by top-k.
    source_rows = sorted_token_ids // experts_per_token
    curr_hidden_states = _fp8_perm(curr_hidden_states, source_rows)

    if a1q_scale is not None:
        a1q_scale = a1q_scale[source_rows]

    return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
            inv_perm)
def _moe_unpermute_and_reduce(
    out: torch.Tensor,
    curr_hidden: torch.Tensor,
    inv_perm: Optional[torch.Tensor],
    topk_weight: torch.Tensor,
) -> None:
    """
    Undo the expert-ordered permutation with `inv_perm`, scale every row by
    its top-k weight and sum over the top-k dimension into `out` using
    `ops.moe_sum`.
    """
    num_tokens, topk = topk_weight.shape
    hidden_dim = curr_hidden.shape[1]

    # Gather rows back into token-major order and group them per token.
    weighted = curr_hidden[inv_perm, ...].view(-1, topk, hidden_dim)

    # In-place scaling of the gathered tensor by the routing weights.
    weighted.mul_(topk_weight.view(num_tokens, -1, 1))

    ops.moe_sum(weighted, out)
def deep_gemm_moe_fp8(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a a8w8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1 and w2, and top-k gating
    mechanism. The matrix multiplications are implemented with DeepGemm
    grouped gemm.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
        Must be contiguous with dtype float32, float16 or bfloat16.
        Shape: [M, K]
    - w1 (torch.Tensor): The first set of fp8 quantized expert weights.
        The code reads it as [num_experts, N, K_in] with
        K_in == hidden_states.shape[1]; N is the width of the first gemm
        output and the activation output has width N // 2.
    - w2 (torch.Tensor): The second set of fp8 quantized expert weights.
        `w2.shape[1]` is used as the width of the second gemm output
        (presumably [num_experts, K, N // 2] -- TODO confirm).
    - w1_scale (torch.Tensor): The scale to dequantize w1. Its first
        dimension must equal the number of experts.
    - w2_scale (torch.Tensor): The scale to dequantize w2. Its first
        dimension must equal the number of experts.
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
    - inplace (bool): If True, the result is written into hidden_states.
        Defaults to False.
    - activation (str): The activation function to apply after the first
        gemm; "silu" or "gelu".
    - global_num_experts (int): The total number of experts in the global
        expert space. -1 means "use w1.shape[0]".
    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
        from the global expert space to the local expert space of the expert
        parallel shard. Must currently be None.
    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize
        hidden_states.
        Shape: scalar or [M]
    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
        quantize the intermediate result between the gemms.
        Shape: scalar or [M]

    Returns:
    - torch.Tensor: A tensor with the shape and dtype of hidden_states
        (hidden_states itself when inplace=True).

    Raises:
    - AssertionError: If a precondition fails or `_valid_deep_gemm` rejects
        the problem.
    - ValueError: If `activation` is neither "silu" nor "gelu".
    """
    # Lazy import to avoid CUDA initialization problems.
    import deep_gemm as dg

    assert expert_map is None, "Expert maps not supported yet"

    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    assert w1.dtype == torch.float8_e4m3fn
    assert w2.dtype == torch.float8_e4m3fn
    assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
    assert w1.shape[0] == w1_scale.shape[
        0], "w1 scales expert number mismatch"
    assert w1.shape[0] == w2_scale.shape[
        0], "w2 scales expert number mismatch"
    assert a1_scale is None or a1_scale.dim(
    ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
        0] == hidden_states.shape[0], "Input scale shape mismatch"
    assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    K = w2.shape[1]
    if global_num_experts == -1:
        global_num_experts = E

    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE

    assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    # The DeepGemm M alignment is used both as the row-block size for the
    # permutation and as the block shape passed to _fp8_quantize.
    block_m = dg.get_m_alignment_for_contiguous_layout()
    block_shape = [block_m, block_m]

    assert w1_scale is not None
    assert w2_scale is not None

    # We attempt to transpose and align offline in Fp8MoEMethod, in which
    # case these calls will be nops.  Otherwise, they'll be performed every
    # time the layer is executed.
    w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
    w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()

    # Upper bound on the number of permuted rows: every expert may need up
    # to block_m - 1 rows of padding, rounded up to a multiple of block_m.
    M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
    M_sum = round_up(M_sum, block_m)

    # One more than the number of full chunks; an empty trailing chunk is
    # skipped by the `break` inside the loop.
    num_chunks = (num_tokens // CHUNK_SIZE) + 1

    # We can reuse the memory between cache1 and cache3 because by the time
    # we need cache3, we're done with cache1
    workspace13 = torch.empty(M_sum * max(N, K),
                              device=hidden_states.device,
                              dtype=hidden_states.dtype)

    workspace1 = workspace13[:M_sum * N].view(M_sum, N)
    workspace2 = torch.empty((M_sum, N // 2),
                             device=hidden_states.device,
                             dtype=hidden_states.dtype)
    workspace3 = workspace13[:M_sum * K].view(M_sum, K)

    for chunk in range(num_chunks):
        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
                                          min((chunk + 1) * CHUNK_SIZE,
                                              num_tokens))
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            break

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        a1q_scale: Optional[torch.Tensor] = None

        # Quantize the chunk's activations, then reorder rows by expert.
        qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
                                                       a1_scale, block_shape)

        (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
         inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
                                  curr_topk_ids, global_num_experts,
                                  expert_map, block_m)

        # Adjust the intermediate cache size and config for the last chunk.
        # Note that in most cases we only have one chunk so the cache size
        # and config are already set correctly and do not need to be adjusted.
        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            curr_M = sorted_token_ids.numel()
            workspace1 = _resize_cache(workspace1, (curr_M, N))
            workspace2 = _resize_cache(workspace2, (curr_M, N // 2))
            workspace3 = _resize_cache(workspace3, (curr_M, K))

        # First grouped gemm: permuted activations x w1 -> workspace1.
        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1,
            expert_ids)

        # Gated activation: reads workspace1 (width N), writes workspace2
        # (width N // 2).
        if activation == "silu":
            torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N))
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N))
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

        a2q_scale: Optional[torch.Tensor] = None

        qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale,
                                               block_shape)

        # Second grouped gemm: activation output x w2 -> workspace3, which
        # shares storage with workspace1.
        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids)

        # Undo the permutation, apply routing weights and reduce over top-k
        # into this chunk's slice of the output.
        _moe_unpermute_and_reduce(
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
            workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights)

    return out_hidden_states
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
15ba07ef
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
import
functools
import
importlib.util
import
json
import
os
from
math
import
prod
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
...
...
@@ -14,10 +12,13 @@ import triton.language as tl
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
_valid_deep_gemm
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
,
round_up
from
vllm.utils
import
direct_register_custom_op
from
.rocm_aiter_fused_moe
import
(
is_rocm_aiter_moe_enabled
,
rocm_aiter_fused_experts
,
...
...
@@ -25,8 +26,6 @@ from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
logger
=
init_logger
(
__name__
)
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
@
triton
.
jit
def
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
...
...
@@ -443,300 +442,13 @@ def fused_moe_kernel(
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
def ceil_div(a, b):
    """Return `(a + b - 1) // b`, i.e. `a / b` rounded up for positive ints."""
    bumped = a + b - 1
    return bumped // b
# Stage 1: per-program histogram of expert ids.
# Program `pid` scans `tokens_per_thread` consecutive entries of the
# flattened topk_ids starting at `pid * tokens_per_thread` and counts, in row
# `pid + 1` of the tokens_cnts buffer (row stride `num_experts`), how many
# of those entries target each expert.
@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)

    start_idx = pid * tokens_per_thread

    # Offset of row (pid + 1); row 0 is not written by this stage.
    off_c = (pid + 1) * num_experts

    for i in range(tokens_per_thread):
        # Guard the tail: the last programs may run past the end of topk_ids.
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
# Stage 2: column-wise inclusive prefix sum.
# Program `pid` walks column `pid` of tokens_cnts from row 1 to row
# `num_experts` and replaces each entry with the running total, so the last
# row ends up holding the column's overall count.
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
# Stage 3: padded cumulative offsets.
# Reads the last row of tokens_cnts (offset num_experts * num_experts),
# rounds each count up to a multiple of `block_size`, and stores the running
# sum into cumsum[1..num_experts]. The final sum is also written to
# total_tokens_post_pad_ptr. The loop is sequential and uses no program id.
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    last_cumsum = 0
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        # cdiv(...) * block_size pads the count up to whole blocks.
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)
# Stage 4: fill expert_ids and scatter token indices.
# Program `pid` does two things:
#   1. writes `pid` into expert_ids for every block in
#      [cumsum[pid], cumsum[pid + 1]);
#   2. for its own slice of topk_ids, stores each flat index `i` into
#      sorted_token_ids at cumsum[expert] + tokens_cnts[pid, expert], then
#      bumps that counter.
# NOTE(review): row `pid` of tokens_cnts presumably holds the counts from
# all earlier programs after stage 2 -- confirm against the launcher.
@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    # One expert_ids entry per block of `block_size` padded slots.
    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    # From here on start_idx is reused for this program's slice of topk_ids.
    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
                                         numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """
    Run the four-stage Triton pipeline that fills `sorted_token_ids`,
    `expert_ids` and `num_tokens_post_pad` in place for the given `topk_ids`.

    Stages 1, 2 and 4 are launched with one program per expert; stage 3 is
    launched with a single program.
    """
    device = topk_ids.device
    total_entries = topk_ids.numel()
    per_expert_grid = (num_experts, )

    # Scratch buffers shared by the stages; both start zeroed.
    counts = torch.zeros((num_experts + 1, num_experts),
                         dtype=torch.int32,
                         device=device)
    offsets = torch.zeros((num_experts + 1, ),
                          dtype=torch.int32,
                          device=device)

    # Number of topk_ids entries each program is responsible for.
    entries_per_program = ceil_div(total_entries, num_experts)

    moe_align_block_size_stage1[per_expert_grid](
        topk_ids,
        counts,
        num_experts,
        total_entries,
        entries_per_program,
    )
    moe_align_block_size_stage2[per_expert_grid](
        counts,
        num_experts,
    )
    moe_align_block_size_stage3[(1, )](
        num_tokens_post_pad,
        counts,
        offsets,
        num_experts,
        block_size,
    )
    moe_align_block_size_stage4[per_expert_grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        counts,
        offsets,
        num_experts,
        block_size,
        total_entries,
        entries_per_program,
    )
def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Group token slots by expert and pad every expert's group to a multiple
    of `block_size`, so the result can drive a blocked matrix multiply.

    Parameters:
    - topk_ids: Tensor of shape [total_tokens, top_k] holding the selected
        expert index for each token slot.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.
    - expert_map: Optional tensor of shape [num_experts] mapping global
        expert indices to the local index space of the current expert
        parallel shard (-1 for experts not on this shard).
    - pad_sorted_ids: When True, the length of the returned sorted ids is
        rounded up to a multiple of block_size.

    Returns:
    - sorted_token_ids: Token slot indices ordered by their assigned expert;
        unused/padding positions hold `topk_ids.numel()`.
    - expert_ids: The expert index assigned to each block (remapped through
        `expert_map` when one is given).
    - num_tokens_post_padded: One-element tensor with the total number of
        slots after padding.

    Example (experts numbered 1..4 for illustration):
    With topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]] and
    block_size = 4 there are 12 slots, 3 per expert, so each expert gets one
    padding slot. The flattened ids are [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]
    and the sorted slot indices become
    [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12],
    where 12 marks a non-existent (padding) slot that the subsequent matrix
    multiplication ignores.
    """
    device = topk_ids.device
    num_slots = topk_ids.numel()

    # Worst case: every expert needs block_size - 1 padding slots.
    capacity = num_slots + num_experts * (block_size - 1)
    if pad_sorted_ids:
        capacity = round_up(capacity, block_size)

    # Pre-fill with the out-of-range id `num_slots` so untouched positions
    # read as padding.
    sorted_ids = torch.full((capacity, ),
                            num_slots,
                            dtype=torch.int32,
                            device=device)

    # Expert ids must be zeroed out to prevent index out of bounds error while
    # mapping global expert ids to local expert ids in expert parallelism.
    num_blocks = triton.cdiv(capacity, block_size)
    expert_ids = torch.zeros((num_blocks, ), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=device)

    if num_experts < 224:
        ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                                 expert_ids, num_tokens_post_pad)
    elif envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
        moe_align_block_size_triton(
            topk_ids,
            num_experts,
            block_size,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        )
    else:
        # Currently requires num_experts=256
        ops.sgl_moe_align_block_size(
            topk_ids,
            num_experts,
            block_size,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        )

    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad
def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
                     w2: torch.Tensor,
                     expert_map: Optional[torch.Tensor]) -> bool:
    """
    Report whether the DeepGemm grouped gemm kernel supports this problem.

    The answer is False if DeepGemm is not installed, if an expert map is
    given, if N (last dim of `w2`) is 512 or less, if there are fewer rows
    in `hidden_states` than `dg.get_m_alignment_for_contiguous_layout()`,
    if N or K (dims 2 and 1 of `w2`) are not multiples of that alignment,
    or if any input tensor is not contiguous.
    """
    if not has_deep_gemm:
        return False

    # Lazy import to avoid CUDA initialization problems.
    import deep_gemm as dg

    # Expert maps not supported yet.
    if expert_map is not None:
        return False

    required = dg.get_m_alignment_for_contiguous_layout()
    rows = hidden_states.shape[0]
    _, out_dim, inner_dim = w2.shape

    # For now, disable DeepGemm for small N until better permute/unpermute
    # ops are available.
    if inner_dim <= 512:
        return False

    misaligned = (required > rows or inner_dim % required != 0
                  or out_dim % required != 0)
    if misaligned:
        return False

    tensors = (hidden_states, w1, w2)
    return all(tensor.is_contiguous() for tensor in tensors)
def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize `A` to fp8 and return `(quantized, scale)`.

    Without a `block_shape` the tensor goes through `ops.scaled_fp8_quant`
    with the supplied `A_scale`. With a two-element `block_shape` the
    tensor is quantized by `per_token_group_quant_fp8` using
    `block_shape[1]` as the group size, and `A_scale` is ignored.
    """
    if block_shape is None:
        quantized, scale = ops.scaled_fp8_quant(A, A_scale)
        return quantized, scale

    assert len(block_shape) == 2
    group_size = block_shape[1]
    quantized, scale = per_token_group_quant_fp8(A, group_size)
    # One scale per group along the last dimension.
    assert triton.cdiv(quantized.shape[-1], group_size) == scale.shape[-1]
    return quantized, scale
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
B_zp
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
Optional
[
torch
.
Tensor
],
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
...
...
@@ -748,7 +460,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
assert
topk_weights
.
stride
(
1
)
==
1
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
use_fp8_w8a8
:
...
...
@@ -765,6 +478,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert
A_scale
is
None
assert
B_scale
is
None
M
=
A
.
shape
[
0
]
num_tokens
=
M
*
top_k
EM
=
sorted_token_ids
.
shape
[
0
]
if
A
.
shape
[
0
]
<
config
[
"BLOCK_SIZE_M"
]:
# optimize for small batch_size.
...
...
@@ -782,7 +498,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert
B_zp
is
None
or
B_zp
.
ndim
==
3
use_moe_wna16_cuda
=
should_moe_wna16_use_cuda
(
num_valid_tokens
=
topk_ids
.
numel
()
,
num_valid_tokens
=
num_tokens
,
group_size
=
block_shape
[
1
],
num_experts
=
B
.
shape
[
0
],
bit
=
4
if
use_int4_w4a16
else
8
)
...
...
@@ -790,12 +506,12 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config
.
update
(
get_moe_wna16_block_config
(
config
=
config
,
use_moe_wna16_cuda
=
use_moe_wna16_cuda
,
num_valid_tokens
=
topk_ids
.
numel
()
,
num_valid_tokens
=
num_tokens
,
size_k
=
A
.
shape
[
1
],
size_n
=
B
.
shape
[
1
],
num_experts
=
B
.
shape
[
1
],
group_size
=
block_shape
[
1
],
real_top_k
=
topk
_ids
.
shape
[
1
]
,
real_top_k
=
top
_
k
,
block_size_m
=
config
[
"BLOCK_SIZE_M"
]))
if
use_moe_wna16_cuda
:
...
...
@@ -821,7 +537,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B
.
shape
[
1
],
A
.
shape
[
1
],
EM
,
topk_ids
.
numel
()
,
num_tokens
,
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
...
...
@@ -864,7 +580,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B
.
shape
[
1
],
B
.
shape
[
2
],
EM
,
topk_ids
.
numel
()
,
num_tokens
,
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
...
...
@@ -1389,6 +1105,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
inplace
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
...
...
@@ -1419,85 +1136,6 @@ def fused_experts(hidden_states: torch.Tensor,
block_shape
=
block_shape
)
def
_fp8_perm
(
m
:
torch
.
Tensor
,
idx
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
A permutation routine that works on fp8 types.
"""
if
torch
.
is_floating_point
(
m
)
and
torch
.
finfo
(
m
.
dtype
).
bits
==
8
:
return
m
.
view
(
dtype
=
torch
.
uint8
)[
idx
,
...].
view
(
dtype
=
m
.
dtype
)
else
:
return
m
[
idx
,
...]
def _moe_permute(
    curr_hidden_states: torch.Tensor,
    a1q_scale: Optional[torch.Tensor],
    curr_topk_ids: torch.Tensor,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    top_k_num: int,
    block_m: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
           torch.Tensor]:
    """
    Sort the (token, expert) pairs of a chunk by expert and gather the
    hidden states (and optional scales) into that order.

    Returns `(hidden_states, a1q_scale, sorted_token_ids, expert_ids,
    inv_perm)`, where `expert_ids` has been repeated `block_m` times per
    block and `inv_perm` is the argsort of the clamped `sorted_token_ids`,
    truncated to `top_k_num * tokens_in_chunk` entries.
    """
    tokens_in_chunk, _ = curr_hidden_states.shape

    aligned = moe_align_block_size(curr_topk_ids,
                                   block_m,
                                   global_num_experts,
                                   expert_map,
                                   pad_sorted_ids=True)
    sorted_token_ids, expert_ids, _num_tokens_post_padded = aligned

    total_slots = top_k_num * tokens_in_chunk
    # Clamp so that every entry (including padding) is a valid index.
    sorted_token_ids = sorted_token_ids.clamp(max=total_slots - 1)
    expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
    inv_perm = torch.argsort(sorted_token_ids)[:total_slots]

    # Each flattened (token, k) slot maps back to its source token row.
    source_rows = sorted_token_ids // top_k_num
    curr_hidden_states = _fp8_perm(curr_hidden_states, source_rows)
    if a1q_scale is not None:
        a1q_scale = a1q_scale[source_rows]

    return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
            inv_perm)
def _moe_unpermute_and_reduce(
    out: torch.Tensor,
    curr_hidden: torch.Tensor,
    inv_perm: Optional[torch.Tensor],
    topk: int,
    K: int,
    topk_weight: torch.Tensor,
) -> None:
    """
    Undo the expert-sorted ordering, scale by the top-k weights and sum the
    per-expert results into `out`.

    `curr_hidden` is indexed with `inv_perm`, viewed as `(-1, topk, K)`,
    multiplied in place by `topk_weight` and handed to `ops.moe_sum`, which
    receives `out` as its second argument.
    """
    num_rows = topk_weight.shape[0]
    unpermuted = curr_hidden[inv_perm, ...].view(-1, topk, K)
    # Broadcast one weight per (token, expert) pair across the hidden dim.
    unpermuted.mul_(topk_weight.view(num_rows, -1, 1))
    ops.moe_sum(unpermuted, out)
def
_resize_cache
(
x
:
torch
.
Tensor
,
v
:
Tuple
[
int
,
...])
->
torch
.
Tensor
:
"""
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
"""
assert
prod
(
v
)
<=
x
.
numel
()
return
x
.
flatten
()[:
prod
(
v
)].
view
(
*
v
)
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
@@ -1629,7 +1267,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w1_scale
,
w1_zp
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
...
...
@@ -1660,28 +1297,34 @@ def fused_experts_impl(hidden_states: torch.Tensor,
qintermediate_cache2
=
intermediate_cache2
a2q_scale
=
a2_scale
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
intermediate_cache3
,
a2q_scale
,
w2_scale
,
w2_zp
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
True
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
block_shape
=
block_shape
)
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
intermediate_cache3
,
a2q_scale
,
w2_scale
,
w2_zp
,
curr_topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
#True,
1
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
block_shape
=
block_shape
)
if
True
:
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
top_k_num
,
K
)
intermediate_cache3
.
mul_
(
curr_topk_weights
.
view
(
tokens_in_chunk
,
-
1
,
1
))
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
return
out_hidden_states
...
...
@@ -1790,327 +1433,3 @@ def fused_moe(
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
)
def deep_gemm_moe_fp8(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a a8w8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1 and w2, and top-k gating
    mechanism. The matrix multiplications are implemented with DeepGemm
    grouped gemm.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
        Must be 2-D, contiguous and float32/float16/bfloat16. Its second
        dimension must equal `w1.shape[2]`.
    - w1 (torch.Tensor): The first set of fp8 (float8_e4m3fn) expert
        weights. Read here as `[num_experts, N, hidden]`, where N is the
        width of the first gemm output (the activation halves it to N // 2).
    - w2 (torch.Tensor): The second set of fp8 (float8_e4m3fn) expert
        weights. `w2.shape[1]` is used as the output width K.
        NOTE(review): presumably `[num_experts, K, N // 2]` - confirm.
    - w1_scale (torch.Tensor): The scale used with w1 in the first gemm.
        First dimension must equal the number of experts.
    - w2_scale (torch.Tensor): The scale used with w2 in the second gemm.
        First dimension must equal the number of experts.
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
        Must have the same shape as topk_weights.
    - inplace (bool): If True, the result is written into hidden_states.
        Defaults to False.
    - activation (str): The activation function to apply after the first
        gemm; "silu" or "gelu".
    - global_num_experts (int): The total number of experts in the global
        expert space. -1 means "use w1.shape[0]".
    - expert_map (Optional[torch.Tensor]): Must be None; expert maps are
        not supported yet.
    - a1_scale (Optional[torch.Tensor]): Optional scale for the input.
        Shape: scalar, [1] or [M]. It is forwarded to `_fp8_quantize`
        together with a block shape.
    - a2_scale (Optional[torch.Tensor]): Optional scale for the intermediate
        result between the gemms. Must match a1_scale's shape when both are
        given.

    Returns:
    - torch.Tensor: hidden_states itself when inplace is True, otherwise a
        new tensor created with `torch.empty_like(hidden_states)`.

    Raises:
    - AssertionError: if any of the input checks fail.
    - ValueError: if activation is neither "silu" nor "gelu".
    """
    # Lazy import to avoid CUDA initialization problems.
    import deep_gemm as dg

    assert expert_map is None, "Expert maps not supported yet"

    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    assert w1.dtype == torch.float8_e4m3fn
    assert w2.dtype == torch.float8_e4m3fn
    assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
    assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
    assert a1_scale is None or a1_scale.dim(
    ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
        0] == hidden_states.shape[0], "Input scale shape mismatch"
    assert a2_scale is None or a1_scale is None or \
        a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    K = w2.shape[1]
    if global_num_experts == -1:
        global_num_experts = E

    top_k_num = topk_ids.shape[1]
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE

    assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    # The same alignment value is used as the permutation block size and for
    # both entries of the quantization block shape.
    block_m = dg.get_m_alignment_for_contiguous_layout()
    block_shape = [block_m, block_m]

    assert w1_scale is not None
    assert w2_scale is not None

    # We attempt to transpose and align offline in Fp8MoEMethod, in which
    # case these calls will be nops.  Otherwise, they'll be performed every
    # time the layer is executed.
    w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
    w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()

    # NOTE(review): M_sum is derived from the full topk_ids, while
    # sorted_token_ids below comes from the per-chunk ids; the caches are
    # only resized for a short trailing chunk (chunk > 0). Confirm the
    # shapes line up when num_tokens > CHUNK_SIZE.
    M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
    M_sum = round_up(M_sum, block_m)

    num_chunks = (num_tokens // CHUNK_SIZE) + 1

    # We can reuse the memory between cache1 and cache3 because by the time
    # we need cache3, we're done with cache1
    cache13 = torch.empty(M_sum * max(N, K),
                          device=hidden_states.device,
                          dtype=hidden_states.dtype)

    intermediate_cache1 = cache13[:M_sum * N].view(M_sum, N)
    intermediate_cache2 = torch.empty((M_sum, N // 2),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache3 = cache13[:M_sum * K].view(M_sum, K)

    for chunk in range(num_chunks):
        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
                                          min((chunk + 1) * CHUNK_SIZE,
                                              num_tokens))
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        # num_chunks always includes one extra iteration, which is empty
        # when num_tokens is an exact multiple of CHUNK_SIZE.
        if tokens_in_chunk == 0:
            break

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        a1q_scale: Optional[torch.Tensor] = None

        qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
                                                       a1_scale, block_shape)

        (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
         inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
                                  curr_topk_ids, global_num_experts,
                                  expert_map, top_k_num, block_m)

        # Adjust the intermediate cache size and config for the last chunk.
        # Note that in most cases we only have one chunk so the cache size
        # and config are already set correctly and do not need to be adjusted.
        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            curr_M = sorted_token_ids.numel()
            intermediate_cache1 = _resize_cache(intermediate_cache1,
                                                (curr_M, N))
            intermediate_cache2 = _resize_cache(intermediate_cache2,
                                                (curr_M, N // 2))
            intermediate_cache3 = _resize_cache(intermediate_cache3,
                                                (curr_M, K))

        # First grouped gemm: permuted activations x w1 -> cache1.
        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (qcurr_hidden_states, a1q_scale), (w1, w1_scale),
            intermediate_cache1, expert_ids)

        if activation == "silu":
            torch.ops._C.silu_and_mul(intermediate_cache2,
                                      intermediate_cache1.view(-1, N))
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(intermediate_cache2,
                                      intermediate_cache1.view(-1, N))
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

        a2q_scale: Optional[torch.Tensor] = None

        qintermediate_cache2, a2q_scale = _fp8_quantize(
            intermediate_cache2, a2_scale, block_shape)

        # Second grouped gemm: activated intermediate x w2 -> cache3.
        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (qintermediate_cache2, a2q_scale), (w2, w2_scale),
            intermediate_cache3, expert_ids)

        _moe_unpermute_and_reduce(
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
            intermediate_cache3.view(*intermediate_cache3.shape), inv_perm,
            top_k_num, K, curr_topk_weights)

    return out_hidden_states
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    ab_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    ab_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    out_dtype: torch.dtype = torch.half,
) -> torch.Tensor:
    """
    This function computes a a8w8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and top-k gating
    mechanism. The matrix multiplications are implemented with CUTLASS
    grouped gemm.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
        Shape: [num_experts, K, 2N] (the weights are passed transposed)
    - w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
        Shape: [num_experts, N, K] (the weights are passed transposed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts] or [num_experts, 2N]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts] or [num_experts, K]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
        Must have the same shape as topk_weights.
    - ab_strides1 (torch.Tensor): The input and weights strides of the first
        grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - ab_strides2 (torch.Tensor): The input and weights strides of the second
        grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
        Shape: scalar or [M]
    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
        quantize the intermediate result between the gemms.
        Shape: scalar or [M]
    - out_dtype (torch.dtype): The output tensor type; torch.half or
        torch.bfloat16.

    Returns:
    - torch.Tensor: The [M, K] output tensor of dtype out_dtype after
        applying the MoE layer.

    Raises:
    - AssertionError: if any of the shape/dtype checks fail.
    """
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.float8_e4m3fn
    assert w2_q.dtype == torch.float8_e4m3fn
    assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
    assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert (a1_scale is None or a1_scale.dim() == 0
            or a1_scale.shape[0] == 1
            or a1_scale.shape[0] == a.shape[0]), "Input scale shape mismatch"
    assert (w1_scale.dim() == 1 or w1_scale.shape[1] == 1
            or w1_scale.shape[1] == w1_q.shape[2]), "W1 scale shape mismatch"
    assert (w2_scale.dim() == 1 or w2_scale.shape[1] == 1
            or w2_scale.shape[1] == w2_q.shape[2]), "W2 scale shape mismatch"
    assert w1_q.shape[0] == w1_scale.shape[
        0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[
        0], "w2 scales expert number mismatch"
    assert (a2_scale is None or a1_scale is None
            or a2_scale.shape == a1_scale.shape
            ), "Intermediate scale shape mismatch"
    assert ab_strides1.shape[0] == w1_q.shape[
        0], "AB Strides 1 expert number mismatch"
    assert c_strides1.shape[0] == w1_q.shape[
        0], "C Strides 1 expert number mismatch"
    assert ab_strides2.shape[0] == w2_q.shape[
        0], "AB Strides 2 expert number mismatch"
    assert c_strides2.shape[0] == w2_q.shape[
        0], "C Strides 2 expert number mismatch"
    assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"

    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(1)
    n = w2_q.size(1)

    topk = topk_ids.size(1)

    # Per-token activation scales are used when the first provided
    # activation scale has more than one element.
    if a1_scale is not None:
        per_act_token = a1_scale.numel() != 1
    elif a2_scale is not None:
        per_act_token = a2_scale.numel() != 1
    else:
        per_act_token = False

    a_q, a1_scale = ops.scaled_fp8_quant(
        a, a1_scale, use_per_token_if_dynamic=per_act_token)
    device = a_q.device

    expert_offsets = torch.empty((num_experts + 1),
                                 dtype=torch.int32,
                                 device=device)
    problem_sizes1 = torch.empty((num_experts, 3),
                                 dtype=torch.int32,
                                 device=device)
    problem_sizes2 = torch.empty((num_experts, 3),
                                 dtype=torch.int32,
                                 device=device)

    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

    # Fills the offsets, problem sizes and the two index maps in place.
    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
                                problem_sizes2, a_map, c_map, num_experts, n,
                                k)

    # Index through a uint8 view and view the result back as fp8.
    rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
    rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale

    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
    c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)

    ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
                       expert_offsets[:-1], problem_sizes1, ab_strides1,
                       ab_strides1, c_strides1)

    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
    torch.ops._C.silu_and_mul(intermediate, c1)

    intermediate_q, a2_scale = ops.scaled_fp8_quant(
        intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)

    ops.cutlass_moe_mm(c2, intermediate_q, w2_q, a2_scale, w2_scale,
                       expert_offsets[:-1], problem_sizes2, ab_strides2,
                       ab_strides2, c_strides2)

    # Reorder rows with c_map, weight each expert's output and reduce over
    # the top-k dimension.
    return (c2[c_map].view(m, topk, k) *
            topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
0 → 100644
View file @
15ba07ef
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
import
torch
import
triton
import
triton.language
as
tl
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
round_up
def ceil_div(a, b):
    """Return `(a + b - 1) // b`, i.e. `a / b` rounded up for positive ints."""
    padded = a + b - 1
    return padded // b
# Stage 1: per-program histogram of expert ids.
#
# Each program handles the slice
# [pid * tokens_per_thread, (pid + 1) * tokens_per_thread) of the flattened
# topk ids and, for every id in range, increments the counter at offset
# (pid + 1) * num_experts + expert_id of tokens_cnts_ptr. With the
# (num_experts + 1, num_experts) buffer that moe_align_block_size_triton
# allocates, that is row pid + 1; row 0 is never written here.
@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)

    # First flattened id owned by this program.
    start_idx = pid * tokens_per_thread

    # Offset of this program's counter row (shifted by one row).
    off_c = (pid + 1) * num_experts

    for i in range(tokens_per_thread):
        # The last program's slice may extend past the end of topk_ids.
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
# Stage 2: running sum of the counters over rows, one column per program.
#
# Program pid walks offsets i * num_experts + pid for i = 1..num_experts and
# replaces each counter with the sum of itself and all earlier rows, so the
# entry at row num_experts ends up holding the column total.
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
# Stage 3: block-aligned prefix sum over the per-expert counters.
#
# Reads num_experts counters starting at offset num_experts * num_experts of
# tokens_cnts_ptr, rounds each up to a multiple of block_size and writes the
# running total to cumsum_ptr[1..num_experts]; cumsum_ptr[0] is not written.
# The final total is also stored to total_tokens_post_pad_ptr.
# moe_align_block_size_triton launches this with a single program.
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    last_cumsum = 0
    # Offset of row `num_experts` in a table with num_experts columns.
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        # Pad each expert's count up to a whole number of blocks.
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)
# Stage 4: write the per-block expert ids and scatter the sorted token ids.
#
# Program pid does two things:
#   1. For every block start in [cumsum[pid], cumsum[pid + 1]) it stores
#      pid into expert_ids_ptr at the block index.
#   2. For every flattened id i in its slice of topk_ids it stores i into
#      sorted_token_ids_ptr at cumsum[expert_id] + counter, where the counter
#      lives at offset pid * num_experts + expert_id of tokens_cnts_ptr and
#      is incremented after each use.
@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    # Padded output range that belongs to expert `pid`.
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    # From here on start_idx is reused as the start of this program's slice
    # of the flattened topk ids.
    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
                                         numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """
    Run the four Triton stages that fill `sorted_token_ids`, `expert_ids`
    and `num_tokens_post_pad` in place.

    Two zero-initialised int32 scratch tensors are allocated on the device
    of `topk_ids`: a `(num_experts + 1, num_experts)` counter table and a
    `(num_experts + 1,)` cumulative-sum vector. Stages 1, 2 and 4 are
    launched with `num_experts` programs; stage 3 with a single program.
    """
    device = topk_ids.device
    total_ids = topk_ids.numel()
    per_expert_grid = (num_experts, )

    counts = torch.zeros((num_experts + 1, num_experts),
                         dtype=torch.int32,
                         device=device)
    padded_offsets = torch.zeros((num_experts + 1, ),
                                 dtype=torch.int32,
                                 device=device)
    # Number of flattened ids each program is responsible for.
    ids_per_program = ceil_div(total_ids, num_experts)

    moe_align_block_size_stage1[per_expert_grid](
        topk_ids,
        counts,
        num_experts,
        total_ids,
        ids_per_program,
    )
    moe_align_block_size_stage2[per_expert_grid](
        counts,
        num_experts,
    )
    moe_align_block_size_stage3[(1, )](
        num_tokens_post_pad,
        counts,
        padded_offsets,
        num_experts,
        block_size,
    )
    moe_align_block_size_stage4[per_expert_grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        counts,
        padded_offsets,
        num_experts,
        block_size,
        total_ids,
        ids_per_program,
    )
def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Group the flattened top-k token ids by expert, padding every expert's
    group so that its length is a multiple of `block_size`.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] holding the top-k
        expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.
    - expert_map: Optional tensor of shape [num_experts] mapping an expert
        index from the global space to the local index space of the current
        expert parallel shard (-1 when the expert is not on this shard).
        When given, the returned expert_ids are looked up through it.
    - pad_sorted_ids: Whether the length of sorted_token_ids should itself
        be rounded up to a multiple of block_size.

    Returns:
    - sorted_token_ids: int32 tensor of token indices ordered by their
        assigned expert. Unused slots hold `topk_ids.numel()`.
    - expert_ids: int32 tensor giving the expert index of each block.
    - num_tokens_post_padded: one-element int32 tensor with the total number
        of tokens after padding.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - There are 12 (token, expert) entries and each expert receives 3 of
        them, so one padding entry is added per expert.
    - topk_ids is flattened to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - After sorting by expert index, the token ids are
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        The value 12 marks a padding slot; such entries are ignored in the
        subsequent matrix multiplication.
    """
    device = topk_ids.device
    num_ids = topk_ids.numel()

    # Worst case: every expert needs block_size - 1 padding entries.
    capacity = num_ids + num_experts * (block_size - 1)
    if pad_sorted_ids:
        capacity = round_up(capacity, block_size)

    # Pre-fill with the out-of-range id that marks padding slots.
    sorted_ids = torch.full((capacity, ),
                            num_ids,
                            dtype=torch.int32,
                            device=device)
    num_blocks = triton.cdiv(capacity, block_size)
    # Expert ids must be zeroed out to prevent index out of bounds error while
    # mapping global expert ids to local expert ids in expert parallelism.
    expert_ids = torch.zeros((num_blocks, ), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=device)

    kernel_args = (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
                   num_tokens_post_pad)
    if num_experts < 224:
        ops.moe_align_block_size(*kernel_args)
    elif envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
        moe_align_block_size_triton(*kernel_args)
    else:
        # Currently requires num_experts=256
        ops.sgl_moe_align_block_size(*kernel_args)

    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad
vllm/model_executor/layers/fused_moe/utils.py
0 → 100644
View file @
15ba07ef
# SPDX-License-Identifier: Apache-2.0
from
math
import
prod
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.utils
import
cdiv
def
_resize_cache
(
x
:
torch
.
Tensor
,
v
:
Tuple
[
int
,
...])
->
torch
.
Tensor
:
"""
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
"""
assert
prod
(
v
)
<=
x
.
numel
()
return
x
.
flatten
()[:
prod
(
v
)].
view
(
*
v
)
def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize `A` to fp8 and return `(quantized, scale)`.

    Without a `block_shape`, the work is delegated to
    `ops.scaled_fp8_quant(A, A_scale)`. With a two-element `block_shape`,
    `per_token_group_quant_fp8` is called using the second entry as the group
    size; the incoming `A_scale` is not used on that path.
    """
    if block_shape is None:
        quantized, scales = ops.scaled_fp8_quant(A, A_scale)
        return quantized, scales

    assert len(block_shape) == 2
    group_k = block_shape[1]
    quantized, scales = per_token_group_quant_fp8(A, group_k)
    # Expect one scale per group of `group_k` elements along the last dim.
    assert cdiv(quantized.shape[-1], group_k) == scales.shape[-1]
    return quantized, scales
def
_fp8_perm
(
m
:
torch
.
Tensor
,
idx
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
A permutation routine that works on fp8 types.
"""
if
torch
.
is_floating_point
(
m
)
and
torch
.
finfo
(
m
.
dtype
).
bits
==
8
:
return
m
.
view
(
dtype
=
torch
.
uint8
)[
idx
,
...].
view
(
dtype
=
m
.
dtype
)
else
:
return
m
[
idx
,
...]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment