Unverified Commit 61059bee authored by Chiyue Wei's avatar Chiyue Wei Committed by GitHub
Browse files

[Hardware][NVIDIA] FP4 MoE kernel optimization (#19110)


Signed-off-by: default avatarChiyue Wei <chiyuew@nvidia.com>
Co-authored-by: default avatarChiyue Wei <chiyuew@nvidia.com>
parent ec89524f
...@@ -91,7 +91,7 @@ def bench_run( ...@@ -91,7 +91,7 @@ def bench_run(
score = torch.randn((m, num_experts), device=device, dtype=dtype) score = torch.randn((m, num_experts), device=device, dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
quant_blocksize = 16 quant_blocksize = 16
w1_blockscale = torch.empty( w1_blockscale = torch.empty(
......
...@@ -31,3 +31,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, ...@@ -31,3 +31,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
#endif #endif
bool moe_permute_unpermute_supported(); bool moe_permute_unpermute_supported();
void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor);
\ No newline at end of file
...@@ -130,6 +130,62 @@ void moe_unpermute( ...@@ -130,6 +130,62 @@ void moe_unpermute(
}); });
} }
template <typename T>
__global__ void shuffleInputRowsKernel(const T* input,
const int32_t* dst2src_map, T* output,
int64_t num_src_rows,
int64_t num_dst_rows, int64_t num_cols) {
int64_t dest_row_idx = blockIdx.x;
int64_t const source_row_idx = dst2src_map[dest_row_idx];
if (blockIdx.x < num_dst_rows) {
// Load 128-bits per thread
constexpr int64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows
auto const* source_row_ptr =
reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
auto* dest_row_ptr =
reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = num_cols / ELEM_PER_THREAD;
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}
void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor) {
TORCH_CHECK(input_tensor.scalar_type() == output_tensor.scalar_type(),
"Input and output tensors must have the same data type");
auto stream = at::cuda::getCurrentCUDAStream().stream();
int64_t const blocks = output_tensor.size(0);
int64_t const threads = 256;
int64_t const num_dest_rows = output_tensor.size(0);
int64_t const num_src_rows = input_tensor.size(0);
int64_t const num_cols = input_tensor.size(1);
TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
"num_cols must be divisible by 128 / "
"sizeof(input_tensor.scalar_type()) / 8");
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
}
#else #else
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
......
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \ MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define MOE_DISPATCH(TYPE, ...) \ #define MOE_DISPATCH(TYPE, ...) \
MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__)) MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
...@@ -39,6 +40,11 @@ template <> ...@@ -39,6 +40,11 @@ template <>
struct ScalarType2CudaType<at::ScalarType::BFloat16> { struct ScalarType2CudaType<at::ScalarType::BFloat16> {
using type = __nv_bfloat16; using type = __nv_bfloat16;
}; };
// uint8 for packed fp4
template <>
struct ScalarType2CudaType<at::ScalarType::Byte> {
using type = uint8_t;
};
// #if __CUDA_ARCH__ >= 890 // #if __CUDA_ARCH__ >= 890
// fp8 // fp8
......
...@@ -81,6 +81,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -81,6 +81,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def("moe_permute_unpermute_supported() -> bool"); m.def("moe_permute_unpermute_supported() -> bool");
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported); m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
// Row shuffle for MoE
m.def(
"shuffle_rows(Tensor input_tensor, Tensor dst2src_map, Tensor! "
"output_tensor) -> ()");
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
#endif #endif
} }
......
...@@ -248,7 +248,8 @@ void get_cutlass_moe_mm_data( ...@@ -248,7 +248,8 @@ void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k); const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
......
...@@ -45,6 +45,23 @@ __global__ void compute_expert_offsets( ...@@ -45,6 +45,23 @@ __global__ void compute_expert_offsets(
} }
} }
__global__ void compute_expert_blockscale_offsets(
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
int32_t* blockscale_offsets, int32_t* atomic_buffer,
const int num_experts) {
int32_t tot_offset = 0;
int32_t tot_offset_round = 0;
expert_offsets[0] = 0;
blockscale_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
atomic_buffer[i] = tot_offset;
tot_offset += problem_sizes1[i * 3];
expert_offsets[i + 1] = tot_offset;
tot_offset_round += (problem_sizes1[i * 3] + (128 - 1)) / 128 * 128;
blockscale_offsets[i + 1] = tot_offset_round;
}
}
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, __global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets, const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation, int32_t* input_permutation,
...@@ -77,7 +94,8 @@ void get_cutlass_moe_mm_data_caller( ...@@ -77,7 +94,8 @@ void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k) { const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 = auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
...@@ -89,10 +107,18 @@ void get_cutlass_moe_mm_data_caller( ...@@ -89,10 +107,18 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(problem_sizes1.data_ptr()), static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()), static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
if (blockscale_offsets.has_value()) {
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
} else {
compute_expert_offsets<<<1, 1, 0, stream>>>( compute_expert_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()), static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()), static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts); static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
}
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>( compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()), static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()), static_cast<const int32_t*>(expert_offsets.data_ptr()),
......
...@@ -54,7 +54,8 @@ void get_cutlass_moe_mm_data_caller( ...@@ -54,7 +54,8 @@ void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k); const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
#endif #endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
...@@ -224,7 +225,8 @@ void get_cutlass_moe_mm_data( ...@@ -224,7 +225,8 @@ void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k) { const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) {
// This function currently gets compiled only if we have a valid cutlass moe // This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for. // mm to run it for.
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
...@@ -232,7 +234,8 @@ void get_cutlass_moe_mm_data( ...@@ -232,7 +234,8 @@ void get_cutlass_moe_mm_data(
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90) (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation, problem_sizes2, input_permutation,
output_permutation, num_experts, n, k); output_permutation, num_experts, n, k,
blockscale_offsets);
return; return;
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
......
...@@ -450,7 +450,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -450,7 +450,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, Tensor! problem_sizes2, " " Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, " " Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, " " Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()", " int n, int k, Tensor? blockscale_offsets) -> ()",
{stride_tag}); {stride_tag});
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
......
...@@ -80,7 +80,10 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ...@@ -80,7 +80,10 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
w2[expert], w2_gs[expert]) w2[expert], w2_gs[expert])
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
......
...@@ -845,11 +845,16 @@ def cutlass_scaled_sparse_mm( ...@@ -845,11 +845,16 @@ def cutlass_scaled_sparse_mm(
return out return out
def get_cutlass_moe_mm_data( def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
topk_ids: torch.Tensor, expert_offsets: torch.Tensor, expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor, problem_sizes1: torch.Tensor,
input_permutation: torch.Tensor, output_permutation: torch.Tensor, problem_sizes2: torch.Tensor,
num_experts: int, n: int, k: int): input_permutation: torch.Tensor,
output_permutation: torch.Tensor,
num_experts: int,
n: int,
k: int,
blockscale_offsets: Optional[torch.Tensor] = None):
""" """
Prepare data necessary to perform CUTLASS grouped matrix multiplications Prepare data necessary to perform CUTLASS grouped matrix multiplications
used in CUTLASS-based fused MoE. used in CUTLASS-based fused MoE.
...@@ -867,12 +872,31 @@ def get_cutlass_moe_mm_data( ...@@ -867,12 +872,31 @@ def get_cutlass_moe_mm_data(
before executing the MMs. before executing the MMs.
- output_permutation: Permutation that must be used to shuffle the output - output_permutation: Permutation that must be used to shuffle the output
after executing the MMs. after executing the MMs.
- blockscale_offsets: Optional argument passed for fp4 moe. Indices that
mark at which block scale index each expert begins
its computation. The number of block scale rows
computed with expert E is blockscale_offsets[E + 1] -
blockscale_offsets[E]
""" """
return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
problem_sizes1, problem_sizes2, problem_sizes1, problem_sizes2,
input_permutation, input_permutation,
output_permutation, output_permutation,
num_experts, n, k) num_experts, n, k,
blockscale_offsets)
def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
"""
Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
This is used in MoE to permute the input tensor before performing grouped matrix multiplications.
"""
num_tokens_permuted = dst2src_map.shape[0]
output_tensor = torch.empty((num_tokens_permuted, input_tensor.shape[1]),
device=input_tensor.device,
dtype=input_tensor.dtype)
torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor)
return output_tensor
def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
...@@ -1124,14 +1148,12 @@ def scaled_fp4_experts_quant( ...@@ -1124,14 +1148,12 @@ def scaled_fp4_experts_quant(
expert_offsets: torch.Tensor, expert_offsets: torch.Tensor,
blockscale_offsets: torch.Tensor, blockscale_offsets: torch.Tensor,
topk: int, topk: int,
expert_map: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Quantize input tensor to FP4 and return quantized tensor and scale, for Quantize input tensor to FP4 and return quantized tensor and scale, for
packed MoE Inputs. packed MoE Inputs.
Args: Args:
input: The input tensor to be quantized to FP4 input_tensor: The input tensor to be quantized to FP4
expert_map: The expert map tensor
input_global_scale: A scalar scaling factor for the entire tensor. input_global_scale: A scalar scaling factor for the entire tensor.
expert_offsets: The expert offsets tensor expert_offsets: The expert offsets tensor
blockscale_offsets: The blockscale offsets tensor blockscale_offsets: The blockscale offsets tensor
...@@ -1143,14 +1165,13 @@ def scaled_fp4_experts_quant( ...@@ -1143,14 +1165,13 @@ def scaled_fp4_experts_quant(
assert input_tensor.ndim == 2, ( assert input_tensor.ndim == 2, (
f'input.ndim needs to be == 2, but got {input_tensor.ndim}.') f'input.ndim needs to be == 2, but got {input_tensor.ndim}.')
input_tensor = input_tensor[
expert_map] if expert_map is not None else input_tensor
m_numtopk, k = input_tensor.shape
# Control the maximum number of tokens per expert supported by the # Control the maximum number of tokens per expert supported by the
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel # NVFP4 MoE Expert Quantization. This is used to prevent the kernel
# from running out of memory. This value can also be increased to support # from running out of memory. This value can also be increased to support
# larger models. # larger models.
MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
m_numtopk, k = input_tensor.shape
assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), ( assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT(" f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
f"{MAX_TOKENS_PER_EXPERT})" f"{MAX_TOKENS_PER_EXPERT})"
......
...@@ -333,6 +333,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, ...@@ -333,6 +333,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
num_topk = topk_ids.shape[1] num_topk = topk_ids.shape[1]
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,2n,k)) # Problem size: (num_experts, (m,2n,k))
problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device) problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,n,k)) # Problem size: (num_experts, (m,n,k))
...@@ -344,12 +345,10 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, ...@@ -344,12 +345,10 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
# problem shapes should have [m, n, k] # problem shapes should have [m, n, k]
# Note that problem sizes are based on logical number of elements. # Note that problem sizes are based on logical number of elements.
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, e, n, k) problem_sizes2, a_map, c_map, e, n, k,
blockscale_offsets)
tokens_per_expert = problem_sizes1[:, 0] a = ops.shuffle_rows(a, a_map)
rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant( rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
a, a,
...@@ -357,7 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, ...@@ -357,7 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
expert_offsets, expert_offsets,
blockscale_offsets, blockscale_offsets,
num_topk, num_topk,
expert_map=a_map) )
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale, c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
w1_blockscale, w1_alphas, problem_sizes1, w1_blockscale, w1_alphas, problem_sizes1,
...@@ -378,6 +377,8 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, ...@@ -378,6 +377,8 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
w2_alphas, problem_sizes2, expert_offsets[:-1], w2_alphas, problem_sizes2, expert_offsets[:-1],
blockscale_offsets[:-1], out_dtype, device) blockscale_offsets[:-1], out_dtype, device)
del int_fp4, int_blockscale del int_fp4, int_blockscale
out = (c2[c_map].view(m, num_topk, k) *
c2 = ops.shuffle_rows(c2, c_map)
out = (c2.view(m, num_topk, k) *
topk_weights.view(m, num_topk, 1).half()).sum(dim=1) topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
return out.to(dtype=out_dtype) return out.to(dtype=out_dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment