Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b17109be
Unverified
Commit
b17109be
authored
Aug 20, 2025
by
shixianc
Committed by
GitHub
Aug 20, 2025
Browse files
[Kernel] CUTLASS MoE FP8: Integrate cuda moe permute/unpermute (#23045)
Signed-off-by:
Shixian Cui
<
shixian@amazon.com
>
parent
44492358
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
369 additions
and
121 deletions
+369
-121
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+34
-1
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+14
-19
csrc/ops.h
csrc/ops.h
+5
-0
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
+4
-2
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
+50
-15
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+24
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+13
-0
tests/kernels/moe/test_cutlass_moe.py
tests/kernels/moe/test_cutlass_moe.py
+14
-4
tests/kernels/moe/test_moe_permute_unpermute.py
tests/kernels/moe/test_moe_permute_unpermute.py
+5
-1
tests/kernels/moe/test_pplx_cutlass_moe.py
tests/kernels/moe/test_pplx_cutlass_moe.py
+21
-1
tests/kernels/quantization/test_cutlass_scaled_mm.py
tests/kernels/quantization/test_cutlass_scaled_mm.py
+1
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+22
-0
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+111
-68
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
.../model_executor/layers/fused_moe/moe_permute_unpermute.py
+20
-9
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+31
-0
No files found.
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
View file @
b17109be
...
@@ -80,6 +80,11 @@ def bench_run(
...
@@ -80,6 +80,11 @@ def bench_run(
a
,
score
,
topk
,
renormalize
=
False
a
,
score
,
topk
,
renormalize
=
False
)
)
ab_strides1
=
torch
.
full
((
num_experts
,),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
num_experts
,),
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
num_experts
,),
2
*
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
num_experts
,),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
def
run_triton_moe
(
def
run_triton_moe
(
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
@@ -111,6 +116,10 @@ def bench_run(
...
@@ -111,6 +116,10 @@ def bench_run(
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
per_act_token
:
bool
,
per_act_token
:
bool
,
...
@@ -125,6 +134,10 @@ def bench_run(
...
@@ -125,6 +134,10 @@ def bench_run(
topk_ids
,
topk_ids
,
w1_scale
,
w1_scale
,
w2_scale
,
w2_scale
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
per_act_token
,
per_act_token
,
a1_scale
=
None
,
a1_scale
=
None
,
)
)
...
@@ -136,6 +149,10 @@ def bench_run(
...
@@ -136,6 +149,10 @@ def bench_run(
w2_q
:
torch
.
Tensor
,
w2_q
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
):
):
...
@@ -150,6 +167,10 @@ def bench_run(
...
@@ -150,6 +167,10 @@ def bench_run(
topk_ids
,
topk_ids
,
w1_scale
,
w1_scale
,
w2_scale
,
w2_scale
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
per_act_token
,
per_act_token
,
a1_scale
=
None
,
a1_scale
=
None
,
)
)
...
@@ -194,6 +215,10 @@ def bench_run(
...
@@ -194,6 +215,10 @@ def bench_run(
w2_q
,
w2_q
,
w1_scale
,
w1_scale
,
w2_scale
,
w2_scale
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
)
)
...
@@ -231,6 +256,10 @@ def bench_run(
...
@@ -231,6 +256,10 @@ def bench_run(
"w1_scale"
:
w1_scale
,
"w1_scale"
:
w1_scale
,
"w2_scale"
:
w2_scale
,
"w2_scale"
:
w2_scale
,
"per_act_token"
:
per_act_token
,
"per_act_token"
:
per_act_token
,
"ab_strides1"
:
ab_strides1
,
"ab_strides2"
:
ab_strides2
,
"c_strides1"
:
c_strides1
,
"c_strides2"
:
c_strides2
,
# cuda graph params
# cuda graph params
"cutlass_graph"
:
cutlass_graph
,
"cutlass_graph"
:
cutlass_graph
,
"triton_graph"
:
triton_graph
,
"triton_graph"
:
triton_graph
,
...
@@ -289,6 +318,10 @@ def bench_run(
...
@@ -289,6 +318,10 @@ def bench_run(
w2_q
,
w2_q
,
w1_scale
,
w1_scale
,
w2_scale
,
w2_scale
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
per_act_token
,
per_act_token
,
...
@@ -297,7 +330,7 @@ def bench_run(
...
@@ -297,7 +330,7 @@ def bench_run(
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
"run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)"
,
# noqa: E501
stmt
=
"run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
ab_strides1, ab_strides2, c_strides1, c_strides2,
topk_weights, topk_ids, per_act_token, num_runs)"
,
# noqa: E501
globals
=
globals
,
globals
=
globals
,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
...
...
csrc/moe/moe_permute_unpermute_op.cu
View file @
b17109be
...
@@ -45,8 +45,6 @@ void moe_permute(
...
@@ -45,8 +45,6 @@ void moe_permute(
auto
copy_topk_ids
=
topk_ids
.
clone
();
// copy topk_ids for preprocess
auto
copy_topk_ids
=
topk_ids
.
clone
();
// copy topk_ids for preprocess
auto
permuted_experts_id
=
torch
::
empty_like
(
topk_ids
);
auto
permuted_experts_id
=
torch
::
empty_like
(
topk_ids
);
auto
sorted_row_idx
=
torch
::
empty_like
(
inv_permuted_idx
);
auto
sorted_row_idx
=
torch
::
empty_like
(
inv_permuted_idx
);
auto
align_expert_first_token_offset
=
torch
::
zeros_like
(
expert_first_token_offset
);
CubKeyValueSorter
sorter
{};
CubKeyValueSorter
sorter
{};
int64_t
*
valid_num_ptr
=
nullptr
;
int64_t
*
valid_num_ptr
=
nullptr
;
...
@@ -85,12 +83,14 @@ void moe_permute(
...
@@ -85,12 +83,14 @@ void moe_permute(
});
});
// get m_indices and update expert_first_token_offset with align block
// get m_indices and update expert_first_token_offset with align block
// this is only required for DeepGemm and not required for CUTLASS group gemm
if
(
align_block_size
.
has_value
())
{
auto
align_expert_first_token_offset
=
torch
::
zeros_like
(
expert_first_token_offset
);
getMIndices
(
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
getMIndices
(
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
get_ptr
<
int64_t
>
(
align_expert_first_token_offset
),
get_ptr
<
int64_t
>
(
align_expert_first_token_offset
),
get_ptr
<
int
>
(
m_indices
),
n_local_expert
,
align_block_size_value
,
get_ptr
<
int
>
(
m_indices
),
n_local_expert
,
align_block_size_value
,
stream
);
stream
);
if
(
align_block_size
.
has_value
())
{
// update align_expert_first_token_offset
expert_first_token_offset
.
copy_
(
align_expert_first_token_offset
);
expert_first_token_offset
.
copy_
(
align_expert_first_token_offset
);
}
}
}
}
...
@@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
...
@@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
torch
::
Tensor
&
expert_first_token_offset
,
torch
::
Tensor
&
expert_first_token_offset
,
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
torch
::
Tensor
&
m_indices
)
{
torch
::
Tensor
&
m_indices
)
{
TORCH_CHECK
(
false
,
"moe_
un
permute is not supported on CUDA < 12.0"
);
TORCH_CHECK
(
false
,
"moe_permute is not supported on CUDA < 12.0"
);
}
}
void
moe_unpermute
(
const
torch
::
Tensor
&
input
,
void
moe_unpermute
(
const
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
permuted_hidden_states
,
const
torch
::
Tensor
&
token_expert_indices
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
inv_permuted_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
const
std
::
optional
<
torch
::
Tensor
>&
expert_first_token_offset
,
int64_t
topk
,
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
torch
::
Tensor
&
hidden_states
)
{
const
std
::
optional
<
int64_t
>&
align_block_size
,
torch
::
Tensor
&
permuted_input
,
torch
::
Tensor
&
expert_first_token_offset
,
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
torch
::
Tensor
&
m_indices
)
{
TORCH_CHECK
(
false
,
"moe_unpermute is not supported on CUDA < 12.0"
);
TORCH_CHECK
(
false
,
"moe_unpermute is not supported on CUDA < 12.0"
);
}
}
...
...
csrc/ops.h
View file @
b17109be
...
@@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data(
...
@@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data(
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
void
get_cutlass_moe_mm_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
problem_sizes2
,
...
...
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
View file @
b17109be
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
template
<
typename
ElementAB
,
typename
ElementC
,
typename
ElementAccumulator
>
template
<
typename
ElementAB
,
typename
ElementC
,
typename
ElementAccumulator
>
__global__
void
get_group_gemm_starts
(
__global__
void
get_group_gemm_starts
(
int
32
_t
*
expert_offsets
,
ElementAB
**
a_offsets
,
ElementAB
**
b_offsets
,
int
64
_t
*
expert_offsets
,
ElementAB
**
a_offsets
,
ElementAB
**
b_offsets
,
ElementC
**
out_offsets
,
ElementAccumulator
**
a_scales_offsets
,
ElementC
**
out_offsets
,
ElementAccumulator
**
a_scales_offsets
,
ElementAccumulator
**
b_scales_offsets
,
ElementAB
*
a_base_as_int
,
ElementAccumulator
**
b_scales_offsets
,
ElementAB
*
a_base_as_int
,
ElementAB
*
b_base_as_int
,
ElementC
*
out_base_as_int
,
ElementAB
*
b_base_as_int
,
ElementC
*
out_base_as_int
,
...
@@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts(
...
@@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts(
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \
<<<1, num_experts, 0, stream>>>( \
static_cast<int
32
_t*>(expert_offsets.data_ptr()), \
static_cast<int
64
_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
...
@@ -61,6 +61,8 @@ void run_get_group_gemm_starts(
...
@@ -61,6 +61,8 @@ void run_get_group_gemm_starts(
TORCH_CHECK
(
b_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b_tensors
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
// expect int64_t to avoid overflow during offset calculations
TORCH_CHECK
(
expert_offsets
.
dtype
()
==
torch
::
kInt64
);
int
num_experts
=
static_cast
<
int
>
(
expert_offsets
.
size
(
0
));
int
num_experts
=
static_cast
<
int
>
(
expert_offsets
.
size
(
0
));
bool
per_act_token
=
a_scales
.
numel
()
!=
1
;
bool
per_act_token
=
a_scales
.
numel
()
!=
1
;
...
...
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
View file @
b17109be
...
@@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
...
@@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
}
}
}
}
namespace
{
inline
void
launch_compute_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
atomic_buffer
,
int64_t
num_experts
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
,
const
bool
swap_ab
)
{
int
num_threads
=
min
(
THREADS_PER_EXPERT
,
topk_ids
.
numel
());
const
int32_t
*
topk_ptr
=
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
());
int32_t
*
ps1_ptr
=
static_cast
<
int32_t
*>
(
problem_sizes1
.
data_ptr
());
int32_t
*
ps2_ptr
=
static_cast
<
int32_t
*>
(
problem_sizes2
.
data_ptr
());
int32_t
*
atomic_ptr
=
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
());
if
(
swap_ab
)
{
compute_problem_sizes
<
true
><<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
topk_ptr
,
ps1_ptr
,
ps2_ptr
,
atomic_ptr
,
static_cast
<
int
>
(
topk_ids
.
numel
()),
static_cast
<
int
>
(
n
),
static_cast
<
int
>
(
k
));
}
else
{
compute_problem_sizes
<
false
><<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
topk_ptr
,
ps1_ptr
,
ps2_ptr
,
atomic_ptr
,
static_cast
<
int
>
(
topk_ids
.
numel
()),
static_cast
<
int
>
(
n
),
static_cast
<
int
>
(
k
));
}
}
}
// namespace
void
get_cutlass_moe_mm_problem_sizes_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
topk_ids
.
device
().
index
());
auto
options_int32
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
topk_ids
.
device
());
torch
::
Tensor
atomic_buffer
=
torch
::
zeros
(
num_experts
,
options_int32
);
// Swap-AB should be disabled for FP4 path
bool
may_swap_ab
=
(
!
blockscale_offsets
.
has_value
())
&&
(
topk_ids
.
numel
()
<=
SWAP_AB_THRESHOLD
);
launch_compute_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
atomic_buffer
,
num_experts
,
n
,
k
,
stream
,
may_swap_ab
);
}
void
get_cutlass_moe_mm_data_caller
(
void
get_cutlass_moe_mm_data_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
...
@@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
...
@@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
bool
may_swap_ab
=
(
!
blockscale_offsets
.
has_value
())
&&
bool
may_swap_ab
=
(
!
blockscale_offsets
.
has_value
())
&&
(
topk_ids
.
numel
()
<=
SWAP_AB_THRESHOLD
);
(
topk_ids
.
numel
()
<=
SWAP_AB_THRESHOLD
);
if
(
may_swap_ab
)
{
launch_compute_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
compute_problem_sizes
<
true
><<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
atomic_buffer
,
num_experts
,
n
,
k
,
stream
,
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
()),
may_swap_ab
);
static_cast
<
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes2
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
n
,
k
);
}
else
{
compute_problem_sizes
<
false
><<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes2
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
n
,
k
);
}
if
(
blockscale_offsets
.
has_value
())
{
if
(
blockscale_offsets
.
has_value
())
{
// fp4 path
// fp4 path
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
b17109be
...
@@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller(
...
@@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller(
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
void
get_cutlass_moe_mm_problem_sizes_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
void
get_cutlass_pplx_moe_mm_data_caller
(
torch
::
Tensor
&
expert_offsets
,
void
get_cutlass_pplx_moe_mm_data_caller
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
problem_sizes2
,
...
@@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data(
...
@@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data(
version_num
,
". Required capability: 90 or 100"
);
version_num
,
". Required capability: 90 or 100"
);
}
}
void
get_cutlass_moe_mm_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
int32_t
version_num
=
get_sm_version_num
();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
get_cutlass_moe_mm_problem_sizes_caller
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
blockscale_offsets
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
"kernel for CUDA device capability: "
,
version_num
,
". Required capability: 90 or 100"
);
}
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
problem_sizes2
,
...
...
csrc/torch_bindings.cpp
View file @
b17109be
...
@@ -440,6 +440,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -440,6 +440,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
{
stride_tag
});
{
stride_tag
});
ops
.
impl
(
"get_cutlass_moe_mm_data"
,
torch
::
kCUDA
,
&
get_cutlass_moe_mm_data
);
ops
.
impl
(
"get_cutlass_moe_mm_data"
,
torch
::
kCUDA
,
&
get_cutlass_moe_mm_data
);
// A function that computes problem sizes for each expert's multiplication
// used by the two mms called from fused MoE operation. It takes topk_ids as
// an input, and computes problem_sizes1 and problem_sizes2 only.
ops
.
def
(
"get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" int num_experts, int n, int k, "
" Tensor? blockscale_offsets) -> ()"
,
{
stride_tag
});
ops
.
impl
(
"get_cutlass_moe_mm_problem_sizes"
,
torch
::
kCUDA
,
&
get_cutlass_moe_mm_problem_sizes
);
// A function that computes data required to run fused MoE with w8a8 grouped
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
// as an input, and computes expert_offsets (token start indices of each
// as an input, and computes expert_offsets (token start indices of each
...
...
tests/kernels/moe/test_cutlass_moe.py
View file @
b17109be
...
@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
...
@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids'
:
topk_ids
,
'topk_ids'
:
topk_ids
,
'w1_scale'
:
moe_tensors
.
w1_scale
,
'w1_scale'
:
moe_tensors
.
w1_scale
,
'w2_scale'
:
moe_tensors
.
w2_scale
,
'w2_scale'
:
moe_tensors
.
w2_scale
,
'ab_strides1'
:
moe_tensors
.
ab_strides1
,
'ab_strides2'
:
moe_tensors
.
ab_strides2
,
'c_strides1'
:
moe_tensors
.
c_strides1
,
'c_strides2'
:
moe_tensors
.
c_strides2
,
'per_act_token'
:
per_act_token
,
'per_act_token'
:
per_act_token
,
'a1_scale'
:
None
#moe_tensors.a_scale
'a1_scale'
:
None
#moe_tensors.a_scale
}
}
...
@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
...
@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
topk_ids
[
0
][
1
]
=
1
topk_ids
[
0
][
1
]
=
1
workspace13_shape
=
(
m
*
topk
,
max
(
2
*
n
,
k
))
workspace13_shape
=
(
m
*
topk
,
max
(
2
*
n
,
k
))
workspace2_shape
=
(
m
*
topk
,
n
)
workspace2_shape
=
(
m
*
topk
,
max
(
n
,
k
)
)
output_shape
=
(
m
*
topk
,
k
)
output_shape
=
(
m
,
k
)
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
device
=
"cuda"
,
device
=
"cuda"
,
...
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
...
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
expert_map
[
start
:
end
]
=
list
(
range
(
num_local_experts
))
expert_map
[
start
:
end
]
=
list
(
range
(
num_local_experts
))
expert_map
=
torch
.
tensor
(
expert_map
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
expert_map
=
torch
.
tensor
(
expert_map
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
ab_strides1
=
torch
.
full
((
e
,
),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
e
,
),
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
e
,
),
2
*
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
e
,
),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
activation
=
lambda
o
,
i
:
torch
.
ops
.
_C
.
silu_and_mul
(
o
,
i
)
activation
=
lambda
o
,
i
:
torch
.
ops
.
_C
.
silu_and_mul
(
o
,
i
)
a1q
,
a1q_scale
=
moe_kernel_quantize_input
(
mt
.
a
,
mt
.
a_scale
,
a1q
,
a1q_scale
=
moe_kernel_quantize_input
(
mt
.
a
,
mt
.
a_scale
,
torch
.
float8_e4m3fn
,
torch
.
float8_e4m3fn
,
...
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
...
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
func
=
lambda
output
:
run_cutlass_moe_fp8
(
func
=
lambda
output
:
run_cutlass_moe_fp8
(
output
,
a1q
,
mt
.
w1_q
,
mt
.
w2_q
,
topk_ids
,
activation
,
output
,
a1q
,
mt
.
w1_q
,
mt
.
w2_q
,
topk_ids
,
activation
,
global_num_experts
,
expert_map
,
mt
.
w1_scale
,
mt
.
w2_scale
,
global_num_experts
,
expert_map
,
mt
.
w1_scale
,
mt
.
w2_scale
,
a1q_scale
,
None
,
workspace13
,
workspace2
,
None
,
mt
.
a
.
dtype
,
a1q_scale
,
None
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
per_act_token
,
per_out_channel
,
False
)
workspace13
,
workspace2
,
None
,
mt
.
a
.
dtype
,
per_act_token
,
per_out_channel
,
False
,
topk_weights
)
workspace13
.
random_
()
workspace13
.
random_
()
output_random_workspace
=
torch
.
empty
(
output_shape
,
output_random_workspace
=
torch
.
empty
(
output_shape
,
...
...
tests/kernels/moe/test_moe_permute_unpermute.py
View file @
b17109be
...
@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
...
@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
atol
=
0
,
atol
=
0
,
rtol
=
0
)
rtol
=
0
)
# check mindice
# check mindice
# current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
if
align_block_size
is
not
None
:
torch
.
testing
.
assert_close
(
gold_m_indices
,
m_indices
,
atol
=
0
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
gold_m_indices
,
m_indices
,
atol
=
0
,
rtol
=
0
)
# check permuted_hidden_states, only valid token
# check permuted_hidden_states, only valid token
torch
.
testing
.
assert_close
(
gold_permuted_hidden_states
[
valid_row_idx
],
torch
.
testing
.
assert_close
(
gold_permuted_hidden_states
[
valid_row_idx
],
permuted_hidden_states
[
valid_row_idx
],
permuted_hidden_states
[
valid_row_idx
],
...
...
tests/kernels/moe/test_pplx_cutlass_moe.py
View file @
b17109be
...
@@ -76,6 +76,7 @@ def pplx_cutlass_moe(
...
@@ -76,6 +76,7 @@ def pplx_cutlass_moe(
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
num_tokens
,
hidden_dim
=
a
.
shape
num_tokens
,
hidden_dim
=
a
.
shape
intermediate_dim
=
w2
.
shape
[
2
]
num_experts
=
w1
.
shape
[
0
]
num_experts
=
w1
.
shape
[
0
]
block_size
=
hidden_dim
# TODO support more cases
block_size
=
hidden_dim
# TODO support more cases
device
=
pgi
.
device
device
=
pgi
.
device
...
@@ -124,8 +125,27 @@ def pplx_cutlass_moe(
...
@@ -124,8 +125,27 @@ def pplx_cutlass_moe(
num_local_experts
=
num_local_experts
,
num_local_experts
=
num_local_experts
,
num_dispatchers
=
num_dispatchers
)
num_dispatchers
=
num_dispatchers
)
ab_strides1
=
torch
.
full
((
num_local_experts
,
),
hidden_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
num_local_experts
,
),
intermediate_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
num_local_experts
,
),
2
*
intermediate_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
num_local_experts
,
),
hidden_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
experts
=
CutlassBatchedExpertsFp8
(
num_local_experts
,
num_dispatchers
,
experts
=
CutlassBatchedExpertsFp8
(
num_local_experts
,
num_dispatchers
,
out_dtype
,
per_act_token
,
per_out_ch
)
out_dtype
,
per_act_token
,
per_out_ch
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
)
fused_cutlass_experts
=
FusedMoEModularKernel
(
fused_cutlass_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
prepare_finalize
,
...
...
tests/kernels/quantization/test_cutlass_scaled_mm.py
View file @
b17109be
...
@@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
...
@@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
expert_offsets
=
torch
.
zeros
((
num_experts
+
1
),
expert_offsets
=
torch
.
zeros
((
num_experts
+
1
),
device
=
device
,
device
=
device
,
dtype
=
torch
.
int
32
)
dtype
=
torch
.
int
64
)
problem_sizes
=
torch
.
zeros
((
num_experts
,
3
),
problem_sizes
=
torch
.
zeros
((
num_experts
,
3
),
device
=
device
,
device
=
device
,
...
...
vllm/_custom_ops.py
View file @
b17109be
...
@@ -844,6 +844,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
...
@@ -844,6 +844,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
blockscale_offsets
)
blockscale_offsets
)
def
get_cutlass_moe_mm_problem_sizes
(
topk_ids
:
torch
.
Tensor
,
problem_sizes1
:
torch
.
Tensor
,
problem_sizes2
:
torch
.
Tensor
,
num_experts
:
int
,
n
:
int
,
k
:
int
,
blockscale_offsets
:
Optional
[
torch
.
Tensor
]
=
None
):
"""
Compute only the per-expert problem sizes needed by the two grouped matrix
multiplications used in CUTLASS-based fused MoE.
The function takes in topk_ids (token→expert mapping) and computes:
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
multiplication for the two grouped MMs
used in the fused MoE operation.
"""
return
torch
.
ops
.
_C
.
get_cutlass_moe_mm_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
blockscale_offsets
)
def
shuffle_rows
(
input_tensor
:
torch
.
Tensor
,
dst2src_map
:
torch
.
Tensor
):
def
shuffle_rows
(
input_tensor
:
torch
.
Tensor
,
dst2src_map
:
torch
.
Tensor
):
"""
"""
Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
b17109be
...
@@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
...
@@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
moe_permute
,
moe_unpermute
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
)
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
,
TopKWeightAndReduceNoOP
)
TopKWeightAndReduceDelegate
,
TopKWeightAndReduceNoOP
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_fp8_perm
,
from
vllm.model_executor.layers.fused_moe.utils
import
(
_fp8_quantize
,
_fp8_quantize
,
_resize_cache
)
_resize_cache
)
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
...
@@ -34,6 +35,10 @@ def run_cutlass_moe_fp8(
...
@@ -34,6 +35,10 @@ def run_cutlass_moe_fp8(
w2_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
workspace13
:
torch
.
Tensor
,
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
...
@@ -41,6 +46,7 @@ def run_cutlass_moe_fp8(
...
@@ -41,6 +46,7 @@ def run_cutlass_moe_fp8(
per_act_token
:
bool
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
per_out_ch
:
bool
,
use_batched_format
:
bool
,
use_batched_format
:
bool
,
topk_weights
:
Optional
[
torch
.
Tensor
],
):
):
a1q
=
hidden_states
a1q
=
hidden_states
...
@@ -99,6 +105,22 @@ def run_cutlass_moe_fp8(
...
@@ -99,6 +105,22 @@ def run_cutlass_moe_fp8(
topk
=
local_topk_ids
.
size
(
1
)
topk
=
local_topk_ids
.
size
(
1
)
local_E
=
w1
.
size
(
0
)
local_E
=
w1
.
size
(
0
)
if
use_batched_format
:
mm1_out
=
_resize_cache
(
workspace13
,
(
local_E
*
padded_M
,
N
*
2
))
act_out
=
_resize_cache
(
workspace2
,
(
local_E
*
padded_M
,
N
))
quant_out
=
_resize_cache
(
workspace13
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
(
local_E
*
padded_M
,
N
))
mm2_out
=
_resize_cache
(
workspace2
,
(
local_E
*
padded_M
,
K
))
else
:
a1q_perm
=
_resize_cache
(
workspace2
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
(
M
*
topk
,
K
))
mm1_out
=
_resize_cache
(
workspace13
,
(
M
*
topk
,
N
*
2
))
act_out
=
_resize_cache
(
workspace2
,
(
M
*
topk
,
N
))
# original workspace are based on input hidden_states dtype (bf16)
quant_out
=
_resize_cache
(
workspace13
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
(
M
*
topk
,
N
))
mm2_out
=
_resize_cache
(
workspace2
,
(
M
*
topk
,
K
))
if
use_batched_format
:
if
use_batched_format
:
assert
expert_num_tokens
is
not
None
assert
expert_num_tokens
is
not
None
...
@@ -120,11 +142,10 @@ def run_cutlass_moe_fp8(
...
@@ -120,11 +142,10 @@ def run_cutlass_moe_fp8(
w2_scale
=
w2_scale
.
reshape
(
w2_scale
.
size
(
0
),
-
1
)
w2_scale
=
w2_scale
.
reshape
(
w2_scale
.
size
(
0
),
-
1
)
a1q
=
a1q
.
reshape
(
-
1
,
a1q
.
size
(
2
))
a1q
=
a1q
.
reshape
(
-
1
,
a1q
.
size
(
2
))
a1q_scale
=
a1q_scale
.
reshape
(
-
1
,
a1q_scale
.
size
(
2
)).
contiguous
()
a1q_scale
=
a1q_scale
.
reshape
(
-
1
,
a1q_scale
.
size
(
2
)).
contiguous
()
# c3x get_group_gemm_starts expects int64 to avoid overflow
# during offset calculations
expert_offsets
=
expert_offsets
.
to
(
torch
.
int64
)
else
:
else
:
expert_offsets
=
torch
.
empty
((
global_num_experts
+
1
),
dtype
=
torch
.
int32
,
device
=
device
)
problem_sizes1
=
torch
.
empty
((
global_num_experts
,
3
),
problem_sizes1
=
torch
.
empty
((
global_num_experts
,
3
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
...
@@ -132,84 +153,57 @@ def run_cutlass_moe_fp8(
...
@@ -132,84 +153,57 @@ def run_cutlass_moe_fp8(
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
# With expert_map each Rank processes only a subset of experts. As
num_expert
=
global_num_experts
if
expert_map
is
None
\
# a result not all of a_map and c2 tensors are filled. We fill it
else
expert_map
.
size
(
0
)
# zeros for correctness.
# permuted a1q reuses workspace2
if
expert_map
is
not
None
:
a1q
,
a1q_scale
,
expert_offsets
,
inv_perm
,
_
=
moe_permute
(
a_map
=
torch
.
zeros
((
local_topk_ids
.
numel
()),
a1q
,
dtype
=
torch
.
int32
,
a1q_scale
,
device
=
device
)
topk_ids
,
else
:
num_expert
,
a_map
=
torch
.
empty
((
local_topk_ids
.
numel
()),
local_E
,
dtype
=
torch
.
int32
,
expert_map
,
device
=
device
)
permuted_hidden_states
=
a1q_perm
)
c_map
=
torch
.
empty
((
local_topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
ops
.
get_cutlass_moe_mm_data
(
local_topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
a_map
,
c_map
,
global_num_experts
,
N
,
K
)
a1q
=
_fp8_perm
(
a1q
,
a_map
)
a1q_scale
=
a1q_scale
[
a_map
]
if
per_act_token
else
a1q_scale
expert_offsets
=
expert_offsets
[:
-
1
]
expert_offsets
=
expert_offsets
[:
-
1
]
ab_strides1
=
torch
.
full
((
w1
.
size
(
0
),
),
ops
.
get_cutlass_moe_mm_problem_sizes
(
local_topk_ids
,
problem_sizes1
,
K
,
problem_sizes2
,
device
=
device
,
global_num_experts
,
N
,
K
)
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
w1
.
size
(
0
),
),
2
*
N
,
device
=
device
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
w1
.
size
(
0
),
),
N
,
device
=
device
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
w1
.
size
(
0
),
),
K
,
device
=
device
,
dtype
=
torch
.
int64
)
if
use_batched_format
:
c1
=
_resize_cache
(
workspace13
,
(
local_E
*
padded_M
,
N
*
2
))
c2
=
_resize_cache
(
workspace2
,
(
local_E
*
padded_M
,
N
))
c3
=
_resize_cache
(
workspace13
,
(
local_E
*
padded_M
,
K
))
else
:
c1
=
_resize_cache
(
workspace13
,
(
M
*
topk
,
N
*
2
))
c2
=
_resize_cache
(
workspace2
,
(
M
*
topk
,
N
))
c3
=
_resize_cache
(
workspace13
,
(
M
*
topk
,
K
))
if
not
per_act_token
and
(
expert_map
is
not
None
or
use_batched_format
):
if
not
per_act_token
and
(
expert_map
is
not
None
or
use_batched_format
):
# this is necessary to avoid imprecise scale calculation caused by
# this is necessary to avoid imprecise scale calculation caused by
# random data in the unused workspace. The workspace is unused when
# random data in the unused workspace. The workspace is unused when
# this rank handles only partial tokens, or when it is batched .
# this rank handles only partial tokens, or when it is batched .
c1
.
fill_
(
0
)
mm1_out
.
fill_
(
0
)
ops
.
cutlass_moe_mm
(
c1
,
a1q
,
w1
,
a1q_scale
,
w1_scale
,
expert_offsets
,
ops
.
cutlass_moe_mm
(
mm1_out
,
a1q
,
w1
,
a1q_scale
,
w1_scale
,
expert_offsets
,
problem_sizes1
,
ab_strides1
,
ab_strides1
,
c_strides1
,
problem_sizes1
,
ab_strides1
,
ab_strides1
,
c_strides1
,
per_act_token
,
per_out_ch
)
per_act_token
,
per_out_ch
)
activation_callable
(
c2
,
c1
)
activation_callable
(
act_out
,
mm1_out
)
a2q
,
a2q_scale
=
ops
.
scaled_fp8_quant
(
a2q
,
a2q_scale
=
ops
.
scaled_fp8_quant
(
c2
,
a2_scale
,
use_per_token_if_dynamic
=
per_act_token
)
act_out
,
a2_scale
,
use_per_token_if_dynamic
=
per_act_token
,
output
=
quant_out
)
if
expert_map
is
not
None
:
if
expert_map
is
not
None
:
c3
.
fill_
(
0
)
mm2_out
.
fill_
(
0
)
ops
.
cutlass_moe_mm
(
c3
,
a2q
,
w2
,
a2q_scale
,
w2_scale
,
expert_offsets
,
ops
.
cutlass_moe_mm
(
mm2_out
,
a2q
,
w2
,
a2q_scale
,
w2_scale
,
expert_offsets
,
problem_sizes2
,
ab_strides2
,
ab_strides2
,
c_strides2
,
problem_sizes2
,
ab_strides2
,
ab_strides2
,
c_strides2
,
per_act_token
,
per_out_ch
)
per_act_token
,
per_out_ch
)
if
use_batched_format
:
if
use_batched_format
:
output
.
copy_
(
c3
.
reshape
(
local_E
,
padded_M
,
K
),
non_blocking
=
True
)
output
.
copy_
(
mm2_out
.
reshape
(
local_E
,
padded_M
,
K
),
non_blocking
=
True
)
else
:
else
:
# We can't do this inplace because output may point to the same tensor
# for non-chunking mode the output is resized from workspace13
# as c3.
# so we need to make sure mm2_out uses workspace2.
output
.
copy_
(
c3
[
c_map
].
view
(
M
*
topk
,
K
),
non_blocking
=
True
)
moe_unpermute
(
out
=
output
,
permuted_hidden_states
=
mm2_out
,
topk_weights
=
topk_weights
,
inv_permuted_idx
=
inv_perm
)
class
CutlassExpertsFp8Base
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
class
CutlassExpertsFp8Base
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
...
@@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
out_dtype
:
Optional
[
torch
.
dtype
],
out_dtype
:
Optional
[
torch
.
dtype
],
per_act_token_quant
:
bool
,
per_act_token_quant
:
bool
,
per_out_ch_quant
:
bool
,
per_out_ch_quant
:
bool
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
):
):
super
().
__init__
(
super
().
__init__
(
...
@@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
block_shape
=
block_shape
,
block_shape
=
block_shape
,
))
))
self
.
out_dtype
=
out_dtype
self
.
out_dtype
=
out_dtype
self
.
ab_strides1
=
ab_strides1
self
.
ab_strides2
=
ab_strides2
self
.
c_strides1
=
c_strides1
self
.
c_strides2
=
c_strides2
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
# Let PrepareAndFinalize::finalize() decide the impl.
# Let PrepareAndFinalize::finalize() decide the impl.
...
@@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
run_cutlass_moe_fp8
(
run_cutlass_moe_fp8
(
output
,
hidden_states
,
w1
,
w2
,
topk_ids
,
activation_callable
,
output
,
hidden_states
,
w1
,
w2
,
topk_ids
,
activation_callable
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
a1q_scale
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
a1q_scale
,
a2_scale
,
workspace13
,
workspace2
,
expert_num_tokens
,
a2_scale
,
self
.
ab_strides1
,
self
.
ab_strides2
,
self
.
c_strides1
,
self
.
c_strides2
,
workspace13
,
workspace2
,
expert_num_tokens
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
in_dtype
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
in_dtype
,
self
.
per_act_token_quant
,
self
.
per_out_ch_quant
,
self
.
per_act_token_quant
,
self
.
per_out_ch_quant
,
use_batched_format
)
use_batched_format
,
topk_weights
)
class
CutlassExpertsFp8
(
CutlassExpertsFp8Base
):
class
CutlassExpertsFp8
(
CutlassExpertsFp8Base
):
...
@@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
...
@@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
out_dtype
:
Optional
[
torch
.
dtype
],
out_dtype
:
Optional
[
torch
.
dtype
],
per_act_token_quant
:
bool
,
per_act_token_quant
:
bool
,
per_out_ch_quant
:
bool
,
per_out_ch_quant
:
bool
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
):
):
super
().
__init__
(
super
().
__init__
(
out_dtype
,
out_dtype
,
per_act_token_quant
,
per_act_token_quant
,
per_out_ch_quant
,
per_out_ch_quant
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
block_shape
,
block_shape
,
)
)
...
@@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
...
@@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
def
supports_expert_map
(
self
)
->
bool
:
def
supports_expert_map
(
self
)
->
bool
:
return
True
return
True
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
# topk weights and reduction are fused in moe_unpermute cuda kernel
return
TopKWeightAndReduceNoOP
()
def
workspace_shapes
(
def
workspace_shapes
(
self
,
self
,
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
...
@@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
...
@@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
workspace1
=
(
M
*
topk
,
max
(
N
,
K
))
workspace1
=
(
M
*
topk
,
max
(
N
,
K
))
workspace2
=
(
M
*
topk
,
N
//
2
)
workspace2
=
(
M
*
topk
,
max
(
N
//
2
,
K
)
)
output
=
(
M
*
topk
,
K
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
return
(
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
)
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
)
...
@@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
...
@@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
out_dtype
:
Optional
[
torch
.
dtype
],
out_dtype
:
Optional
[
torch
.
dtype
],
per_act_token_quant
:
bool
,
per_act_token_quant
:
bool
,
per_out_ch_quant
:
bool
,
per_out_ch_quant
:
bool
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
):
):
super
().
__init__
(
super
().
__init__
(
out_dtype
,
out_dtype
,
per_act_token_quant
,
per_act_token_quant
,
per_out_ch_quant
,
per_out_ch_quant
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
block_shape
,
block_shape
,
)
)
assert
max_experts_per_worker
>
0
assert
max_experts_per_worker
>
0
...
@@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
...
@@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
assert
num_dp
is
not
None
assert
num_dp
is
not
None
workspace1
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
workspace1
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
max
(
N
,
K
))
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
(
N
//
2
))
workspace2
=
(
self
.
max_experts_per_worker
,
padded_M
*
num_dp
,
max
(
N
//
2
,
K
))
output
=
(
self
.
max_experts_per_worker
,
padded_M
,
K
)
output
=
(
self
.
max_experts_per_worker
,
padded_M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
return
(
workspace1
,
workspace2
,
output
,
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
)
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
a
.
dtype
)
...
@@ -392,6 +416,10 @@ def cutlass_moe_fp8(
...
@@ -392,6 +416,10 @@ def cutlass_moe_fp8(
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
per_act_token
:
Optional
[
bool
]
=
None
,
per_act_token
:
Optional
[
bool
]
=
None
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -419,6 +447,17 @@ def cutlass_moe_fp8(
...
@@ -419,6 +447,17 @@ def cutlass_moe_fp8(
Shape: [num_experts] or [num_experts, 2N]
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
...
@@ -450,6 +489,10 @@ def cutlass_moe_fp8(
...
@@ -450,6 +489,10 @@ def cutlass_moe_fp8(
out_dtype
=
a
.
dtype
,
out_dtype
=
a
.
dtype
,
per_act_token_quant
=
per_act_token
,
per_act_token_quant
=
per_act_token
,
per_out_ch_quant
=
per_out_ch
,
per_out_ch_quant
=
per_out_ch
,
ab_strides1
=
ab_strides1
,
ab_strides2
=
ab_strides2
,
c_strides1
=
c_strides1
,
c_strides2
=
c_strides2
,
),
),
)
)
...
...
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
View file @
b17109be
...
@@ -82,7 +82,8 @@ def moe_permute(
...
@@ -82,7 +82,8 @@ def moe_permute(
n_local_expert
:
int
=
-
1
,
n_local_expert
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
align_block_size
:
Optional
[
int
]
=
None
,
align_block_size
:
Optional
[
int
]
=
None
,
fill_invalid_expert
:
int
=
-
1
fill_invalid_expert
:
int
=
-
1
,
permuted_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
torch
.
Tensor
]:
"""
"""
...
@@ -100,9 +101,12 @@ def moe_permute(
...
@@ -100,9 +101,12 @@ def moe_permute(
- align_block_size (Optional[int]): align group gemm block size for deepgemm
- align_block_size (Optional[int]): align group gemm block size for deepgemm
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert
to workaround DeepGemm unsupported -1 in m_indices
to workaround DeepGemm unsupported -1 in m_indices
- permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
If None, the output tensor will be created in this function.
Returns:
Returns:
- permuted_hidden_states (torch.Tensor): permuted activation.
- permuted_hidden_states (torch.Tensor): permuted activation.
- a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
- a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
if original scale not per-tensor scaling
- expert_first_token_offset (torch.Tensor): offset of the first token
- expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped gemm. if enable 'align_block_size'
of each expert for standard grouped gemm. if enable 'align_block_size'
expert_first_token_offset will align up to 'align_block_size'.
expert_first_token_offset will align up to 'align_block_size'.
...
@@ -122,11 +126,16 @@ def moe_permute(
...
@@ -122,11 +126,16 @@ def moe_permute(
1
)
//
align_block_size
*
align_block_size
1
)
//
align_block_size
*
align_block_size
if
n_local_expert
==
-
1
:
if
n_local_expert
==
-
1
:
n_local_expert
=
n_expert
n_local_expert
=
n_expert
if
permuted_hidden_states
is
None
:
permuted_hidden_states
=
torch
.
empty
(
permuted_hidden_states
=
torch
.
empty
(
(
permuted_row_size
,
n_hidden
),
(
permuted_row_size
,
n_hidden
),
dtype
=
hidden_states
.
dtype
,
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
)
)
assert
permuted_hidden_states
.
size
()
==
(
permuted_row_size
,
n_hidden
),
(
f
"Expected permuted hidden states to be
{
(
permuted_row_size
,
n_hidden
)
}
"
f
" but got
{
permuted_hidden_states
.
size
()
}
"
)
token_expert_indices
=
torch
.
arange
(
0
,
token_expert_indices
=
torch
.
arange
(
0
,
n_token
*
topk
,
n_token
*
topk
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
...
@@ -153,7 +162,8 @@ def moe_permute(
...
@@ -153,7 +162,8 @@ def moe_permute(
align_block_size
,
permuted_hidden_states
,
align_block_size
,
permuted_hidden_states
,
expert_first_token_offset
,
inv_permuted_idx
,
expert_first_token_offset
,
inv_permuted_idx
,
permuted_idx
,
m_indices
)
permuted_idx
,
m_indices
)
if
a1q_scale
is
not
None
:
if
a1q_scale
is
not
None
and
a1q_scale
.
dim
()
>
1
:
a1q_scale
=
a1q_scale
[
permuted_idx
.
clamp
(
max
=
n_token
*
topk
-
1
)
//
a1q_scale
=
a1q_scale
[
permuted_idx
.
clamp
(
max
=
n_token
*
topk
-
1
)
//
topk
]
topk
]
return
(
permuted_hidden_states
,
a1q_scale
,
expert_first_token_offset
,
return
(
permuted_hidden_states
,
a1q_scale
,
expert_first_token_offset
,
...
@@ -185,6 +195,7 @@ def moe_unpermute(
...
@@ -185,6 +195,7 @@ def moe_unpermute(
n_hidden
=
permuted_hidden_states
.
size
(
-
1
)
n_hidden
=
permuted_hidden_states
.
size
(
-
1
)
assert
(
n_hidden
*
permuted_hidden_states
.
element_size
()
assert
(
n_hidden
*
permuted_hidden_states
.
element_size
()
)
%
16
==
0
,
"unpermue kernel need hidden dim align to 16B"
)
%
16
==
0
,
"unpermue kernel need hidden dim align to 16B"
torch
.
ops
.
_moe_C
.
moe_unpermute
(
permuted_hidden_states
,
topk_weights
,
torch
.
ops
.
_moe_C
.
moe_unpermute
(
permuted_hidden_states
,
topk_weights
,
inv_permuted_idx
,
expert_first_token_offset
,
inv_permuted_idx
,
expert_first_token_offset
,
topk
,
out
)
topk
,
out
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
b17109be
...
@@ -669,6 +669,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -669,6 +669,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
self
.
fused_experts_func
=
fused_experts
self
.
fused_experts_func
=
fused_experts
if
self
.
use_cutlass
:
device
=
layer
.
w13_weight
.
device
# ab_strides1 and c_strides2 are the same
self
.
ab_strides1_c_strides2
=
torch
.
full
(
(
layer
.
local_num_experts
,
),
layer
.
hidden_size
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
ab_strides2
=
torch
.
full
(
(
layer
.
local_num_experts
,
),
layer
.
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
c_strides1
=
torch
.
full
(
(
layer
.
local_num_experts
,
),
2
*
layer
.
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
)
def
select_gemm_impl
(
def
select_gemm_impl
(
self
,
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
...
@@ -693,6 +712,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -693,6 +712,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe
.
in_dtype
,
moe
.
in_dtype
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
)
)
else
:
else
:
logger
.
debug
(
"CutlassExpertsFp8(%s)"
,
self
.
__class__
.
__name__
)
logger
.
debug
(
"CutlassExpertsFp8(%s)"
,
self
.
__class__
.
__name__
)
...
@@ -700,6 +723,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -700,6 +723,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe
.
in_dtype
,
moe
.
in_dtype
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
)
)
self
.
disable_expert_map
=
(
num_dispatchers
>
1
self
.
disable_expert_map
=
(
num_dispatchers
>
1
...
@@ -822,6 +849,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -822,6 +849,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_map
=
None
if
self
.
disable_expert_map
else
expert_map
,
expert_map
=
None
if
self
.
disable_expert_map
else
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
a1_scale
=
layer
.
w13_input_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment