Unverified Commit d4a93841, authored by chenxj and committed by GitHub

[feat] Support tp mode for DeepSeek-R1-W4AFP8 (#8118)


Co-authored-by: yuhyao <827623970@qq.com>
parent 21e1bc47
@@ -405,9 +405,10 @@ class ModelConfig:
        # compressed-tensors uses a "compression_config" key
        quant_cfg = getattr(self.hf_config, "compression_config", None)
        if quant_cfg is None:
-           # check if is modelopt model -- modelopt doesn't have corresponding field
+           # check if it is a modelopt or mixed-precision model -- neither has a corresponding field
            # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
            # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
+           # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
            is_local = os.path.exists(self.model_path)
            modelopt_quant_config = {"quant_method": "modelopt"}
            if not is_local:
......
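Note (editorial, not part of the diff): the comment above refers to checkpoints such as modelopt and the mixed-precision DeepSeek-R1-W4AFP8 release, which keep their quantization metadata in a standalone `hf_quant_config.json` instead of a field inside `config.json`. A minimal sketch of the local-path case only (the helper name is illustrative, not the repository's API):

```python
import json
import os
from typing import Optional


def load_standalone_quant_config(model_path: str) -> Optional[dict]:
    """Return the parsed hf_quant_config.json if the checkpoint ships one locally."""
    cfg_path = os.path.join(model_path, "hf_quant_config.json")
    if not os.path.exists(cfg_path):
        return None  # a remote repo would need a Hub lookup/download instead
    with open(cfg_path) as f:
        return json.load(f)
```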
@@ -91,14 +91,6 @@ def cutlass_w4a8_moe(
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
-   assert (
-       w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
-       and w1_scale.shape[2] == w1_q.shape[1] * 4
-   ), "W1 scale shape mismatch"
-   assert (
-       w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
-       and w2_scale.shape[2] == w2_q.shape[1] * 4
-   ), "W2 scale shape mismatch"
    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
......
@@ -114,9 +114,6 @@ class EPMoE(FusedMoE):
            with_bias=with_bias,
        )
-       self.start_expert_id = self.moe_ep_rank * self.num_local_experts
-       self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
        self.intermediate_size = intermediate_size

        if isinstance(quant_config, Fp8Config):
......
@@ -175,6 +175,8 @@ class FusedMoE(torch.nn.Module):
        self.moe_tp_rank = get_moe_tensor_parallel_rank()
        assert num_experts % self.moe_ep_size == 0
        self.num_local_experts = num_experts // self.moe_ep_size
+       self.start_expert_id = self.moe_ep_rank * self.num_local_experts
+       self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
        if self.moe_ep_size > 1:
            # TODO(ch-wan): support shared experts fusion
            # Create a tensor of size num_experts filled with -1
@@ -593,8 +595,9 @@ class FusedMoE(torch.nn.Module):
            if (
                "compressed" in self.quant_method.__class__.__name__.lower()
-               and param.data[expert_id] != 1
-               and (param.data[expert_id] - loaded_weight).abs() > 1e-5
+               or "w4afp8" in self.quant_config.get_name()
+               and (param.data[expert_id] != 1).any()
+               and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
......
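Note (editorial, not part of the diff): the `.any()` reductions added above are needed once `param.data[expert_id]` can hold more than one scale element, because truth-testing a multi-element tensor raises an error. A small, self-contained illustration with made-up values:

```python
import torch

param = torch.tensor([1.0, 1.0])   # hypothetical per-expert input_scale entries
loaded = torch.tensor([1.0, 1.0])

# bool(param != 1) would raise "Boolean value of Tensor with more than one
# element is ambiguous"; reducing with .any() yields a single flag instead.
mismatch = (param != 1).any() and ((param - loaded).abs() > 1e-5).any()
print(bool(mismatch))  # False: the stored scale matches the loaded one
```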
from __future__ import annotations

import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    QuantizationConfig,
@@ -91,12 +93,13 @@ class W4AFp8Config(QuantizationConfig):
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+       from sglang.srt.managers.schedule_batch import global_server_args_dict

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
-       elif isinstance(layer, EPMoE):
+       elif isinstance(layer, FusedMoE):
            return W4AFp8MoEMethod(self)
        return None
@@ -104,8 +107,24 @@ class W4AFp8Config(QuantizationConfig):
        return []


+def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
+    """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
+    s_shape = scales.shape
+    # Reshape to separate groups of 4
+    alignment = 4 if s_shape[2] % 4 == 0 else 1
+    scales_interleaved = scales.reshape(
+        s_shape[0], s_shape[1], (s_shape[2] // alignment), alignment
+    )
+    # Permute dimensions to interleave
+    scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
+    # Reshape back to original dimensions but with interleaved values
+    scales_interleaved = scales_interleaved.reshape(
+        s_shape[0], s_shape[2] // alignment, s_shape[1] * alignment
+    )
+    return scales_interleaved.contiguous()
+
+
class W4AFp8MoEMethod(FusedMoEMethodBase):
    def __init__(self, quant_config: W4AFp8Config):
        self.quant_config = quant_config
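Note (editorial, not part of the diff): a quick shape check of `interleave_scales` with hypothetical shard sizes. With `alignment = 4`, an `[E, N, K/group]` scale tensor becomes `[E, K/(group*4), N*4]`, so four consecutive K-groups of the same output channel end up adjacent in memory:

```python
import torch

E, N, K, group = 2, 8, 1024, 128        # hypothetical sizes, not taken from the model
scales = torch.randn(E, N, K // group)  # [2, 8, 8]

alignment = 4 if scales.shape[2] % 4 == 0 else 1
out = (
    scales.reshape(E, N, scales.shape[2] // alignment, alignment)
    .permute(0, 2, 1, 3)
    .reshape(E, scales.shape[2] // alignment, N * alignment)
    .contiguous()
)
print(out.shape)  # torch.Size([2, 2, 32])
```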
@@ -234,33 +253,18 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
        return

-   def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
-       """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
-       s_shape = scales.shape
-       # Reshape to separate groups of 4
-       scales_interleaved = scales.reshape(
-           s_shape[0], s_shape[1], (s_shape[2] // 4), 4
-       )
-       # Permute dimensions to interleave
-       scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
-       # Reshape back to original dimensions but with interleaved values
-       scales_interleaved = scales_interleaved.reshape(
-           s_shape[0], s_shape[2] // 4, s_shape[1] * 4
-       )
-       return scales_interleaved.contiguous()
-
    def process_weights_after_loading(self, layer: Module) -> None:
        dtype = torch.bfloat16
        device = layer.w2_weight.device

        # Interleave w13_weight_scale (gate_up_proj)
        w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
-       w13_weight_scale = self._interleave_scales(w13_weight_scale)
+       w13_weight_scale = interleave_scales(w13_weight_scale)
        layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)

        # Interleave w2_weight_scale (down_proj)
        w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
-       w2_weight_scale = self._interleave_scales(w2_weight_scale)
+       w2_weight_scale = interleave_scales(w2_weight_scale)
        layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)

        # Process input scales
@@ -291,6 +295,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
        topk_weights, topk_ids, _ = topk_output
        local_topk_ids = topk_ids
+       if get_moe_expert_parallel_world_size() > 1:
            local_topk_ids = torch.where(
                topk_ids == -1,
                layer.num_experts,
......
@@ -2185,6 +2185,8 @@ class DeepseekV2ForCausalLM(nn.Module):
            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
        elif get_moe_expert_parallel_world_size() > 1:
            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
+       elif self.quant_config.get_name() == "w4afp8":
+           disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."

        if disable_reason is not None:
            global_server_args_dict["disable_shared_experts_fusion"] = True
@@ -2496,6 +2498,9 @@ class DeepseekV2ForCausalLM(nn.Module):
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
        )
+       # Params for special naming rules in mixed-precision models, for example:
+       # model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
+       # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
        if self.quant_config and self.quant_config.get_name() == "w4afp8":
            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                num_experts=self.config.n_routed_experts
......
# SPDX-License-Identifier: Apache-2.0
-from typing import Optional
+from typing import Literal, Optional

import pytest
import torch
@@ -25,7 +25,7 @@ def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Ten
    return packed_tensor.to(torch.int8)


-def pack_interleave(num_experts, ref_weight, ref_scale):
+def pack_interleave(num_experts, ref_weight, ref_scale, alignment=4):
    n, k = ref_weight.shape[1], ref_weight.shape[2]
    weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
@@ -33,11 +33,16 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
    w_q = w_q.contiguous()

    scale_interleaved = ref_scale.reshape(
-       ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
+       ref_scale.shape[0],
+       ref_scale.shape[1],
+       (ref_scale.shape[2] // alignment),
+       alignment,
    )  # [E, N, K/4, 4]
    scale_interleaved = scale_interleaved.permute(0, 2, 1, 3)  # [E, K/4, N, 4]
    scale_interleaved = scale_interleaved.reshape(
-       ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
+       ref_scale.shape[0],
+       ref_scale.shape[2] // alignment,
+       ref_scale.shape[1] * alignment,
    )  # [E, K/4, N*4]
    w_scale = scale_interleaved.contiguous()
@@ -48,12 +53,17 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
@pytest.mark.parametrize("N", [2048])
@pytest.mark.parametrize("K", [7168])
@pytest.mark.parametrize("E", [256])
-@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.parametrize("tp_size", [8])
+@pytest.mark.parametrize("use_ep_moe", [True, False])
@pytest.mark.parametrize("topk", [8])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
-def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
-    local_e = E // ep_size
+def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dtype):
+    if use_ep_moe:
+        local_e = E // tp_size
+    else:  # tp mode
+        local_e = E
+        N = N // tp_size
    debug = False
    if debug:
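Note (editorial, not part of the diff): the branch above is the heart of the new TP coverage. Expert parallelism gives each rank a slice of the experts at full width, while tensor parallelism keeps every expert but shards the intermediate width N. With the test's defaults (E=256, N=2048, tp_size=8):

```python
E, N, tp_size = 256, 2048, 8

# EP mode: fewer experts per rank, full intermediate size.
ep_local_e, ep_n = E // tp_size, N             # 32 experts, N = 2048
# TP mode: all experts on every rank, 1/tp_size of the intermediate size.
tp_local_e, tp_n = E, N // tp_size             # 256 experts, N = 256
print((ep_local_e, ep_n), (tp_local_e, tp_n))  # (32, 2048) (256, 256)
```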
@@ -87,7 +97,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
    )

    w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
+   if use_ep_moe:
        w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
+   else:
+       w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2, 1)

    device = "cuda"
    a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
@@ -265,7 +278,9 @@ def ref(
        gate, fc1 = fc1.chunk(2, dim=-1)
        fc1 = fc1 * torch.nn.functional.silu(gate)
-       act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn)
+       act = torch.clamp((fc1 / pre_quant_scale_2.float()), -448.0, 448.0).to(
+           torch.float8_e4m3fn
+       )
        act = act.to(dtype)

        w2 = ref_weight_2[e_idx]
......
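Note (editorial, not part of the diff): the clamp added to `ref()` matches the finite range of `float8_e4m3fn`, whose largest representable magnitude is 448 and which has no infinity encoding, so out-of-range activations are saturated before the cast. Quick check:

```python
import torch

print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
x = torch.tensor([500.0, -500.0, 96.0])
print(torch.clamp(x, -448.0, 448.0).to(torch.float8_e4m3fn))  # saturates to +/-448
```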
@@ -31,7 +31,7 @@ __global__ void int4_fp8_get_group_gemm_starts(
    b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2;
    out_offsets[expert_id] = out_base_as_int + expert_offset * n;
    a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0);
-   b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * 4 * k / 512 : expert_id);
+   b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * k / 128 : expert_id);
  }

#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
......
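Note (editorial, not part of the diff): the two offset formulas count the same number of per-expert scale elements, since `n * 4 * k / 512 == n * k / 128`; the rewrite expresses it directly in terms of the group size 128, which also reads correctly now that the scale packing alignment can be 4 or 1. A quick arithmetic check with the shapes touched by this PR:

```python
# Per-expert scale elements: one scale per 128-wide K group for each of n channels.
for n, k in [(4096, 7168), (7168, 2048), (512, 7168), (7168, 256)]:
    assert n * 4 * k // 512 == n * k // 128 == n * (k // 128)
    print(n, k, n * k // 128)
```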
@@ -2,6 +2,8 @@
#include <cudaTypedefs.h>
#include <torch/all.h>

+#include <type_traits>
+
#include "cutlass/cutlass.h"
#include "w4a8_grouped_mm_c3x.cuh"
@@ -9,38 +11,60 @@ using namespace cute;

namespace {

-#define JOIN_STRUCT_NAME(m, n, k, a, b, c) sm90_fp8_config##_##m##_##n##_##k##_##a##_##b##_##c
-#define JOIN_STRUCT_NAME_CO(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
-
-#define GENERATE_SM90_W4A8_PP_CONFIG(M, N, K, A, B, C) \
-  struct JOIN_STRUCT_NAME(M, N, K, A, B, C) { \
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; \
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
-    using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
-    using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
-    \
-    using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
-  };
-
-#define GENERATE_SM90_W4A8_CO_CONFIG(M, N, K, A, B, C) \
-  struct JOIN_STRUCT_NAME_CO(M, N, K, A, B, C) { \
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; \
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
-    using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
-    using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
-    \
-    using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
-  };
-
-GENERATE_SM90_W4A8_PP_CONFIG(64, 16, 512, 1, 1, 1)
-GENERATE_SM90_W4A8_PP_CONFIG(64, 32, 512, 2, 1, 1)
-GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 1, 1, 1)
-GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 2, 1, 1)
-GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 1, 1, 1)
-GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 2, 1, 1)
-GENERATE_SM90_W4A8_CO_CONFIG(128, 64, 512, 1, 1, 1)
+enum class Sched { PP, CO };
+
+template <int M, int N, int K, int A, int B, int C, Sched S>
+struct SM90W4A8Config {
+  using KernelSchedule = std::conditional_t<
+      S == Sched::PP,
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong,
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
+
+  using EpilogueSchedule = std::conditional_t<
+      S == Sched::PP,
+      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong,
+      cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative>;
+
+  using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;
+  using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>;
+  using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
+};
+
+template <int M, int N, int K, int A, int B, int C>
+using SM90_PP = SM90W4A8Config<M, N, K, A, B, C, Sched::PP>;
+template <int M, int N, int K, int A, int B, int C>
+using SM90_CO = SM90W4A8Config<M, N, K, A, B, C, Sched::CO>;
+
+template <typename Config>
+inline void invoke_gemm(
+    torch::Tensor& d_tensors,
+    torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors,
+    torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales,
+    torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes,
+    torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides,
+    torch::Tensor const& d_strides,
+    torch::Tensor const& s_strides,
+    int64_t chunk_size) {
+  using GemmT = typename Config::Cutlass3xW4A8Gemm;
+  cutlass_w4a8_group_gemm_caller<GemmT>(
+      d_tensors,
+      a_tensors,
+      b_tensors,
+      a_scales,
+      b_scales,
+      expert_offsets,
+      problem_sizes,
+      a_strides,
+      b_strides,
+      d_strides,
+      s_strides,
+      chunk_size);
+}
void dispatch_w4a8_moe_mm_sm90(
    torch::Tensor& d_tensors,
@@ -56,9 +80,6 @@ void dispatch_w4a8_moe_mm_sm90(
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk) {
-  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
-  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
  uint32_t const m = a_tensors.size(0) / topk;
  uint32_t const n = d_tensors.size(1);
  uint32_t const k = a_tensors.size(1);
@@ -66,8 +87,7 @@ void dispatch_w4a8_moe_mm_sm90(
  if (n == 4096 && k == 7168) {
    // group gemm 1
    if (m <= 4) {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_PP<64, 32, 512, 2, 1, 1>>(
          d_tensors,
          a_tensors,
          b_tensors,
@@ -81,8 +101,7 @@ void dispatch_w4a8_moe_mm_sm90(
          s_strides,
          chunk_size);
    } else if (m <= 16) {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
          d_tensors,
          a_tensors,
          b_tensors,
@@ -96,8 +115,7 @@ void dispatch_w4a8_moe_mm_sm90(
          s_strides,
          chunk_size);
    } else if (m <= 256) {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_CO<128, 16, 512, 1, 1, 1>>(
          d_tensors,
          a_tensors,
          b_tensors,
@@ -111,8 +129,7 @@ void dispatch_w4a8_moe_mm_sm90(
          s_strides,
          chunk_size);
    } else if (m <= 1024) {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_CO<128, 32, 512, 2, 1, 1>>(
          d_tensors,
          a_tensors,
          b_tensors,
@@ -126,8 +143,7 @@ void dispatch_w4a8_moe_mm_sm90(
          s_strides,
          chunk_size);
    } else {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
          d_tensors,
          a_tensors,
          b_tensors,
@@ -144,8 +160,7 @@ void dispatch_w4a8_moe_mm_sm90(
  } else if (n == 7168 && k == 2048) {
    // group gemm 2
    if (m <= 8) {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_PP<64, 16, 512, 1, 1, 1>>(
          d_tensors,
          a_tensors,
          b_tensors,
@@ -159,8 +174,7 @@ void dispatch_w4a8_moe_mm_sm90(
          s_strides,
          chunk_size);
    } else if (m <= 512) {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_CO<128, 32, 512, 1, 1, 1>>(
          d_tensors,
          a_tensors,
          b_tensors,
@@ -174,8 +188,7 @@ void dispatch_w4a8_moe_mm_sm90(
          s_strides,
          chunk_size);
    } else {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
          d_tensors,
          a_tensors,
          b_tensors,
@@ -189,9 +202,83 @@ void dispatch_w4a8_moe_mm_sm90(
          s_strides,
          chunk_size);
    }
+  } else if (n == 512 && k == 7168) {
+    // group gemm 1 for tp
+    if (m <= 4) {
+      invoke_gemm<SM90_PP<64, 32, 512, 2, 1, 1>>(
+          d_tensors,
+          a_tensors,
+          b_tensors,
+          a_scales,
+          b_scales,
+          expert_offsets,
+          problem_sizes,
+          a_strides,
+          b_strides,
+          d_strides,
+          s_strides,
+          chunk_size);
+    } else if (m <= 16) {
+      invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
+          d_tensors,
+          a_tensors,
+          b_tensors,
+          a_scales,
+          b_scales,
+          expert_offsets,
+          problem_sizes,
+          a_strides,
+          b_strides,
+          d_strides,
+          s_strides,
+          chunk_size);
+    } else if (m <= 256) {
+      invoke_gemm<SM90_CO<128, 16, 512, 2, 1, 1>>(
+          d_tensors,
+          a_tensors,
+          b_tensors,
+          a_scales,
+          b_scales,
+          expert_offsets,
+          problem_sizes,
+          a_strides,
+          b_strides,
+          d_strides,
+          s_strides,
+          chunk_size);
+    } else if (m <= 1024) {
+      invoke_gemm<SM90_CO<128, 32, 512, 2, 1, 1>>(
+          d_tensors,
+          a_tensors,
+          b_tensors,
+          a_scales,
+          b_scales,
+          expert_offsets,
+          problem_sizes,
+          a_strides,
+          b_strides,
+          d_strides,
+          s_strides,
+          chunk_size);
    } else {
-      using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
-      cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
+      invoke_gemm<SM90_CO<128, 64, 512, 1, 1, 1>>(
+          d_tensors,
+          a_tensors,
+          b_tensors,
+          a_scales,
+          b_scales,
+          expert_offsets,
+          problem_sizes,
+          a_strides,
+          b_strides,
+          d_strides,
+          s_strides,
+          chunk_size);
+    }
+  } else if (n == 7168 && k == 256) {
+    // group gemm 2 for tp
+    if (m <= 8) {
+      invoke_gemm<SM90_PP<64, 16, 128, 1, 1, 1>>(
          d_tensors,
          a_tensors,
          b_tensors,
@@ -204,6 +291,65 @@ void dispatch_w4a8_moe_mm_sm90(
          d_strides,
          s_strides,
          chunk_size);
} else if (m <= 512) {
invoke_gemm<SM90_PP<128, 32, 128, 2, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
invoke_gemm<SM90_PP<128, 64, 128, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else {
if (k % 512 == 0) {
invoke_gemm<SM90_CO<128, 32, 512, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
invoke_gemm<SM90_PP<128, 64, 128, 1, 1, 1>>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
  }
}
......
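Note (editorial, not part of the diff): the new `(n, k)` branches above add TP-shard problem shapes plus a generic fallback. A compact sketch of the selection table as data (the real dispatch stays in the C++ above; `None` means no upper bound on m):

```python
# (n, k) -> [(max_m, schedule, tile_shape, cluster_shape), ...]
NEW_W4A8_SM90_BRANCHES = {
    (512, 7168): [                      # group gemm 1, TP shard
        (4,    "pingpong",    (64, 32, 512),  (2, 1, 1)),
        (16,   "cooperative", (128, 16, 512), (2, 1, 1)),
        (256,  "cooperative", (128, 16, 512), (2, 1, 1)),
        (1024, "cooperative", (128, 32, 512), (2, 1, 1)),
        (None, "cooperative", (128, 64, 512), (1, 1, 1)),
    ],
    (7168, 256): [                      # group gemm 2, TP shard
        (8,    "pingpong",    (64, 16, 128),  (1, 1, 1)),
        (512,  "pingpong",    (128, 32, 128), (2, 1, 1)),
        (None, "pingpong",    (128, 64, 128), (1, 1, 1)),
    ],
}


def pick_config(n: int, k: int, m: int):
    """Mirror of the m-threshold cascade for the two new shapes."""
    for max_m, sched, tile, cluster in NEW_W4A8_SM90_BRANCHES[(n, k)]:
        if max_m is None or m <= max_m:
            return sched, tile, cluster


print(pick_config(512, 7168, 64))  # ('cooperative', (128, 16, 512), (2, 1, 1))
```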
@@ -41,7 +41,6 @@ using MmaType = cutlass::float_e4m3_t;  // FP8 e4m3 type
using QuantType = cutlass::int4b_t;  // 4-bit integer type
using ElementAccumulator = float;  // Accumulator type
using ElementScale = cutlass::bfloat16_t;  // Scale type
-using ElementScalePacked = cutlass::Array<ElementScale, 4>;
using ElementC = cutlass::half_t;  // Default output type (FP16)
using ElementD = ElementC;  // Default output type (FP16)
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
@@ -73,6 +72,10 @@ static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
struct cutlass_3x_w4a8_group_gemm {
+  static constexpr int GroupSize = 128;
+  static constexpr int PackedScalesNum = get<2>(TileShape{}) / GroupSize;
+  using ElementScalePacked = cutlass::Array<ElementScale, PackedScalesNum>;
+
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
@@ -184,8 +187,6 @@ void cutlass_w4a8_group_gemm_caller(
  TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
  TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
  TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
-  TORCH_CHECK(b_scales.size(1) == a_tensors.size(1) / 512, "Scale tensor second dimension must be K//512");
-  TORCH_CHECK(b_scales.size(2) == 4 * b_tensors.size(1), "Scale tensor last dimension must be 4*N");

  // Check tensor types
  TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
@@ -241,7 +242,7 @@ void cutlass_w4a8_group_gemm_caller(
      static_cast<typename Gemm::StrideB*>(b_strides.data_ptr()),
      static_cast<const MmaType**>(a_ptrs.data_ptr()),
      static_cast<typename Gemm::StrideA*>(a_strides.data_ptr()),
-     static_cast<const ElementScalePacked**>(b_scales_ptrs.data_ptr()),
+     static_cast<const typename Gemm::ElementScalePacked**>(b_scales_ptrs.data_ptr()),
      static_cast<typename Gemm::StrideS*>(s_strides.data_ptr()),
      static_cast<int>(chunk_size)},
      {fusion_args,
......
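Note (editorial, not part of the diff): making `ElementScalePacked` depend on the tile's K extent is what lets the same GEMM template serve both layouts: the 512-wide K tiles keep 4-way packed scales, while the new 128-wide K tiles used for the small TP shards fall back to a single scale per group. With the group size fixed at 128:

```python
GROUP_SIZE = 128
for tile_k in (512, 128):
    print(tile_k, tile_k // GROUP_SIZE)  # 512 -> 4 packed scales, 128 -> 1
```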
@@ -27,12 +27,18 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
    w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
    w_q = w_q.contiguous()

+   alignment = 4 if k % 512 == 0 else 1
    scale_interleaved = ref_scale.reshape(
-       ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
+       ref_scale.shape[0],
+       ref_scale.shape[1],
+       (ref_scale.shape[2] // alignment),
+       alignment,
    )  # [E, N, K/4, 4]
    scale_interleaved = scale_interleaved.permute(0, 2, 1, 3)  # [E, K/4, N, 4]
    scale_interleaved = scale_interleaved.reshape(
-       ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
+       ref_scale.shape[0],
+       ref_scale.shape[2] // alignment,
+       ref_scale.shape[1] * alignment,
    )  # [E, K/4, N*4]
    w_scale = scale_interleaved.contiguous()
@@ -137,8 +143,8 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
    reason="cutlass_w4a8_moe_mm is only supported on sm90",
)
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
-@pytest.mark.parametrize("k", [512, 1024])
-@pytest.mark.parametrize("n", [1024, 2048])
+@pytest.mark.parametrize("k", [256, 512, 1024])
+@pytest.mark.parametrize("n", [1024, 2048, 7168])
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
    torch.manual_seed(0)