Unverified Commit cb9d91ea authored by SijiaYang, committed by GitHub

feat: support DeepSeek-R1-W4AFP8 model with ep-moe mode (#7762)


Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
parent 6a6e0bb7
@@ -359,7 +359,17 @@ class ModelConfig:
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
quant_cfg = modelopt_quant_config
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
quant_config_file = os.path.join(
self.model_path, "hf_quant_config.json"
)
with open(quant_config_file) as f:
quant_config_dict = json.load(f)
json_quant_configs = quant_config_dict["quantization"]
quant_algo = json_quant_configs.get("quant_algo", None)
if quant_algo == "MIXED_PRECISION":
quant_cfg = {"quant_method": "w4afp8"}
else:
quant_cfg = modelopt_quant_config
return quant_cfg
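For orientation, the branch above keys entirely off the "quant_algo" field inside the "quantization" section of hf_quant_config.json. A minimal sketch of a file that would be routed to the w4afp8 path (only "quant_algo" is actually inspected here; any other fields would be illustrative assumptions):

# hf_quant_config.json (hypothetical minimal example)
# {
#     "quantization": {
#         "quant_algo": "MIXED_PRECISION"
#     }
# }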
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
@@ -389,6 +399,7 @@ class ModelConfig:
"w8a8_fp8",
"moe_wna16",
"qoq",
"w4afp8",
]
compatible_quantization_methods = {
"modelopt_fp4": ["modelopt"],
...
# SPDX-License-Identifier: Apache-2.0
"""Cutlass W4A8 MoE kernel."""
from typing import Optional
import torch
from sgl_kernel import (
cutlass_w4a8_moe_mm,
get_cutlass_w4a8_moe_mm_data,
sgl_per_tensor_quant_fp8,
silu_and_mul,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
post_reorder_triton_kernel,
pre_reorder_triton_kernel_for_cutlass_moe,
run_cutlass_moe_ep_preproess,
)
def cutlass_w4a8_moe(
start_expert_id: int,
end_expert_id: int,
total_num_experts: int,
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids_: torch.Tensor,
local_topk_ids: torch.Tensor,
a_strides1: torch.Tensor,
b_strides1: torch.Tensor,
c_strides1: torch.Tensor,
a_strides2: torch.Tensor,
b_strides2: torch.Tensor,
c_strides2: torch.Tensor,
s_strides13: torch.Tensor,
s_strides2: torch.Tensor,
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
This function computes a w4a8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and a top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of int4-quantized expert weights.
Shape: [num_experts, N * 2, K // 2]
(the weights are passed transposed and int4-packed)
- w2_q (torch.Tensor): The second set of int4-quantized expert weights.
Shape: [num_experts, K, N // 2]
(the weights are passed transposed and int4-packed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts, K // 512, N * 8]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts, N // 512, K * 4]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
- a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
- b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
- s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
- s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [1, K]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [1, N]
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer; it has the same dtype as the input tensor a.
"""
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
assert w1_q.dtype == torch.int8
assert w2_q.dtype == torch.int8
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert (
w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
and w1_scale.shape[2] == w1_q.shape[1] * 4
), "W1 scale shape mismatch"
assert (
w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
and w2_scale.shape[2] == w2_q.shape[1] * 4
), "W2 scale shape mismatch"
assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
num_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(2) * 2 # w1_q is transposed and packed
n = w2_q.size(2) * 2 # w2_q is transposed and packed
topk = topk_ids_.size(1)
if apply_router_weight_on_input:
assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
device = a.device
_, src2dst, _ = run_cutlass_moe_ep_preproess(
local_topk_ids,
num_experts,
)
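# Scatter each token's activations into expert-contiguous order and quantize
# them to fp8 with a1_scale before the first grouped GEMM.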
gateup_input = torch.empty(
(m * topk, k),
device=device,
dtype=torch.float8_e4m3fn,
)
pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
a,
gateup_input,
src2dst,
local_topk_ids,
a1_scale,
total_num_experts,
topk,
k,
BLOCK_SIZE=512,
)
# NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
# they are kept to allow for a quick switch of the permutation logic
# from the current triton kernel implementation to the cutlass-based one if needed.
a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
get_cutlass_w4a8_moe_mm_data(
local_topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
a_map,
c_map,
num_experts,
n,
k,
)
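# First grouped GEMM (gate_up_proj): fp8 activations x int4 weights -> fp16
# gate/up outputs in c1; c2 will later hold the down_proj results.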
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
cutlass_w4a8_moe_mm(
c1,
gateup_input,
w1_q,
a1_scale.float(),
w1_scale,
expert_offsets[:-1],
problem_sizes1,
a_strides1,
b_strides1,
c_strides1,
s_strides13,
128,
topk,
)
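# Apply SiLU-and-mul over the gate/up halves, then requantize the result to
# fp8 with a2_scale for the second grouped GEMM.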
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
silu_and_mul(c1, intermediate)
intermediate_q = torch.empty(
intermediate.shape, dtype=torch.float8_e4m3fn, device=device
)
sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)
cutlass_w4a8_moe_mm(
c2,
intermediate_q,
w2_q,
a2_scale.float(),
w2_scale,
expert_offsets[:-1],
problem_sizes2,
a_strides2,
b_strides2,
c_strides2,
s_strides2,
128,
topk,
)
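# Gather the expert outputs back into the original token order and apply the
# top-k routing weights.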
output = torch.empty_like(a)
post_reorder_triton_kernel[(m,)](
c2,
output,
src2dst,
topk_ids_,
topk_weights,
start_expert_id,
end_expert_id,
topk,
k,
0,
BLOCK_SIZE=512,
)
return output
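As a concrete instantiation of the shapes documented in the docstring above (the numbers below mirror the unit test further down and are illustrative only):

# Example shapes for K = 7168, N = 2048, 32 local experts, group_size = 128:
#   a         [M, 7168]          activations
#   w1_q      [32, 4096, 3584]   int8 (N * 2 rows, K // 2 packed-int4 columns)
#   w2_q      [32, 7168, 1024]   int8 (K rows, N // 2 packed-int4 columns)
#   w1_scale  [32, 14, 16384]    (K // 512, N * 8), scale groups interleaved by 4
#   w2_scale  [32, 4, 28672]     (N // 512, K * 4)
#   output    [M, 7168]          same dtype as a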
@@ -146,6 +146,7 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
@@ -158,9 +159,66 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
compute_src2dst_triton_kernel[grid](
reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
)
return reorder_topk_ids, src2dst, seg_indptr
def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(
local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
)
src2dst = torch.empty(
local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
)
BLOCK_SIZE = 512
grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
compute_src2dst_triton_kernel[grid](
reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
)
return reorder_topk_ids, src2dst, seg_indptr
@triton.jit
def pre_reorder_triton_kernel_for_cutlass_moe(
input_ptr,
gateup_input_ptr,
src2dst_ptr,
topk_ids_ptr,
a1_scales_ptr,
num_experts,
topk,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
OutDtype = gateup_input_ptr.dtype.element_ty
src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
src_ptr = input_ptr + src_idx * hidden_size
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id != num_experts:
if a1_scales_ptr is not None:
scale = 1.0 / tl.load(a1_scales_ptr)
else:
scale = 1.0
dst_idx = tl.load(src2dst_ptr + idx)
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
out_data = (in_data * scale).to(OutDtype)
tl.store(dst_ptr + offset, out_data, mask=mask)
@triton.jit
def pre_reorder_triton_kernel(
input_ptr,
...
@@ -12,6 +12,7 @@ from sglang.srt.distributed import (
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
@@ -20,6 +21,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
moe_ep_deepgemm_preprocess,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
pre_reorder_triton_kernel_for_cutlass_moe,
run_cutlass_moe_ep_preproess,
run_moe_ep_preproess,
silu_and_mul_masked_post_quant_fwd,
silu_and_mul_triton_kernel,
@@ -41,6 +44,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import (
@@ -191,7 +195,7 @@ class EPMoE(torch.nn.Module):
num_fused_shared_experts == 0
), "num_fused_shared_experts is not supported in EP"
self.num_fused_shared_experts = num_fused_shared_experts
self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -215,6 +219,18 @@ class EPMoE(torch.nn.Module):
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
self.use_w4afp8 = False
elif isinstance(quant_config, W4AFp8Config):
self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
quant_config
)
self.use_w4afp8 = True
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.fp8_dtype = torch.float8_e4m3fn
self.w13_weight_scale = None
self.w2_weight_scale = None
self.activation_scheme = quant_config.moe_activation_scheme
else:
self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
quant_config
@@ -228,6 +244,7 @@ class EPMoE(torch.nn.Module):
)
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
self.use_w4afp8 = False
self.quant_method.create_weights(
layer=self,
@@ -253,6 +270,49 @@ class EPMoE(torch.nn.Module):
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)
# Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
# Modifications: implement determine_expert_map as a class-internal method, and mark experts not assigned to the current rank with 'global_num_experts' rather than '-1'.
def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
"""
Calculates how many experts should be assigned to each rank for EP and
creates a mapping from global to local expert index. Experts are
distributed evenly across ranks. Any remaining are assigned to the
last rank.
Returns:
Tuple[int, Optional[torch.Tensor]]: A tuple containing:
- local_num_experts (int): The number of experts assigned
to the current rank.
- expert_map (Optional[torch.Tensor]): A tensor of shape
(global_num_experts,) mapping from global to local index.
Contains global_num_experts for experts not assigned to the current rank.
Returns None if ep_size is 1.
"""
ep_size = self.tp_size
ep_rank = self.tp_rank
global_num_experts = self.num_experts
assert ep_size > 0
if ep_size == 1:
return (global_num_experts, None)
local_num_experts = global_num_experts // ep_size
expert_map = torch.full(
(global_num_experts,), self.num_experts, dtype=torch.int32
)
if ep_rank < (ep_size - 1):
expert_map[
ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
] = torch.arange(0, local_num_experts, dtype=torch.int32)
else:
local_num_experts = global_num_experts - ep_rank * local_num_experts
expert_map[-local_num_experts:] = torch.arange(
0, local_num_experts, dtype=torch.int32
)
return (local_num_experts, expert_map)
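As a worked example of the mapping this method produces, here is the same logic run standalone (the sizes below are assumptions, not taken from the diff):

# Standalone sketch: 8 global experts split across an EP size of 2.
import torch

global_num_experts, ep_size = 8, 2
local = global_num_experts // ep_size
for ep_rank in range(ep_size):
    expert_map = torch.full((global_num_experts,), global_num_experts, dtype=torch.int32)
    if ep_rank < ep_size - 1:
        expert_map[ep_rank * local : (ep_rank + 1) * local] = torch.arange(local, dtype=torch.int32)
    else:
        last = global_num_experts - ep_rank * local
        expert_map[-last:] = torch.arange(last, dtype=torch.int32)
    print(ep_rank, expert_map.tolist())
# rank 0 -> [0, 1, 2, 3, 8, 8, 8, 8]   (8 == global_num_experts marks "not local")
# rank 1 -> [8, 8, 8, 8, 0, 1, 2, 3]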
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm(hidden_states, router_logits)
@@ -440,6 +500,51 @@ class EPMoE(torch.nn.Module):
),
)
if self.use_w4afp8:
local_topk_ids = topk_ids
if self.expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(
self.expert_map[topk_ids] != self.num_experts,
self.expert_map[topk_ids],
self.num_experts,
)
output = cutlass_w4a8_moe(
self.start_expert_id,
self.end_expert_id,
self.num_experts,
hidden_states,
self.w13_weight,
self.w2_weight,
self.w13_weight_scale_inv,
self.w2_weight_scale_inv,
topk_weights,
topk_ids,
local_topk_ids,
self.quant_method.a_strides1,
self.quant_method.b_strides1,
self.quant_method.c_strides1,
self.quant_method.a_strides2,
self.quant_method.b_strides2,
self.quant_method.c_strides2,
self.quant_method.s_strides13,
self.quant_method.s_strides2,
self.quant_method.expert_offsets,
self.quant_method.problem_sizes1,
self.quant_method.problem_sizes2,
self.w13_input_scale,
self.w2_input_scale,
)
return output
if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device,
use_flashinfer=False, # TODO: use flashinfer
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
topk_ids, self.num_experts
)
@@ -449,7 +554,7 @@ class EPMoE(torch.nn.Module):
device=hidden_states.device,
dtype=(
self.fp8_dtype
if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
else hidden_states.dtype
),
)
@@ -656,6 +761,23 @@ class EPMoE(torch.nn.Module):
]
]
@classmethod
def make_expert_input_scale_params_mapping(
cls,
num_experts: int,
) -> List[Tuple[str, str, int, str]]:
# (param_name, weight_name, expert_id, shard_id)
return [
(
"experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
f"experts.{expert_id}.{shard_id}.",
expert_id,
shard_id,
)
for expert_id in range(num_experts)
for shard_id in ["w1", "w2", "w3"]
]
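For instance, with num_experts=2 (an illustrative value), the mapping above expands to:

# ("experts.w13_", "experts.0.w1.", 0, "w1")
# ("experts.w2_",  "experts.0.w2.", 0, "w2")
# ("experts.w13_", "experts.0.w3.", 0, "w3")
# ("experts.w13_", "experts.1.w1.", 1, "w1")
# ("experts.w2_",  "experts.1.w2.", 1, "w2")
# ("experts.w13_", "experts.1.w3.", 1, "w3")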
def weight_loader(
self,
param: torch.nn.Parameter,
@@ -727,6 +849,15 @@ class EPMoE(torch.nn.Module):
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if self.use_w4afp8:
if shard_id == "w1":
param_data[expert_id][0] = loaded_weight
elif shard_id == "w3":
param_data[expert_id][1] = loaded_weight
else:
param_data[expert_id] = loaded_weight
return
if (
(shard_id == "w1" or shard_id == "w3")
and param_data[expert_id] != 1
@@ -752,6 +883,13 @@ class EPMoE(torch.nn.Module):
] = loaded_weight
else: # w2
param_data[expert_id] = loaded_weight
elif self.use_w4afp8:
if shard_id == "w1":
param_data[expert_id][: self.intermediate_size, :] = loaded_weight
elif shard_id == "w3":
param_data[expert_id][self.intermediate_size :, :] = loaded_weight
else:
param_data[expert_id] = loaded_weight
# If we are in merged column case (gate_up_proj)
else:
if shard_id in ("w1", "w3"):
...
@@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -82,6 +83,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"moe_wna16": MoeWNA16Config,
"compressed-tensors": CompressedTensorsConfig,
"qoq": QoQConfig,
"w4afp8": W4AFp8Config,
}
# VLLM-dependent quantization methods
...
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
import logging
from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
@@ -200,7 +200,7 @@ class Fp8LinearMethod(LinearMethodBase):
quant_config: The quantization config.
"""
def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
@@ -286,7 +286,10 @@ class Fp8LinearMethod(LinearMethodBase):
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
if self.block_quant:
if hasattr(self.quant_config, "activation_scheme"):
assert self.quant_config.activation_scheme == "dynamic"
elif hasattr(self.quant_config, "linear_activation_scheme"):
assert self.quant_config.linear_activation_scheme == "dynamic"
scale = BlockQuantScaleParameter(
data=torch.empty(
(output_size_per_partition + block_n - 1) // block_n,
@@ -308,7 +311,13 @@ class Fp8LinearMethod(LinearMethodBase):
layer.register_parameter("weight_scale", scale)
# INPUT ACTIVATION SCALE
if (
hasattr(self.quant_config, "activation_scheme")
and self.quant_config.activation_scheme == "static"
) or (
hasattr(self.quant_config, "linear_activation_scheme")
and self.quant_config.linear_activation_scheme == "static"
):
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
@@ -371,7 +380,13 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False
)
if (
hasattr(self.quant_config, "activation_scheme")
and self.quant_config.activation_scheme == "static"
) or (
hasattr(self.quant_config, "linear_activation_scheme")
and self.quant_config.linear_activation_scheme == "static"
):
layer.input_scale = torch.nn.Parameter(
layer.input_scale.data, requires_grad=False
)
@@ -405,7 +420,13 @@ class Fp8LinearMethod(LinearMethodBase):
# Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
if (
hasattr(self.quant_config, "activation_scheme")
and self.quant_config.activation_scheme == "static"
) or (
hasattr(self.quant_config, "linear_activation_scheme")
and self.quant_config.linear_activation_scheme == "static"
):
layer.input_scale = Parameter(
layer.input_scale.max(), requires_grad=False
)
...
import logging
from typing import Any, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = logging.getLogger(__name__)
class W4AFp8Config(QuantizationConfig):
"""Config class for MIXED_PRECISION W4AFp8."""
def __init__(
self,
is_checkpoint_fp8_serialized: bool = True,
is_checkpoint_w4afp8_serialized: bool = True,
linear_activation_scheme: str = "dynamic",
moe_activation_scheme: str = "static",
ignored_layers: Optional[List[str]] = None,
weight_block_size: Optional[List[int]] = None,
group_size: int = 128,
) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
if is_checkpoint_w4afp8_serialized:
logger.warning("Detected w4afp8 checkpoint. Please note that")
if moe_activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(f"Unsupported activation scheme {moe_activation_scheme}")
self.linear_activation_scheme = linear_activation_scheme
self.moe_activation_scheme = moe_activation_scheme
self.ignored_layers = ignored_layers or []
self.weight_block_size = [128, 128]
self.group_size = group_size
@classmethod
def get_name(cls) -> str:
return "w4afp8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.float8_e4m3fn]
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = "fp8" in quant_method
is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
linear_activation_scheme = "dynamic"
moe_activation_scheme = "static"
weight_block_size = [128, 128]
return cls(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
is_checkpoint_w4afp8_serialized=is_checkpoint_w4afp8_serialized,
linear_activation_scheme=linear_activation_scheme,
moe_activation_scheme=moe_activation_scheme,
weight_block_size=weight_block_size,
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return W4AFp8MoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class W4AFp8MoEMethod:
def __init__(self, quant_config: W4AFp8Config):
self.quant_config = quant_config
def create_weights(
self,
layer: Module,
num_experts_per_partition: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
assert "weight_loader" in extra_weight_attrs
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
intermediate_size * 2,
hidden_size // 2,
dtype=torch.int8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
hidden_size,
intermediate_size // 2,
dtype=torch.int8,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts_per_partition,
2 * intermediate_size,
hidden_size // self.quant_config.group_size,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts_per_partition,
hidden_size,
intermediate_size // self.quant_config.group_size,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# Input scales
w13_input_scale = torch.nn.Parameter(
torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
requires_grad=False,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
# Pre-populate the strides
device = layer.w13_weight.device
self.a_strides1 = torch.full(
(num_experts_per_partition, 3),
hidden_size,
device=device,
dtype=torch.int64,
)
self.c_strides1 = torch.full(
(num_experts_per_partition, 3),
2 * intermediate_size,
device=device,
dtype=torch.int64,
)
self.a_strides2 = torch.full(
(num_experts_per_partition, 3),
intermediate_size,
device=device,
dtype=torch.int64,
)
self.c_strides2 = torch.full(
(num_experts_per_partition, 3),
hidden_size,
device=device,
dtype=torch.int64,
)
self.b_strides1 = self.a_strides1
self.s_strides13 = self.c_strides1
self.b_strides2 = self.a_strides2
self.s_strides2 = self.c_strides2
self.expert_offsets = torch.empty(
(num_experts_per_partition + 1), dtype=torch.int32, device=device
)
self.problem_sizes1 = torch.empty(
(num_experts_per_partition, 3), dtype=torch.int32, device=device
)
self.problem_sizes2 = torch.empty(
(num_experts_per_partition, 3), dtype=torch.int32, device=device
)
return
def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
"""Interleave scales in groups of 4 similar to TRT-LLM implementation."""
s_shape = scales.shape
# Reshape to separate groups of 4
scales_interleaved = scales.reshape(
s_shape[0], s_shape[1], (s_shape[2] // 4), 4
)
# Permute dimensions to interleave
scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
# Reshape back to original dimensions but with interleaved values
scales_interleaved = scales_interleaved.reshape(
s_shape[0], s_shape[2] // 4, s_shape[1] * 4
)
return scales_interleaved.contiguous()
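A small, self-contained illustration of this interleaving (the shapes below are toy assumptions):

import torch

# Toy scales: 1 expert, 2 rows, 8 scale groups per row.
scales = torch.arange(16).reshape(1, 2, 8)
out = scales.reshape(1, 2, 2, 4).permute(0, 2, 1, 3).reshape(1, 2, 8)
# out[0] == [[ 0,  1,  2,  3,  8,  9, 10, 11],
#            [ 4,  5,  6,  7, 12, 13, 14, 15]]
# i.e. output row g concatenates scale group g (4 values) from every original row.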
def process_weights_after_loading(self, layer: Module) -> None:
dtype = torch.bfloat16
device = layer.w2_weight.device
# Interleave w13_weight_scale (gate_up_proj)
w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
w13_weight_scale = self._interleave_scales(w13_weight_scale)
layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)
# Interleave w2_weight_scale (down_proj)
w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
w2_weight_scale = self._interleave_scales(w2_weight_scale)
layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)
# Process input scales
w13_input_scale_max = layer.w13_input_scale.max().to(dtype).item()
new_w13_input_scale = torch.tensor(
[w13_input_scale_max],
dtype=dtype,
device=device,
)
layer.w13_input_scale = Parameter(new_w13_input_scale, requires_grad=False)
w2_input_scale_max = layer.w2_input_scale.max().to(dtype).item()
new_w2_input_scale = torch.tensor(
[w2_input_scale_max], dtype=dtype, device=device
)
layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
@@ -2363,6 +2363,12 @@ class DeepseekV2ForCausalLM(nn.Module):
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
)
if self.quant_config and self.quant_config.get_name() == "w4afp8":
expert_params_mapping += (
get_moe_impl_class().make_expert_input_scale_params_mapping(
num_experts=self.config.n_routed_experts
)
)
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
...
@@ -708,6 +708,7 @@ class ServerArgs:
"w8a8_fp8",
"moe_wna16",
"qoq",
"w4afp8",
],
help="The quantization method.",
)
...
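With the choice added above, a server launch might look like the sketch below; the model path is a placeholder and the exact flag combination is an assumption, not something the diff prescribes:

# python -m sglang.launch_server \
#     --model-path <path-to-DeepSeek-R1-W4AFP8-checkpoint> \
#     --quantization w4afp8 \
#     --enable-ep-moe \
#     --tp-size 8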
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
import torch
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.moe.topk import select_experts
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
if int4_values_interleaved.shape[-1] % 2 != 0:
raise ValueError(
"the last dim size of int4_values_interleaved tensor must be even."
)
input_tensor_int8 = int4_values_interleaved.to(torch.int8)
low_nibbles = input_tensor_int8[..., 0::2]
high_nibbles = input_tensor_int8[..., 1::2]
packed_tensor = (high_nibbles << 4) | (low_nibbles & 0x0F)
return packed_tensor.to(torch.int8)
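A quick sanity check of the nibble packing (values chosen arbitrarily):

# int4 values [1, -2, 3, 4] pack pairwise as (high << 4) | (low & 0xF):
vals = torch.tensor([[1, -2, 3, 4]])
print(pack_int4_values_to_int8(vals))  # -> [[-31, 67]] as int8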
def pack_interleave(num_experts, ref_weight, ref_scale):
n, k = ref_weight.shape[1], ref_weight.shape[2]
weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
w_q = w_q.contiguous()
scale_interleaved = ref_scale.reshape(
ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
) # [E, N, K/4, 4]
scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/4, N, 4]
scale_interleaved = scale_interleaved.reshape(
ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
) # [E, K/4, N*4]
w_scale = scale_interleaved.contiguous()
return w_q, w_scale
@pytest.mark.parametrize("M", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("N", [2048])
@pytest.mark.parametrize("K", [7168])
@pytest.mark.parametrize("E", [256])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.parametrize("topk", [8])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
local_e = E // ep_size
debug = False
if debug:
a = torch.ones((M, K), dtype=dtype, device="cuda") * 0.001
ref_weight_1 = torch.ones((local_e, N * 2, K), dtype=torch.int8, device="cuda")
ref_weight_2 = torch.ones((local_e, K, N), dtype=torch.int8, device="cuda")
a1_scale = torch.ones(1, dtype=torch.float32, device="cuda")
a2_scale = torch.ones(1, dtype=torch.float32, device="cuda")
scale_1 = torch.ones(
(local_e, N * 2, K // group_size), dtype=dtype, device="cuda"
)
scale_2 = torch.ones((local_e, K, N // group_size), dtype=dtype, device="cuda")
else:
a = torch.randn(M, K, dtype=dtype, device="cuda")
ref_weight_1 = torch.randint(
-8, 8, (local_e, N * 2, K), dtype=torch.int8, device="cuda"
)
ref_weight_2 = torch.randint(
-8, 8, (local_e, K, N), dtype=torch.int8, device="cuda"
)
affine_coeff = 0.005
a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
scale_1 = (
torch.randn(local_e, N * 2, K // group_size, dtype=dtype, device="cuda")
* affine_coeff
)
scale_2 = (
torch.randn(local_e, K, N // group_size, dtype=dtype, device="cuda")
* affine_coeff
)
w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)
device = "cuda"
a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
c_strides1 = torch.full((local_e, 3), 2 * N, device=device, dtype=torch.int64)
a_strides2 = torch.full((local_e, 3), N, device=device, dtype=torch.int64)
c_strides2 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
b_strides1 = a_strides1
s_strides13 = c_strides1
b_strides2 = a_strides2
s_strides2 = c_strides2
score = torch.randn((M, E), dtype=dtype, device=device)
topk_weights, topk_ids = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
use_grouped_topk=False,
renormalize=False,
)
expert_map = torch.arange(E, dtype=torch.int32, device=device)
expert_map[local_e:] = E
output = cutlass_moe(
a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
a_strides1,
b_strides1,
c_strides1,
a_strides2,
b_strides2,
c_strides2,
s_strides13,
s_strides2,
0,
local_e - 1,
E,
a1_scale,
a2_scale,
expert_map,
)
ref_output = ref(
a,
local_e,
topk_weights,
topk_ids,
ref_weight_1,
ref_weight_2,
scale_1,
scale_2,
has_pre_quant=True,
has_alpha=True,
pre_quant_scale_1=a1_scale,
pre_quant_scale_2=a2_scale,
alpha_1=a1_scale,
alpha_2=a2_scale,
)
# compare
torch.cuda.synchronize()
# compare final output
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
print("SUCCESS: Final output tensors are close.")
def cutlass_moe(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids_: torch.Tensor,
a_strides1: torch.Tensor,
b_strides1: torch.Tensor,
c_strides1: torch.Tensor,
a_strides2: torch.Tensor,
b_strides2: torch.Tensor,
c_strides2: torch.Tensor,
s_strides13: torch.Tensor,
s_strides2: torch.Tensor,
start_expert_id: int,
end_expert_id: int,
E: int,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
):
local_topk_ids = topk_ids_
local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
device = a.device
local_num_experts = end_expert_id - start_expert_id + 1
expert_offsets = torch.empty(
(local_num_experts + 1), dtype=torch.int32, device=device
)
problem_sizes1 = torch.empty(
(local_num_experts, 3), dtype=torch.int32, device=device
)
problem_sizes2 = torch.empty(
(local_num_experts, 3), dtype=torch.int32, device=device
)
return cutlass_w4a8_moe(
start_expert_id,
end_expert_id,
E,
a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids_,
local_topk_ids,
a_strides1,
b_strides1,
c_strides1,
a_strides2,
b_strides2,
c_strides2,
s_strides13,
s_strides2,
expert_offsets,
problem_sizes1,
problem_sizes2,
a1_scale,
a2_scale,
apply_router_weight_on_input,
)
def ref(
x: torch.Tensor,
num_experts: int,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ref_weight_1: torch.Tensor,
ref_weight_2: torch.Tensor,
ref_weight_scale_1: torch.Tensor,
ref_weight_scale_2: torch.Tensor,
has_pre_quant: bool = False,
has_alpha: bool = False,
pre_quant_scale_1: Optional[torch.Tensor] = None,
pre_quant_scale_2: Optional[torch.Tensor] = None,
alpha_1: Optional[torch.Tensor] = None,
alpha_2: Optional[torch.Tensor] = None,
):
results = torch.zeros_like(x)
dtype = x.dtype
for e_idx in range(num_experts):
mask = topk_ids == e_idx
activated_tokens = mask.sum(1).bool()
act = x[activated_tokens, :]
if act.shape[0] == 0:
continue
final_scale = (topk_weights * mask).sum(1)[activated_tokens].unsqueeze(1)
act = (
torch.clamp((act / pre_quant_scale_1.float()), -448.0, 448.0)
.to(torch.float8_e4m3fn)
.to(dtype)
)
w3_w1 = ref_weight_1[e_idx]
ref_w_scale_repeat = (
ref_weight_scale_1[e_idx].repeat_interleave(128, dim=1).to(float)
)
w3_w1 = (w3_w1.to(float) * ref_w_scale_repeat).to(dtype)
fc1 = ((torch.matmul(act, w3_w1.T)) * alpha_1).to(torch.float16)
gate, fc1 = fc1.chunk(2, dim=-1)
fc1 = fc1 * torch.nn.functional.silu(gate)
act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn)
act = act.to(dtype)
w2 = ref_weight_2[e_idx]
ref_w_scale_repeat = (
ref_weight_scale_2[e_idx].repeat_interleave(128, dim=1).to(float)
)
w2 = (w2.to(float) * ref_w_scale_repeat).to(dtype)
fc2 = (torch.matmul(act, w2.T) * alpha_2).to(torch.float16)
results[activated_tokens, :] += (fc2 * final_scale).to(results.dtype)
return results