Unverified Commit 7b8a2ab7 authored by Varun Sundar Rabindranath's avatar Varun Sundar Rabindranath Committed by GitHub
Browse files

[Kernel] Add expert_map support to Cutlass FP8 MOE (#16861)


Signed-off-by: default avatarvarun sundar rabindranath <vsundarr@redhat.com>
Co-authored-by: default avatarvarun sundar rabindranath <vsundarr@redhat.com>
parent c9acbf11
...@@ -46,14 +46,26 @@ __global__ void compute_expert_offsets( ...@@ -46,14 +46,26 @@ __global__ void compute_expert_offsets(
} }
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, __global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation, int32_t* input_permutation,
int32_t* output_permutation, int32_t* output_permutation,
int32_t* atomic_buffer, const int topk_length, int32_t* atomic_buffer, const int topk_length,
const int topk) { const int topk) {
int expert_id = blockIdx.x; int const blk_expert_id = blockIdx.x;
int const num_experts = gridDim.x;
int32_t const num_tokens = expert_offsets[num_experts];
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
if (topk_ids[i] == expert_id) { int const expert_id = topk_ids[i];
if (expert_id == -1 && blockIdx.x == 0) {
// output_permutation is used to re-order the moe outputs. It is
// used as c2 = c2[c_map], where c2 is a torch.tensor that is the
// output of the cutlass kernels and c_map is the output_permutation.
// c2 is initialized to zeros, therefore by setting the output_permutation
// to num_tokens, we are guaranteed to fill the moe outputs to zero
// for "invalid" topk_ids.
output_permutation[i] = num_tokens;
} else if (expert_id == blk_expert_id) {
int start = atomicAdd(&atomic_buffer[expert_id], 1); int start = atomicAdd(&atomic_buffer[expert_id], 1);
input_permutation[start] = i / topk; input_permutation[start] = i / topk;
output_permutation[i] = start; output_permutation[i] = start;
...@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller( ...@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts); static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>( compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()), static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()), static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()), static_cast<int32_t*>(output_permutation.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
......
This diff is collapsed.
...@@ -1693,6 +1693,7 @@ class ParallelConfig: ...@@ -1693,6 +1693,7 @@ class ParallelConfig:
factors: list[Any] = [] factors: list[Any] = []
factors.append(self.pipeline_parallel_size) factors.append(self.pipeline_parallel_size)
factors.append(self.tensor_parallel_size) factors.append(self.tensor_parallel_size)
factors.append(self.enable_expert_parallel)
return hashlib.sha256(str(factors).encode()).hexdigest() return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None: def __post_init__(self) -> None:
......
...@@ -15,7 +15,7 @@ def cutlass_moe_fp8( ...@@ -15,7 +15,7 @@ def cutlass_moe_fp8(
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids_: torch.Tensor,
ab_strides1: torch.Tensor, ab_strides1: torch.Tensor,
c_strides1: torch.Tensor, c_strides1: torch.Tensor,
ab_strides2: torch.Tensor, ab_strides2: torch.Tensor,
...@@ -23,6 +23,7 @@ def cutlass_moe_fp8( ...@@ -23,6 +23,7 @@ def cutlass_moe_fp8(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half, out_dtype: torch.dtype = torch.half,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -57,12 +58,19 @@ def cutlass_moe_fp8( ...@@ -57,12 +58,19 @@ def cutlass_moe_fp8(
quantize the intermediate result between the gemms. quantize the intermediate result between the gemms.
Shape: scalar or [M] Shape: scalar or [M]
- out_dtype (torch.Tensor): The output tensor type. - out_dtype (torch.Tensor): The output tensor type.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
is -1, it means that this Rank is not responsible for global
expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
Returns: Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer. - torch.Tensor: The fp16 output tensor after applying the MoE layer.
""" """
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
assert w1_q.dtype == torch.float8_e4m3fn assert w1_q.dtype == torch.float8_e4m3fn
assert w2_q.dtype == torch.float8_e4m3fn assert w2_q.dtype == torch.float8_e4m3fn
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
...@@ -96,7 +104,13 @@ def cutlass_moe_fp8( ...@@ -96,7 +104,13 @@ def cutlass_moe_fp8(
k = w1_q.size(1) k = w1_q.size(1)
n = w2_q.size(1) n = w2_q.size(1)
topk = topk_ids.size(1) local_topk_ids = topk_ids_
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids_] != -1,
expert_map[topk_ids_], -1)
topk = local_topk_ids.size(1)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False) a2_scale.numel() != 1 if a2_scale is not None else False)
...@@ -120,10 +134,23 @@ def cutlass_moe_fp8( ...@@ -120,10 +134,23 @@ def cutlass_moe_fp8(
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) a_map_initializer = torch.empty
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c2_initializer = torch.empty
if expert_map is not None:
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, # With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
a_map_initializer = torch.zeros
c2_initializer = torch.zeros
a_map = a_map_initializer((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, num_experts, n, problem_sizes2, a_map, c_map, num_experts, n,
k) k)
...@@ -131,7 +158,7 @@ def cutlass_moe_fp8( ...@@ -131,7 +158,7 @@ def cutlass_moe_fp8(
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype)
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
expert_offsets[:-1], problem_sizes1, ab_strides1, expert_offsets[:-1], problem_sizes1, ab_strides1,
......
...@@ -67,7 +67,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -67,7 +67,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
else: else:
return CompressedTensorsWNA16MarlinMoEMethod(quant_config) return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
and layer.activation == "silu" and layer.expert_map is None): and layer.activation == "silu"):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant): elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config) return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
...@@ -510,8 +510,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -510,8 +510,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu" assert activation == "silu"
assert global_num_experts == layer.w13_weight.shape[0]
assert expert_map is None
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -542,6 +540,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -542,6 +540,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
out_dtype=x.dtype, out_dtype=x.dtype,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment