Unverified Commit bf0f448f authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

[2/N] MoE Refactor: Unify weight loader and quant methods (#8397)

parent 36d6f0ba
......@@ -30,13 +30,13 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
sglang_per_token_group_quant_fp8,
sglang_per_token_quant_fp8,
)
from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -62,8 +62,6 @@ use_flashinfer_trtllm_moe = (
if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
if _use_aiter:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
......@@ -162,7 +160,7 @@ def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
return tile_tokens_dim
class EPMoE(torch.nn.Module):
class EPMoE(FusedMoE):
"""
MoE Expert Parallel Impl
......@@ -184,51 +182,60 @@ class EPMoE(torch.nn.Module):
routed_scaling_factor: Optional[float] = None,
use_per_token_if_dynamic: bool = True,
):
super().__init__()
super().__init__(
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
top_k=top_k,
layer_id=layer_id,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
prefix=prefix,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
enable_ep_moe=True,
skip_quant=True,
)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
self.tp_rank = get_tensor_model_parallel_rank()
self.layer_id = layer_id
self.num_experts = num_experts
assert self.num_experts % self.tp_size == 0
self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
self.num_local_experts, self.expert_map = self.determine_expert_map()
self.start_expert_id = self.ep_rank * self.num_local_experts
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
self.top_k = top_k
self.intermediate_size = intermediate_size
self.activation = activation
self.routed_scaling_factor = routed_scaling_factor
self.use_per_token_if_dynamic = use_per_token_if_dynamic
# TODO(ch-wan): move quant preparation to FusedMoE
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod()
)
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
self.use_w4afp8 = False
self.w13_input_scale = None
self.w2_input_scale = None
self.w13_weight_scale = None
self.w2_weight_scale = None
elif isinstance(quant_config, W4AFp8Config):
self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
quant_config
)
self.use_w4afp8 = True
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.fp8_dtype = torch.float8_e4m3fn
self.w13_input_scale = None
self.w2_input_scale = None
self.w13_weight_scale = None
self.w2_weight_scale = None
self.activation_scheme = quant_config.moe_activation_scheme
else:
self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
quant_config
)
elif isinstance(quant_config, Fp8Config):
self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
self.use_fp8_w8a8 = True
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
......@@ -238,11 +245,13 @@ class EPMoE(torch.nn.Module):
)
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
self.use_w4afp8 = False
else:
raise ValueError(f"Unsupported quant_config: {quant_config}")
self.quant_config = quant_config
self.quant_method.create_weights(
layer=self,
num_experts_per_partition=self.num_experts_per_partition,
num_experts=self.num_local_experts,
hidden_size=hidden_size,
intermediate_size=self.intermediate_size,
params_dtype=params_dtype,
......@@ -251,19 +260,6 @@ class EPMoE(torch.nn.Module):
self.grouped_gemm_runner = None
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)
# Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
# Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
......@@ -282,8 +278,8 @@ class EPMoE(torch.nn.Module):
Contains global_num_experts for experts not assigned to the current rank.
Returns None if ep_size is 1.
"""
ep_size = self.tp_size
ep_rank = self.tp_rank
ep_size = self.ep_size
ep_rank = self.ep_rank
global_num_experts = self.num_experts
assert ep_size > 0
......@@ -293,7 +289,7 @@ class EPMoE(torch.nn.Module):
local_num_experts = global_num_experts // ep_size
expert_map = torch.full(
(global_num_experts,), self.num_experts, dtype=torch.int32
(global_num_experts,), global_num_experts, dtype=torch.int32
)
if ep_rank < (ep_size - 1):
expert_map[
......@@ -318,6 +314,20 @@ class EPMoE(torch.nn.Module):
hidden_states: torch.Tensor,
topk_output: TopKOutput,
):
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)
assert self.quant_method is not None
assert self.activation == "silu"
hidden_states_shape = hidden_states.shape
......@@ -457,7 +467,10 @@ class EPMoE(torch.nn.Module):
return output
def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert self.quant_method is not None
return self.quant_method.apply(self, hidden_states, topk_output)
def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
topk_weights, topk_ids, _ = topk_output
hidden_states_shape = hidden_states.shape
......@@ -470,53 +483,11 @@ class EPMoE(torch.nn.Module):
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
if self.use_w4afp8:
local_topk_ids = topk_ids
if self.expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(
self.expert_map[topk_ids] != self.num_experts,
self.expert_map[topk_ids],
self.num_experts,
)
output = cutlass_w4a8_moe(
self.start_expert_id,
self.end_expert_id,
self.num_experts,
hidden_states,
self.w13_weight,
self.w2_weight,
self.w13_weight_scale_inv,
self.w2_weight_scale_inv,
topk_weights,
topk_ids,
local_topk_ids,
self.quant_method.a_strides1,
self.quant_method.b_strides1,
self.quant_method.c_strides1,
self.quant_method.a_strides2,
self.quant_method.b_strides2,
self.quant_method.c_strides2,
self.quant_method.s_strides13,
self.quant_method.s_strides2,
self.quant_method.expert_offsets,
self.quant_method.problem_sizes1,
self.quant_method.problem_sizes2,
self.w13_input_scale,
self.w2_input_scale,
)
return output
if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device,
use_flashinfer=False, # TODO: use flashinfer
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
num_experts = self.num_experts
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
topk_ids, self.num_experts
topk_ids,
num_experts,
)
gateup_input = torch.empty(
......@@ -524,7 +495,7 @@ class EPMoE(torch.nn.Module):
device=hidden_states.device,
dtype=(
self.fp8_dtype
if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
if self.use_fp8_w8a8 and not self.use_block_quant
else hidden_states.dtype
),
)
......@@ -535,7 +506,7 @@ class EPMoE(torch.nn.Module):
else:
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.repeat(self.num_local_experts)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
......@@ -576,7 +547,7 @@ class EPMoE(torch.nn.Module):
seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
weight_indices_cur_rank = torch.arange(
0,
self.num_experts_per_partition,
self.num_local_experts,
device=hidden_states_device,
dtype=torch.int64,
)
......@@ -586,17 +557,13 @@ class EPMoE(torch.nn.Module):
b=self.w13_weight,
c=None,
c_dtype=hidden_states_dtype,
batch_size=self.num_experts_per_partition,
batch_size=self.num_local_experts,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w13_input_scale,
scale_b=(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
scale_b=self.w13_weight_scale,
block_shape=self.block_shape,
)
del gateup_input
......@@ -653,7 +620,7 @@ class EPMoE(torch.nn.Module):
down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
else:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
self.num_local_experts,
dtype=torch.float32,
device=hidden_states_device,
)
......@@ -669,17 +636,13 @@ class EPMoE(torch.nn.Module):
a=down_input,
b=self.w2_weight,
c=down_output,
batch_size=self.num_experts_per_partition,
batch_size=self.num_local_experts,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w2_input_scale,
scale_b=(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
),
scale_b=self.w2_weight_scale,
block_shape=self.block_shape,
)
del down_input
......@@ -782,107 +745,14 @@ class EPMoE(torch.nn.Module):
return
expert_id = expert_id - self.start_expert_id
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(
f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
)
# Special case for fp8 scales.
if "scale" in weight_name:
self._load_fp8_scale(
param.data,
loaded_weight,
weight_name,
shard_id,
expert_id,
)
return
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
if use_flashinfer_trtllm_moe:
actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
else:
actual_shard_id = shard_id
if actual_shard_id == "w2":
param.data[expert_id] = loaded_weight
elif actual_shard_id == "w1":
param.data[expert_id][: self.intermediate_size, :] = loaded_weight
elif actual_shard_id == "w3":
param.data[expert_id][self.intermediate_size :, :] = loaded_weight
else:
raise ValueError(f"Expected shard_id w1,w2 or w3 but got {actual_shard_id}")
def _load_fp8_scale(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
param_data = param.data
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if self.use_w4afp8:
if shard_id == "w1":
param_data[expert_id][0] = loaded_weight
elif shard_id == "w3":
param_data[expert_id][1] = loaded_weight
else:
param_data[expert_id] = loaded_weight
return
if (
(shard_id == "w1" or shard_id == "w3")
and param_data[expert_id] != 1
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
):
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}"
)
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
if self.use_block_quant:
if use_flashinfer_trtllm_moe:
actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
else:
actual_shard_id = shard_id
block_n, block_k = self.block_shape[0], self.block_shape[1]
if actual_shard_id == "w1":
param_data[expert_id][
: (self.intermediate_size + block_n - 1) // block_n, :
] = loaded_weight
elif actual_shard_id == "w3":
param_data[expert_id][
(self.intermediate_size + block_n - 1) // block_n :, :
] = loaded_weight
else: # w2
param_data[expert_id] = loaded_weight
elif self.use_w4afp8:
if shard_id == "w1":
param_data[expert_id][: self.intermediate_size, :] = loaded_weight
elif shard_id == "w3":
param_data[expert_id][self.intermediate_size :, :] = loaded_weight
else:
param_data[expert_id] = loaded_weight
# If we are in merged column case (gate_up_proj)
else:
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
return
class DeepEPMoE(EPMoE):
......@@ -932,13 +802,13 @@ class DeepEPMoE(EPMoE):
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
if _use_aiter:
# expert_mask is of size (self.num_experts_per_partition + 1),
# expert_mask is of size (self.num_local_experts + 1),
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
# for instance, if we have 4 experts on this rank, we would have a expert_mask like:
# self.expert_mask = [1, 1, 1, 1, 0]
# idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
self.expert_mask = torch.zeros(
(self.num_experts_per_partition + 1),
(self.num_local_experts + 1),
device=torch.cuda.current_device(),
dtype=torch.int,
)
......@@ -1011,13 +881,13 @@ class DeepEPMoE(EPMoE):
if self.activation_scheme == "dynamic" and not self.use_block_quant:
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.repeat(self.num_local_experts)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
weight_indices_cur_rank = torch.arange(
0,
self.num_experts_per_partition,
self.num_local_experts,
device=hidden_states.device,
dtype=torch.int64,
)
......@@ -1029,7 +899,7 @@ class DeepEPMoE(EPMoE):
b=self.w13_weight,
c=None,
c_dtype=hidden_states.dtype,
batch_size=self.num_experts_per_partition,
batch_size=self.num_local_experts,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices_cur_rank,
......@@ -1063,7 +933,7 @@ class DeepEPMoE(EPMoE):
)
if self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
self.num_local_experts,
dtype=torch.float32,
device=hidden_states_device,
)
......@@ -1076,7 +946,7 @@ class DeepEPMoE(EPMoE):
reorder_topk_ids,
self.w2_input_scale,
0,
self.num_experts_per_partition - 1,
self.num_local_experts - 1,
BLOCK_SIZE=512,
)
else:
......@@ -1096,7 +966,7 @@ class DeepEPMoE(EPMoE):
a=down_input,
b=self.w2_weight,
c=down_output,
batch_size=self.num_experts_per_partition,
batch_size=self.num_local_experts,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices_cur_rank,
......@@ -1121,9 +991,9 @@ class DeepEPMoE(EPMoE):
return hidden_states
# in original deepep, idx == -1 meaning invalid and will not be processed.
# aiter does not accept -1, we use a expert mask to make these idx invalid
# (idx == num_experts_per_partition) meaning not used in aiter fused_moe
# (idx == num_local_experts) meaning not used in aiter fused_moe
topk_idx_copy = topk_idx.to(torch.int32)
topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
return fused_moe(
hidden_states,
......
......@@ -77,6 +77,7 @@ class FusedMoE(torch.nn.Module):
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
enable_ep_moe: Optional[bool] = False,
skip_quant: Optional[bool] = False,
):
super().__init__()
......@@ -99,9 +100,6 @@ class FusedMoE(torch.nn.Module):
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
if enable_ep_moe:
assert (
self.enable_flashinfer_cutlass_moe
), "FusedMoE only supports EP with --enable-flashinfer-cutlass-moe"
self.ep_size = self.tp_size
self.ep_rank = self.tp_rank
self.tp_size = 1
......@@ -110,16 +108,16 @@ class FusedMoE(torch.nn.Module):
self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
# Create a expert map for the local experts
assert num_experts % self.ep_size == 0
self.local_num_experts = num_experts // self.ep_size
self.num_local_experts = num_experts // self.ep_size
self.expert_map[
self.ep_rank
* self.local_num_experts : (self.ep_rank + 1)
* self.local_num_experts
] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
* self.num_local_experts : (self.ep_rank + 1)
* self.num_local_experts
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
else:
self.ep_size = 1
self.ep_rank = 0
self.local_num_experts = num_experts
self.num_local_experts = num_experts
self.routed_scaling_factor = routed_scaling_factor
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
......@@ -134,6 +132,9 @@ class FusedMoE(torch.nn.Module):
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
)
if skip_quant:
return
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
self.use_triton_kernels
......@@ -149,7 +150,7 @@ class FusedMoE(torch.nn.Module):
self.quant_config = quant_config
self.quant_method.create_weights(
layer=self,
num_experts=self.local_num_experts,
num_experts=self.num_local_experts,
hidden_size=hidden_size,
# FIXME: figure out which intermediate_size to use
intermediate_size=self.intermediate_size_per_partition,
......@@ -378,6 +379,23 @@ class FusedMoE(torch.nn.Module):
if expert_id == -1:
return
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
def _weight_loader_impl(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
# TP rank is set to 0 if EP is enabled
tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
......@@ -398,6 +416,10 @@ class FusedMoE(torch.nn.Module):
f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
)
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
if getattr(self, "use_flashinfer_trtllm_moe", False):
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
......@@ -605,37 +627,3 @@ class FusedMoE(torch.nn.Module):
("w3", ckpt_up_proj_name),
]
]
def _load_fp8_scale(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
param_data = param.data
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if (
param_data[expert_id] != 1
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
):
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}"
)
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
......@@ -172,6 +172,7 @@ class Fp8Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
......@@ -180,6 +181,8 @@ class Fp8Config(QuantizationConfig):
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self)
elif isinstance(layer, EPMoE):
return Fp8EPMoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
......@@ -791,11 +794,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# merged w13 weights and generate a single scaling factor.
layer.w13_weight_scale = torch.nn.Parameter(
torch.ones(
layer.num_experts, dtype=torch.float32, device=w13_weight.device
layer.num_local_experts,
dtype=torch.float32,
device=w13_weight.device,
),
requires_grad=False,
)
for expert in range(layer.num_experts):
for expert in range(layer.num_local_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
......@@ -871,7 +876,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
for expert_id in range(layer.num_local_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
......@@ -914,7 +919,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
for expert_id in range(layer.num_local_experts):
start = 0
max_w13_scale_fp8 = max_w13_scales[expert_id]
for shard_id in range(2):
......@@ -931,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
# optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
for expert_id in range(layer.num_experts):
for expert_id in range(layer.num_local_experts):
layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
......@@ -979,8 +984,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
if isinstance(layer, EPMoE):
layer.w13_weight_scale = (
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
)
layer.w2_weight_scale = (
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
)
return layer.run_moe(
hidden_states=x,
topk_output=topk_output,
)
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
......@@ -1138,248 +1158,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return None
class Fp8EPMoEMethod(Fp8MoEMethod):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Args:
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
def create_weights(
self,
layer: Module,
num_experts_per_partition: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
)
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1:
# Required by row parallel
if intermediate_size % block_k != 0:
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_k = {block_k}."
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
2 * intermediate_size,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
hidden_size,
intermediate_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if self.block_quant:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts_per_partition,
2 * ((intermediate_size + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts_per_partition,
(hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
else:
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
if self.block_quant
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
layer.w13_weight_scale = torch.nn.Parameter(
torch.ones(
layer.num_experts_per_partition,
dtype=torch.float32,
device=w13_weight.device,
),
requires_grad=False,
)
for expert in range(layer.num_experts_per_partition):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else:
if self.quant_config.activation_scheme == "static":
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
layer.w13_weight_scale = torch.nn.Parameter(
torch.max(layer.w13_weight_scale, dim=1).values,
requires_grad=False,
)
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_fp8_fnuz:
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
weight_scale=layer.w13_weight_scale_inv,
input_scale=None,
)
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w2_weight,
weight_scale=layer.w2_weight_scale_inv,
input_scale=None,
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(
w13_weight, requires_grad=False
)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
layer.w13_input_scale = None
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
layer.w2_input_scale = None
if _use_aiter:
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,
)
layer.w2_weight = torch.nn.Parameter(
shuffle_weight(layer.w2_weight.data, (16, 16)),
requires_grad=False,
)
return
def apply(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
) -> torch.Tensor:
raise NotImplementedError
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
......
......@@ -24,6 +24,7 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.topk import TopKOutput
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
......@@ -194,6 +195,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
if isinstance(layer, EPMoE):
return layer.run_moe(
hidden_states=x,
topk_output=topk_output,
)
return self.forward(
x=x,
layer=layer,
......@@ -354,69 +364,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
raise NotImplementedError("The TPU backend currently does not support MoE.")
forward_native = forward_cpu
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
def create_weights(
self,
layer: torch.nn.Module,
num_experts_per_partition: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
2 * intermediate_size,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
hidden_size,
intermediate_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# scale
layer.register_parameter("w13_input_scale", None)
layer.register_parameter("w13_weight_scale", None)
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
w2_input_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
) -> torch.Tensor:
raise NotImplementedError
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from torch.nn import Module
......@@ -17,6 +17,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = logging.getLogger(__name__)
......@@ -84,13 +87,14 @@ class W4AFp8Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
elif isinstance(layer, EPMoE):
return W4AFp8MoEMethod(self)
return None
......@@ -105,8 +109,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def create_weights(
self,
layer: Module,
num_experts_per_partition: int,
layer: EPMoE,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
......@@ -117,7 +121,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
num_experts,
intermediate_size * 2,
hidden_size // 2,
dtype=torch.int8,
......@@ -130,7 +134,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
num_experts,
hidden_size,
intermediate_size // 2,
dtype=torch.int8,
......@@ -142,7 +146,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts_per_partition,
num_experts,
2 * intermediate_size,
hidden_size // self.quant_config.group_size,
dtype=torch.float32,
......@@ -154,7 +158,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
w2_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts_per_partition,
num_experts,
hidden_size,
intermediate_size // self.quant_config.group_size,
dtype=torch.float32,
......@@ -166,14 +170,14 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
# Input scales
w13_input_scale = torch.nn.Parameter(
torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
torch.ones((num_experts, 2), dtype=torch.bfloat16),
requires_grad=False,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
torch.ones(num_experts, dtype=torch.bfloat16),
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
......@@ -183,25 +187,25 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
device = layer.w13_weight.device
self.a_strides1 = torch.full(
(num_experts_per_partition, 3),
(num_experts, 3),
hidden_size,
device=device,
dtype=torch.int64,
)
self.c_strides1 = torch.full(
(num_experts_per_partition, 3),
(num_experts, 3),
2 * intermediate_size,
device=device,
dtype=torch.int64,
)
self.a_strides2 = torch.full(
(num_experts_per_partition, 3),
(num_experts, 3),
intermediate_size,
device=device,
dtype=torch.int64,
)
self.c_strides2 = torch.full(
(num_experts_per_partition, 3),
(num_experts, 3),
hidden_size,
device=device,
dtype=torch.int64,
......@@ -212,13 +216,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
self.s_strides2 = self.c_strides2
self.expert_offsets = torch.empty(
(num_experts_per_partition + 1), dtype=torch.int32, device=device
(num_experts + 1), dtype=torch.int32, device=device
)
self.problem_sizes1 = torch.empty(
(num_experts_per_partition, 3), dtype=torch.int32, device=device
(num_experts, 3), dtype=torch.int32, device=device
)
self.problem_sizes2 = torch.empty(
(num_experts_per_partition, 3), dtype=torch.int32, device=device
(num_experts, 3), dtype=torch.int32, device=device
)
return
......@@ -266,3 +270,50 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
[w2_input_scale_max], dtype=dtype, device=device
)
layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
def apply(
self,
layer: EPMoE,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
) -> torch.Tensor:
# TODO(ch-wan): move it out of this class
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
topk_ids, topk_weights, _ = topk_output
local_topk_ids = topk_ids
if layer.expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(
layer.expert_map[topk_ids] != layer.num_experts,
layer.expert_map[topk_ids],
layer.num_experts,
)
return cutlass_w4a8_moe(
layer.start_expert_id,
layer.end_expert_id,
layer.num_experts,
hidden_states,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
topk_weights,
topk_ids,
local_topk_ids,
self.a_strides1,
self.b_strides1,
self.c_strides1,
self.a_strides2,
self.b_strides2,
self.c_strides2,
self.s_strides13,
self.s_strides2,
self.expert_offsets,
self.problem_sizes1,
self.problem_sizes2,
layer.w13_input_scale,
layer.w2_input_scale,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment