Unverified commit 21463e32, authored by lukec and committed by GitHub

Expert Parallelism (EP) Support for DeepSeek V3/R1 (#3602)


Co-authored-by: laixin <xielx@shanghaitech.edu.cn>
Co-authored-by: HandH1998 <1335248067@qq.com>
Co-authored-by: laixin <q865809639@gmail.com>
parent 3dc9ff3c
import logging
from typing import List, Optional

import torch
import triton
import triton.language as tl

from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

_is_cuda = torch.cuda.is_available() and torch.version.cuda
if _is_cuda:
    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8,
    )

logger = logging.getLogger(__name__)
@@ -218,12 +225,19 @@ def grouped_gemm_triton_kernel(
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    scale_a,
    scale_b,
    use_fp8_w8a8: tl.constexpr,
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    a_stride_0: tl.constexpr,
    b_stride_0: tl.constexpr,
    b_stride_1: tl.constexpr,
    as_stride_0: tl.constexpr,
    as_stride_1: tl.constexpr,
    bs_stride_0: tl.constexpr,
    bs_stride_2: tl.constexpr,
    bs_stride_1: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
@@ -260,6 +274,12 @@ def grouped_gemm_triton_kernel(
        + (n_range_start + offs_bn[:, None]) * b_stride_1
        + offs_k[None, :]
    )
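    # Block-wise FP8 path: `scale_a` holds one scale per token per K-group of size
    # `group_k`; `scale_b` holds one scale per (group_n x group_k) block of the current
    # expert's weight. Both are indexed by the K-block offset inside the main loop below.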
    if group_k > 0 and group_n > 0:
        a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
        offs_bsn = (n_range_start + offs_bn) // group_n
        b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a_tile = tl.load(
@@ -268,14 +288,23 @@ def grouped_gemm_triton_kernel(
        b_tile = tl.load(
            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )
        if group_k > 0 and group_n > 0:
            k_start = k * BLOCK_SIZE_K
            offs_ks = k_start // group_k
            a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
            b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
            accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
        else:
            accumulator = tl.dot(a_tile, b_tile.T, accumulator)
        a_ptr += BLOCK_SIZE_K
        b_ptr += BLOCK_SIZE_K

    if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
        scale_a_value = tl.load(scale_a + expert_id)
        scale_b_value = tl.load(scale_b + expert_id)
        accumulator *= scale_a_value * scale_b_value

    c_tile = accumulator.to(c_dtype)
    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
@@ -307,14 +336,29 @@ def grouped_gemm_triton(
    use_fp8_w8a8: bool = False,
    scale_a: torch.Tensor = None,
    scale_b: torch.Tensor = None,
    block_shape: Optional[List[int]] = None,
):
    assert weight_column_major == True  # TODO: more
    if use_fp8_w8a8 and block_shape is None:
        assert scale_a is not None and scale_b is not None

    if block_shape is not None:
        assert len(block_shape) == 2
        block_n, block_k = block_shape[0], block_shape[1]
        if _is_cuda:
            a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
        else:
            a, scale_a = per_token_group_quant_fp8(a, block_k)

        assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
        assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
        assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
    # TODO: adjust config or tune kernel
    # Reduce block size to prevent L40 shared memory overflow.
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
    }
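    # Rough estimate (assumption, not a measured number): each software-pipelined stage
    # keeps a (BLOCK_SIZE_M x BLOCK_SIZE_K) A tile and a (BLOCK_SIZE_N x BLOCK_SIZE_K)
    # B tile in shared memory. With 2-byte inputs, the previous 128/128/128 config needs
    # (128 + 128) * 128 * 2 B = 64 KiB per stage, while 64/32/128 needs
    # (64 + 32) * 128 * 2 B = 24 KiB, which fits on parts with ~100 KB of shared memory
    # per SM such as the L40.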
@@ -338,12 +382,19 @@ def grouped_gemm_triton(
        seg_indptr,
        weight_indices,
        m_num_tiles_indptr,
        scale_a,
        scale_b,
        use_fp8_w8a8,
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        a.stride(0),
        b.stride(0),
        b.stride(1),
        scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
        scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
        scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
        scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
        scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
        **config,
    )
    return c
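# Illustrative call sketch (not part of this change; tensor names follow the EPMoE
# forward pass below, and the [128, 128] block shape is an assumption): with block-wise
# FP8, `a` is quantized per token-group inside grouped_gemm_triton and `scale_b` holds
# the per-block weight scales, e.g.
#
#   gateup_output = grouped_gemm_triton(
#       a=gateup_input,                      # bf16 activations, quantized internally
#       b=w13_weight,                        # fp8 weights, (E, 2 * I, H), column major
#       c=gateup_output,
#       batch_size=num_experts_per_partition,
#       weight_column_major=True,
#       seg_indptr=seg_indptr_cur_rank,
#       weight_indices=weight_indices_cur_rank,
#       use_fp8_w8a8=True,
#       scale_a=None,                        # recomputed internally when block_shape is set
#       scale_b=w13_weight_scale_inv,        # (E, ceil(2 * I / 128), ceil(H / 128))
#       block_shape=[128, 128],
#   )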
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
    run_moe_ep_preproess,
    silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
@@ -61,6 +62,7 @@ class GroupedGemmRunner(torch.nn.Module):
        use_fp8_w8a8: bool = False,
        scale_a: torch.Tensor = None,
        scale_b: torch.Tensor = None,
        block_shape: Optional[List[int]] = None,
    ):
        if self.use_flashinfer:
            # TODO: flashinfer
@@ -87,6 +89,7 @@ class GroupedGemmRunner(torch.nn.Module):
                use_fp8_w8a8,
                scale_a,
                scale_b,
                block_shape=block_shape,
            )
        return c
@@ -147,12 +150,20 @@ class EPMoE(torch.nn.Module):
        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
            self.block_shape = None
            self.activation_scheme = None
        else:
            self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                quant_config
            )
            self.use_fp8_w8a8 = True
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
                self.quant_method.quant_config.weight_block_size
                if self.use_block_quant
                else None
            )
            self.fp8_dtype = torch.float8_e4m3fn
            self.activation_scheme = quant_config.activation_scheme
@@ -173,7 +184,8 @@ class EPMoE(torch.nn.Module):
        if self.grouped_gemm_runner is None:
            self.grouped_gemm_runner = GroupedGemmRunner(
                hidden_states.device,
                use_flashinfer=False,  # TODO: use flashinfer
            )
        topk_weights, topk_ids = select_experts(
@@ -195,9 +207,13 @@ class EPMoE(torch.nn.Module):
        gateup_input = torch.empty(
            (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
            device=hidden_states.device,
            dtype=(
                self.fp8_dtype
                if (self.use_fp8_w8a8 and not self.use_block_quant)
                else hidden_states.dtype
            ),
        )
        if self.activation_scheme == "dynamic" and not self.use_block_quant:
            max_value = (
                torch.max(hidden_states)
                .repeat(self.num_experts_per_partition)
@@ -243,7 +259,12 @@ class EPMoE(torch.nn.Module):
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w13_input_scale,
            scale_b=(
                self.w13_weight_scale_inv
                if self.use_block_quant
                else self.w13_weight_scale
            ),
            block_shape=self.block_shape,
        )
        # Act
@@ -251,9 +272,13 @@ class EPMoE(torch.nn.Module):
            gateup_output.shape[0],
            gateup_output.shape[1] // 2,
            device=gateup_output.device,
            dtype=(
                self.fp8_dtype
                if (self.use_fp8_w8a8 and not self.use_block_quant)
                else hidden_states.dtype
            ),
        )
        if self.w2_input_scale is None and not self.use_block_quant:
            self.w2_input_scale = torch.ones(
                self.num_experts_per_partition,
                dtype=torch.float32,
@@ -291,7 +316,12 @@ class EPMoE(torch.nn.Module):
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w2_input_scale,
            scale_b=(
                self.w2_weight_scale_inv
                if self.use_block_quant
                else self.w2_weight_scale
            ),
            block_shape=self.block_shape,
        )
        # PostReorder
@@ -358,7 +388,11 @@ class EPMoE(torch.nn.Module):
        # Special case for fp8 scales.
        if "scale" in weight_name:
            self._load_fp8_scale(
                param.data,
                loaded_weight,
                weight_name,
                shard_id,
                expert_id,
            )
            return
@@ -395,18 +429,33 @@ class EPMoE(torch.nn.Module):
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            if self.use_block_quant:
                block_n, block_k = self.block_shape[0], self.block_shape[1]
                if shard_id == "w1":
                    param_data[expert_id][
                        : (self.intermediate_size + block_n - 1) // block_n, :
                    ] = loaded_weight
                elif shard_id == "w3":
                    param_data[expert_id][
                        (self.intermediate_size + block_n - 1) // block_n :, :
                    ] = loaded_weight
                else:  # w2
                    param_data[expert_id] = loaded_weight
            # If we are in merged column case (gate_up_proj)
            else:
                if shard_id in ("w1", "w3"):
                    # We have to keep the weight scales of w1 and w3 because
                    # we need to re-quantize w1/w3 weights after weight loading.
                    idx = 0 if shard_id == "w1" else 1
                    param_data[expert_id][idx] = loaded_weight
                # If we are in the row parallel case (down_proj)
                else:
                    param_data[expert_id] = loaded_weight

class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
    def create_weights(
        self,
        layer: torch.nn.Module,
@@ -498,6 +547,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

    def create_weights(
        self,
@@ -512,6 +562,29 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        tp_size = get_tensor_model_parallel_world_size()
        if self.block_quant:
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # NOTE(HandH1998): To ensure proper alignment of the block-wise
            # quantization scales, the output_size of the weights for both the gate
            # and up layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1:
                # Required by row parallel
                if intermediate_size % block_k != 0:
                    raise ValueError(
                        f"The input_size of down's weight = "
                        f"{intermediate_size} is not divisible by "
                        f"weight quantization block_k = {block_k}."
                    )
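            # Example (assumed DeepSeek-V3 shapes, for illustration only): with
            # weight_block_size = [128, 128] and moe_intermediate_size = 2048,
            # 2048 % 128 == 0, so both checks above pass.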
        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
@@ -537,6 +610,30 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
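        # Block-wise scales store one value per (block_n x block_k) weight block.
        # Illustrative shapes (assumed dimensions, e.g. hidden_size = 7168,
        # intermediate_size = 2048, 128x128 blocks): w13 scales become
        # (num_experts_per_partition, 32, 56) and w2 scales (num_experts_per_partition, 56, 16).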
        if self.block_quant:
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts_per_partition,
                    2 * ((intermediate_size + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts_per_partition,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"
        else:
            # WEIGHT_SCALES
            # Allocate 2 scales for w1 and w3 respectively.
            w13_weight_scale = torch.nn.Parameter(
@@ -552,7 +649,11 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()
import itertools
import random
import unittest
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from sglang.srt.layers.moe.ep_moe.kernels import (
    grouped_gemm_triton,
    post_reorder_triton_kernel,
    pre_reorder_triton_kernel,
    run_moe_ep_preproess,
    silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.topk import select_experts
# For test
def ep_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    # ep config
    num_experts: int = 256,
    fp8_dtype: torch.types = torch.float8_e4m3fn,
    num_experts_per_partition: int = 128,
    start_expert_id: int = 0,
    end_expert_id: int = 127,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
    w1_scale_inv: Optional[torch.Tensor] = None,
    w2_scale_inv: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
):
    use_blockwise_fp8 = block_shape is not None
    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        # correction_bias=correction_bias,  # skip this in test
        custom_routing_function=custom_routing_function,
    )

    reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)

    gateup_input = torch.empty(
        (int(hidden_states.shape[0] * top_k), hidden_states.shape[1]),
        device=hidden_states.device,
        dtype=(
            fp8_dtype
            if (use_fp8_w8a8 and not use_blockwise_fp8)
            else hidden_states.dtype
        ),
    )
    if use_fp8_w8a8 and not use_blockwise_fp8:
        max_value = (
            torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32)
        )
        w1_input_scale = max_value / torch.finfo(fp8_dtype).max
    else:
        w1_input_scale = None

    # PreReorder
    pre_reorder_triton_kernel[(hidden_states.shape[0],)](
        hidden_states,
        gateup_input,
        src2dst,
        topk_ids,
        w1_input_scale,
        start_expert_id,
        end_expert_id,
        top_k,
        hidden_states.shape[1],
        BLOCK_SIZE=512,
    )

    seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2]
    weight_indices_cur_rank = torch.arange(
        0,
        num_experts_per_partition,
        device=hidden_states.device,
        dtype=torch.int64,
    )
    # GroupGemm-0
    gateup_output = torch.empty(
        gateup_input.shape[0],
        w1.shape[1],
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    gateup_output = grouped_gemm_triton(
        a=gateup_input,
        b=w1,
        c=gateup_output,
        batch_size=num_experts_per_partition,
        weight_column_major=True,
        seg_indptr=seg_indptr_cur_rank,
        weight_indices=weight_indices_cur_rank,
        use_fp8_w8a8=use_fp8_w8a8,
        scale_a=w1_input_scale,
        scale_b=w1_scale_inv,
        block_shape=block_shape,
    )

    # Act
    down_input = torch.empty(
        gateup_output.shape[0],
        gateup_output.shape[1] // 2,
        device=gateup_output.device,
        dtype=(
            fp8_dtype
            if (use_fp8_w8a8 and not use_blockwise_fp8)
            else hidden_states.dtype
        ),
    )
    if use_fp8_w8a8 and not use_blockwise_fp8:
        w2_input_scale = torch.ones(
            num_experts_per_partition,
            dtype=torch.float32,
            device=hidden_states.device,
        )
    else:
        w2_input_scale = None

    silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
        gateup_output,
        down_input,
        gateup_output.shape[1],
        reorder_topk_ids,
        w2_input_scale,
        start_expert_id,
        end_expert_id,
        BLOCK_SIZE=512,
    )
    # GroupGemm-1
    down_output = torch.empty(
        down_input.shape[0],
        w2.shape[1],
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    down_output = grouped_gemm_triton(
        a=down_input,
        b=w2,
        c=down_output,
        batch_size=num_experts_per_partition,
        weight_column_major=True,
        seg_indptr=seg_indptr_cur_rank,
        weight_indices=weight_indices_cur_rank,
        use_fp8_w8a8=use_fp8_w8a8,
        scale_a=w2_input_scale,
        scale_b=w2_scale_inv,
        block_shape=block_shape,
    )

    # PostReorder
    output = torch.empty_like(hidden_states)
    post_reorder_triton_kernel[(hidden_states.size(0),)](
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        top_k,
        hidden_states.size(1),
        BLOCK_SIZE=512,
    )
    return output
# test util
def block_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """This function converts block-wise quantization to tensor-wise quantization.
    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
    and the block size.
    The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
    Note only float8 is supported for now.
    """
    # process 3D tensor
    if x_q_block.dim() == 3:
        batch_size = x_q_block.size(0)
        return torch.stack(
            [block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)]
        )

    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    x_dq_block = x_q_block.to(torch.float32)
    x_dq_block_tiles = [
        [
            x_dq_block[
                j * block_n : min((j + 1) * block_n, n),
                i * block_k : min((i + 1) * block_k, k),
            ]
            for i in range(k_tiles)
        ]
        for j in range(n_tiles)
    ]
    for i in range(k_tiles):
        for j in range(n_tiles):
            x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
    return x_dq_block
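# Minimal usage sketch for block_dequant (hypothetical values, for illustration only):
#
#   w_q = torch.randn(256, 384).clamp(-448, 448).to(torch.float8_e4m3fn)
#   w_s = torch.rand(2, 3, dtype=torch.float32)    # one scale per 128x128 block
#   w_ref = block_dequant(w_q, w_s, [128, 128])    # float32 tensor of shape (256, 384)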
class TestW8A8BlockFP8EPMoE(unittest.TestCase):
    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 222, 1024, 2048]
    N = [128, 1024, 2048]
    K = [256, 4096, 5120]
    E = [8, 16]
    ep_size = [2, 4]
    TOP_KS = [2, 4]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")
    def _w8a8_block_fp8_ep_moe(
        self, M, N, K, E, ep_size, topk, block_size, dtype, seed
    ):
        torch.manual_seed(seed)
        random.seed(seed)
        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
        factor_for_scale = 1e-2
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        a = torch.randn((M, K), dtype=dtype) / 10

        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max
        w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
        w2_fp32 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max
        w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        block_n, block_k = block_size[0], block_size[1]
        n_tiles_w1 = (2 * N + block_n - 1) // block_n
        n_tiles_w2 = (K + block_n - 1) // block_n
        k_tiles_w1 = (K + block_k - 1) // block_k
        k_tiles_w2 = (N + block_k - 1) // block_k

        w1_s = (
            torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
            * factor_for_scale
        )
        w2_s = (
            torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
            * factor_for_scale
        )

        w1_ref = block_dequant(w1, w1_s, block_size).to(dtype)
        w2_ref = block_dequant(w2, w2_s, block_size).to(dtype)

        score = torch.randn((M, E), dtype=dtype)
        num_experts_per_partition = E // ep_size
        cur_rank = random.randint(0, ep_size - 1)
        start_id = cur_rank * num_experts_per_partition
        end_id = start_id + num_experts_per_partition - 1
        with torch.inference_mode():
            out = ep_moe(
                hidden_states=a,
                w1=w1,
                w2=w2,
                router_logits=score,
                top_k=topk,
                renormalize=False,
                use_fp8_w8a8=True,
                w1_scale_inv=w1_s,
                w2_scale_inv=w2_s,
                block_shape=block_size,
                num_experts=E,
                num_experts_per_partition=num_experts_per_partition,
                start_expert_id=start_id,
                end_expert_id=end_id,
            )
            ref_out = ep_moe(
                hidden_states=a,
                w1=w1_ref,
                w2=w2_ref,
                router_logits=score,
                top_k=topk,
                renormalize=False,
                use_fp8_w8a8=False,
                w1_scale_inv=None,
                w2_scale_inv=None,
                block_shape=None,
                num_experts=E,
                num_experts_per_partition=num_experts_per_partition,
                start_expert_id=start_id,
                end_expert_id=end_id,
            )
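        # The quantized path is compared against the block-dequantized reference with a
        # mean relative (L1) error tolerance of 6%.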
        self.assertTrue(
            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
            / (torch.mean(torch.abs(ref_out.to(torch.float32))) + 1e-6)
            < 0.06
        )
    def test_w8a8_block_fp8_ep_moe(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.ep_size,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                ep_size=params[4],
                topk=params[5],
                block_size=params[6],
                dtype=params[7],
                seed=params[8],
            ):
                self._w8a8_block_fp8_ep_moe(*params)
            torch.cuda.empty_cache()
if __name__ == "__main__":
    unittest.main(verbosity=2)