Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
78fe7753
Unverified
Commit
78fe7753
authored
Jul 03, 2025
by
bnellnm
Committed by
GitHub
Jul 03, 2025
Browse files
[Kernel] Enable fp8 support for pplx and BatchedTritonExperts. (#18864)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
2f2fcb31
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1190 additions
and
651 deletions
+1190
-651
tests/kernels/moe/parallel_utils.py
tests/kernels/moe/parallel_utils.py
+3
-7
tests/kernels/moe/test_batched_moe.py
tests/kernels/moe/test_batched_moe.py
+54
-30
tests/kernels/moe/test_deepep_deepgemm_moe.py
tests/kernels/moe/test_deepep_deepgemm_moe.py
+1
-2
tests/kernels/moe/test_deepep_moe.py
tests/kernels/moe/test_deepep_moe.py
+3
-2
tests/kernels/moe/test_pplx_cutlass_moe.py
tests/kernels/moe/test_pplx_cutlass_moe.py
+44
-35
tests/kernels/moe/test_pplx_moe.py
tests/kernels/moe/test_pplx_moe.py
+488
-244
tests/kernels/moe/utils.py
tests/kernels/moe/utils.py
+6
-8
tests/kernels/quant_utils.py
tests/kernels/quant_utils.py
+18
-0
tests/kernels/utils.py
tests/kernels/utils.py
+6
-4
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+6
-8
vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
...cutor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
+11
-23
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+55
-9
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+12
-6
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+6
-17
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+7
-8
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+391
-192
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+2
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+12
-22
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+4
-0
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
.../model_executor/layers/fused_moe/pplx_prepare_finalize.py
+61
-34
No files found.
tests/kernels/moe/parallel_utils.py
View file @
78fe7753
...
...
@@ -137,8 +137,7 @@ def make_deepep_ht_a2a(pg: ProcessGroup,
low_latency_mode
=
low_latency_mode
,
num_qps_per_rank
=
num_qps_per_rank
)
return
DeepEPHTPrepareAndFinalize
(
buffer
=
buffer
,
world_size
=
pgi
.
world_size
,
rank
=
pgi
.
rank
,
num_dispatchers
=
pgi
.
world_size
,
dp_size
=
dp_size
,
rank_expert_offset
=
pgi
.
rank
*
ht_args
.
num_local_experts
)
...
...
@@ -146,7 +145,6 @@ def make_deepep_ht_a2a(pg: ProcessGroup,
def
make_deepep_ll_a2a
(
pg
:
ProcessGroup
,
pgi
:
ProcessGroupInfo
,
dp_size
:
int
,
deepep_ll_args
:
DeepEPLLArgs
,
q_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
):
...
...
@@ -166,8 +164,7 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
return
DeepEPLLPrepareAndFinalize
(
buffer
=
buffer
,
world_size
=
pgi
.
world_size
,
dp_size
=
dp_size
,
num_dispatchers
=
pgi
.
world_size
,
max_tokens_per_rank
=
deepep_ll_args
.
max_tokens_per_rank
,
use_fp8_dispatch
=
deepep_ll_args
.
use_fp8_dispatch
,
)
...
...
@@ -186,5 +183,4 @@ def make_deepep_a2a(pg: ProcessGroup,
block_shape
)
assert
deepep_ll_args
is
not
None
return
make_deepep_ll_a2a
(
pg
,
pgi
,
dp_size
,
deepep_ll_args
,
q_dtype
,
block_shape
)
return
make_deepep_ll_a2a
(
pg
,
pgi
,
deepep_ll_args
,
q_dtype
,
block_shape
)
tests/kernels/moe/test_batched_moe.py
View file @
78fe7753
...
...
@@ -10,7 +10,7 @@ import triton.language as tl
from
tests.kernels.moe.utils
import
(
batched_moe
,
make_quantized_test_activations
,
make_test_weights
,
triton
_moe
)
make_test_weights
,
naive_batched
_moe
)
from
tests.kernels.quant_utils
import
native_batched_masked_quant_matmul
from
tests.kernels.utils
import
torch_experts
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
...
...
@@ -33,12 +33,10 @@ MNK_FACTORS = [
(
45
,
512
,
512
),
(
45
,
1024
,
128
),
(
45
,
1024
,
2048
),
(
64
,
128
,
128
),
(
64
,
512
,
512
),
(
64
,
1024
,
2048
),
(
222
,
128
,
128
),
(
222
,
128
,
2048
),
(
222
,
512
,
512
),
(
222
,
1024
,
128
),
(
222
,
1024
,
2048
),
]
...
...
@@ -95,11 +93,12 @@ class BatchedMMTensors:
@
pytest
.
mark
.
parametrize
(
"max_tokens_per_expert"
,
[
32
,
64
,
128
,
192
,
224
,
256
,
512
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"N"
,
[
128
,
256
,
512
,
1024
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"block_shape"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"N"
,
[
128
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"block_shape"
,
[
None
,
[
128
,
128
]])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
def
test_batched_mm
(
num_experts
:
int
,
max_tokens_per_expert
:
int
,
K
:
int
,
N
:
int
,
dtype
:
torch
.
dtype
,
block_shape
:
Optional
[
list
[
int
]],
...
...
@@ -134,7 +133,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
in_dtype
=
act_dtype
,
quant_dtype
=
quant_dtype
,
block_shape
=
block_shape
,
per_act_token_quant
=
per_act_token_quant
)
per_act_token_quant
=
per_act_token_quant
,
)
B
,
B_q
,
B_scale
,
_
,
_
,
_
=
make_test_weights
(
num_experts
,
...
...
@@ -143,6 +143,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
in_dtype
=
act_dtype
,
quant_dtype
=
quant_dtype
,
block_shape
=
block_shape
,
per_act_token_quant
=
per_act_token_quant
,
)
out_shape
=
(
num_experts
,
max_tokens_per_expert
,
N
)
...
...
@@ -177,6 +178,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
16
if
dtype
.
itemsize
>
1
else
32
},
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
)
...
...
@@ -185,15 +187,13 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
B
,
ref_output
,
num_expert_tokens
,
None
,
None
,
None
,
)
q_ref_output
=
native_batched_masked_quant_matmul
(
A_q
,
B_q
,
q_ref_output
,
num_expert_tokens
,
A_scale
,
B_scale
,
block_shape
)
block_shape
,
per_act_token_quant
)
rtol
,
atol
=
{
torch
.
float16
:
(
6e-2
,
6e-2
),
...
...
@@ -201,16 +201,17 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
torch
.
float32
:
(
1e-2
,
1e-2
),
}[
test_output
.
dtype
]
torch
.
testing
.
assert_close
(
ref_output
,
test
_output
,
atol
=
atol
,
rtol
=
rtol
)
torch
.
testing
.
assert_close
(
ref_output
,
q_ref
_output
,
atol
=
atol
,
rtol
=
rtol
)
torch
.
testing
.
assert_close
(
test_output
,
q_ref_output
,
atol
=
atol
,
rtol
=
rtol
)
@
pytest
.
mark
.
parametrize
((
"m"
,
"n"
,
"k"
),
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"block_shape"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"block_shape"
,
[
None
,
[
128
,
128
]])
@
pytest
.
mark
.
parametrize
(
"input_scales"
,
[
False
])
def
test_fused_moe_batched_experts
(
m
:
int
,
n
:
int
,
...
...
@@ -220,15 +221,19 @@ def test_fused_moe_batched_experts(
dtype
:
torch
.
dtype
,
per_act_token_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]],
input_scales
:
bool
,
):
current_platform
.
seed_everything
(
7
)
use_fp8_w8a8
=
dtype
==
torch
.
float8_e4m3fn
if
topk
>
e
:
pytest
.
skip
(
"topk > e"
)
if
not
use_fp8_w8a8
and
(
per_act_token_quant
or
block_shape
is
not
None
):
pytest
.
skip
(
"Skip quantization test for non-quantized type"
)
if
per_act_token_quant
and
block_shape
is
not
None
or
topk
>
e
:
if
per_act_token_quant
and
block_shape
is
not
None
:
pytest
.
skip
(
"Skip illegal quantization test."
)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
/
10
...
...
@@ -241,16 +246,27 @@ def test_fused_moe_batched_experts(
act_dtype
=
dtype
quant_dtype
=
None
_
,
w1
,
w1_s
,
_
,
w2
,
w2_s
=
make_test_weights
(
e
,
w1_16
,
w1
,
w1_s
,
w2_16
,
w2
,
w2_s
=
make_test_weights
(
e
,
n
,
k
,
block_shape
=
block_shape
,
in_dtype
=
act_dtype
,
quant_dtype
=
quant_dtype
)
quant_dtype
=
quant_dtype
,
per_act_token_quant
=
per_act_token_quant
,
)
if
input_scales
and
quant_dtype
is
not
None
:
a1_scale
=
torch
.
tensor
(
1
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
tensor
(
1
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
else
:
a1_scale
=
None
a2_scale
=
None
with
set_current_vllm_config
(
vllm_config
):
topk_weight
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
batched_output
=
batched_moe
(
baseline_output
=
torch_experts
(
a
,
w1
,
w2
,
...
...
@@ -258,11 +274,14 @@ def test_fused_moe_batched_experts(
topk_ids
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
quant_dtype
=
quant_dtype
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
)
baseline_output
=
torch_experts
(
batched_output
=
naive_batched_moe
(
a
,
w1
,
w2
,
...
...
@@ -270,11 +289,14 @@ def test_fused_moe_batched_experts(
topk_ids
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
quant_dtype
=
quant_dtype
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
)
block_shape
=
block_shape
,
)
triton_output
=
triton
_moe
(
triton_output
=
batched
_moe
(
a
,
w1
,
w2
,
...
...
@@ -282,14 +304,16 @@ def test_fused_moe_batched_experts(
topk_ids
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
quant_dtype
=
quant_dtype
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
)
torch
.
testing
.
assert_close
(
triton
_output
,
torch
.
testing
.
assert_close
(
batched
_output
,
baseline_output
,
atol
=
2
e-2
,
atol
=
3
e-2
,
rtol
=
2e-2
)
torch
.
testing
.
assert_close
(
triton_output
,
...
...
tests/kernels/moe/test_deepep_deepgemm_moe.py
View file @
78fe7753
...
...
@@ -148,8 +148,7 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
fused_experts
=
BatchedDeepGemmExperts
(
max_num_tokens
=
max_tokens_per_rank
,
world_size
=
pgi
.
world_size
,
dp_size
=
dp_size
,
num_dispatchers
=
pgi
.
world_size
//
dp_size
,
block_shape
=
test_config
.
block_size
,
per_act_token_quant
=
test_config
.
per_act_token_quant
)
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
...
...
tests/kernels/moe/test_deepep_moe.py
View file @
78fe7753
...
...
@@ -154,12 +154,13 @@ def make_modular_kernel(
deepep_ht_args
=
ht_args
,
deepep_ll_args
=
ll_args
)
num_dispatchers
=
pgi
.
world_size
//
dp_size
if
low_latency_mode
:
assert
not
per_act_token_quant
,
"not supported in ll mode"
fused_experts
=
BatchedTritonExperts
(
max_num_tokens
=
MAX_TOKENS_PER_RANK
,
world_size
=
pgi
.
world_size
,
dp_size
=
dp_size
,
num_dispatchers
=
num_dispatchers
,
use_fp8_w8a8
=
is_quantized
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
...
...
tests/kernels/moe/test_pplx_cutlass_moe.py
View file @
78fe7753
...
...
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
...
...
@@ -112,18 +113,21 @@ def pplx_cutlass_moe(
w2_scale
=
w2_scale
.
to
(
device
)
a1_scale
=
a1_scale
.
to
(
device
)
assert
num_experts
%
world_size
==
0
num_local_experts
=
cdiv
(
num_experts
,
world_size
)
num_dispatchers
=
pgi
.
world_size
//
dp_size
prepare_finalize
=
PplxPrepareAndFinalize
(
ata
,
max_num_tokens
,
pgi
.
world_size
,
rank
,
dp_size
,
)
max_num_tokens
=
max_num_tokens
,
num_local_experts
=
num_local_experts
,
num_dispatchers
=
num_dispatchers
)
experts
=
CutlassExpertsFp8
(
(
num_experts
+
world_size
-
1
)
//
world_size
,
experts
=
CutlassExpertsFp8
(
num_
local_
experts
,
out_dtype
,
per_act_token
,
per_out_ch
,
num_dispatchers
=
num_dispatchers
,
use_batched_format
=
True
)
fused_cutlass_experts
=
FusedMoEModularKernel
(
...
...
@@ -181,6 +185,7 @@ def _pplx_moe(
per_out_ch
:
bool
,
use_internode
:
bool
,
):
try
:
if
use_internode
:
uid
=
nvshmem_get_unique_id
(
)
if
pgi
.
rank
==
0
else
nvshmem_alloc_empty_unique_id
()
...
...
@@ -188,12 +193,13 @@ def _pplx_moe(
nvshmem_init
(
uid
,
pgi
.
rank
,
pgi
.
world_size
)
else
:
group_ranks
=
list
(
range
(
pgi
.
world_size
))
cpu_group
=
torch
.
distributed
.
new_group
(
group_ranks
,
backend
=
"gloo"
)
cpu_group
=
torch
.
distributed
.
new_group
(
group_ranks
,
backend
=
"gloo"
)
group_name
=
cpu_group
.
group_name
with
set_current_vllm_config
(
vllm_config
):
torch_output
=
torch_experts
(
a_full
,
w1_full
,
w2_full
,
topk_weights
,
topk_ids
)
torch_output
=
torch_experts
(
a_full
,
w1_full
,
w2_full
,
topk_weights
,
topk_ids
)
pplx_output
=
pplx_cutlass_moe
(
pgi
,
dp_size
,
a
,
w1
,
w2
,
w1_scale
,
w2_scale
,
topk_weights
,
topk_ids
,
a1_scale
,
out_dtype
,
per_act_token
,
...
...
@@ -206,8 +212,11 @@ def _pplx_moe(
# print("PPLX OUT:", pplx_output)
# print("TORCH OUT:", torch_output)
torch
.
testing
.
assert_close
(
pplx_output
,
torch_output
,
atol
=
0.05
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
pplx_output
,
torch_output
,
atol
=
0.05
,
rtol
=
0
)
finally
:
if
use_internode
:
nvshmem_finalize
()
...
...
tests/kernels/moe/test_pplx_moe.py
View file @
78fe7753
...
...
@@ -4,7 +4,10 @@
Run `pytest tests/kernels/test_pplx_moe.py`.
"""
from
typing
import
Optional
import
itertools
import
textwrap
import
traceback
from
typing
import
Callable
,
Optional
import
pytest
import
torch
...
...
@@ -19,12 +22,13 @@ except ImportError:
has_pplx
=
False
from
tests.kernels.moe.utils
import
make_test_weights
,
naive_batched_moe
from
tests.kernels.quant_utils
import
dequant
from
tests.kernels.utils
import
torch_experts
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe
import
fused_topk
,
override_config
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
Batched
PrepareAndFinalize
,
BatchedTritonExperts
,
NaiveBatched
Experts
)
Batched
Triton
Experts
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
get_default_config
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
)
...
...
@@ -38,22 +42,22 @@ requires_pplx = pytest.mark.skipif(
reason
=
"Requires PPLX kernels"
,
)
PPLX_PREPARE_COMBOS
=
[(
4
,
128
,
128
),
(
32
,
1024
,
512
),
(
64
,
1024
,
512
),
(
222
,
2048
,
1024
)]
PPLX_MOE_COMBOS
=
[
(
1
,
128
,
128
),
PPLX_COMBOS
=
[
# TODO: figure out why this fails, seems to be test problem
#(1, 128, 128),
(
2
,
128
,
512
),
(
3
,
1024
,
2048
),
(
32
,
128
,
1024
),
(
4
,
128
,
128
),
(
32
,
1024
,
512
),
(
45
,
512
,
2048
),
(
64
,
1024
,
1024
),
(
222
,
1024
,
2048
),
(
64
,
1024
,
512
),
(
222
,
2048
,
1024
),
(
256
,
1408
,
2048
),
]
NUM_EXPERTS
=
[
8
,
64
]
EP_SIZE
=
[
1
,
4
]
TOP_KS
=
[
1
,
2
,
6
]
DTYPES
=
[
torch
.
float8_e4m3fn
,
torch
.
bfloat16
]
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
...
...
@@ -169,9 +173,11 @@ def test_fused_moe_batched_experts(
with
set_current_vllm_config
(
vllm_config
):
topk_weight
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
baseline_output
=
torch_experts
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
baseline_output
=
torch_experts
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
# only for baseline
torch_output
=
torch_batched_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
batched_output
=
naive_batched_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
batched_output
=
naive_batched_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
# pick torch_experts or this
torch
.
testing
.
assert_close
(
baseline_output
,
torch_output
,
...
...
@@ -183,6 +189,63 @@ def test_fused_moe_batched_experts(
rtol
=
0
)
def
create_pplx_prepare_finalize
(
num_tokens
:
int
,
hidden_dim
:
int
,
topk
:
int
,
num_experts
:
int
,
rank
:
int
,
dp_size
:
int
,
world_size
:
int
,
in_dtype
:
torch
.
dtype
,
quant_dtype
:
Optional
[
torch
.
dtype
],
block_shape
:
Optional
[
list
[
int
]],
per_act_token_quant
:
bool
,
group_name
:
Optional
[
str
],
):
from
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize
import
(
PplxPrepareAndFinalize
,
pplx_hidden_dim_scale_bytes
)
max_num_tokens
=
max
(
rank_chunk
(
num_tokens
,
0
,
world_size
),
1
)
num_local_experts
=
rank_chunk
(
num_experts
,
0
,
world_size
)
hidden_dim_bytes
,
scale_bytes
=
pplx_hidden_dim_scale_bytes
(
max_num_tokens
,
hidden_dim
,
in_dtype
,
quant_dtype
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
)
args
=
dict
(
max_num_tokens
=
max_num_tokens
,
num_experts
=
num_experts
,
experts_per_token
=
topk
,
rank
=
rank
,
world_size
=
world_size
,
dp_size
=
dp_size
,
hidden_dim
=
hidden_dim
,
hidden_dim_bytes
=
hidden_dim_bytes
,
hidden_dim_scale_bytes
=
scale_bytes
,
)
if
group_name
is
None
:
ata
=
AllToAll
.
internode
(
**
args
)
else
:
args
[
"group_name"
]
=
group_name
ata
=
AllToAll
.
intranode
(
**
args
)
prepare_finalize
=
PplxPrepareAndFinalize
(
ata
,
max_num_tokens
=
max_num_tokens
,
num_local_experts
=
num_local_experts
,
num_dispatchers
=
world_size
//
dp_size
,
)
return
prepare_finalize
,
ata
def
rank_chunk
(
num
:
int
,
r
:
int
,
w
:
int
)
->
int
:
rem
=
num
%
w
return
(
num
//
w
)
+
(
1
if
r
<
rem
else
0
)
...
...
@@ -193,6 +256,35 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
return
t
[(
r
*
chunk
):(
r
+
1
)
*
chunk
]
def
maybe_chunk_by_rank
(
t
:
Optional
[
torch
.
Tensor
],
r
:
int
,
w
:
int
)
->
Optional
[
torch
.
Tensor
]:
if
t
is
not
None
:
return
chunk_by_rank
(
t
,
r
,
w
)
else
:
return
t
def
chunk_scales_by_rank
(
t
:
Optional
[
torch
.
Tensor
],
r
:
int
,
w
:
int
)
->
Optional
[
torch
.
Tensor
]:
if
t
is
not
None
and
t
.
numel
()
>
1
:
chunk
=
rank_chunk
(
t
.
shape
[
0
],
r
,
w
)
return
t
[(
r
*
chunk
):(
r
+
1
)
*
chunk
]
else
:
return
t
def
chunk_scales
(
t
:
Optional
[
torch
.
Tensor
],
start
:
int
,
end
:
int
)
->
Optional
[
torch
.
Tensor
]:
if
t
is
not
None
and
t
.
numel
()
>
1
:
return
t
[
start
:
end
]
else
:
return
t
def
dummy_work
(
a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
a
*
1.1
def
pplx_prepare_finalize
(
pgi
:
ProcessGroupInfo
,
dp_size
:
int
,
...
...
@@ -200,11 +292,11 @@ def pplx_prepare_finalize(
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
quant_dtype
:
Optional
[
torch
.
dtype
],
block_shape
:
Optional
[
list
[
int
]],
per_act_token_quant
:
bool
,
group_name
:
Optional
[
str
],
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize
import
(
PplxPrepareAndFinalize
)
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
topk
=
topk_ids
.
shape
[
1
]
...
...
@@ -212,60 +304,66 @@ def pplx_prepare_finalize(
device
=
pgi
.
device
rank
=
pgi
.
rank
world_size
=
pgi
.
world_size
max_num_tokens
=
rank_chunk
(
num_tokens
,
0
,
world_size
)
args
=
dict
(
max_num_tokens
=
max_num_tokens
,
num_experts
=
num_experts
,
experts_per_token
=
topk
,
rank
=
rank
,
world_size
=
world_size
,
dp_size
=
dp_size
,
hidden_dim
=
hidden_dim
,
hidden_dim_bytes
=
hidden_dim
*
a
.
dtype
.
itemsize
,
hidden_dim_scale_bytes
=
0
,
)
if
group_name
is
None
:
ata
=
AllToAll
.
internode
(
**
args
)
else
:
args
[
"group_name"
]
=
group_name
ata
=
AllToAll
.
intranode
(
**
args
)
topk_ids
=
topk_ids
.
to
(
dtype
=
torch
.
uint32
)
prepare_finalize
=
PplxPrepareAndFinalize
(
ata
,
max_num_tokens
,
world_size
,
prepare_finalize
,
ata
=
create_pplx_prepare_finalize
(
num_tokens
,
hidden_dim
,
topk
,
num_experts
,
rank
,
dp_size
,
world_size
,
a
.
dtype
,
quant_dtype
,
block_shape
,
per_act_token_quant
,
group_name
,
)
assert
a
.
shape
[
0
]
==
topk_ids
.
shape
[
0
]
a_chunk
=
chunk_by_rank
(
a
,
rank
,
world_size
).
to
(
device
)
chunk_topk_weight
=
chunk_by_rank
(
topk_weight
,
rank
,
world_size
).
to
(
device
)
chunk_topk_ids
=
chunk_by_rank
(
topk_ids
,
rank
,
world_size
).
to
(
device
)
assert
a_chunk
.
shape
[
0
]
==
chunk_topk_ids
.
shape
[
0
]
out
=
torch
.
full
(
a_chunk
.
shape
,
torch
.
nan
,
dtype
=
a
.
dtype
,
device
=
device
,
)
if
(
quant_dtype
is
not
None
and
not
per_act_token_quant
and
block_shape
is
None
):
a1_scale
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
else
:
a1_scale
=
None
a2_scale
=
None
b_a
,
b_a_scale
,
expert_num_tokens
,
_
,
_
=
prepare_finalize
.
prepare
(
a_chunk
,
Non
e
,
Non
e
,
a1_scal
e
,
a2_scal
e
,
chunk_topk_weight
,
chunk_topk_ids
,
num_experts
,
None
,
False
,
FusedMoEQuantConfig
(),
FusedMoEQuantConfig
(
quant_dtype
,
per_act_token_quant
,
False
,
block_shape
,
),
)
b_a
=
b_a
*
1.5
out
=
torch
.
full
(
(
max_num_tokens
,
hidden_dim
),
torch
.
nan
,
dtype
=
a
.
dtype
,
device
=
device
,
)
b_a
=
dummy_work
(
dequant
(
b_a
,
b_a_scale
,
block_shape
,
per_act_token_quant
,
a
.
dtype
))
prepare_finalize
.
finalize
(
out
,
...
...
@@ -291,8 +389,12 @@ def _pplx_prepare_finalize(
score
:
torch
.
Tensor
,
topk
:
torch
.
Tensor
,
num_experts
:
int
,
quant_dtype
:
Optional
[
torch
.
dtype
],
block_shape
:
Optional
[
list
[
int
]],
per_act_token_quant
:
bool
,
use_internode
:
bool
,
):
try
:
if
use_internode
:
uid
=
nvshmem_get_unique_id
(
)
if
pgi
.
rank
==
0
else
nvshmem_alloc_empty_unique_id
()
...
...
@@ -301,60 +403,82 @@ def _pplx_prepare_finalize(
group_name
=
None
else
:
group_ranks
=
list
(
range
(
pgi
.
world_size
))
cpu_group
=
torch
.
distributed
.
new_group
(
group_ranks
,
backend
=
"gloo"
)
cpu_group
=
torch
.
distributed
.
new_group
(
group_ranks
,
backend
=
"gloo"
)
group_name
=
cpu_group
.
group_name
device
=
pgi
.
device
topk_weight
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
k
=
a
.
shape
[
1
]
m
,
k
=
a
.
shape
a_rep
=
torch
.
repeat_interleave
(
a
,
topk
,
dim
=
0
)
.
to
(
device
)
a_rep
=
torch
.
repeat_interleave
(
dummy_work
(
a
)
,
topk
,
dim
=
0
)
torch_output
=
(
a_rep
.
view
(
-
1
,
topk
,
k
)
*
1.5
*
topk_weight
.
view
(
-
1
,
topk
,
1
).
to
(
device
)).
sum
(
dim
=
1
).
to
(
a
.
dtype
)
torch_output
=
(
a_rep
.
view
(
m
,
topk
,
k
)
*
topk_weight
.
view
(
m
,
topk
,
1
).
to
(
a_rep
.
dtype
)).
sum
(
dim
=
1
)
pplx_output
=
pplx_prepare_finalize
(
pgi
,
dp_size
,
a
,
topk_weight
,
topk_ids
,
num_experts
,
group_name
)
pplx_output
=
pplx_prepare_finalize
(
pgi
,
dp_size
,
a
,
topk_weight
,
topk_ids
,
num_experts
,
quant_dtype
,
block_shape
,
per_act_token_quant
,
group_name
)
torch_output
=
chunk_by_rank
(
torch_output
,
pgi
.
rank
,
pgi
.
world_size
).
to
(
pplx_output
.
device
)
torch
.
testing
.
assert_close
(
pplx_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
pgi
.
world_size
).
to
(
pgi
.
device
)
torch
.
testing
.
assert_close
(
pplx_output
,
torch_output
,
atol
=
3e-2
,
rtol
=
3e-2
)
finally
:
if
use_internode
:
nvshmem_finalize
()
# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
# TODO (bnell) add fp8 tests
@
pytest
.
mark
.
parametrize
(
"mnk"
,
PPLX_PREPARE_COMBOS
)
@
pytest
.
mark
.
parametrize
(
"mnk"
,
PPLX_COMBOS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
]
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"block_shape"
,
[
None
,
[
128
,
128
]])
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
pytest
.
mark
.
optional
@
requires_pplx
def
test_pplx_prepare_finalize
(
def
test_pplx_prepare_finalize
_slow
(
mnk
:
tuple
[
int
,
int
,
int
],
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
world_dp_size
:
tuple
[
int
,
int
],
per_act_token_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]],
use_internode
:
bool
,
):
if
dtype
==
torch
.
float8_e4m3fn
:
use_fp8_w8a8
=
True
act_dtype
=
torch
.
bfloat16
quant_dtype
=
dtype
else
:
use_fp8_w8a8
=
False
act_dtype
=
dtype
quant_dtype
=
None
if
not
use_fp8_w8a8
and
(
per_act_token_quant
or
block_shape
is
not
None
):
pytest
.
skip
(
"Skip quantization test for non-quantized type"
)
if
per_act_token_quant
and
block_shape
is
not
None
:
pytest
.
skip
(
"Skip illegal quantization combination"
)
current_platform
.
seed_everything
(
7
)
m
,
n
,
k
=
mnk
world_size
,
dp_size
=
world_dp_size
device
=
"cuda"
a
=
torch
.
randn
((
m
,
k
),
device
=
device
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
device
,
dtype
=
dtype
)
a
=
torch
.
randn
((
m
,
k
),
device
=
device
,
dtype
=
act_dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
device
,
dtype
=
act_dtype
)
parallel_launch
(
world_size
,
_pplx_prepare_finalize
,
dp_size
,
a
,
score
,
topk
,
e
,
use_internode
)
topk
,
e
,
quant_dtype
,
block_shape
,
per_act_token_quant
,
use_internode
)
def
pplx_moe
(
...
...
@@ -369,84 +493,62 @@ def pplx_moe(
topk_ids
:
torch
.
Tensor
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
qtype
:
Optional
[
torch
.
dtype
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
quant_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
per_act_token_quant
=
False
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
use_compile
:
bool
=
False
,
use_cudagraphs
:
bool
=
True
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize
import
(
PplxPrepareAndFinalize
,
pplx_hidden_dim_scale_bytes
)
device
=
torch
.
device
(
"cuda"
,
rank
)
hidden_dim
=
a
.
shape
[
1
]
num_tokens
,
hidden_dim
=
a
.
shape
num_experts
=
w1
.
shape
[
0
]
topk
=
topk_ids
.
shape
[
1
]
max_num_tokens
=
round_up
(
rank_chunk
(
a
.
shape
[
0
],
0
,
world_size
),
6
4
)
max_num_tokens
=
round_up
(
rank_chunk
(
a
.
shape
[
0
],
0
,
world_size
),
1
6
)
hidden_dim_bytes
,
scale_bytes
=
pplx_hidden_dim_scale_bytes
(
max_
num_tokens
,
prepare_finalize
,
ata
=
create_pplx_prepare_finalize
(
num_tokens
,
hidden_dim
,
topk
,
num_experts
,
rank
,
dp_size
,
world_size
,
a
.
dtype
,
qtype
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
quant_dtype
,
block_shape
,
per_act_token_quant
,
group_name
,
)
args
=
dict
(
max_num_tokens
=
max_num_tokens
,
num_experts
=
num_experts
,
experts_per_token
=
topk
,
rank
=
rank
,
world_size
=
world_size
,
dp_size
=
dp_size
,
hidden_dim
=
hidden_dim
,
hidden_dim_bytes
=
hidden_dim_bytes
,
hidden_dim_scale_bytes
=
scale_bytes
,
)
if
group_name
is
None
:
ata
=
AllToAll
.
internode
(
**
args
)
else
:
args
[
"group_name"
]
=
group_name
ata
=
AllToAll
.
intranode
(
**
args
)
topk_ids
=
topk_ids
.
to
(
dtype
=
torch
.
uint32
)
prepare_finalize
=
PplxPrepareAndFinalize
(
ata
,
max_num_tokens
,
world_size
,
rank
,
dp_size
,
experts
=
BatchedTritonExperts
(
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
()
,
use_fp8_w8a8
=
quant_dtype
==
torch
.
float8_e4m3fn
,
block_shape
=
block_shape
,
per_act_token_quant
=
per_act_token_quant
,
)
experts
=
BatchedTritonExperts
(
max_num_tokens
=
max_num_tokens
,
world_size
=
world_size
,
dp_size
=
dp_size
,
use_fp8_w8a8
=
qtype
==
torch
.
float8_e4m3fn
,
block_shape
=
block_shape
)
fused_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk
=
chunk_by_rank
(
a
,
rank
,
world_size
)
.
to
(
device
)
chunk_topk_weight
=
chunk_by_rank
(
topk_weight
,
rank
,
world_size
)
.
to
(
device
)
chunk_topk_ids
=
chunk_by_rank
(
topk_ids
,
rank
,
world_size
)
.
to
(
device
)
a_chunk
=
chunk_by_rank
(
a
,
rank
,
world_size
)
chunk_topk_weight
=
chunk_by_rank
(
topk_weight
,
rank
,
world_size
)
chunk_topk_ids
=
chunk_by_rank
(
topk_ids
,
rank
,
world_size
)
# Chunking weights like this only works for batched format
w1_chunk
=
chunk_by_rank
(
w1
,
rank
,
world_size
).
to
(
device
)
w2_chunk
=
chunk_by_rank
(
w2
,
rank
,
world_size
).
to
(
device
)
if
w1_scale
is
not
None
:
w1_scale_chunk
=
chunk_by_rank
(
w1_scale
,
rank
,
world_size
).
to
(
device
)
w2_scale_chunk
=
chunk_by_rank
(
w2_scale
,
rank
,
world_size
).
to
(
device
)
else
:
w1_scale_chunk
=
None
w2_scale_chunk
=
None
w1_chunk
=
chunk_by_rank
(
w1
,
rank
,
world_size
)
w2_chunk
=
chunk_by_rank
(
w2
,
rank
,
world_size
)
w1_scale_chunk
=
maybe_chunk_by_rank
(
w1_scale
,
rank
,
world_size
)
w2_scale_chunk
=
maybe_chunk_by_rank
(
w2_scale
,
rank
,
world_size
)
a1_scale_chunk
=
chunk_scales_by_rank
(
a1_scale
,
rank
,
world_size
)
a2_scale_chunk
=
chunk_scales_by_rank
(
a2_scale
,
rank
,
world_size
)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
...
...
@@ -468,6 +570,8 @@ def pplx_moe(
chunk_topk_ids
,
w1_scale
=
w1_scale_chunk
,
w2_scale
=
w2_scale_chunk
,
a1_scale
=
a1_scale_chunk
,
a2_scale
=
a2_scale_chunk
,
global_num_experts
=
num_experts
)
if
use_cudagraphs
:
...
...
@@ -482,6 +586,8 @@ def pplx_moe(
chunk_topk_ids
,
w1_scale
=
w1_scale_chunk
,
w2_scale
=
w2_scale_chunk
,
a1_scale
=
a1_scale_chunk
,
a2_scale
=
a2_scale_chunk
,
global_num_experts
=
num_experts
)
torch
.
cuda
.
synchronize
()
...
...
@@ -494,48 +600,6 @@ def pplx_moe(
return
out
def
_batched_moe
(
pgi
,
dp_size
,
a
,
w1
,
w2
,
topk_weight
,
topk_ids
):
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
num_experts
=
w1
.
shape
[
0
]
device
=
pgi
.
device
rank
=
pgi
.
rank
world_size
=
pgi
.
world_size
max_num_tokens
=
rank_chunk
(
a
.
shape
[
0
],
0
,
world_size
)
prepare_finalize
=
BatchedPrepareAndFinalize
(
max_num_tokens
=
max_num_tokens
,
world_size
=
world_size
,
dp_size
=
dp_size
,
rank
=
rank
,
)
experts
=
NaiveBatchedExperts
(
max_num_tokens
=
a
.
shape
[
0
],
world_size
=
1
,
dp_size
=
1
)
fused_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk
=
chunk_by_rank
(
a
,
rank
,
world_size
).
to
(
device
)
chunk_topk_weight
=
chunk_by_rank
(
topk_weight
,
rank
,
world_size
).
to
(
device
)
chunk_topk_ids
=
chunk_by_rank
(
topk_ids
,
rank
,
world_size
).
to
(
device
)
out
=
fused_experts
(
a_chunk
,
# Chunking weights like this only works for batched format
chunk_by_rank
(
w1
,
rank
,
world_size
).
to
(
device
),
chunk_by_rank
(
w2
,
rank
,
world_size
).
to
(
device
),
chunk_topk_weight
,
chunk_topk_ids
,
global_num_experts
=
num_experts
)
return
out
def
_pplx_moe
(
pgi
:
ProcessGroupInfo
,
dp_size
:
int
,
...
...
@@ -544,13 +608,15 @@ def _pplx_moe(
w2
:
torch
.
Tensor
,
score
:
torch
.
Tensor
,
topk
:
int
,
num_experts
:
int
,
w1_s
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_s
:
Optional
[
torch
.
Tensor
]
=
None
,
qtype
:
Optional
[
torch
.
dtype
]
=
None
,
q
uant_d
type
:
Optional
[
torch
.
dtype
]
=
None
,
per_act_token_quant
:
bool
=
False
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
use_internode
:
bool
=
False
,
):
try
:
if
use_internode
:
uid
=
nvshmem_get_unique_id
(
)
if
pgi
.
rank
==
0
else
nvshmem_alloc_empty_unique_id
()
...
...
@@ -559,7 +625,8 @@ def _pplx_moe(
group_name
=
None
else
:
group_ranks
=
list
(
range
(
pgi
.
world_size
))
cpu_group
=
torch
.
distributed
.
new_group
(
group_ranks
,
backend
=
"gloo"
)
cpu_group
=
torch
.
distributed
.
new_group
(
group_ranks
,
backend
=
"gloo"
)
group_name
=
cpu_group
.
group_name
m
,
k
=
a
.
shape
...
...
@@ -568,51 +635,103 @@ def _pplx_moe(
moe_config
=
get_default_config
(
m
,
e
,
n
,
k
,
topk
,
a
.
dtype
,
False
)
device
=
torch
.
device
(
"cuda"
,
pgi
.
rank
)
rank
=
pgi
.
rank
world_size
=
pgi
.
world_size
a
=
a
.
to
(
device
)
w1
=
w1
.
to
(
device
)
w2
=
w2
.
to
(
device
)
w1_s
=
w1_s
.
to
(
device
)
if
w1_s
is
not
None
else
None
w2_s
=
w2_s
.
to
(
device
)
if
w2_s
is
not
None
else
None
if
(
quant_dtype
is
not
None
and
not
per_act_token_quant
and
block_shape
is
None
):
a1_scale
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
else
:
a1_scale
=
None
a2_scale
=
None
with
set_current_vllm_config
(
vllm_config
),
override_config
(
moe_config
):
topk_weight
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
torch_output
=
torch_experts
(
a
,
torch_output
=
torch_experts
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
quant_dtype
=
qtype
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
quant_dtype
=
quant_dtype
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
)
pplx_output
=
pplx_moe
(
group_name
,
pgi
.
rank
,
pgi
.
world_size
,
dp_size
,
a
,
w1
,
w2
,
topk_weight
,
topk_ids
,
w1_s
,
w2_s
,
qtype
,
per_act_token_quant
,
block_shape
)
# TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids)
block_shape
=
block_shape
,
)
torch_output
=
chunk_by_rank
(
torch_output
,
pgi
.
rank
,
pgi
.
world_size
).
to
(
pplx_output
.
device
)
batched_output
=
naive_batched_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
quant_dtype
=
quant_dtype
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
)
pplx_output
=
pplx_moe
(
group_name
,
rank
,
world_size
,
dp_size
,
a
,
w1
,
w2
,
topk_weight
,
topk_ids
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
quant_dtype
=
quant_dtype
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
)
torch
.
testing
.
assert_close
(
pplx_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
#torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0
)
chunked_batch_output
=
chunk_by_rank
(
batched_output
,
pgi
.
rank
,
pgi
.
world_size
).
to
(
pplx_output
.
device
)
torch
.
testing
.
assert_close
(
batched_output
,
torch_output
,
atol
=
3e-2
,
rtol
=
3e-2
)
torch
.
testing
.
assert_close
(
pplx_output
,
chunked_batch_output
,
atol
=
3e-2
,
rtol
=
3e-2
)
finally
:
if
use_internode
:
nvshmem_finalize
()
@
pytest
.
mark
.
parametrize
(
"mnk"
,
PPLX_
MOE_
COMBOS
)
@
pytest
.
mark
.
parametrize
(
"mnk"
,
PPLX_COMBOS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
]
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"block_shape"
,
[
None
,
[
128
,
128
]])
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
pytest
.
mark
.
optional
@
requires_pplx
def
test_pplx_moe
(
def
test_pplx_moe
_slow
(
mnk
:
tuple
[
int
,
int
,
int
],
e
:
int
,
topk
:
int
,
...
...
@@ -633,18 +752,143 @@ def test_pplx_moe(
use_fp8_w8a8
=
False
quant_dtype
=
None
if
not
use_fp8_w8a8
and
per_act_token_quant
and
block_shape
is
not
None
:
if
not
use_fp8_w8a8
and
(
per_act_token_quant
or
block_shape
is
not
None
)
:
pytest
.
skip
(
"Skip quantization test for non-quantized type"
)
if
per_act_token_quant
and
block_shape
is
not
None
:
pytest
.
skip
(
"Skip illegal quantization combination"
)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
_
,
w1
,
w1_s
,
_
,
w2
,
w2_s
=
make_test_weights
(
e
,
_
,
w1
,
w1_s
,
_
,
w2
,
w2_s
=
make_test_weights
(
e
,
n
,
k
,
quant_dtype
=
quant_dtype
,
block_shape
=
block_shape
)
block_shape
=
block_shape
,
per_act_token_quant
=
per_act_token_quant
,
)
parallel_launch
(
world_size
,
_pplx_moe
,
dp_size
,
a
,
w1
,
w2
,
score
,
topk
,
parallel_launch
(
world_size
,
_pplx_moe
,
dp_size
,
a
,
w1
,
w2
,
score
,
topk
,
e
,
w1_s
,
w2_s
,
quant_dtype
,
per_act_token_quant
,
block_shape
,
use_internode
)
def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
                    make_weights: bool, test_fn: Callable):
    """Run ``test_fn`` over every PPLX parameter combination in one process.

    Iterates the product of ``PPLX_COMBOS``, ``NUM_EXPERTS``, ``TOP_KS``,
    ``DTYPES``, the per-activation-token flag and the block shape. Invalid
    combinations are reported and skipped. Each remaining combination is
    executed, its outcome is printed, and any exception is collected so the
    loop keeps going.

    Args:
        pgi: Process-group info forwarded to ``test_fn``; ``pgi.rank`` is
            used in the summary message.
        dp_size: Forwarded to ``test_fn``.
        use_internode: Forwarded to ``test_fn``.
        make_weights: When true, weights and weight scales are generated
            with ``make_test_weights`` and passed as ``w1``, ``w2``,
            ``w1_s`` and ``w2_s``.
        test_fn: Callable invoked with keyword arguments for each
            combination.

    Raises:
        RuntimeError: If at least one combination raised.
    """

    def format_result(msg, ex=None):
        # Print a PASSED/FAILED line; on failure also dump the active
        # traceback (this is only called from inside the except handler).
        if ex is not None:
            x = str(ex)
            newx = x.strip(" \n\t")[:16]
            if len(newx) < len(x):
                newx = newx + " ..."

            prefix = "E\t"
            print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
            print(f"FAILED {msg} - {newx}\n")
        else:
            print(f"PASSED {msg}")

    current_platform.seed_everything(7)
    combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
                               [False, True], [None, [128, 128]])
    exceptions = []
    # Number of combinations actually executed. Skipped combinations are not
    # counted, so the summary reflects real test runs only.
    count = 0
    for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos:
        m, n, k = mnk

        use_fp8_w8a8 = dtype == torch.float8_e4m3fn
        quant_dtype = dtype if use_fp8_w8a8 else None

        test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
                     f"dtype={dtype}, per_act_token={per_act_token_quant}, "
                     f"block_shape={block_shape}]")

        if not use_fp8_w8a8 and (per_act_token_quant
                                 or block_shape is not None):
            print(
                f"{test_desc} - Skip quantization test for non-quantized type."
            )
            continue

        if per_act_token_quant and block_shape is not None:
            print(f"{test_desc} - Skip illegal quantization combination.")
            continue

        count += 1

        a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

        args = {}
        if make_weights:
            _, w1, w1_s, _, w2, w2_s = make_test_weights(
                e,
                n,
                k,
                quant_dtype=quant_dtype,
                block_shape=block_shape,
                per_act_token_quant=per_act_token_quant,
            )
            args["w1"] = w1
            args["w2"] = w2
            args["w1_s"] = w1_s
            args["w2_s"] = w2_s

        try:
            test_fn(
                pgi=pgi,
                dp_size=dp_size,
                a=a,
                score=score,
                topk=topk,
                num_experts=e,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
                use_internode=use_internode,
                **args,
            )
            format_result(test_desc)
        except Exception as ex:
            # Keep going so one failing combination does not hide the rest;
            # failures are re-raised in aggregate below.
            format_result(test_desc, ex)
            exceptions.append(ex)

    if exceptions:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}.")
    else:
        print(f"{count} of {count} tests passed in child process, "
              f"rank={pgi.rank}.")
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
def test_pplx_prepare_finalize(
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Launch ``_pplx_test_loop`` with ``_pplx_prepare_finalize`` as the test."""
    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size
    num_processes = world_size * dp_size
    # The positional ``False`` is ``make_weights``: no expert weights are
    # generated for this test.
    parallel_launch(
        num_processes,
        _pplx_test_loop,
        dp_size,
        use_internode,
        False,
        _pplx_prepare_finalize,
    )
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
def test_pplx_moe(
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Launch ``_pplx_test_loop`` with ``_pplx_moe`` as the test."""
    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size
    # The positional ``True`` is ``make_weights``: the loop generates expert
    # weights and scales for every combination.
    parallel_launch(
        world_size,
        _pplx_test_loop,
        dp_size,
        use_internode,
        True,
        _pplx_moe,
    )
tests/kernels/moe/utils.py
View file @
78fe7753
...
...
@@ -63,13 +63,12 @@ def batched_moe(
fused_experts
=
FusedMoEModularKernel
(
BatchedPrepareAndFinalize
(
max_num_tokens
,
world_size
=
1
,
dp_size
=
1
,
num_dispatchers
=
1
,
num_local_experts
=
w1
.
shape
[
0
]
,
rank
=
0
),
BatchedTritonExperts
(
max_num_tokens
=
max_num_tokens
,
world_size
=
1
,
dp_size
=
1
,
num_dispatchers
=
1
,
use_fp8_w8a8
=
quant_dtype
==
torch
.
float8_e4m3fn
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
...
...
@@ -105,13 +104,12 @@ def naive_batched_moe(
fused_experts
=
FusedMoEModularKernel
(
BatchedPrepareAndFinalize
(
max_num_tokens
,
world_size
=
1
,
dp_size
=
1
,
num_dispatchers
=
1
,
num_local_experts
=
w1
.
shape
[
0
]
,
rank
=
0
),
NaiveBatchedExperts
(
max_num_tokens
=
max_num_tokens
,
dp_size
=
1
,
world_size
=
1
,
num_dispatchers
=
1
,
use_fp8_w8a8
=
quant_dtype
==
torch
.
float8_e4m3fn
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
...
...
tests/kernels/quant_utils.py
View file @
78fe7753
...
...
@@ -277,6 +277,24 @@ def dequant(
return
t
.
to
(
out_dtype
)
def batched_dequant(
    t: torch.Tensor,
    scale: Optional[torch.Tensor],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    out_dtype: Optional[torch.dtype] = torch.float32,
) -> torch.Tensor:
    """Dequantize ``t`` one leading-dimension slice at a time.

    When ``scale`` is ``None`` the tensor is only cast to ``out_dtype``.
    Otherwise ``t`` and ``scale`` must have the same size along dimension 0,
    and ``dequant`` is applied to each pair ``(t[i], scale[i])``.

    Returns:
        A tensor shaped like ``t`` with dtype ``out_dtype``.
    """
    if scale is None:
        return t.to(out_dtype)

    num_slices = t.shape[0]
    assert num_slices == scale.shape[0]

    result = torch.empty_like(t, dtype=out_dtype)
    for idx in range(num_slices):
        result[idx] = dequant(t[idx], scale[idx], block_shape,
                              per_act_token_quant, out_dtype)
    return result
def
native_batched_masked_quant_matmul
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
...
...
tests/kernels/utils.py
View file @
78fe7753
...
...
@@ -1094,6 +1094,8 @@ def torch_experts(
if
expert_map
is
not
None
:
topk_ids
=
expert_map
[
topk_ids
]
f32
=
torch
.
float32
for
i
in
range
(
num_experts
):
mask
=
topk_ids
==
i
if
mask
.
sum
():
...
...
@@ -1109,7 +1111,8 @@ def torch_experts(
out
.
dtype
)
tmp2
=
SiluAndMul
()(
tmp1
)
tmp2
,
b_scale
=
moe_kernel_quantize_input
(
tmp2
,
None
,
quant_dtype
,
per_act_token_quant
,
block_shape
)
tmp2
,
a2_scale
,
quant_dtype
,
per_act_token_quant
,
block_shape
)
out
[
mask
]
=
native_w8a8_block_matmul
(
tmp2
,
w2
[
i
],
b_scale
,
w2_scale
[
i
],
block_shape
,
...
...
@@ -1117,7 +1120,6 @@ def torch_experts(
else
:
assert
(
a_scale
is
not
None
and
w1_scale
is
not
None
and
w2_scale
is
not
None
)
f32
=
torch
.
float32
scales
=
a_scale
if
a_scale
.
numel
()
==
1
else
a_scale
[
mask
]
tmp1
=
a
[
mask
].
to
(
f32
)
*
scales
w1_dq
=
(
w1
[
i
].
to
(
f32
)
*
w1_scale
[
i
]).
transpose
(
0
,
1
)
...
...
@@ -1126,8 +1128,8 @@ def torch_experts(
w2_dq
=
(
w2
[
i
].
to
(
f32
)
*
w2_scale
[
i
]).
transpose
(
0
,
1
)
out
[
mask
]
=
(
tmp2
@
w2_dq
).
to
(
out
.
dtype
)
return
(
out
.
view
(
M
,
-
1
,
w2
.
shape
[
1
])
*
topk_weight
.
view
(
M
,
-
1
,
1
).
to
(
out
.
dtype
)
).
sum
(
dim
=
1
)
return
(
out
.
view
(
M
,
-
1
,
w2
.
shape
[
1
])
.
to
(
f32
)
*
topk_weight
.
view
(
M
,
-
1
,
1
)).
sum
(
dim
=
1
).
to
(
out
.
dtype
)
def
torch_moe
(
a
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
78fe7753
...
...
@@ -184,15 +184,14 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
__init__
(
self
,
max_num_tokens
:
int
,
world_size
:
int
,
dp_size
:
int
,
num_dispatchers
:
int
,
block_shape
:
list
[
int
],
per_act_token_quant
=
False
):
"""
max_num_tokens: Maximum number of tokens from a DP Rank
world_size: N
umber of
E
P
ranks
dp_size: Number of data-parallel ranks
block_shape: Block
quantization
block shape
num_dispatchers: The n
umber of
D
P
dispatchers.
block_shape: Block quantization block shape.
per_act_token_quant: Per activation token
quantization
flag.
"""
super
().
__init__
(
FusedMoEQuantConfig
(
...
...
@@ -202,8 +201,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
))
assert
self
.
block_shape
==
self
.
DEEPGEMM_BLOCK_SHAPE
self
.
max_num_tokens
=
max_num_tokens
self
.
world_size
=
world_size
self
.
dp_size
=
dp_size
self
.
num_dispatchers
=
num_dispatchers
@
property
def
activation_formats
(
...
...
@@ -233,7 +231,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers
=
self
.
world_size
num_dispatchers
=
self
.
num_dispatchers
num_experts
=
local_num_experts
max_num_tokens
=
a
.
size
(
0
)
if
self
.
max_num_tokens
is
None
else
self
.
max_num_tokens
...
...
vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
View file @
78fe7753
...
...
@@ -15,8 +15,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
__init__
(
self
,
max_num_tokens
:
int
,
world_size
:
int
,
dp_size
:
int
,
num_dispatchers
:
int
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
...
...
@@ -37,35 +36,28 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
block_shape
=
block_shape
,
per_act_token_quant
=
per_act_token_quant
,
))
self
.
max_num_tokens
=
max_num_tokens
self
.
world_size
=
world_size
self
.
dp_size
=
dp_size
self
.
allow_deep_gemm
=
allow_deep_gemm
# BatchedTritonKernel doesn't support block quantization
# at the moment.
self
.
batched_triton_experts
=
BatchedTritonExperts
(
max_num_tokens
=
self
.
max_num_tokens
,
world_size
=
self
.
world_size
,
dp_size
=
self
.
dp_size
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_act_token_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
,
)
if
self
.
block_shape
is
None
else
None
)
is_fp8_128_block_quantized
=
(
use_fp8_w8a8
and
self
.
block_shape
self
.
allow_deep_gemm
=
(
allow_deep_gemm
and
use_fp8_w8a8
and
self
.
block_shape
==
BatchedDeepGemmExperts
.
DEEPGEMM_BLOCK_SHAPE
)
self
.
batched_deep_gemm_experts
=
BatchedDeepGemmExperts
(
max_num_tokens
=
self
.
max_num_tokens
,
world_size
=
self
.
world_size
,
dp_size
=
self
.
dp_size
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
block_shape
=
self
.
block_shape
,
# type: ignore[arg-type]
)
if
(
self
.
allow_deep_gemm
and
is_fp8_128_block_quantized
)
else
None
)
if
self
.
allow_deep_gemm
else
None
assert
(
self
.
batched_deep_gemm_experts
is
not
None
or
self
.
batched_triton_experts
is
not
None
)
...
...
@@ -138,12 +130,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
):
use_batched_deep_gemm_experts
=
(
self
.
allow_deep_gemm
and
self
.
batched_deep_gemm_experts
is
not
None
)
experts
=
(
self
.
batched_deep_gemm_experts
if
use_batched_deep_gemm_experts
else
self
.
batched_triton_experts
)
if
self
.
allow_deep_gemm
else
self
.
batched_triton_experts
)
assert
experts
is
not
None
experts
.
apply
(
output
,
hidden_states
,
w1
,
w2
,
topk_ids
,
activation
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
78fe7753
...
...
@@ -14,6 +14,7 @@ from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.utils
import
cdiv
logger
=
init_logger
(
__name__
)
...
...
@@ -68,6 +69,57 @@ class FusedMoEQuantConfig:
# TODO: add col major flag?
# add detailed quant info for input, intermediates, weights, etc?
def __post_init__(self):
    """Reject configs that request both per-token and block quantization."""
    both_requested = (self.per_act_token_quant
                      and self.block_shape is not None)
    assert not both_requested, "illegal quantization"
@property
def is_quantized(self) -> bool:
    """True when a quantization dtype has been set."""
    has_quant_dtype = self.quant_dtype is not None
    return has_quant_dtype
@property
def is_per_act_token(self) -> bool:
    """Value of the ``per_act_token_quant`` flag."""
    return self.per_act_token_quant
@property
def is_block_quantized(self) -> bool:
    """True when a quantization block shape has been set."""
    has_block_shape = self.block_shape is not None
    return has_block_shape
@property
def is_per_tensor(self) -> bool:
    """True when neither per-token nor block quantization is configured."""
    return not (self.per_act_token_quant or self.block_shape is not None)
def scale_shape(
    self,
    max_tokens: int,
    hidden_dim: int,
) -> Optional[tuple[int, int]]:
    """Return the 2-D shape of the activation scales for this config.

    Returns:
        ``None`` when the config is not quantized. Otherwise:
        ``(max_tokens, cdiv(hidden_dim, block_k))`` for block quantization,
        where ``block_k`` is the second entry of ``block_shape``;
        ``(max_tokens, 1)`` for per-activation-token quantization;
        ``(1, 1)`` in the remaining case.
    """
    if not self.is_quantized:
        return None

    if self.is_block_quantized:
        assert self.block_shape is not None
        _, block_k = self.block_shape
        return (max_tokens, cdiv(hidden_dim, block_k))

    if self.is_per_act_token:
        return (max_tokens, 1)

    return (1, 1)
def batched_scale_shape(
    self,
    num_experts: int,
    max_tokens: int,
    hidden_dim: int,
) -> Optional[tuple[int, int, int]]:
    """Return ``scale_shape`` with ``num_experts`` prepended.

    Returns:
        ``None`` when the config is not quantized, otherwise
        ``(num_experts, *self.scale_shape(max_tokens, hidden_dim))``.
    """
    if not self.is_quantized:
        return None

    per_expert_shape = self.scale_shape(max_tokens, hidden_dim)
    assert per_expert_shape is not None
    return (num_experts, *per_expert_shape)
@
staticmethod
def
make
(
use_fp8_w8a8
:
bool
=
False
,
...
...
@@ -109,7 +161,6 @@ class FusedMoEParallelConfig:
tp_rank
:
int
dp_rank
:
int
ep_rank
:
int
world_size
:
int
use_ep
:
bool
# whether to use EP or not
...
...
@@ -133,7 +184,7 @@ class FusedMoEParallelConfig:
and
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
@
staticmethod
def
make
(
tp_size_
:
int
,
dp_size_
:
int
,
world_size_
:
int
,
def
make
(
tp_size_
:
int
,
dp_size_
:
int
,
vllm_parallel_config
:
ParallelConfig
)
->
"FusedMoEParallelConfig"
:
"""
Determine MoE parallel configuration. Based on the input tp_size_,
...
...
@@ -144,7 +195,6 @@ class FusedMoEParallelConfig:
tp_size_ (int): tp_size passed into the FusedMoE constructor.
dp_size_ (int): dp_size passed into the FusedMoE constructor.
ep_size_ (int): ep_size passed into the FusedMoE constructor.
world_size_ (int): the world size of the current All2All manager.
vllm_parallel_config (ParallelConfig): vllm's parallel config
object.
...
...
@@ -223,7 +273,6 @@ class FusedMoEParallelConfig:
dp_rank
=
dp_rank
,
ep_size
=
1
,
ep_rank
=
0
,
world_size
=
world_size_
,
use_ep
=
False
)
# DP + EP / TP + EP / DP + TP + EP
assert
use_ep
...
...
@@ -237,7 +286,6 @@ class FusedMoEParallelConfig:
dp_rank
=
dp_rank
,
ep_size
=
ep_size
,
ep_rank
=
ep_rank
,
world_size
=
world_size_
,
use_ep
=
True
)
...
...
@@ -263,6 +311,8 @@ class FusedMoEConfig:
logger
.
debug
(
"Using FusedMoEConfig::max_num_tokens=%d"
,
self
.
max_num_tokens
)
assert
self
.
max_num_tokens
>
0
@
property
def
quant_dtype
(
self
)
->
Optional
[
torch
.
dtype
]:
if
self
.
quant_config
is
not
None
:
...
...
@@ -303,10 +353,6 @@ class FusedMoEConfig:
def
ep_size
(
self
):
return
self
.
moe_parallel_config
.
ep_size
@
property
def
world_size
(
self
):
return
self
.
moe_parallel_config
.
world_size
@
property
def
tp_rank
(
self
):
return
self
.
moe_parallel_config
.
tp_rank
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
78fe7753
...
...
@@ -41,10 +41,7 @@ def run_cutlass_moe_fp8(
assert
w2_scale
is
not
None
assert
w1
.
dtype
==
torch
.
float8_e4m3fn
assert
w2
.
dtype
==
torch
.
float8_e4m3fn
if
expert_num_tokens
is
None
:
assert
a1q
.
size
(
1
)
==
w1
.
size
(
2
),
"Hidden size mismatch w1"
else
:
assert
a1q
.
size
(
2
)
==
w1
.
size
(
2
),
"Hidden size mismatch w1"
assert
a1q
.
size
(
-
1
)
==
w1
.
size
(
2
),
"Hidden size mismatch w1"
assert
w1
.
size
(
1
)
==
w2
.
size
(
2
)
*
2
,
"Hidden size mismatch w2"
assert
w1_scale
.
dim
()
==
1
or
w1_scale
.
size
(
1
)
==
1
or
w1_scale
.
shape
[
1
]
==
w1
.
size
(
1
),
"W1 scale shape mismatch"
...
...
@@ -178,6 +175,8 @@ def run_cutlass_moe_fp8(
c2
=
_resize_cache
(
workspace2
,
(
M
*
topk
,
N
))
c3
=
_resize_cache
(
workspace13
,
(
M
*
topk
,
K
))
c1
.
fill_
(
0
)
ops
.
cutlass_moe_mm
(
c1
,
a1q
,
w1
,
a1q_scale
,
w1_scale
,
expert_offsets
,
problem_sizes1
,
ab_strides1
,
ab_strides1
,
c_strides1
,
per_act_token
,
per_out_ch
)
...
...
@@ -213,6 +212,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
per_act_token_quant
:
bool
,
per_out_ch_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
num_dispatchers
:
Optional
[
int
]
=
None
,
use_batched_format
:
bool
=
False
,
):
super
().
__init__
(
...
...
@@ -223,7 +223,9 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
block_shape
=
block_shape
,
))
assert
max_experts_per_worker
>
0
assert
not
use_batched_format
or
num_dispatchers
is
not
None
self
.
max_experts_per_worker
=
max_experts_per_worker
self
.
num_dispatchers
=
num_dispatchers
self
.
out_dtype
=
out_dtype
self
.
use_batched_format
=
use_batched_format
...
...
@@ -260,8 +262,12 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
output
:
tuple
[
int
,
...]
=
()
if
self
.
use_batched_format
:
padded_M
=
aq
.
size
(
1
)
workspace1
=
(
self
.
max_experts_per_worker
,
padded_M
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
padded_M
,
(
N
//
2
))
num_dp
=
self
.
num_dispatchers
assert
num_dp
is
not
None
workspace1
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
(
N
//
2
))
output
=
(
self
.
max_experts_per_worker
,
padded_M
,
K
)
else
:
workspace1
=
(
M
*
topk
,
max
(
2
*
N
,
K
))
...
...
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
78fe7753
...
...
@@ -16,12 +16,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
Prepare/Finalize using DeepEP High-Throughput kernels.
"""
def
__init__
(
self
,
buffer
:
deep_ep
.
Buffer
,
world_size
:
int
,
rank
:
int
,
def
__init__
(
self
,
buffer
:
deep_ep
.
Buffer
,
num_dispatchers
:
int
,
dp_size
:
int
,
rank_expert_offset
:
int
):
super
().
__init__
()
self
.
buffer
=
buffer
self
.
world_size
=
world_size
self
.
rank
=
rank
self
.
num_dispatchers_
=
num_dispatchers
self
.
dp_size
=
dp_size
self
.
rank_expert_offset
=
rank_expert_offset
# The dispatch function returns a handle that the combine function
...
...
@@ -32,6 +31,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
self
.
available_rank_configs
=
[
2
,
4
,
8
,
16
,
24
,
32
,
64
,
128
,
144
,
160
]
def num_dispatchers(self) -> int:
    """Return the dispatcher count stored at construction time."""
    return self.num_dispatchers_
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
...
...
@@ -136,20 +138,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"apply_router_weight_on_input is only implemented for topk=1"
)
a1
=
a1
*
topk_weights
.
to
(
a1
.
dtype
)
# Check if there is a block_shape / or if we can infer the quantization
# schemes from the scales.
per_token_quant
=
None
if
all
([
x
is
None
for
x
in
[
quant_config
.
block_shape
,
a1_scale
,
a2_scale
]
])
and
quant_config
.
quant_dtype
is
not
None
:
# Quantization required despite none of the inputs suggesting
# quantization. Fallback to per_token_dynamic quant.
per_token_quant
=
True
else
:
per_token_quant
=
False
if
per_token_quant
:
if
quant_config
.
per_act_token_quant
:
a1q
,
a1q_scale
=
moe_kernel_quantize_input
(
a1
,
a1_scale
,
...
...
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
78fe7753
...
...
@@ -7,7 +7,7 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.utils
import
(
maybe_fix_scales
,
moe_kernel_quantize_input
)
moe_kernel_quantize_input
,
normalize_batched_scales_shape
)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE
=
128
...
...
@@ -42,20 +42,21 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
__init__
(
self
,
buffer
:
deep_ep
.
Buffer
,
max_tokens_per_rank
:
int
,
world_size
:
int
,
dp_size
:
int
,
num_dispatchers
:
int
,
use_fp8_dispatch
:
bool
=
False
):
super
().
__init__
()
self
.
buffer
=
buffer
self
.
max_tokens_per_rank
=
max_tokens_per_rank
self
.
world_size
=
world_size
self
.
dp_size
=
dp_size
self
.
use_fp8_dispatch
=
use_fp8_dispatch
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
self
.
handle
=
None
self
.
num_dispatchers_
=
num_dispatchers
def num_dispatchers(self) -> int:
    """Return the dispatcher count stored at construction time."""
    return self.num_dispatchers_
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
...
...
@@ -91,8 +92,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
assert
isinstance
(
x
,
torch
.
Tensor
)
assert
not
per_act_token_quant
num_experts
,
max_tokens
,
hidden_dim
=
x
.
size
()
# TODO (varun): Optimization - Use a batched version of quant
...
...
@@ -104,7 +103,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if
quant_dtype
is
not
None
:
assert
x_scales
is
not
None
x_scales
=
maybe_fix
_scales
(
x_scales
,
num_experts
)
x_scales
=
normalize_batched
_scales
_shape
(
x_scales
,
num_experts
)
return
x
,
x_scales
...
...
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
View file @
78fe7753
...
...
@@ -12,7 +12,10 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
get_config_dtype_str
,
try_get_optimal_moe_config
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_resize_cache
,
moe_kernel_quantize_input
)
_resize_cache
,
moe_kernel_quantize_input
,
normalize_batched_scales_shape
,
normalize_scales_shape
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
group_broadcast
)
@
triton
.
jit
...
...
@@ -27,16 +30,18 @@ def moe_mmk(
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ak
,
stride_bk
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
stride_ak
:
tl
.
int64
,
stride_bk
:
tl
.
int64
,
stride_ase
:
tl
.
int64
,
stride_asm
:
tl
.
int64
,
stride_ask
:
tl
.
int64
,
stride_bse
:
tl
.
int64
,
stride_bsk
:
tl
.
int64
,
stride_bsn
:
tl
.
int64
,
# Offsets and masks
offs_m
,
offs_n
,
offs_bn
,
mask_m
,
# Block size for block-wise quantization
group_n
:
tl
.
constexpr
,
...
...
@@ -47,7 +52,9 @@ def moe_mmk(
BLOCK_K
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_w8a8
:
tl
.
constexpr
,
use_w8a16
:
tl
.
constexpr
):
use_w8a16
:
tl
.
constexpr
,
per_act_token_quant
:
tl
.
constexpr
,
):
offs_k
=
tl
.
arange
(
0
,
BLOCK_K
)
...
...
@@ -60,13 +67,22 @@ def moe_mmk(
# block-wise
if
group_k
>
0
and
group_n
>
0
:
a_scale_ptrs
=
a_scale_ptr
+
offs_m
*
stride_asm
offs_bsn
=
offs_n
//
group_n
b_scale_ptrs
=
(
b_scale_ptr
+
expert_id
*
stride_bse
+
offs_bsn
*
stride_bsn
)
offs_bsn
=
offs_bn
//
group_n
b_scale_ptrs
=
b_scale_ptr
+
offs_bsn
*
stride_bsn
# per act token
elif
per_act_token_quant
:
# Load per-token scale for activations
a_scale_ptrs
=
a_scale_ptr
+
offs_m
*
stride_asm
a_scale
=
tl
.
load
(
a_scale_ptrs
,
mask
=
mask_m
,
other
=
0.0
)[:,
None
]
b_scale_ptrs
=
b_scale_ptr
+
offs_bn
[
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
# tensor-wise
else
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
expert_id
)
b_scale
=
tl
.
load
(
b_scale_ptr
)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
...
...
@@ -96,13 +112,11 @@ def moe_mmk(
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_scale
[:,
None
]
*
b_scale
[
None
,
:]
else
:
if
use_w8a8
:
# acc used to enable fp8_fast_accum
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
# Advance the ptrs to the next K block.
a_ptrs
+=
BLOCK_K
*
stride_ak
b_ptrs
+=
BLOCK_K
*
stride_bk
...
...
@@ -136,33 +150,39 @@ def expert_triton_kernel(
b_scale_ptr
,
b_zp_ptr
,
# strides
stride_am
,
stride_ak
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
stride_am
:
tl
.
int64
,
stride_ak
:
tl
.
int64
,
stride_bk
:
tl
.
int64
,
stride_bn
:
tl
.
int64
,
stride_cm
:
tl
.
int64
,
stride_cn
:
tl
.
int64
,
stride_ase
:
tl
.
int64
,
stride_asm
:
tl
.
int64
,
stride_ask
:
tl
.
int64
,
stride_bse
:
tl
.
int64
,
stride_bsk
:
tl
.
int64
,
stride_bsn
:
tl
.
int64
,
# offsets
offs_bn
,
# Blockwise quantization data
group_n
,
group_k
,
# Quantization schemes
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
,
per_act_token_quant
:
tl
.
constexpr
,
# Kernel config
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
):
BLOCK_K
:
tl
.
constexpr
,
):
offs_m
=
tl
.
arange
(
0
,
BLOCK_M
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_K
)
mask_m
=
offs_m
<
M
# Make grids of a + b pointers
a_ptrs
=
a_ptr
+
offs_m
[:,
None
]
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
b_ptrs
=
b_ptr
+
offs_k
[:,
None
]
*
stride_bk
+
offs_n
[
None
,
:]
*
stride_bn
...
...
@@ -179,6 +199,7 @@ def expert_triton_kernel(
# (A has M rows).
stride_ak
,
stride_bk
,
stride_ase
,
stride_asm
,
stride_ask
,
stride_bse
,
...
...
@@ -187,6 +208,7 @@ def expert_triton_kernel(
# Offsets and masks
offs_m
,
offs_n
,
offs_bn
,
mask_m
,
# Block size for block-wise quantization
group_n
,
...
...
@@ -197,7 +219,8 @@ def expert_triton_kernel(
BLOCK_K
,
compute_type
,
use_fp8_w8a8
,
use_int8_w8a16
)
use_int8_w8a16
,
per_act_token_quant
)
# store in C
offs_cn
=
tl
.
arange
(
0
,
BLOCK_N
)
...
...
@@ -225,36 +248,40 @@ def batched_triton_kernel(
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ae
,
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
stride_bn
,
stride_ce
,
stride_cm
,
stride_cn
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
stride_ae
:
tl
.
int64
,
stride_am
:
tl
.
int64
,
stride_ak
:
tl
.
int64
,
stride_be
:
tl
.
int64
,
stride_bk
:
tl
.
int64
,
stride_bn
:
tl
.
int64
,
stride_ce
:
tl
.
int64
,
stride_cm
:
tl
.
int64
,
stride_cn
:
tl
.
int64
,
stride_ase
:
tl
.
int64
,
stride_asm
:
tl
.
int64
,
stride_ask
:
tl
.
int64
,
stride_bse
:
tl
.
int64
,
stride_bsk
:
tl
.
int64
,
stride_bsn
:
tl
.
int64
,
# Blockwise quantization data
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
# Quantization schemes
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
,
per_act_token_quant
:
tl
.
constexpr
,
# Kernel config
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
):
BLOCK_K
:
tl
.
constexpr
,
):
expert_id
=
tl
.
program_id
(
axis
=
0
)
e_num_tokens
=
tl
.
load
(
expert_num_tokens
+
expert_id
)
if
e_num_tokens
==
0
:
# Early exit
return
# axis 1 is M_blocks * N_blocks
pid_mn
=
tl
.
program_id
(
axis
=
1
)
#num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_N
)
...
...
@@ -275,6 +302,16 @@ def batched_triton_kernel(
c_ptr
=
(
c_ptr
+
expert_id
*
stride_ce
+
cta_m_start
*
stride_cm
+
cta_n_start
*
stride_cn
)
offs_bn
=
(
pid_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
).
to
(
tl
.
int64
))
%
N
if
use_fp8_w8a8
:
a_scale_ptr
=
a_scale_ptr
+
expert_id
*
stride_ase
b_scale_ptr
=
b_scale_ptr
+
expert_id
*
stride_bse
# block-wise
if
group_k
>
0
and
group_n
>
0
or
per_act_token_quant
:
a_scale_ptr
=
a_scale_ptr
+
cta_m_start
*
stride_asm
expert_triton_kernel
(
a_ptr
,
b_ptr
,
...
...
@@ -294,17 +331,21 @@ def batched_triton_kernel(
stride_bn
,
stride_cm
,
stride_cn
,
stride_ase
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# offsets
offs_bn
,
# Blockwise quantization data
group_n
,
group_k
,
# Quantization schemes
use_fp8_w8a8
,
use_int8_w8a16
,
per_act_token_quant
,
# Kernel config
BLOCK_M
,
BLOCK_N
,
...
...
@@ -326,6 +367,7 @@ def invoke_moe_batched_triton_kernel(
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
config
:
dict
[
str
,
int
],
per_act_token_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
):
assert
not
use_int4_w4a16
...
...
@@ -340,6 +382,42 @@ def invoke_moe_batched_triton_kernel(
grid
=
(
expert_num_tokens
.
size
(
0
),
triton
.
cdiv
(
max_num_tokens
,
BLOCK_M
)
*
triton
.
cdiv
(
B
.
size
(
1
),
BLOCK_N
))
A_scale
=
normalize_batched_scales_shape
(
A_scale
,
expert_num_tokens
.
shape
[
0
])
if
B_scale
is
not
None
and
B_scale
.
ndim
==
1
:
assert
B_scale
.
numel
()
==
expert_num_tokens
.
shape
[
0
]
B_scale
=
B_scale
.
view
(
-
1
,
1
,
1
)
assert
A_scale
is
None
or
A_scale
.
ndim
==
3
,
(
f
"
{
0
if
A_scale
is
None
else
A_scale
.
shape
}
"
)
assert
B_scale
is
None
or
B_scale
.
ndim
==
1
or
B_scale
.
ndim
==
3
,
(
f
"
{
0
if
B_scale
is
None
else
B_scale
.
shape
}
"
)
if
B_scale
is
not
None
:
if
B_scale
.
ndim
==
1
:
stride_bse
=
1
stride_bsk
=
0
stride_bsn
=
0
else
:
stride_bse
=
B_scale
.
stride
(
0
)
stride_bsk
=
B_scale
.
stride
(
2
)
stride_bsn
=
B_scale
.
stride
(
1
)
else
:
stride_bse
=
0
stride_bsk
=
0
stride_bsn
=
0
if
A_scale
is
not
None
:
stride_ase
=
A_scale
.
stride
(
0
)
stride_asm
=
A_scale
.
stride
(
1
)
stride_ask
=
A_scale
.
stride
(
2
)
else
:
stride_ase
=
0
stride_asm
=
0
stride_ask
=
0
batched_triton_kernel
[
grid
](
A
,
B
,
...
...
@@ -364,17 +442,19 @@ def invoke_moe_batched_triton_kernel(
C
.
stride
(
0
),
C
.
stride
(
1
),
C
.
stride
(
2
),
A_scale
.
stride
(
0
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
A_scale
.
stride
(
1
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
B_scale
.
stride
(
0
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
B_scale
.
stride
(
2
)
if
B_scale
is
not
None
and
B_scale
.
ndim
==
3
else
0
,
B_scale
.
stride
(
1
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
stride_ase
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# Blockwise quantization data
0
if
block_shape
is
None
else
block_shape
[
0
],
0
if
block_shape
is
None
else
block_shape
[
1
],
# Quantization schemes
use_fp8_w8a8
,
use_int8_w8a16
,
per_act_token_quant
,
# Kernel config
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
,
...
...
@@ -391,15 +471,15 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
__init__
(
self
,
max_num_tokens
:
int
,
world_size
:
int
,
dp_size
:
int
,
num_local_experts
:
int
,
num_dispatchers
:
int
,
rank
:
int
,
):
super
().
__init__
()
self
.
world_size
=
world_size
self
.
dp_size
=
dp_size
self
.
rank
=
rank
self
.
max_num_tokens
=
max_num_tokens
self
.
num_local_experts
=
num_local_experts
self
.
rank
=
rank
self
.
num_dispatchers_
=
num_dispatchers
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
...
...
@@ -411,6 +491,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
topk_indices_dtype
(
self
)
->
Optional
[
torch
.
dtype
]:
return
None
def num_dispatchers(self) -> int:
    """Return the dispatcher count this object was constructed with.

    The value is the one stored in ``self.num_dispatchers_`` by the
    constructor; it is returned unchanged.
    """
    dispatcher_count = self.num_dispatchers_
    return dispatcher_count
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
...
...
@@ -442,9 +525,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dtype
=
torch
.
int
,
device
=
a1
.
device
)
assert
num_experts
%
self
.
world_size
==
0
num_local_experts
=
num_experts
//
self
.
world_size
num_local_experts
=
self
.
num_local_experts
if
quant_config
.
quant_dtype
is
None
:
b_type
=
a1
.
dtype
...
...
@@ -456,21 +537,53 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dtype
=
b_type
,
device
=
a1
.
device
)
b_a1_scale
=
None
if
quant_config
.
is_quantized
:
scale_shape
=
quant_config
.
batched_scale_shape
(
num_local_experts
,
self
.
max_num_tokens
,
hidden_dim
)
assert
quant_config
.
quant_dtype
is
None
,
"quantization NYI"
b_a1_scale
=
torch
.
empty
(
scale_shape
,
dtype
=
torch
.
float32
,
device
=
a1
.
device
)
else
:
assert
a1_scale
is
None
b_a1_scale
=
None
first_expert
=
num_local_experts
*
self
.
rank
last_expert
=
first_expert
+
num_local_experts
a1_scale
=
normalize_scales_shape
(
a1_scale
)
a2_scale
=
normalize_scales_shape
(
a2_scale
)
for
expert_id
in
range
(
first_expert
,
last_expert
):
topks
=
torch
.
any
(
topk_ids
==
expert_id
,
dim
=
1
).
flatten
()
rows
=
torch
.
count_nonzero
(
topks
.
flatten
())
if
rows
==
0
:
continue
idx
=
expert_id
-
first_expert
b_a1
[
idx
,
:
rows
,
:]
=
a1
[:
topks
.
numel
()][
topks
]
tokens_per_expert
[
idx
]
=
rows
rhs
=
a1
[:
topks
.
numel
()][
topks
]
if
quant_config
.
quant_dtype
is
not
None
:
if
a1_scale
is
not
None
:
if
quant_config
.
is_per_act_token
:
rhs_a1_scale
=
a1_scale
[:
topks
.
numel
()][
topks
]
else
:
rhs_a1_scale
=
a1_scale
else
:
rhs_a1_scale
=
None
b_a1
[
idx
,
:
rows
,
:],
b_s
=
moe_kernel_quantize_input
(
rhs
,
rhs_a1_scale
,
quant_config
.
quant_dtype
,
quant_config
.
per_act_token_quant
,
quant_config
.
block_shape
,
)
assert
b_s
is
not
None
if
quant_config
.
is_per_act_token
:
b_a1_scale
[
idx
,
:
rows
]
=
b_s
[:
rows
]
else
:
b_a1_scale
[
idx
,
:
b_s
.
shape
[
0
]]
=
b_s
else
:
b_a1
[
idx
,
:
rows
,
:]
=
rhs
assert
b_a1_scale
is
None
or
b_a1_scale
.
ndim
==
3
...
...
@@ -514,8 +627,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
__init__
(
self
,
max_num_tokens
:
int
,
world_size
:
int
,
dp_size
:
int
,
num_dispatchers
:
int
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
...
...
@@ -532,13 +644,11 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
))
assert
not
use_fp8_w8a8
,
"NYI"
assert
not
use_int8_w8a8
,
"NYI"
assert
not
use_int8_w8a16
,
"NYI"
assert
not
use_int4_w4a16
,
"NYI"
self
.
max_num_tokens
=
max_num_tokens
self
.
world_size
=
world_size
self
.
dp_size
=
dp_size
self
.
num_dispatchers
=
num_dispatchers
@
property
def
activation_formats
(
...
...
@@ -565,11 +675,21 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts
:
int
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
assert
a
.
dim
()
==
2
num_dp
=
self
.
dp_size
num_dp
=
self
.
num_dispatchers
num_experts
=
local_num_experts
workspace13
=
(
num_experts
,
self
.
max_num_tokens
*
num_dp
,
K
)
workspace2
=
(
self
.
max_num_tokens
*
num_dp
,
N
)
return
(
workspace13
,
workspace2
,
workspace13
,
a
.
dtype
)
output
=
workspace13
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize ``t`` to float32 using ``scale``.

    For per-activation-token or per-tensor quantization the scale is
    multiplied in directly (relying on normal broadcasting). Otherwise the
    scale is first expanded to ``t.shape`` with ``group_broadcast``.

    Args:
        t: Quantized tensor to convert.
        scale: Scale factors associated with ``t``.

    Returns:
        ``t`` cast to ``torch.float32`` and multiplied by the scale.

    Raises:
        AssertionError: if the quant config is not marked as quantized.
    """
    cfg = self.quant_config
    assert cfg.is_quantized

    # Per-token and per-tensor scales need no explicit expansion.
    direct_scale = cfg.is_per_act_token or cfg.is_per_tensor
    if direct_scale:
        return t.to(torch.float32) * scale

    # Remaining case: expand the scale to the tensor's shape first.
    return t.to(torch.float32) * group_broadcast(scale, t.shape)
def
apply
(
self
,
...
...
@@ -612,9 +732,95 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
continue
tmp
=
_resize_cache
(
workspace2
,
(
num
,
N
))
input
=
hidden_states
[
expert
,
:
num
,
:]
@
w1
[
expert
].
transpose
(
0
,
1
)
self
.
activation
(
activation
,
tmp
,
input
)
output
[
expert
,
:
num
,
:]
=
tmp
@
w2
[
expert
].
transpose
(
0
,
1
)
if
self
.
quant_config
.
is_quantized
:
assert
a1q_scale
is
not
None
and
w1_scale
is
not
None
input
=
self
.
dequant
(
hidden_states
[
expert
,
:,
:],
a1q_scale
[
expert
])
w1_dq
=
self
.
dequant
(
w1
[
expert
],
w1_scale
[
expert
])
input
=
input
[:
num
]
@
w1_dq
.
transpose
(
0
,
1
)
else
:
input
=
hidden_states
[
expert
,
:
num
,
:]
@
w1
[
expert
].
transpose
(
0
,
1
)
self
.
activation
(
activation
,
tmp
,
input
.
to
(
tmp
.
dtype
))
if
self
.
quant_config
.
is_quantized
:
assert
w2_scale
is
not
None
w2_dq
=
self
.
dequant
(
w2
[
expert
],
w2_scale
[
expert
])
else
:
w2_dq
=
w2
[
expert
]
output
[
expert
,
:
num
,
:]
=
tmp
@
w2_dq
.
transpose
(
0
,
1
).
to
(
tmp
.
dtype
)
def batched_moe_kernel_quantize_input(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    num_tokens: int,
    E: int,
    N: int,
    expert_num_tokens: torch.Tensor,
    qtype: Optional[torch.dtype],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Quantize a batched (per-expert) activation tensor.

    Three paths are taken, in this order:

    * While torch.compile is tracing or a CUDA graph is being captured, the
      whole tensor is flattened to 2-D and quantized in one call, ignoring
      ``expert_num_tokens``; the result is viewed back as
      ``(E, -1, hidden_dim)``.
    * If ``qtype`` is ``None`` the input is returned as-is together with the
      normalized ``A_scale``.
    * Otherwise each expert's first ``expert_num_tokens[e]`` rows are
      quantized separately. Rows past that count are never written in the
      returned tensor (it is allocated with ``torch.empty_like``); the scale
      buffer is zero-initialized.

    Args:
        A: Activations, indexed as ``A[expert, token, hidden]``.
        A_scale: Optional pre-computed activation scales.
        num_tokens: Token dimension used to size the scale buffer for the
            per-token and block-wise cases.
        E: Number of experts iterated over / used for reshaping.
        N: Not referenced in this function body.
        expert_num_tokens: Per-expert count of valid rows in ``A``.
        qtype: Target quantized dtype, or ``None`` for no quantization.
        per_act_token_quant: Use one scale per token row.
        block_shape: Optional ``[block_n, block_k]``; only ``block_k`` is
            used here.

    Returns:
        ``(quantized_A, scales)``.
    """
    if (torch.compiler.is_compiling()
            or torch.cuda.is_current_stream_capturing()):
        # Note: this does a bunch of extra work because expert_num_tokens is
        # ignored but it does support torch.compile + cudagraphs.
        hidden_dim = A.size(-1)
        assert A_scale is None or A_scale.ndim <= 2, (
            f"{A_scale.shape if A_scale is not None else None}")
        flat_q, flat_scale = moe_kernel_quantize_input(
            A.view(-1, hidden_dim), A_scale, qtype, per_act_token_quant,
            block_shape)
        batched_q = flat_q.view(E, -1, hidden_dim)
        batched_scale = normalize_batched_scales_shape(flat_scale, E)
        return batched_q, batched_scale

    if qtype is None:
        # Nothing to quantize; only bring the scales into batched form.
        return A, normalize_batched_scales_shape(A_scale, E)

    A_q = torch.empty_like(A, dtype=qtype)

    # Pick the scale buffer layout for the active quantization scheme.
    if per_act_token_quant:
        assert block_shape is None
        scale_shape = (E, num_tokens, 1)
    elif block_shape is not None:
        _, block_k = block_shape
        k_tiles = (A.shape[-1] + block_k - 1) // block_k
        scale_shape = (E, num_tokens, k_tiles)
    else:
        scale_shape = (E, 1, 1)

    A_q_scale = torch.zeros(scale_shape,
                            dtype=torch.float32,
                            device=A.device)

    num_experts = expert_num_tokens.numel()
    A_scale = normalize_batched_scales_shape(A_scale, num_experts)

    for e in range(E):
        valid_rows = int(expert_num_tokens[e].item())
        if valid_rows <= 0:
            continue

        if A_scale is None:
            scales = None
        else:
            scales = A_scale[e, :min(valid_rows, A_scale.shape[1])]

        expert_q, tmp_scale = moe_kernel_quantize_input(
            A[e, :valid_rows],
            scales,
            qtype,
            per_act_token_quant,
            block_shape,
        )
        A_q[e, :valid_rows] = expert_q
        assert tmp_scale is not None
        A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale

    return A_q, A_q_scale
class
BatchedTritonExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
...
...
@@ -627,8 +833,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
__init__
(
self
,
max_num_tokens
:
int
,
world_size
:
int
,
dp_size
:
int
,
num_dispatchers
:
int
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
...
...
@@ -648,17 +853,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert
not
use_int8_w8a8
,
"NYI"
assert
not
use_int8_w8a16
,
"NYI"
assert
not
use_int4_w4a16
,
"NYI"
assert
max_num_tokens
>
0
assert
num_dispatchers
>
0
self
.
use_fp8_w8a8
=
use_fp8_w8a8
self
.
use_int8_w8a8
=
use_int8_w8a8
self
.
use_int4_w4a16
=
use_int4_w4a16
self
.
use_int8_w8a16
=
use_int8_w8a16
self
.
max_num_tokens
=
max_num_tokens
self
.
world_size
=
world_size
self
.
dp_size
=
dp_size
assert
world_size
>
0
assert
dp_size
>
0
assert
dp_size
<=
world_size
assert
max_num_tokens
>
0
self
.
num_dispatchers
=
num_dispatchers
@
property
def
activation_formats
(
...
...
@@ -685,7 +887,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts
:
int
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
assert
a
.
dim
()
==
2
num_dp
=
self
.
world_size
num_dp
=
self
.
num_dispatchers
num_experts
=
local_num_experts
max_num_tokens
=
self
.
max_num_tokens
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dp
,
max
(
K
,
N
))
...
...
@@ -772,8 +974,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
if
self
.
use_fp8_w8a8
:
intermediate_cache1
.
fill_
(
0
)
a1q_scale
=
normalize_batched_scales_shape
(
a1q_scale
,
E
)
# MM1
invoke_moe_batched_triton_kernel
(
A
=
hidden_states
,
invoke_moe_batched_triton_kernel
(
A
=
hidden_states
,
B
=
w1
,
C
=
intermediate_cache1
,
expert_num_tokens
=
expert_num_tokens
,
...
...
@@ -785,29 +990,22 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
config
=
config
,
per_act_token_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
)
intermediate_cache2
.
fill_
(
0
)
# TODO: would be nice to use expert_num_tokens here to reduce
# garbage compute
# TODO (bnell): use triton utility from batched deep gemm.
self
.
activation
(
activation
,
intermediate_cache2
.
view
(
-
1
,
N
//
2
),
intermediate_cache1
.
view
(
-
1
,
N
))
ic2_hidden_size
=
intermediate_cache2
.
size
(
-
1
)
intermediate_cache2
=
intermediate_cache2
.
view
(
-
1
,
ic2_hidden_size
)
qintermediate_cache2
,
a2q_scale
=
moe_kernel_quantize_input
(
A
=
intermediate_cache2
,
A_scale
=
a2_scale
,
quant_dtype
=
self
.
quant_dtype
,
per_act_token_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
)
qintermediate_cache2
=
qintermediate_cache2
.
view
(
(
E
,
-
1
,
ic2_hidden_size
))
qintermediate_cache2
,
a2q_scale
=
batched_moe_kernel_quantize_input
(
intermediate_cache2
,
a2_scale
,
max_num_tokens
,
E
,
N
,
expert_num_tokens
,
self
.
quant_dtype
,
self
.
per_act_token_quant
,
self
.
block_shape
)
invoke_moe_batched_triton_kernel
(
A
=
qintermediate_cache2
,
invoke_moe_batched_triton_kernel
(
A
=
qintermediate_cache2
,
B
=
w2
,
C
=
output
,
expert_num_tokens
=
expert_num_tokens
,
...
...
@@ -819,4 +1017,5 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
config
=
config
,
per_act_token_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
)
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
78fe7753
...
...
@@ -1127,6 +1127,8 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
return
torch_vllm_outplace_fused_experts
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
78fe7753
...
...
@@ -14,7 +14,6 @@ import vllm.envs as envs
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed
import
(
get_dp_group
,
get_ep_group
,
get_tensor_model_parallel_world_size
,
get_world_group
,
tensor_model_parallel_all_reduce
)
from
vllm.distributed.eplb.eplb_state
import
EplbState
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
...
...
@@ -114,6 +113,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
hidden_dim_scale_bytes
=
hidden_scale_bytes
,
)
num_dispatchers
=
(
all2all_manager
.
world_size
//
all2all_manager
.
tp_group
.
world_size
)
# Intranode pplx a2a takes a group name while internode does not.
if
not
all2all_manager
.
internode
:
all_to_all_args
[
...
...
@@ -124,10 +126,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
prepare_finalize
=
PplxPrepareAndFinalize
(
handle
,
max_num_tokens
=
moe
.
max_num_tokens
,
world_size
=
all2all_manager
.
world_size
,
rank
=
all2all_manager
.
rank
,
# dp_size actually means tp_size, bug in pplx kernels
dp_size
=
all2all_manager
.
tp_group
.
world_size
,
num_local_experts
=
moe
.
num_local_experts
,
num_dispatchers
=
num_dispatchers
,
)
elif
moe
.
use_deepep_ht_kernels
:
assert
moe
.
dp_size
==
all2all_manager
.
dp_world_size
...
...
@@ -136,16 +136,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
handle
=
all2all_manager
.
get_handle
(
all_to_all_args
)
prepare_finalize
=
DeepEPHTPrepareAndFinalize
(
handle
,
world_size
=
all2all_manager
.
world_size
,
rank
=
all2all_manager
.
rank
,
num_dispatchers
=
all2all_manager
.
world_size
,
dp_size
=
all2all_manager
.
dp_world_size
,
rank_expert_offset
=
all2all_manager
.
rank
*
moe
.
num_local_experts
,
)
elif
moe
.
use_deepep_ll_kernels
:
assert
moe
.
dp_size
==
all2all_manager
.
dp_world_size
all_to_all_args
=
dict
(
max_num_tokens_per_dp_rank
=
moe
.
max_num_tokens
,
token_hidden_size
=
moe
.
hidden_dim
,
...
...
@@ -168,8 +165,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
prepare_finalize
=
DeepEPLLPrepareAndFinalize
(
handle
,
max_tokens_per_rank
=
moe
.
max_num_tokens
,
world_size
=
all2all_manager
.
world_size
,
dp_size
=
all2all_manager
.
dp_world_size
,
num_dispatchers
=
all2all_manager
.
world_size
,
use_fp8_dispatch
=
use_fp8_dispatch
,
)
...
...
@@ -245,18 +241,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
assert
self
.
fused_experts
==
fused_experts
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
if
(
prepare_finalize
.
activation_format
==
FusedMoEActivationFormat
.
BatchedExperts
):
logger
.
debug
(
"BatchedTritonExperts %s"
,
self
.
moe
)
assert
self
.
moe
.
dp_size
==
all2all_manager
.
dp_world_size
return
BatchedTritonExperts
(
max_num_tokens
=
self
.
moe
.
max_num_tokens
,
world_size
=
all2all_manager
.
world_size
,
# dp_size actually means tp_size, bug in pplx kernels
dp_size
=
all2all_manager
.
tp_group
.
world_size
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
)
else
:
logger
.
debug
(
"TritonExperts %s"
,
self
.
moe
)
...
...
@@ -652,14 +642,12 @@ class FusedMoE(torch.nn.Module):
get_tensor_model_parallel_world_size
())
dp_size_
=
(
dp_size
if
dp_size
is
not
None
else
get_dp_group
().
world_size
)
world_size_
=
get_world_group
().
world_size
vllm_config
=
get_current_vllm_config
()
self
.
moe_parallel_config
:
FusedMoEParallelConfig
=
(
FusedMoEParallelConfig
.
make
(
tp_size_
=
tp_size_
,
dp_size_
=
dp_size_
,
world_size_
=
world_size_
,
vllm_parallel_config
=
vllm_config
.
parallel_config
))
self
.
global_num_experts
=
num_experts
+
num_redundant_experts
...
...
@@ -1299,6 +1287,8 @@ class FusedMoE(torch.nn.Module):
topk_ids
=
topk_ids
.
to
(
dtype
=
indices_type
)
assert
topk_ids
.
dtype
==
indices_type
or
indices_type
is
None
return
topk_weights
,
topk_ids
def
must_reduce_shared_expert_outputs
(
self
)
->
bool
:
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
78fe7753
...
...
@@ -193,6 +193,10 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
raise
NotImplementedError
@abstractmethod
def num_dispatchers(self) -> int:
    """Return the number of dispatchers; subclasses must override this.

    Raises:
        NotImplementedError: always, in this base declaration.
    """
    raise NotImplementedError
class
FusedMoEPermuteExpertsUnpermute
(
ABC
):
"""
...
...
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
View file @
78fe7753
...
...
@@ -8,7 +8,7 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.utils
import
(
moe_kernel_quantize_input
)
_validate_scale_shape
,
moe_kernel_quantize_input
)
from
vllm.utils
import
cdiv
,
round_up
...
...
@@ -32,16 +32,16 @@ def pplx_hidden_dim_scale_bytes(
elem_size
=
torch
.
float32
.
itemsize
if
per_act_token_quant
:
# per-token
# per-token
(M x 1)
assert
block_shape
is
None
hidden_scale_bytes
=
elem_size
elif
block_shape
is
not
None
:
# per-group
# per-group
(M x K_tiles)
block_size
=
block_shape
[
1
]
num_blocks
=
cdiv
(
hidden_dim
,
block_size
)
hidden_scale_bytes
=
num_blocks
*
elem_size
else
:
# per-tensor
# per-tensor
(1 x 1)
hidden_scale_bytes
=
elem_size
else
:
hidden_dim_bytes
=
hidden_dim
*
in_dtype
.
itemsize
...
...
@@ -53,25 +53,22 @@ def pplx_hidden_dim_scale_bytes(
)
# The max_num_tokens, world_size and dp_size must be the same
# as the ones used to create the AllToAll.
class
PplxPrepareAndFinalize
(
mk
.
FusedMoEPrepareAndFinalize
):
def
__init__
(
self
,
a2a
:
pplx
.
AllToAll
,
max_num_tokens
:
int
,
world_size
:
int
,
rank
:
int
,
dp_size
:
int
,
num_local_experts
:
int
,
num_dispatchers
:
int
,
):
super
().
__init__
()
assert
max_num_tokens
>
0
assert
num_local_experts
>
0
self
.
a2a
=
a2a
self
.
max_num_tokens
=
max_num_tokens
self
.
world_size
=
world_size
self
.
rank
=
rank
self
.
dp_size
=
dp_size
self
.
num_local_experts
=
num_local_experts
self
.
num_dispatchers_
=
num_dispatchers
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
...
...
@@ -83,6 +80,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
topk_indices_dtype
(
self
)
->
Optional
[
torch
.
dtype
]:
return
torch
.
uint32
def num_dispatchers(self) -> int:
    """Return the dispatcher count this object was constructed with.

    The value is read back unchanged from ``self.num_dispatchers_``.
    """
    dispatcher_count = self.num_dispatchers_
    return dispatcher_count
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
...
...
@@ -120,42 +120,64 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
per_act_token_quant
=
quant_config
.
per_act_token_quant
,
block_shape
=
quant_config
.
block_shape
)
_validate_scale_shape
(
a1q
,
a1q_scale
,
quant_config
.
per_act_token_quant
,
quant_config
.
block_shape
)
if
a1q_scale
is
not
None
:
if
a1q_scale
.
numel
()
==
1
:
orig_a_scale_block_shape
=
1
else
:
scalar_scales
=
a1q_scale
.
numel
()
==
1
# pplx requires 2-d scales even for scalar scales
if
a1q_scale
.
dim
()
<=
1
:
assert
scalar_scales
a1q_scale
=
a1q_scale
.
view
(
1
,
1
)
orig_a_scale_block_shape
=
a1q_scale
.
shape
[
-
1
]
if
not
quant_config
.
is_block_quantized
:
# TODO (bnell): use group_broadcast instead?
a1q_scale
=
a1q_scale
.
repeat
(
repeat_rows
,
repeat_cols
)
# rem_experts need to be 0 for pplx to work properly.
rem_experts
=
num_experts
%
self
.
world_size
assert
rem_experts
==
0
num_local_experts
=
((
num_experts
//
self
.
world_size
)
+
(
1
if
self
.
rank
<
rem_experts
else
0
))
assert
a1q_scale
is
None
or
a1q_scale
.
ndim
==
2
,
\
f
"
{
0
if
a1q_scale
is
None
else
(
a1q_scale
.
ndim
,
a1q_scale
.
shape
)
}
"
expert_num_tokens
=
torch
.
empty
(
num_local_experts
,
self
.
num_local_experts
,
dtype
=
torch
.
int32
,
device
=
device
,
)
num_dp
=
self
.
world_size
//
self
.
dp_size
expert_x
=
torch
.
empty
(
(
num_local_experts
,
self
.
max_num_tokens
*
num_dp
,
hidden_dim
),
(
self
.
num_local_experts
,
self
.
max_num_tokens
*
self
.
num_dispatchers
(),
hidden_dim
),
dtype
=
a1q
.
dtype
,
device
=
device
,
)
expert_x_scale
:
Optional
[
torch
.
Tensor
]
=
None
if
a1q
.
dtype
.
itemsize
==
1
:
block_size
=
(
quant_config
.
block_shape
[
1
]
if
quant_config
.
block_shape
is
not
None
else
1
)
if
quant_config
.
is_per_act_token
:
# (M x 1) -> (E x M x K)
final_dim
=
expert_x
.
size
(
2
)
elif
quant_config
.
is_per_tensor
:
# (1 x 1) -> (E x 1 x 1)
final_dim
=
1
else
:
# (M x K_tiles) -> (E x M x K_tiles)
assert
quant_config
.
block_shape
is
not
None
num_blocks
=
cdiv
(
expert_x
.
size
(
2
),
quant_config
.
block_shape
[
1
])
final_dim
=
num_blocks
expert_x_scale_shape
=
(
self
.
num_local_experts
,
expert_x
.
size
(
1
),
round_up
(
final_dim
,
4
)
# round up for alignment
)
expert_x_scale
=
torch
.
empty
(
(
num_local_experts
,
expert_x
.
size
(
1
),
round_up
(
(
expert_x
.
size
(
2
)
+
block_size
-
1
)
//
block_size
,
4
)),
expert_x_scale_shape
,
dtype
=
torch
.
float32
,
device
=
device
,
device
=
expert_x
.
device
,
)
# This argument is optional, defaults to indices.size(0)
...
...
@@ -171,8 +193,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
indices
=
topk_ids
,
bound_m
=
bound_m
,
)
if
expert_x_scale
is
not
None
:
expert_x_scale
=
expert_x_scale
[:,
:,
:
orig_a_scale_block_shape
]
assert
expert_x_scale
.
ndim
==
3
return
expert_x
,
expert_x_scale
,
expert_num_tokens
,
None
,
None
...
...
@@ -184,13 +208,16 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
)
->
None
:
num_tokens
=
output
.
size
(
0
)
# M
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
bound_m
:
Optional
[
torch
.
Tensor
]
=
None
assert
topk_ids
.
size
(
0
)
==
num_tokens
,
(
f
"
{
topk_ids
.
size
(
0
)
}
==
{
num_tokens
}
"
)
# TODO (bnell): fails in test_pplx_moe.py, figure out what's going on
#num_tokens = output.size(0) # M
#assert topk_ids.size(0) == num_tokens, (
# f"{topk_ids.size(0)} == {num_tokens}")
assert
topk_ids
.
size
()
==
topk_weights
.
size
(),
(
f
"
{
topk_ids
.
size
()
}
==
{
topk_weights
.
size
()
}
"
)
assert
output
.
size
(
0
)
<=
self
.
max_num_tokens
,
(
f
"
{
output
.
size
(
0
)
}
<=
{
self
.
max_num_tokens
}
"
)
assert
output
.
size
(
1
)
==
fused_expert_output
.
size
(
-
1
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment