Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b17109be
Unverified
Commit
b17109be
authored
Aug 20, 2025
by
shixianc
Committed by
GitHub
Aug 20, 2025
Browse files
[Kernel] CUTLASS MoE FP8: Integrate cuda moe permute/unpermute (#23045)
Signed-off-by:
Shixian Cui
<
shixian@amazon.com
>
parent
44492358
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
369 additions
and
121 deletions
+369
-121
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+34
-1
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+14
-19
csrc/ops.h
csrc/ops.h
+5
-0
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
+4
-2
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
+50
-15
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+24
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+13
-0
tests/kernels/moe/test_cutlass_moe.py
tests/kernels/moe/test_cutlass_moe.py
+14
-4
tests/kernels/moe/test_moe_permute_unpermute.py
tests/kernels/moe/test_moe_permute_unpermute.py
+5
-1
tests/kernels/moe/test_pplx_cutlass_moe.py
tests/kernels/moe/test_pplx_cutlass_moe.py
+21
-1
tests/kernels/quantization/test_cutlass_scaled_mm.py
tests/kernels/quantization/test_cutlass_scaled_mm.py
+1
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+22
-0
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+111
-68
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
.../model_executor/layers/fused_moe/moe_permute_unpermute.py
+20
-9
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+31
-0
No files found.
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
View file @
b17109be
...
@@ -80,6 +80,11 @@ def bench_run(
...
@@ -80,6 +80,11 @@ def bench_run(
a
,
score
,
topk
,
renormalize
=
False
a
,
score
,
topk
,
renormalize
=
False
)
)
ab_strides1
=
torch
.
full
((
num_experts
,),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
num_experts
,),
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
num_experts
,),
2
*
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
num_experts
,),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
def
run_triton_moe
(
def
run_triton_moe
(
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
@@ -111,6 +116,10 @@ def bench_run(
...
@@ -111,6 +116,10 @@ def bench_run(
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
per_act_token
:
bool
,
per_act_token
:
bool
,
...
@@ -125,6 +134,10 @@ def bench_run(
...
@@ -125,6 +134,10 @@ def bench_run(
topk_ids
,
topk_ids
,
w1_scale
,
w1_scale
,
w2_scale
,
w2_scale
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
per_act_token
,
per_act_token
,
a1_scale
=
None
,
a1_scale
=
None
,
)
)
...
@@ -136,6 +149,10 @@ def bench_run(
...
@@ -136,6 +149,10 @@ def bench_run(
w2_q
:
torch
.
Tensor
,
w2_q
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
):
):
...
@@ -150,6 +167,10 @@ def bench_run(
...
@@ -150,6 +167,10 @@ def bench_run(
topk_ids
,
topk_ids
,
w1_scale
,
w1_scale
,
w2_scale
,
w2_scale
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
per_act_token
,
per_act_token
,
a1_scale
=
None
,
a1_scale
=
None
,
)
)
...
@@ -194,6 +215,10 @@ def bench_run(
...
@@ -194,6 +215,10 @@ def bench_run(
w2_q
,
w2_q
,
w1_scale
,
w1_scale
,
w2_scale
,
w2_scale
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
)
)
...
@@ -231,6 +256,10 @@ def bench_run(
...
@@ -231,6 +256,10 @@ def bench_run(
"w1_scale"
:
w1_scale
,
"w1_scale"
:
w1_scale
,
"w2_scale"
:
w2_scale
,
"w2_scale"
:
w2_scale
,
"per_act_token"
:
per_act_token
,
"per_act_token"
:
per_act_token
,
"ab_strides1"
:
ab_strides1
,
"ab_strides2"
:
ab_strides2
,
"c_strides1"
:
c_strides1
,
"c_strides2"
:
c_strides2
,
# cuda graph params
# cuda graph params
"cutlass_graph"
:
cutlass_graph
,
"cutlass_graph"
:
cutlass_graph
,
"triton_graph"
:
triton_graph
,
"triton_graph"
:
triton_graph
,
...
@@ -289,6 +318,10 @@ def bench_run(
...
@@ -289,6 +318,10 @@ def bench_run(
w2_q
,
w2_q
,
w1_scale
,
w1_scale
,
w2_scale
,
w2_scale
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
per_act_token
,
per_act_token
,
...
@@ -297,7 +330,7 @@ def bench_run(
...
@@ -297,7 +330,7 @@ def bench_run(
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
"run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)"
,
# noqa: E501
stmt
=
"run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
ab_strides1, ab_strides2, c_strides1, c_strides2,
topk_weights, topk_ids, per_act_token, num_runs)"
,
# noqa: E501
globals
=
globals
,
globals
=
globals
,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
...
...
csrc/moe/moe_permute_unpermute_op.cu
View file @
b17109be
...
@@ -45,8 +45,6 @@ void moe_permute(
...
@@ -45,8 +45,6 @@ void moe_permute(
auto
copy_topk_ids
=
topk_ids
.
clone
();
// copy topk_ids for preprocess
auto
copy_topk_ids
=
topk_ids
.
clone
();
// copy topk_ids for preprocess
auto
permuted_experts_id
=
torch
::
empty_like
(
topk_ids
);
auto
permuted_experts_id
=
torch
::
empty_like
(
topk_ids
);
auto
sorted_row_idx
=
torch
::
empty_like
(
inv_permuted_idx
);
auto
sorted_row_idx
=
torch
::
empty_like
(
inv_permuted_idx
);
auto
align_expert_first_token_offset
=
torch
::
zeros_like
(
expert_first_token_offset
);
CubKeyValueSorter
sorter
{};
CubKeyValueSorter
sorter
{};
int64_t
*
valid_num_ptr
=
nullptr
;
int64_t
*
valid_num_ptr
=
nullptr
;
...
@@ -85,12 +83,14 @@ void moe_permute(
...
@@ -85,12 +83,14 @@ void moe_permute(
});
});
// get m_indices and update expert_first_token_offset with align block
// get m_indices and update expert_first_token_offset with align block
getMIndices
(
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
// this is only required for DeepGemm and not required for CUTLASS group gemm
get_ptr
<
int64_t
>
(
align_expert_first_token_offset
),
get_ptr
<
int
>
(
m_indices
),
n_local_expert
,
align_block_size_value
,
stream
);
if
(
align_block_size
.
has_value
())
{
if
(
align_block_size
.
has_value
())
{
// update align_expert_first_token_offset
auto
align_expert_first_token_offset
=
torch
::
zeros_like
(
expert_first_token_offset
);
getMIndices
(
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
get_ptr
<
int64_t
>
(
align_expert_first_token_offset
),
get_ptr
<
int
>
(
m_indices
),
n_local_expert
,
align_block_size_value
,
stream
);
expert_first_token_offset
.
copy_
(
align_expert_first_token_offset
);
expert_first_token_offset
.
copy_
(
align_expert_first_token_offset
);
}
}
}
}
...
@@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
...
@@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
torch
::
Tensor
&
expert_first_token_offset
,
torch
::
Tensor
&
expert_first_token_offset
,
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
torch
::
Tensor
&
m_indices
)
{
torch
::
Tensor
&
m_indices
)
{
TORCH_CHECK
(
false
,
"moe_
un
permute is not supported on CUDA < 12.0"
);
TORCH_CHECK
(
false
,
"moe_permute is not supported on CUDA < 12.0"
);
}
}
void
moe_unpermute
(
const
torch
::
Tensor
&
input
,
void
moe_unpermute
(
const
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
permuted_hidden_states
,
const
torch
::
Tensor
&
token_expert_indices
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
inv_permuted_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
const
std
::
optional
<
torch
::
Tensor
>&
expert_first_token_offset
,
int64_t
topk
,
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
torch
::
Tensor
&
hidden_states
)
{
const
std
::
optional
<
int64_t
>&
align_block_size
,
torch
::
Tensor
&
permuted_input
,
torch
::
Tensor
&
expert_first_token_offset
,
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
torch
::
Tensor
&
m_indices
)
{
TORCH_CHECK
(
false
,
"moe_unpermute is not supported on CUDA < 12.0"
);
TORCH_CHECK
(
false
,
"moe_unpermute is not supported on CUDA < 12.0"
);
}
}
...
@@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() {
...
@@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() {
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"moe_permute"
,
&
moe_permute
);
m
.
impl
(
"moe_permute"
,
&
moe_permute
);
m
.
impl
(
"moe_unpermute"
,
&
moe_unpermute
);
m
.
impl
(
"moe_unpermute"
,
&
moe_unpermute
);
}
}
\ No newline at end of file
csrc/ops.h
View file @
b17109be
...
@@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data(
...
@@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data(
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
void
get_cutlass_moe_mm_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
problem_sizes2
,
...
...
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
View file @
b17109be
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
template
<
typename
ElementAB
,
typename
ElementC
,
typename
ElementAccumulator
>
template
<
typename
ElementAB
,
typename
ElementC
,
typename
ElementAccumulator
>
__global__
void
get_group_gemm_starts
(
__global__
void
get_group_gemm_starts
(
int
32
_t
*
expert_offsets
,
ElementAB
**
a_offsets
,
ElementAB
**
b_offsets
,
int
64
_t
*
expert_offsets
,
ElementAB
**
a_offsets
,
ElementAB
**
b_offsets
,
ElementC
**
out_offsets
,
ElementAccumulator
**
a_scales_offsets
,
ElementC
**
out_offsets
,
ElementAccumulator
**
a_scales_offsets
,
ElementAccumulator
**
b_scales_offsets
,
ElementAB
*
a_base_as_int
,
ElementAccumulator
**
b_scales_offsets
,
ElementAB
*
a_base_as_int
,
ElementAB
*
b_base_as_int
,
ElementC
*
out_base_as_int
,
ElementAB
*
b_base_as_int
,
ElementC
*
out_base_as_int
,
...
@@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts(
...
@@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts(
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \
<<<1, num_experts, 0, stream>>>( \
static_cast<int
32
_t*>(expert_offsets.data_ptr()), \
static_cast<int
64
_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
...
@@ -61,6 +61,8 @@ void run_get_group_gemm_starts(
...
@@ -61,6 +61,8 @@ void run_get_group_gemm_starts(
TORCH_CHECK
(
b_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
// expect int64_t to avoid overflow during offset calculations
TORCH_CHECK
(
expert_offsets
.
dtype
()
==
torch
::
kInt64
);
int
num_experts
=
static_cast
<
int
>
(
expert_offsets
.
size
(
0
));
int
num_experts
=
static_cast
<
int
>
(
expert_offsets
.
size
(
0
));
bool
per_act_token
=
a_scales
.
numel
()
!=
1
;
bool
per_act_token
=
a_scales
.
numel
()
!=
1
;
...
...
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
View file @
b17109be
...
@@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
...
@@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
}
}
}
}
namespace
{
inline
void
launch_compute_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
atomic_buffer
,
int64_t
num_experts
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
,
const
bool
swap_ab
)
{
int
num_threads
=
min
(
THREADS_PER_EXPERT
,
topk_ids
.
numel
());
const
int32_t
*
topk_ptr
=
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
());
int32_t
*
ps1_ptr
=
static_cast
<
int32_t
*>
(
problem_sizes1
.
data_ptr
());
int32_t
*
ps2_ptr
=
static_cast
<
int32_t
*>
(
problem_sizes2
.
data_ptr
());
int32_t
*
atomic_ptr
=
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
());
if
(
swap_ab
)
{
compute_problem_sizes
<
true
><<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
topk_ptr
,
ps1_ptr
,
ps2_ptr
,
atomic_ptr
,
static_cast
<
int
>
(
topk_ids
.
numel
()),
static_cast
<
int
>
(
n
),
static_cast
<
int
>
(
k
));
}
else
{
compute_problem_sizes
<
false
><<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
topk_ptr
,
ps1_ptr
,
ps2_ptr
,
atomic_ptr
,
static_cast
<
int
>
(
topk_ids
.
numel
()),
static_cast
<
int
>
(
n
),
static_cast
<
int
>
(
k
));
}
}
}
// namespace
void
get_cutlass_moe_mm_problem_sizes_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
topk_ids
.
device
().
index
());
auto
options_int32
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
topk_ids
.
device
());
torch
::
Tensor
atomic_buffer
=
torch
::
zeros
(
num_experts
,
options_int32
);
// Swap-AB should be disabled for FP4 path
bool
may_swap_ab
=
(
!
blockscale_offsets
.
has_value
())
&&
(
topk_ids
.
numel
()
<=
SWAP_AB_THRESHOLD
);
launch_compute_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
atomic_buffer
,
num_experts
,
n
,
k
,
stream
,
may_swap_ab
);
}
void
get_cutlass_moe_mm_data_caller
(
void
get_cutlass_moe_mm_data_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
...
@@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
...
@@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
bool
may_swap_ab
=
(
!
blockscale_offsets
.
has_value
())
&&
bool
may_swap_ab
=
(
!
blockscale_offsets
.
has_value
())
&&
(
topk_ids
.
numel
()
<=
SWAP_AB_THRESHOLD
);
(
topk_ids
.
numel
()
<=
SWAP_AB_THRESHOLD
);
if
(
may_swap_ab
)
{
launch_compute_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
compute_problem_sizes
<
true
><<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
atomic_buffer
,
num_experts
,
n
,
k
,
stream
,
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
()),
may_swap_ab
);
static_cast
<
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes2
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
n
,
k
);
}
else
{
compute_problem_sizes
<
false
><<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes2
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
n
,
k
);
}
if
(
blockscale_offsets
.
has_value
())
{
if
(
blockscale_offsets
.
has_value
())
{
// fp4 path
// fp4 path
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
b17109be
...
@@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller(
...
@@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller(
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
void
get_cutlass_moe_mm_problem_sizes_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
void
get_cutlass_pplx_moe_mm_data_caller
(
torch
::
Tensor
&
expert_offsets
,
void
get_cutlass_pplx_moe_mm_data_caller
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
problem_sizes2
,
...
@@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data(
...
@@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data(
version_num
,
". Required capability: 90 or 100"
);
version_num
,
". Required capability: 90 or 100"
);
}
}
void
get_cutlass_moe_mm_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
int32_t
version_num
=
get_sm_version_num
();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
get_cutlass_moe_mm_problem_sizes_caller
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
blockscale_offsets
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
"kernel for CUDA device capability: "
,
version_num
,
". Required capability: 90 or 100"
);
}
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
problem_sizes2
,
...
...
csrc/torch_bindings.cpp
View file @
b17109be
...
@@ -440,6 +440,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -440,6 +440,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
{
stride_tag
});
{
stride_tag
});
ops
.
impl
(
"get_cutlass_moe_mm_data"
,
torch
::
kCUDA
,
&
get_cutlass_moe_mm_data
);
ops
.
impl
(
"get_cutlass_moe_mm_data"
,
torch
::
kCUDA
,
&
get_cutlass_moe_mm_data
);
// A function that computes problem sizes for each expert's multiplication
// used by the two mms called from fused MoE operation. It takes topk_ids as
// an input, and computes problem_sizes1 and problem_sizes2 only.
ops
.
def
(
"get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" int num_experts, int n, int k, "
" Tensor? blockscale_offsets) -> ()"
,
{
stride_tag
});
ops
.
impl
(
"get_cutlass_moe_mm_problem_sizes"
,
torch
::
kCUDA
,
&
get_cutlass_moe_mm_problem_sizes
);
// A function that computes data required to run fused MoE with w8a8 grouped
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
// as an input, and computes expert_offsets (token start indices of each
// as an input, and computes expert_offsets (token start indices of each
...
...
tests/kernels/moe/test_cutlass_moe.py
View file @
b17109be
...
@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
...
@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids'
:
topk_ids
,
'topk_ids'
:
topk_ids
,
'w1_scale'
:
moe_tensors
.
w1_scale
,
'w1_scale'
:
moe_tensors
.
w1_scale
,
'w2_scale'
:
moe_tensors
.
w2_scale
,
'w2_scale'
:
moe_tensors
.
w2_scale
,
'ab_strides1'
:
moe_tensors
.
ab_strides1
,
'ab_strides2'
:
moe_tensors
.
ab_strides2
,
'c_strides1'
:
moe_tensors
.
c_strides1
,
'c_strides2'
:
moe_tensors
.
c_strides2
,
'per_act_token'
:
per_act_token
,
'per_act_token'
:
per_act_token
,
'a1_scale'
:
None
#moe_tensors.a_scale
'a1_scale'
:
None
#moe_tensors.a_scale
}
}
...
@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
...
@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
topk_ids
[
0
][
1
]
=
1
topk_ids
[
0
][
1
]
=
1
workspace13_shape
=
(
m
*
topk
,
max
(
2
*
n
,
k
))
workspace13_shape
=
(
m
*
topk
,
max
(
2
*
n
,
k
))
workspace2_shape
=
(
m
*
topk
,
n
)
workspace2_shape
=
(
m
*
topk
,
max
(
n
,
k
)
)
output_shape
=
(
m
*
topk
,
k
)
output_shape
=
(
m
,
k
)
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
device
=
"cuda"
,
device
=
"cuda"
,
...
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
...
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
expert_map
[
start
:
end
]
=
list
(
range
(
num_local_experts
))
expert_map
[
start
:
end
]
=
list
(
range
(
num_local_experts
))
expert_map
=
torch
.
tensor
(
expert_map
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
expert_map
=
torch
.
tensor
(
expert_map
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
ab_strides1
=
torch
.
full
((
e
,
),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
e
,
),
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
e
,
),
2
*
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
e
,
),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
activation
=
lambda
o
,
i
:
torch
.
ops
.
_C
.
silu_and_mul
(
o
,
i
)
activation
=
lambda
o
,
i
:
torch
.
ops
.
_C
.
silu_and_mul
(
o
,
i
)
a1q
,
a1q_scale
=
moe_kernel_quantize_input
(
mt
.
a
,
mt
.
a_scale
,
a1q
,
a1q_scale
=
moe_kernel_quantize_input
(
mt
.
a
,
mt
.
a_scale
,
torch
.
float8_e4m3fn
,
torch
.
float8_e4m3fn
,
...
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
...
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
func
=
lambda
output
:
run_cutlass_moe_fp8
(
func
=
lambda
output
:
run_cutlass_moe_fp8
(
output
,
a1q
,
mt
.
w1_q
,
mt
.
w2_q
,
topk_ids
,
activation
,
output
,
a1q
,
mt
.
w1_q
,
mt
.
w2_q
,
topk_ids
,
activation
,
global_num_experts
,
expert_map
,
mt
.
w1_scale
,
mt
.
w2_scale
,
global_num_experts
,
expert_map
,
mt
.
w1_scale
,
mt
.
w2_scale
,
a1q_scale
,
None
,
workspace13
,
workspace2
,
None
,
mt
.
a
.
dtype
,
a1q_scale
,
None
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
per_act_token
,
per_out_channel
,
False
)
workspace13
,
workspace2
,
None
,
mt
.
a
.
dtype
,
per_act_token
,
per_out_channel
,
False
,
topk_weights
)
workspace13
.
random_
()
workspace13
.
random_
()
output_random_workspace
=
torch
.
empty
(
output_shape
,
output_random_workspace
=
torch
.
empty
(
output_shape
,
...
...
tests/kernels/moe/test_moe_permute_unpermute.py
View file @
b17109be
...
@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
...
@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
atol
=
0
,
atol
=
0
,
rtol
=
0
)
rtol
=
0
)
# check mindice
# check mindice
torch
.
testing
.
assert_close
(
gold_m_indices
,
m_indices
,
atol
=
0
,
rtol
=
0
)
# current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
if
align_block_size
is
not
None
:
torch
.
testing
.
assert_close
(
gold_m_indices
,
m_indices
,
atol
=
0
,
rtol
=
0
)
# check permuted_hidden_states, only valid token
# check permuted_hidden_states, only valid token
torch
.
testing
.
assert_close
(
gold_permuted_hidden_states
[
valid_row_idx
],
torch
.
testing
.
assert_close
(
gold_permuted_hidden_states
[
valid_row_idx
],
permuted_hidden_states
[
valid_row_idx
],
permuted_hidden_states
[
valid_row_idx
],
...
...
tests/kernels/moe/test_pplx_cutlass_moe.py
View file @
b17109be
...
@@ -76,6 +76,7 @@ def pplx_cutlass_moe(
...
@@ -76,6 +76,7 @@ def pplx_cutlass_moe(
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
num_tokens
,
hidden_dim
=
a
.
shape
num_tokens
,
hidden_dim
=
a
.
shape
intermediate_dim
=
w2
.
shape
[
2
]
num_experts
=
w1
.
shape
[
0
]
num_experts
=
w1
.
shape
[
0
]
block_size
=
hidden_dim
# TODO support more cases
block_size
=
hidden_dim
# TODO support more cases
device
=
pgi
.
device
device
=
pgi
.
device
...
@@ -124,8 +125,27 @@ def pplx_cutlass_moe(
...
@@ -124,8 +125,27 @@ def pplx_cutlass_moe(
num_local_experts
=
num_local_experts
,
num_local_experts
=
num_local_experts
,
num_dispatchers
=
num_dispatchers
)
num_dispatchers
=
num_dispatchers
)
ab_strides1
=
torch
.
full
((
num_local_experts
,
),
hidden_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
num_local_experts
,
),
intermediate_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
num_local_experts
,
),
2
*
intermediate_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
num_local_experts
,
),
hidden_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
experts
=
CutlassBatchedExpertsFp8
(
num_local_experts
,
num_dispatchers
,
experts
=
CutlassBatchedExpertsFp8
(
num_local_experts
,
num_dispatchers
,
out_dtype
,
per_act_token
,
per_out_ch
)
out_dtype
,
per_act_token
,
per_out_ch
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
)
fused_cutlass_experts
=
FusedMoEModularKernel
(
fused_cutlass_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
prepare_finalize
,
...
...
tests/kernels/quantization/test_cutlass_scaled_mm.py
View file @
b17109be
...
@@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
...
@@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
expert_offsets
=
torch
.
zeros
((
num_experts
+
1
),
expert_offsets
=
torch
.
zeros
((
num_experts
+
1
),
device
=
device
,
device
=
device
,
dtype
=
torch
.
int
32
)
dtype
=
torch
.
int
64
)
problem_sizes
=
torch
.
zeros
((
num_experts
,
3
),
problem_sizes
=
torch
.
zeros
((
num_experts
,
3
),
device
=
device
,
device
=
device
,
...
...
vllm/_custom_ops.py
View file @
b17109be
...
@@ -844,6 +844,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
...
@@ -844,6 +844,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
blockscale_offsets
)
blockscale_offsets
)
def
get_cutlass_moe_mm_problem_sizes
(
topk_ids
:
torch
.
Tensor
,
problem_sizes1
:
torch
.
Tensor
,
problem_sizes2
:
torch
.
Tensor
,
num_experts
:
int
,
n
:
int
,
k
:
int
,
blockscale_offsets
:
Optional
[
torch
.
Tensor
]
=
None
):
"""
Compute only the per-expert problem sizes needed by the two grouped matrix
multiplications used in CUTLASS-based fused MoE.
The function takes in topk_ids (token→expert mapping) and computes:
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
multiplication for the two grouped MMs
used in the fused MoE operation.
"""
return
torch
.
ops
.
_C
.
get_cutlass_moe_mm_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
blockscale_offsets
)
def
shuffle_rows
(
input_tensor
:
torch
.
Tensor
,
dst2src_map
:
torch
.
Tensor
):
def
shuffle_rows
(
input_tensor
:
torch
.
Tensor
,
dst2src_map
:
torch
.
Tensor
):
"""
"""
Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
b17109be
...
@@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
...
@@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
moe_permute
,
moe_unpermute
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
)
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
,
TopKWeightAndReduceNoOP
)
TopKWeightAndReduceDelegate
,
TopKWeightAndReduceNoOP
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_fp8_perm
,
from
vllm.model_executor.layers.fused_moe.utils
import
(
_fp8_quantize
,
_fp8_quantize
,
_resize_cache
)
_resize_cache
)
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
...
@@ -34,6 +35,10 @@ def run_cutlass_moe_fp8(
...
@@ -34,6 +35,10 @@ def run_cutlass_moe_fp8(
w2_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
workspace13
:
torch
.
Tensor
,
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
...
@@ -41,6 +46,7 @@ def run_cutlass_moe_fp8(
...
@@ -41,6 +46,7 @@ def run_cutlass_moe_fp8(
per_act_token
:
bool
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
per_out_ch
:
bool
,
use_batched_format
:
bool
,
use_batched_format
:
bool
,
topk_weights
:
Optional
[
torch
.
Tensor
],
):
):
a1q
=
hidden_states
a1q
=
hidden_states
...
@@ -99,6 +105,22 @@ def run_cutlass_moe_fp8(
...
@@ -99,6 +105,22 @@ def run_cutlass_moe_fp8(
topk
=
local_topk_ids
.
size
(
1
)
topk
=
local_topk_ids
.
size
(
1
)
local_E
=
w1
.
size
(
0
)
local_E
=
w1
.
size
(
0
)
if
use_batched_format
:
mm1_out
=
_resize_cache
(
workspace13
,
(
local_E
*
padded_M
,
N
*
2
))
act_out
=
_resize_cache
(
workspace2
,
(
local_E
*
padded_M
,
N
))
quant_out
=
_resize_cache
(
workspace13
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
(
local_E
*
padded_M
,
N
))
mm2_out
=
_resize_cache
(
workspace2
,
(
local_E
*
padded_M
,
K
))
else
:
a1q_perm
=
_resize_cache
(
workspace2
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
(
M
*
topk
,
K
))
mm1_out
=
_resize_cache
(
workspace13
,
(
M
*
topk
,
N
*
2
))
act_out
=
_resize_cache
(
workspace2
,
(
M
*
topk
,
N
))
# original workspace are based on input hidden_states dtype (bf16)
quant_out
=
_resize_cache
(
workspace13
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
(
M
*
topk
,
N
))
mm2_out
=
_resize_cache
(
workspace2
,
(
M
*
topk
,
K
))
if
use_batched_format
:
if
use_batched_format
:
assert
expert_num_tokens
is
not
None
assert
expert_num_tokens
is
not
None
...
@@ -120,11 +142,10 @@ def run_cutlass_moe_fp8(
...
@@ -120,11 +142,10 @@ def run_cutlass_moe_fp8(
w2_scale
=
w2_scale
.
reshape
(
w2_scale
.
size
(
0
),
-
1
)
w2_scale
=
w2_scale
.
reshape
(
w2_scale
.
size
(
0
),
-
1
)
a1q
=
a1q
.
reshape
(
-
1
,
a1q
.
size
(
2
))
a1q
=
a1q
.
reshape
(
-
1
,
a1q
.
size
(
2
))
a1q_scale
=
a1q_scale
.
reshape
(
-
1
,
a1q_scale
.
size
(
2
)).
contiguous
()
a1q_scale
=
a1q_scale
.
reshape
(
-
1
,
a1q_scale
.
size
(
2
)).
contiguous
()
# c3x get_group_gemm_starts expects int64 to avoid overflow
# during offset calculations
expert_offsets
=
expert_offsets
.
to
(
torch
.
int64
)
else
:
else
:
expert_offsets
=
torch
.
empty
((
global_num_experts
+
1
),
dtype
=
torch
.
int32
,
device
=
device
)
problem_sizes1
=
torch
.
empty
((
global_num_experts
,
3
),
problem_sizes1
=
torch
.
empty
((
global_num_experts
,
3
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
...
@@ -132,84 +153,57 @@ def run_cutlass_moe_fp8(
...
@@ -132,84 +153,57 @@ def run_cutlass_moe_fp8(
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
# With expert_map each Rank processes only a subset of experts. As
num_expert
=
global_num_experts
if
expert_map
is
None
\
# a result not all of a_map and c2 tensors are filled. We fill it
else
expert_map
.
size
(
0
)
# zeros for correctness.
# permuted a1q reuses workspace2
if
expert_map
is
not
None
:
a1q
,
a1q_scale
,
expert_offsets
,
inv_perm
,
_
=
moe_permute
(
a_map
=
torch
.
zeros
((
local_topk_ids
.
numel
()),
a1q
,
dtype
=
torch
.
int32
,
a1q_scale
,
device
=
device
)
topk_ids
,
else
:
num_expert
,
a_map
=
torch
.
empty
((
local_topk_ids
.
numel
()),
local_E
,
dtype
=
torch
.
int32
,
expert_map
,
device
=
device
)
permuted_hidden_states
=
a1q_perm
)
c_map
=
torch
.
empty
((
local_topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
ops
.
get_cutlass_moe_mm_data
(
local_topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
a_map
,
c_map
,
global_num_experts
,
N
,
K
)
a1q
=
_fp8_perm
(
a1q
,
a_map
)
a1q_scale
=
a1q_scale
[
a_map
]
if
per_act_token
else
a1q_scale
expert_offsets
=
expert_offsets
[:
-
1
]
expert_offsets
=
expert_offsets
[:
-
1
]
ab_strides1
=
torch
.
full
((
w1
.
size
(
0
),
),
ops
.
get_cutlass_moe_mm_problem_sizes
(
local_topk_ids
,
problem_sizes1
,
K
,
problem_sizes2
,
device
=
device
,
global_num_experts
,
N
,
K
)
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
w1
.
size
(
0
),
),
2
*
N
,
device
=
device
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
w1
.
size
(
0
),
),
N
,
device
=
device
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
w1
.
size
(
0
),
),
K
,
device
=
device
,
dtype
=
torch
.
int64
)
if
use_batched_format
:
c1
=
_resize_cache
(
workspace13
,
(
local_E
*
padded_M
,
N
*
2
))
c2
=
_resize_cache
(
workspace2
,
(
local_E
*
padded_M
,
N
))
c3
=
_resize_cache
(
workspace13
,
(
local_E
*
padded_M
,
K
))
else
:
c1
=
_resize_cache
(
workspace13
,
(
M
*
topk
,
N
*
2
))
c2
=
_resize_cache
(
workspace2
,
(
M
*
topk
,
N
))
c3
=
_resize_cache
(
workspace13
,
(
M
*
topk
,
K
))
if
not
per_act_token
and
(
expert_map
is
not
None
or
use_batched_format
):
if
not
per_act_token
and
(
expert_map
is
not
None
or
use_batched_format
):
# this is necessary to avoid imprecise scale calculation caused by
# this is necessary to avoid imprecise scale calculation caused by
# random data in the unused workspace. The workspace is unused when
# random data in the unused workspace. The workspace is unused when
# this rank handles only partial tokens, or when it is batched.
# this rank handles only partial tokens, or when it is batched.
c1
.
fill_
(
0
)
mm1_out
.
fill_
(
0
)
ops
.
cutlass_moe_mm
(
c1
,
a1q
,
w1
,
a1q_scale
,
w1_scale
,
expert_offsets
,
ops
.
cutlass_moe_mm
(
mm1_out
,
a1q
,
w1
,
a1q_scale
,
w1_scale
,
expert_offsets
,
problem_sizes1
,
ab_strides1
,
ab_strides1
,
c_strides1
,
problem_sizes1
,
ab_strides1
,
ab_strides1
,
c_strides1
,
per_act_token
,
per_out_ch
)
per_act_token
,
per_out_ch
)
activation_callable
(
c2
,
c1
)
activation_callable
(
act_out
,
mm1_out
)
a2q
,
a2q_scale
=
ops
.
scaled_fp8_quant
(
a2q
,
a2q_scale
=
ops
.
scaled_fp8_quant
(
c2
,
a2_scale
,
use_per_token_if_dynamic
=
per_act_token
)
act_out
,
a2_scale
,
use_per_token_if_dynamic
=
per_act_token
,
output
=
quant_out
)
if
expert_map
is
not
None
:
if
expert_map
is
not
None
:
c3
.
fill_
(
0
)
mm2_out
.
fill_
(
0
)
ops
.
cutlass_moe_mm
(
c3
,
a2q
,
w2
,
a2q_scale
,
w2_scale
,
expert_offsets
,
ops
.
cutlass_moe_mm
(
mm2_out
,
a2q
,
w2
,
a2q_scale
,
w2_scale
,
expert_offsets
,
problem_sizes2
,
ab_strides2
,
ab_strides2
,
c_strides2
,
problem_sizes2
,
ab_strides2
,
ab_strides2
,
c_strides2
,
per_act_token
,
per_out_ch
)
per_act_token
,
per_out_ch
)
if
use_batched_format
:
if
use_batched_format
:
output
.
copy_
(
c3
.
reshape
(
local_E
,
padded_M
,
K
),
non_blocking
=
True
)
output
.
copy_
(
mm2_out
.
reshape
(
local_E
,
padded_M
,
K
),
non_blocking
=
True
)
else
:
else
:
# We can't do this inplace because output may point to the same tensor
# for non-chunking mode the output is resized from workspace13
# as c3.
# so we need to make sure mm2_out uses workspace2.
output
.
copy_
(
c3
[
c_map
].
view
(
M
*
topk
,
K
),
non_blocking
=
True
)
moe_unpermute
(
out
=
output
,
permuted_hidden_states
=
mm2_out
,
topk_weights
=
topk_weights
,
inv_permuted_idx
=
inv_perm
)
class
CutlassExpertsFp8Base
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
class
CutlassExpertsFp8Base
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
...
@@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
out_dtype
:
Optional
[
torch
.
dtype
],
out_dtype
:
Optional
[
torch
.
dtype
],
per_act_token_quant
:
bool
,
per_act_token_quant
:
bool
,
per_out_ch_quant
:
bool
,
per_out_ch_quant
:
bool
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
):
):
super
().
__init__
(
super
().
__init__
(
...
@@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
block_shape
=
block_shape
,
block_shape
=
block_shape
,
))
))
self
.
out_dtype
=
out_dtype
self
.
out_dtype
=
out_dtype
self
.
ab_strides1
=
ab_strides1
self
.
ab_strides2
=
ab_strides2
self
.
c_strides1
=
c_strides1
self
.
c_strides2
=
c_strides2
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
# Let PrepareAndFinalize::finalize() decide the impl.
# Let PrepareAndFinalize::finalize() decide the impl.
...
@@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
run_cutlass_moe_fp8
(
run_cutlass_moe_fp8
(
output
,
hidden_states
,
w1
,
w2
,
topk_ids
,
activation_callable
,
output
,
hidden_states
,
w1
,
w2
,
topk_ids
,
activation_callable
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
a1q_scale
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
a1q_scale
,
a2_scale
,
workspace13
,
workspace2
,
expert_num_tokens
,
a2_scale
,
self
.
ab_strides1
,
self
.
ab_strides2
,
self
.
c_strides1
,
self
.
c_strides2
,
workspace13
,
workspace2
,
expert_num_tokens
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
in_dtype
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
in_dtype
,
self
.
per_act_token_quant
,
self
.
per_out_ch_quant
,
self
.
per_act_token_quant
,
self
.
per_out_ch_quant
,
use_batched_format
)
use_batched_format
,
topk_weights
)
class
CutlassExpertsFp8
(
CutlassExpertsFp8Base
):
class
CutlassExpertsFp8
(
CutlassExpertsFp8Base
):
...
@@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
...
@@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
out_dtype
:
Optional
[
torch
.
dtype
],
out_dtype
:
Optional
[
torch
.
dtype
],
per_act_token_quant
:
bool
,
per_act_token_quant
:
bool
,
per_out_ch_quant
:
bool
,
per_out_ch_quant
:
bool
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
):
):
super
().
__init__
(
super
().
__init__
(
out_dtype
,
out_dtype
,
per_act_token_quant
,
per_act_token_quant
,
per_out_ch_quant
,
per_out_ch_quant
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
block_shape
,
block_shape
,
)
)
...
@@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
...
@@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
def
supports_expert_map
(
self
)
->
bool
:
def
supports_expert_map
(
self
)
->
bool
:
return
True
return
True
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
# topk weights and reduction are fused in moe_unpermute cuda kernel
return
TopKWeightAndReduceNoOP
()
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
...
@@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
...
@@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
workspace1
=
(
M
*
topk
,
max
(
N
,
K
))
workspace1
=
(
M
*
topk
,
max
(
N
,
K
))
workspace2
=
(
M
*
topk
,
N
//
2
)
workspace2
=
(
M
*
topk
,
max
(
N
//
2
,
K
)
)
output
=
(
M
*
topk
,
K
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
return
(
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
)
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
)
...
@@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
...
@@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
out_dtype
:
Optional
[
torch
.
dtype
],
out_dtype
:
Optional
[
torch
.
dtype
],
per_act_token_quant
:
bool
,
per_act_token_quant
:
bool
,
per_out_ch_quant
:
bool
,
per_out_ch_quant
:
bool
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
):
):
super
().
__init__
(
super
().
__init__
(
out_dtype
,
out_dtype
,
per_act_token_quant
,
per_act_token_quant
,
per_out_ch_quant
,
per_out_ch_quant
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
block_shape
,
block_shape
,
)
)
assert
max_experts_per_worker
>
0
assert
max_experts_per_worker
>
0
...
@@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
...
@@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
assert
num_dp
is
not
None
assert
num_dp
is
not
None
workspace1
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
workspace1
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
max
(
N
,
K
))
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
(
N
//
2
))
workspace2
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
max
(
N
//
2
,
K
))
output
=
(
self
.
max_experts_per_worker
,
padded_M
,
K
)
output
=
(
self
.
max_experts_per_worker
,
padded_M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
return
(
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
)
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
)
...
@@ -392,6 +416,10 @@ def cutlass_moe_fp8(
...
@@ -392,6 +416,10 @@ def cutlass_moe_fp8(
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
per_act_token
:
Optional
[
bool
]
=
None
,
per_act_token
:
Optional
[
bool
]
=
None
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -419,6 +447,17 @@ def cutlass_moe_fp8(
...
@@ -419,6 +447,17 @@ def cutlass_moe_fp8(
Shape: [num_experts] or [num_experts, 2N]
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
...
@@ -450,6 +489,10 @@ def cutlass_moe_fp8(
...
@@ -450,6 +489,10 @@ def cutlass_moe_fp8(
out_dtype
=
a
.
dtype
,
out_dtype
=
a
.
dtype
,
per_act_token_quant
=
per_act_token
,
per_act_token_quant
=
per_act_token
,
per_out_ch_quant
=
per_out_ch
,
per_out_ch_quant
=
per_out_ch
,
ab_strides1
=
ab_strides1
,
ab_strides2
=
ab_strides2
,
c_strides1
=
c_strides1
,
c_strides2
=
c_strides2
,
),
),
)
)
...
...
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
View file @
b17109be
...
@@ -82,7 +82,8 @@ def moe_permute(
...
@@ -82,7 +82,8 @@ def moe_permute(
n_local_expert
:
int
=
-
1
,
n_local_expert
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
align_block_size
:
Optional
[
int
]
=
None
,
align_block_size
:
Optional
[
int
]
=
None
,
fill_invalid_expert
:
int
=
-
1
fill_invalid_expert
:
int
=
-
1
,
permuted_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
torch
.
Tensor
]:
"""
"""
...
@@ -95,14 +96,17 @@ def moe_permute(
...
@@ -95,14 +96,17 @@ def moe_permute(
- n_expert (int): The number of experts.
- n_expert (int): The number of experts.
- n_local_expert (int): The number of experts in the current EP rank.
- n_local_expert (int): The number of experts in the current EP rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
from the global expert space to the local expert space of the expert
parallel shard.
parallel shard.
- align_block_size (Optional[int]): group gemm block size alignment for DeepGemm
- align_block_size (Optional[int]): group gemm block size alignment for DeepGemm
- fill_invalid_expert (int): expert id to fill into m_indices for invalid experts,
- fill_invalid_expert (int): expert id to fill into m_indices for invalid experts,
to work around DeepGemm not supporting -1 in m_indices
to work around DeepGemm not supporting -1 in m_indices
- permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
If None, the output tensor will be created in this function.
Returns:
Returns:
- permuted_hidden_states (torch.Tensor): permuted activation.
- permuted_hidden_states (torch.Tensor): permuted activation.
- a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
- a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
if the original scale is not a per-tensor scale
- expert_first_token_offset (torch.Tensor): offset of the first token
- expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped gemm. If 'align_block_size' is enabled,
of each expert for standard grouped gemm. If 'align_block_size' is enabled,
expert_first_token_offset is aligned up to 'align_block_size'.
expert_first_token_offset is aligned up to 'align_block_size'.
...
@@ -122,11 +126,16 @@ def moe_permute(
...
@@ -122,11 +126,16 @@ def moe_permute(
1
)
//
align_block_size
*
align_block_size
1
)
//
align_block_size
*
align_block_size
if
n_local_expert
==
-
1
:
if
n_local_expert
==
-
1
:
n_local_expert
=
n_expert
n_local_expert
=
n_expert
permuted_hidden_states
=
torch
.
empty
(
if
permuted_hidden_states
is
None
:
(
permuted_row_size
,
n_hidden
),
permuted_hidden_states
=
torch
.
empty
(
dtype
=
hidden_states
.
dtype
,
(
permuted_row_size
,
n_hidden
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
device
=
hidden_states
.
device
,
)
assert
permuted_hidden_states
.
size
()
==
(
permuted_row_size
,
n_hidden
),
(
f
"Expected permuted hidden states to be
{
(
permuted_row_size
,
n_hidden
)
}
"
f
" but got
{
permuted_hidden_states
.
size
()
}
"
)
token_expert_indices
=
torch
.
arange
(
0
,
token_expert_indices
=
torch
.
arange
(
0
,
n_token
*
topk
,
n_token
*
topk
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
...
@@ -153,7 +162,8 @@ def moe_permute(
...
@@ -153,7 +162,8 @@ def moe_permute(
align_block_size
,
permuted_hidden_states
,
align_block_size
,
permuted_hidden_states
,
expert_first_token_offset
,
inv_permuted_idx
,
expert_first_token_offset
,
inv_permuted_idx
,
permuted_idx
,
m_indices
)
permuted_idx
,
m_indices
)
if
a1q_scale
is
not
None
:
if
a1q_scale
is
not
None
and
a1q_scale
.
dim
()
>
1
:
a1q_scale
=
a1q_scale
[
permuted_idx
.
clamp
(
max
=
n_token
*
topk
-
1
)
//
a1q_scale
=
a1q_scale
[
permuted_idx
.
clamp
(
max
=
n_token
*
topk
-
1
)
//
topk
]
topk
]
return
(
permuted_hidden_states
,
a1q_scale
,
expert_first_token_offset
,
return
(
permuted_hidden_states
,
a1q_scale
,
expert_first_token_offset
,
...
@@ -185,6 +195,7 @@ def moe_unpermute(
...
@@ -185,6 +195,7 @@ def moe_unpermute(
n_hidden
=
permuted_hidden_states
.
size
(
-
1
)
n_hidden
=
permuted_hidden_states
.
size
(
-
1
)
assert
(
n_hidden
*
permuted_hidden_states
.
element_size
()
assert
(
n_hidden
*
permuted_hidden_states
.
element_size
()
)
%
16
==
0
,
"unpermue kernel need hidden dim align to 16B"
)
%
16
==
0
,
"unpermue kernel need hidden dim align to 16B"
torch
.
ops
.
_moe_C
.
moe_unpermute
(
permuted_hidden_states
,
topk_weights
,
torch
.
ops
.
_moe_C
.
moe_unpermute
(
permuted_hidden_states
,
topk_weights
,
inv_permuted_idx
,
expert_first_token_offset
,
inv_permuted_idx
,
expert_first_token_offset
,
topk
,
out
)
topk
,
out
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
b17109be
...
@@ -669,6 +669,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -669,6 +669,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
self
.
fused_experts_func
=
fused_experts
self
.
fused_experts_func
=
fused_experts
if
self
.
use_cutlass
:
device
=
layer
.
w13_weight
.
device
# ab_strides1 and c_strides2 are the same
self
.
ab_strides1_c_strides2
=
torch
.
full
(
(
layer
.
local_num_experts
,
),
layer
.
hidden_size
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
ab_strides2
=
torch
.
full
(
(
layer
.
local_num_experts
,
),
layer
.
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
c_strides1
=
torch
.
full
(
(
layer
.
local_num_experts
,
),
2
*
layer
.
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
)
def
select_gemm_impl
(
def
select_gemm_impl
(
self
,
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
...
@@ -693,6 +712,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -693,6 +712,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe
.
in_dtype
,
moe
.
in_dtype
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
)
)
else
:
else
:
logger
.
debug
(
"CutlassExpertsFp8(%s)"
,
self
.
__class__
.
__name__
)
logger
.
debug
(
"CutlassExpertsFp8(%s)"
,
self
.
__class__
.
__name__
)
...
@@ -700,6 +723,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -700,6 +723,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe
.
in_dtype
,
moe
.
in_dtype
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
)
)
self
.
disable_expert_map
=
(
num_dispatchers
>
1
self
.
disable_expert_map
=
(
num_dispatchers
>
1
...
@@ -822,6 +849,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -822,6 +849,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_map
=
None
if
self
.
disable_expert_map
else
expert_map
,
expert_map
=
None
if
self
.
disable_expert_map
else
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
a1_scale
=
layer
.
w13_input_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment