Unverified Commit 84166fee authored by ElizaWszola's avatar ElizaWszola Committed by GitHub
Browse files

[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)


Signed-off-by: default avatarElizaWszola <ewszola@redhat.com>
Signed-off-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
parent 6e0cd10f
...@@ -543,8 +543,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -543,8 +543,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUTLASS MoE kernels # CUTLASS MoE kernels
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled
# to compile MoE kernels that use its output. # if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
......
...@@ -7,8 +7,8 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE ...@@ -7,8 +7,8 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
cutlass_moe_fp8,
fused_experts, fused_experts,
fused_topk, fused_topk,
) )
...@@ -70,18 +70,9 @@ def bench_run( ...@@ -70,18 +70,9 @@ def bench_run(
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
for expert in range(num_experts): for expert in range(num_experts):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
w1_q_notransp = w1_q.clone()
w2_q_notransp = w2_q.clone()
w1_q = w1_q.transpose(1, 2)
w2_q = w2_q.transpose(1, 2)
score = torch.randn((m, num_experts), device="cuda", dtype=dtype) score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
...@@ -122,10 +113,6 @@ def bench_run( ...@@ -122,10 +113,6 @@ def bench_run(
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
num_repeats: int, num_repeats: int,
): ):
for _ in range(num_repeats): for _ in range(num_repeats):
...@@ -133,14 +120,10 @@ def bench_run( ...@@ -133,14 +120,10 @@ def bench_run(
a, a,
w1, w1,
w2, w2,
w1_scale,
w2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
ab_strides1, w1_scale,
c_strides1, w2_scale,
ab_strides2,
c_strides2,
a1_scale=a_scale, a1_scale=a_scale,
) )
...@@ -153,10 +136,6 @@ def bench_run( ...@@ -153,10 +136,6 @@ def bench_run(
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
): ):
with set_current_vllm_config( with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
...@@ -165,14 +144,10 @@ def bench_run( ...@@ -165,14 +144,10 @@ def bench_run(
a, a,
w1_q, w1_q,
w2_q, w2_q,
w1_scale,
w2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
ab_strides1, w1_scale,
c_strides1, w2_scale,
ab_strides2,
c_strides2,
a1_scale=a_scale, a1_scale=a_scale,
) )
...@@ -218,10 +193,6 @@ def bench_run( ...@@ -218,10 +193,6 @@ def bench_run(
w2_scale, w2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -230,8 +201,8 @@ def bench_run( ...@@ -230,8 +201,8 @@ def bench_run(
with torch.cuda.graph(triton_graph, stream=triton_stream): with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph( run_triton_from_graph(
a, a,
w1_q_notransp, w1_q,
w2_q_notransp, w2_q,
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale, w1_scale,
...@@ -250,18 +221,12 @@ def bench_run( ...@@ -250,18 +221,12 @@ def bench_run(
"w2": w2, "w2": w2,
"score": score, "score": score,
"topk": topk, "topk": topk,
"w1_q_notransp": w1_q_notransp,
"w2_q_notransp": w2_q_notransp,
# Cutlass params # Cutlass params
"a_scale": a_scale, "a_scale": a_scale,
"w1_q": w1_q, "w1_q": w1_q,
"w2_q": w2_q, "w2_q": w2_q,
"w1_scale": w1_scale, "w1_scale": w1_scale,
"w2_scale": w2_scale, "w2_scale": w2_scale,
"ab_strides1": ab_strides1,
"c_strides1": c_strides1,
"ab_strides2": ab_strides2,
"c_strides2": c_strides2,
# cuda graph params # cuda graph params
"cutlass_graph": cutlass_graph, "cutlass_graph": cutlass_graph,
"triton_graph": triton_graph, "triton_graph": triton_graph,
...@@ -279,8 +244,8 @@ def bench_run( ...@@ -279,8 +244,8 @@ def bench_run(
# Warmup # Warmup
run_triton_moe( run_triton_moe(
a, a,
w1_q_notransp, w1_q,
w2_q_notransp, w2_q,
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale, w1_scale,
...@@ -291,7 +256,7 @@ def bench_run( ...@@ -291,7 +256,7 @@ def bench_run(
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
...@@ -322,16 +287,12 @@ def bench_run( ...@@ -322,16 +287,12 @@ def bench_run(
w2_scale, w2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
num_warmup, num_warmup,
) )
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
......
...@@ -236,7 +236,8 @@ void cutlass_moe_mm( ...@@ -236,7 +236,8 @@ void cutlass_moe_mm(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides); torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
void cutlass_fp4_group_mm( void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
...@@ -251,6 +252,14 @@ void get_cutlass_moe_mm_data( ...@@ -251,6 +252,14 @@ void get_cutlass_moe_mm_data(
const int64_t num_experts, const int64_t n, const int64_t k, const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets); const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
......
...@@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90( ...@@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
...@@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90( ...@@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
if (n >= 8192) { if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>( cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (k >= 8192) { } else if (k >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmK8192>( cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (m <= 16) { } else if (m <= 16) {
cutlass_group_gemm_caller<Cutlass3xGemmM16>( cutlass_group_gemm_caller<Cutlass3xGemmM16>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else { } else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>( cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} }
} }
...@@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90( ...@@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) { if (out_tensors.dtype() == torch::kBFloat16) {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>( run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else { } else {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>( run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} }
} }
...@@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90( ...@@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides); c_strides, per_act_token, per_out_ch);
} }
...@@ -76,7 +76,8 @@ void cutlass_group_gemm_caller( ...@@ -76,7 +76,8 @@ void cutlass_group_gemm_caller(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
...@@ -84,9 +85,6 @@ void cutlass_group_gemm_caller( ...@@ -84,9 +85,6 @@ void cutlass_group_gemm_caller(
int k_size = a_tensors.size(1); int k_size = a_tensors.size(1);
int n_size = out_tensors.size(1); int n_size = out_tensors.size(1);
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto options_int = auto options_int =
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
constexpr uint64_t THREADS_PER_EXPERT = 512; constexpr uint64_t THREADS_PER_EXPERT = 512;
__global__ void compute_problem_sizes(const int* __restrict__ topk_ids, __global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
int32_t* problem_sizes1, int32_t* problem_sizes1,
int32_t* problem_sizes2, int32_t* problem_sizes2,
int32_t* atomic_buffer, int32_t* atomic_buffer,
...@@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets( ...@@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
} }
} }
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, __global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets, const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation, int32_t* input_permutation,
int32_t* output_permutation, int32_t* output_permutation,
...@@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller( ...@@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>( compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()), static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()), static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()), static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
...@@ -120,10 +120,44 @@ void get_cutlass_moe_mm_data_caller( ...@@ -120,10 +120,44 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts); static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
} }
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>( compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()), static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()), static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()), static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()), static_cast<int32_t*>(output_permutation.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
topk_ids.size(1)); topk_ids.size(1));
} }
__global__ void compute_pplx_data(int32_t* expert_offsets,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
const int32_t* __restrict__ expert_num_tokens,
const int padded_m, const int n,
const int k) {
int expert_idx = threadIdx.x;
expert_offsets[expert_idx] = expert_idx * padded_m;
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
problem_sizes1[expert_idx * 3 + 2] = k;
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes2[expert_idx * 3 + 1] = k;
problem_sizes2[expert_idx * 3 + 2] = n;
}
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}
...@@ -36,7 +36,8 @@ void cutlass_moe_mm_sm90( ...@@ -36,7 +36,8 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides); torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif #endif
...@@ -56,6 +57,14 @@ void get_cutlass_moe_mm_data_caller( ...@@ -56,6 +57,14 @@ void get_cutlass_moe_mm_data_caller(
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k, const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets); const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k);
#endif #endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
...@@ -207,12 +216,13 @@ void cutlass_moe_mm( ...@@ -207,12 +216,13 @@ void cutlass_moe_mm(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides); c_strides, per_act_token, per_out_ch);
return; return;
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
...@@ -245,6 +255,29 @@ void get_cutlass_moe_mm_data( ...@@ -245,6 +255,29 @@ void get_cutlass_moe_mm_data(
version_num, ". Required capability: 90"); version_num, ". Required capability: 90");
} }
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90");
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
......
...@@ -435,7 +435,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -435,7 +435,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, " "cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, " " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, " " Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor c_strides) -> ()", " Tensor b_strides, Tensor c_strides, bool per_act_token, "
" bool per_out_ch) -> ()",
{stride_tag}); {stride_tag});
ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm); ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);
...@@ -454,6 +455,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -454,6 +455,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
{stride_tag}); {stride_tag});
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
// as an input, and computes expert_offsets (token start indices of each
// expert). In addition to this, it computes problem sizes for each expert's
// multiplication used by the two mms called from fused MoE operation.
ops.def(
"get_cutlass_pplx_moe_mm_data(Tensor! expert_offsets, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" Tensor expert_num_tokens, "
" int num_local_experts, int padded_m, "
" int n, int k) -> ()",
{stride_tag});
ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA,
&get_cutlass_pplx_moe_mm_data);
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3) // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
ops.def( ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> " "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
......
...@@ -193,14 +193,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit, ...@@ -193,14 +193,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
kwargs = { kwargs = {
'a': moe_tensors.a, 'a': moe_tensors.a,
'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] 'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] 'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
'topk_weights': topk_weights, 'topk_weights': topk_weights,
'topk_ids': topk_ids, 'topk_ids': topk_ids,
'ab_strides1': moe_tensors.ab_strides1,
'c_strides1': moe_tensors.c_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides2': moe_tensors.c_strides2,
'w1_scale': moe_tensors.w1_scale, 'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale, 'w2_scale': moe_tensors.w2_scale,
'a1_scale': moe_tensors.a_scale 'a1_scale': moe_tensors.a_scale
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.pplx_utils import ProcessGroupInfo, parallel_launch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
try:
from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_finalize, nvshmem_get_unique_id,
nvshmem_init)
has_pplx = True
except ImportError:
has_pplx = False
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]
def rank_chunk(num, r, w):
rem = num % w
return (num // w) + (1 if r < rem else 0)
def chunk_by_rank(t, r, w):
num = t.shape[0]
chunk = rank_chunk(num, r, w)
rem = num % w
if rem == 0 or r < rem:
return t[(r * chunk):(r + 1) * chunk].contiguous()
else:
long_chunks = (num // w + 1) * rem
short_chunks = (r - rem) * chunk
start = long_chunks + short_chunks
return t[start:start + chunk].contiguous()
def pplx_cutlass_moe(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
a1_scale: torch.Tensor,
out_dtype,
per_act_token: bool,
per_out_ch: bool,
):
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
assert torch.cuda.current_device() == pgi.local_rank
num_tokens, hidden_dim = a.shape
num_experts = w1.shape[0]
block_size = hidden_dim # TODO support more cases
device = pgi.device
rank = pgi.rank
world_size = pgi.world_size
rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
max_num_tokens = rank_chunk(num_tokens, 0, world_size)
topk = topk_ids.shape[1]
if block_size == hidden_dim:
scale_elems = 4 # hack to circumvent pplx data format requirements
else:
scale_elems = (hidden_dim + block_size - 1) // block_size
ata = AllToAll.internode(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
rank=rank,
world_size=pgi.world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1
hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
)
w1 = w1.to(device)
w2 = w2.to(device)
w1_scale = w1_scale.to(device)
w2_scale = w2_scale.to(device)
a1_scale = a1_scale.to(device)
prepare_finalize = PplxPrepareAndFinalize(
ata,
max_num_tokens,
pgi.world_size,
rank,
dp_size,
quant_dtype=torch.float8_e4m3fn,
per_act_token=per_act_token,
)
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
out_dtype, per_act_token, per_out_ch)
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weights, rank,
world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank,
world_size).to(torch.uint32).to(device)
out = fused_cutlass_experts(
a_chunk,
chunk_by_rank(w1, rank, world_size),
chunk_by_rank(w2, rank, world_size),
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts,
expert_map=None, #TODO
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank])
torch.cuda.synchronize()
ata.destroy()
return out[:rank_num_tokens]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def torch_moe2(a, w1, w2, topk_weight, topk_ids):
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
num_experts = w1.shape[0]
for i in range(num_experts):
mask = (topk_ids == i).view(-1)
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
def _pplx_moe(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
a1_scale: torch.Tensor,
out_dtype,
a_full: torch.Tensor,
w1_full: torch.Tensor,
w2_full: torch.Tensor,
per_act_token: bool,
per_out_ch: bool,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
topk_ids)
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
w2_scale, topk_weights, topk_ids,
a1_scale, out_dtype, per_act_token,
per_out_ch)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
# Uncomment if more debugging is needed
# print("PPLX OUT:", pplx_output)
# print("TORCH OUT:", torch_output)
torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
nvshmem_finalize()
@pytest.mark.parametrize("m", [2, 224])
@pytest.mark.parametrize("n", [3072])
@pytest.mark.parametrize("k", [1536])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
@requires_pplx
def test_cutlass_moe_pplx(
m: int,
n: int,
k: int,
e: int,
topk: int,
per_act_token: bool,
per_out_ch: bool,
world_dp_size: tuple[int, int],
):
current_platform.seed_everything(7)
with set_current_vllm_config(vllm_config):
dtype = torch.half
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10.0
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10.0
n_b_scales = 2 * n if per_out_ch else 1
k_b_scales = k if per_out_ch else 1
w1_q = torch.empty((e, 2 * n, k),
device="cuda",
dtype=torch.float8_e4m3fn)
w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=per_out_ch)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=per_out_ch)
w1_d = torch.empty_like(w1)
w2_d = torch.empty_like(w2)
for expert in range(e):
w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
world_size, dp_size = world_dp_size
a_scale1 = torch.randn(
(m if per_act_token else 1, 1), device="cuda",
dtype=torch.float32) / 10.0
if not per_act_token:
a_scale1 = a_scale1.repeat(world_size, 1)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
dtype, a, w1_d, w2_d, per_act_token, per_out_ch)
...@@ -4,10 +4,7 @@ ...@@ -4,10 +4,7 @@
Run `pytest tests/kernels/test_pplx_moe.py`. Run `pytest tests/kernels/test_pplx_moe.py`.
""" """
import dataclasses from typing import Optional
import os
import traceback
from typing import Callable, Optional
import pytest import pytest
import torch import torch
...@@ -21,10 +18,7 @@ try: ...@@ -21,10 +18,7 @@ try:
except ImportError: except ImportError:
has_pplx = False has_pplx = False
from torch.multiprocessing import ( from tests.pplx_utils import ProcessGroupInfo, parallel_launch
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config from vllm.model_executor.layers.fused_moe import override_config
...@@ -36,6 +30,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -36,6 +30,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.platforms import current_platform from vllm.platforms import current_platform
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512), PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
(222, 2048, 1024)] (222, 2048, 1024)]
...@@ -57,122 +56,6 @@ vllm_config = VllmConfig() ...@@ -57,122 +56,6 @@ vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192 vllm_config.scheduler_config.max_model_len = 8192
P = ParamSpec("P")
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
"tcp://localhost:29500",
worker,
) + args,
nprocs=world_size,
join=True,
)
def parallel_launch_from_env(
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert not kwargs
world_size = int(os.environ["WORLD_SIZE"])
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
node_rank = int(os.environ["NODE_RANK"])
assert "MASTER_ADDR" in os.environ
assert "MASTER_PORT" in os.environ
spawn(
_worker_parallel_launch,
args=(
world_size,
world_local_size,
node_rank,
"env://",
worker,
) + args,
nprocs=world_local_size,
join=True,
)
def torch_prepare( def torch_prepare(
a: torch.Tensor, a: torch.Tensor,
......
...@@ -632,7 +632,8 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, ...@@ -632,7 +632,8 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked, ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked,
b_tensors_stacked, a_scales_tensors_stacked, b_tensors_stacked, a_scales_tensors_stacked,
b_scales_tensors_stacked, expert_offsets[:-1], b_scales_tensors_stacked, expert_offsets[:-1],
problem_sizes, ab_strides, ab_strides, c_strides) problem_sizes, ab_strides, ab_strides, c_strides,
per_act_token, per_out_ch)
# Validate each group's result against the baseline # Validate each group's result against the baseline
for g in range(num_experts): for g in range(num_experts):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import os
import traceback
from typing import Callable
import torch
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
"tcp://localhost:29500",
worker,
) + args,
nprocs=world_size,
join=True,
)
def parallel_launch_from_env(
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert not kwargs
world_size = int(os.environ["WORLD_SIZE"])
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
node_rank = int(os.environ["NODE_RANK"])
assert "MASTER_ADDR" in os.environ
assert "MASTER_PORT" in os.environ
spawn(
_worker_parallel_launch,
args=(
world_size,
world_local_size,
node_rank,
"env://",
worker,
) + args,
nprocs=world_local_size,
join=True,
)
...@@ -899,11 +899,36 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): ...@@ -899,11 +899,36 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
return output_tensor return output_tensor
def get_cutlass_pplx_moe_mm_data(expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
expert_num_tokens: torch.Tensor,
num_local_experts: int, padded_m: int, n: int,
k: int):
"""
Prepare data necessary to perform CUTLASS grouped matrix multiplications
used in CUTLASS-based fused MoE.
The function takes in expert_num_tokens (token count per expert) and
non_zero_expert_idxs (consecutive indices of experts with non-zero token
counts) and uses them to compute:
- expert_offsets: Indices that mark at which token index each expert begins
its computation.
- problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
multiplication in two grouped MMs used in
the fused MoE operation.
"""
return torch.ops._C.get_cutlass_pplx_moe_mm_data(
expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k)
def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
b_tensors: torch.Tensor, a_scales: torch.Tensor, b_tensors: torch.Tensor, a_scales: torch.Tensor,
b_scales: torch.Tensor, expert_offsets: torch.Tensor, b_scales: torch.Tensor, expert_offsets: torch.Tensor,
problem_sizes: torch.Tensor, a_strides: torch.Tensor, problem_sizes: torch.Tensor, a_strides: torch.Tensor,
b_strides: torch.Tensor, c_strides: torch.Tensor): b_strides: torch.Tensor, c_strides: torch.Tensor,
per_act_token: bool, per_out_ch: bool):
""" """
A single grouped matrix multiplication used in CUTLASS-based fused MoE. A single grouped matrix multiplication used in CUTLASS-based fused MoE.
The function executes fp8-quantized OUT = AB matrix multiplication. The function executes fp8-quantized OUT = AB matrix multiplication.
...@@ -918,7 +943,7 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, ...@@ -918,7 +943,7 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors,
a_scales, b_scales, expert_offsets, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, problem_sizes, a_strides, b_strides,
c_strides) c_strides, per_act_token, per_out_ch)
def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
......
...@@ -39,6 +39,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -39,6 +39,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
......
...@@ -67,6 +67,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -67,6 +67,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -78,11 +79,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -78,11 +79,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# even if we fall back to triton later, e.g. if expert maps are set. # even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None: if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
return self.batched_deep_gemm_experts.workspace_shapes( return self.batched_deep_gemm_experts.workspace_shapes(
a, M, N, K, topk, num_experts) a, aq, M, N, K, topk, num_experts)
else: else:
assert self.batched_triton_experts is not None assert self.batched_triton_experts is not None
return self.batched_triton_experts.workspace_shapes( return self.batched_triton_experts.workspace_shapes(
a, M, N, K, topk, num_experts) a, aq, M, N, K, topk, num_experts)
def apply( def apply(
self, self,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" CUTLASS based Fused MoE kernels.""" """ CUTLASS based Fused MoE kernels."""
from typing import Optional from typing import Callable, Optional
import torch import torch
...@@ -13,110 +13,109 @@ from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache ...@@ -13,110 +13,109 @@ from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): def run_cutlass_moe_fp8(
hidden_states: torch.Tensor,
def __init__( w1: torch.Tensor,
self, w2: torch.Tensor,
ab_strides1: torch.Tensor, topk_ids: torch.Tensor,
c_strides1: torch.Tensor, activation_callable: Callable,
ab_strides2: torch.Tensor, global_num_experts: int,
c_strides2: torch.Tensor, expert_map: Optional[torch.Tensor],
out_dtype: torch.dtype, w1_scale: Optional[torch.Tensor],
): w2_scale: Optional[torch.Tensor],
super().__init__() a1q_scale: Optional[torch.Tensor],
self.ab_strides1 = ab_strides1 a2_scale: Optional[torch.Tensor],
self.c_strides1 = c_strides1 workspace13: torch.Tensor,
self.ab_strides2 = ab_strides2 workspace2: torch.Tensor,
self.c_strides2 = c_strides2 expert_num_tokens: Optional[torch.Tensor],
self.out_dtype = out_dtype out_dtype: torch.dtype,
per_act_token: bool,
def workspace_shapes( per_out_ch: bool,
self, ) -> torch.Tensor:
a: torch.Tensor, a1q = hidden_states
M: int,
N: int, assert w1_scale is not None
K: int, assert w2_scale is not None
topk: int, assert w1.dtype == torch.float8_e4m3fn
num_experts: int, assert w2.dtype == torch.float8_e4m3fn
) -> tuple[int, int, torch.dtype]: if expert_num_tokens is None:
# Note that K, N are transposed assert a1q.shape[1] == w1.shape[2], "Hidden size mismatch w1"
N, K = K, N else:
workspace1 = M * topk * max(2 * N, K) assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1"
workspace2 = M * topk * N assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2"
return (workspace1, workspace2, self.out_dtype) assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1.shape[1], "W1 scale shape mismatch"
def apply( assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
self, 1] == w2.shape[1], "W2 scale shape mismatch"
hidden_states: torch.Tensor, assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
w1: torch.Tensor, assert a1q_scale is None or a1q_scale.dim(
w2: torch.Tensor, ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
topk_ids: torch.Tensor, 0], "Input scale shape mismatch"
activation: str, assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
global_num_experts: int, assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
expert_map: Optional[torch.Tensor], assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
w1_scale: Optional[torch.Tensor], assert a2_scale is None or a2_scale.dim(
w2_scale: Optional[torch.Tensor], ) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[
w1_zp: Optional[torch.Tensor], 0], "Intermediate scale shape mismatch"
w2_zp: Optional[torch.Tensor], assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
a1q_scale: Optional[torch.Tensor], if expert_map is not None:
a2_scale: Optional[torch.Tensor], assert expert_num_tokens is None
workspace13: torch.Tensor,
workspace2: torch.Tensor, # We have two modes: PPLX and non-PPLX. We differentiate them by checking
expert_num_tokens: Optional[torch.Tensor], # if expert_num_tokens is None (expert_num_tokens is a tensor which PPLX
) -> torch.Tensor: # uses to track the number of tokens per expert).
a1q = hidden_states # In the non-PPLX mode, the input tokens are not padded: thus, the shape
# of the input is [total_num_tokens, hidden_size]. The input and output
assert w1_scale is not None # require shuffling by a_map and c_map such that the tokens assigned to
assert w2_scale is not None # each expert are contiguous.
assert w1.dtype == torch.float8_e4m3fn # In the PPLX mode, the input tokens are padded per expert to ensure that
assert w2.dtype == torch.float8_e4m3fn # the PPLX dispatch and combine functions work correctly: thus, the shape
assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1" # of the input is [num_experts, max_num_tokens_per_expert, hidden_size].
assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2" # The PPLX input and output require no shuffling by a_map and c_map since
assert w1.shape[0] == w2.shape[0], "Expert number mismatch" # their tokens are already contiguous for each expert as a result of
assert a1q_scale is None or a1q_scale.dim( # the dispatch function.
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[ is_pplx = expert_num_tokens is not None
0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ M = a1q.shape[0] # no pplx
1] == w1.shape[2], "W1 scale shape mismatch" padded_M = a1q.shape[1] # pplx
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ _, K, N = w2.shape
1] == w2.shape[2], "W2 scale shape mismatch" device = a1q.device
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
assert w1.shape[0] == w1_scale.shape[ assert w1.shape[2] == K
0], "w1 scales expert number mismatch" assert global_num_experts != -1
assert w1.shape[0] == w2_scale.shape[ assert a1q_scale is not None
0], "w2 scales expert number mismatch"
assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 if expert_map is not None:
assert self.ab_strides1.shape[0] == w1.shape[ "Translate info from expert_map to topk_ids"
0], "AB Strides 1 expert number mismatch" local_topk_ids = torch.where(expert_map[topk_ids] != -1,
assert self.c_strides1.shape[0] == w1.shape[ expert_map[topk_ids], -1)
0], "C Strides 1 expert number mismatch" else:
assert self.ab_strides2.shape[0] == w2.shape[ local_topk_ids = topk_ids
0], "AB Strides 2 expert number mismatch"
assert self.c_strides2.shape[0] == w2.shape[ topk = local_topk_ids.shape[1]
0], "C Strides 2 expert number mismatch" local_E = w1.shape[0]
assert self.out_dtype in [torch.half,
torch.bfloat16], "Invalid output dtype" if is_pplx:
expert_offsets = torch.empty((local_E),
M = a1q.shape[0] dtype=torch.int32,
_, N, K = w2.shape # because w1 + w2 are transposed device=device)
device = a1q.device problem_sizes1 = torch.empty((local_E, 3),
dtype=torch.int32,
assert w1.shape[1] == K device=device)
assert global_num_experts != -1 problem_sizes2 = torch.empty((local_E, 3),
assert a1q_scale is not None dtype=torch.int32,
device=device)
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids] != -1,
expert_map[topk_ids], -1)
else:
local_topk_ids = topk_ids
topk = local_topk_ids.shape[1] ops.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
local_E, padded_M, N, K)
per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( w1_scale = w1_scale.reshape(w1_scale.shape[0], -1)
a2_scale.numel() != 1 if a2_scale is not None else False) w2_scale = w2_scale.reshape(w2_scale.shape[0], -1)
a1q = a1q.reshape(-1, a1q.shape[2])
a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous()
else:
expert_offsets = torch.empty((global_num_experts + 1), expert_offsets = torch.empty((global_num_experts + 1),
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
...@@ -149,50 +148,130 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -149,50 +148,130 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
a1q = _fp8_perm(a1q, a_map) a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
expert_offsets = expert_offsets[:-1]
ab_strides1 = torch.full((w1.shape[0], ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.shape[0], ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.shape[0], ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.shape[0], ),
K,
device=device,
dtype=torch.int64)
if is_pplx:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
c3 = _resize_cache(workspace13, (local_E * padded_M, K))
else:
c1 = _resize_cache(workspace13, (M * topk, N * 2)) c1 = _resize_cache(workspace13, (M * topk, N * 2))
c2 = _resize_cache(workspace2, (M * topk, N)) c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K)) c3 = _resize_cache(workspace13, (M * topk, K))
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
expert_offsets[:-1], problem_sizes1, problem_sizes1, ab_strides1, ab_strides1, c_strides1,
self.ab_strides1, self.ab_strides1, self.c_strides1) per_act_token, per_out_ch)
self.activation(activation, c2, c1) activation_callable(c2, c1)
a2q, a2q_scale = ops.scaled_fp8_quant( a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token) c2, a2_scale, use_per_token_if_dynamic=per_act_token)
if expert_map is not None: if expert_map is not None:
c3.fill_(0) c3.fill_(0)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets,
problem_sizes2, ab_strides2, ab_strides2, c_strides2,
per_act_token, per_out_ch)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, if is_pplx:
expert_offsets[:-1], problem_sizes2, return c3.reshape(local_E, padded_M, K)
self.ab_strides2, self.ab_strides2, self.c_strides2) else:
return c3[c_map].view(M, topk, K)
c3 = c3[c_map]
return c3 class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_experts_per_worker: int,
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
):
super().__init__()
self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype
self.per_act_token = per_act_token
self.per_out_ch = per_out_ch
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
padded_M = aq.shape[1]
workspace1 = self.max_experts_per_worker * padded_M * max(N, K)
workspace2 = self.max_experts_per_worker * padded_M * (N // 2)
return (workspace1, workspace2, self.out_dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
activation_callable = lambda i, o: self.activation(activation, i, o)
return run_cutlass_moe_fp8(hidden_states, w1, w2, topk_ids,
activation_callable, global_num_experts,
expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2,
expert_num_tokens, self.out_dtype,
self.per_act_token, self.per_out_ch)
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8( def cutlass_moe_fp8(
a: torch.Tensor, a: torch.Tensor,
w1_q: torch.Tensor, w1_q: torch.Tensor,
w2_q: torch.Tensor, w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
ab_strides1: torch.Tensor, w1_scale: torch.Tensor,
c_strides1: torch.Tensor, w2_scale: torch.Tensor,
ab_strides2: torch.Tensor, activation: str = "silu",
c_strides2: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a a8w8-quantized Mixture of Experts (MoE) layer This function computes a a8w8-quantized Mixture of Experts (MoE) layer
...@@ -207,25 +286,17 @@ def cutlass_moe_fp8( ...@@ -207,25 +286,17 @@ def cutlass_moe_fp8(
Shape: [num_experts, K, 2N] (the weights are passed transposed) Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights. - w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed) Shape: [num_experts, N, K] (the weights are passed transposed)
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mappings.
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N] Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K] Shape: [num_experts] or [num_experts, K]
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- ab_strides1 (torch.Tensor): The input and weights strides of the first
grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
- ab_strides2 (torch.Tensor): The input and weights strides of the second
grouped gemm.
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M] Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms. quantize the intermediate result between the gemms.
Shape: scalar or [M] Shape: scalar or [M]
- out_dtype (torch.dtype): The output tensor type.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel, - expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i] mapping from global expert-id to local expert-id. When expert_map[i]
...@@ -233,24 +304,27 @@ def cutlass_moe_fp8( ...@@ -233,24 +304,27 @@ def cutlass_moe_fp8(
expert-id i. expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are - apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1. applied directly on the inputs. This is only applicable when topk is 1.
- global_num_experts (int): The total number of experts.
Returns: Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer. - torch.Tensor: The fp16 output tensor after applying the MoE layer.
""" """
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False) a2_scale.numel() != 1 if a2_scale is not None else False)
per_out_ch = w1_scale.numel() != w1_q.shape[0]
out_dtype = a.dtype
fn = mk.FusedMoEModularKernel( fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP( MoEPrepareAndFinalizeNoEP(
per_channel_quant=per_act_token,
quant_dtype=torch.float8_e4m3fn, quant_dtype=torch.float8_e4m3fn,
per_channel_quant=per_act_token,
), ),
CutlassExpertsFp8( CutlassExpertsFp8(
ab_strides1, max_experts_per_worker=global_num_experts,
c_strides1, out_dtype=out_dtype,
ab_strides2, per_act_token=per_act_token,
c_strides2, per_out_ch=per_out_ch,
out_dtype,
), ),
) )
...@@ -260,9 +334,12 @@ def cutlass_moe_fp8( ...@@ -260,9 +334,12 @@ def cutlass_moe_fp8(
w2_q, w2_q,
topk_weights, topk_weights,
topk_ids, topk_ids,
expert_map=expert_map, False,
w1_scale=w1_scale, activation,
w2_scale=w2_scale, global_num_experts if global_num_experts != -1 else w1_q.size(0),
expert_map,
w1_scale,
w2_scale,
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
......
...@@ -73,6 +73,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -73,6 +73,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
......
...@@ -521,6 +521,7 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -521,6 +521,7 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -632,6 +633,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -632,6 +633,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
......
...@@ -1545,6 +1545,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1545,6 +1545,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment