Unverified Commit 61089465 authored by tomeras91's avatar tomeras91 Committed by GitHub
Browse files

[Model] Add MoE support for NemotronH (#25863)


Signed-off-by: default avatarTomer Asida <57313761+tomeras91@users.noreply.github.com>
parent 88afa110
...@@ -823,6 +823,8 @@ class FusedMoEConfig: ...@@ -823,6 +823,8 @@ class FusedMoEConfig:
has_bias: bool = False has_bias: bool = False
is_act_and_mul: bool = True
def __post_init__(self): def __post_init__(self):
if self.dp_size > 1: if self.dp_size > 1:
logger.debug_once( logger.debug_once(
......
...@@ -1647,6 +1647,7 @@ def fused_experts( ...@@ -1647,6 +1647,7 @@ def fused_experts(
SILU_NO_MUL: str = activation_without_mul("silu") SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu") GELU_NO_MUL: str = activation_without_mul("gelu")
RELU2_NO_MUL: str = activation_without_mul("relu2")
def _get_config_quant_dtype( def _get_config_quant_dtype(
...@@ -1914,7 +1915,8 @@ def fused_experts_impl( ...@@ -1914,7 +1915,8 @@ def fused_experts_impl(
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
elif activation == GELU_NO_MUL: elif activation == GELU_NO_MUL:
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
elif activation == RELU2_NO_MUL:
intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N)))
else: else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}.") raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
......
...@@ -411,11 +411,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -411,11 +411,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
if self.moe.is_act_and_mul:
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition
# Fused gate_up_proj (column parallel) # Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
2 * intermediate_size_per_partition, w13_up_dim,
hidden_size, hidden_size,
dtype=params_dtype, dtype=params_dtype,
), ),
...@@ -425,9 +429,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -425,9 +429,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
if self.moe.has_bias: if self.moe.has_bias:
w13_bias = torch.nn.Parameter( w13_bias = torch.nn.Parameter(
torch.zeros( torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
num_experts, 2 * intermediate_size_per_partition, dtype=params_dtype
),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w13_bias", w13_bias) layer.register_parameter("w13_bias", w13_bias)
...@@ -1073,6 +1075,7 @@ class FusedMoE(CustomOp): ...@@ -1073,6 +1075,7 @@ class FusedMoE(CustomOp):
e_score_correction_bias: torch.Tensor | None = None, e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
is_act_and_mul: bool = True,
enable_eplb: bool = False, enable_eplb: bool = False,
num_redundant_experts: int = 0, num_redundant_experts: int = 0,
has_bias: bool = False, has_bias: bool = False,
...@@ -1263,6 +1266,7 @@ class FusedMoE(CustomOp): ...@@ -1263,6 +1266,7 @@ class FusedMoE(CustomOp):
in_dtype=moe_in_dtype, in_dtype=moe_in_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias, has_bias=has_bias,
is_act_and_mul=is_act_and_mul,
) )
self.moe_config = moe self.moe_config = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None self.moe_quant_config: FusedMoEQuantConfig | None = None
...@@ -1283,6 +1287,24 @@ class FusedMoE(CustomOp): ...@@ -1283,6 +1287,24 @@ class FusedMoE(CustomOp):
assert isinstance(quant_method, FusedMoEMethodBase) assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method self.quant_method = quant_method
if not self.moe_config.is_act_and_mul:
# Avoid circular import
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8MoEMethod,
)
if not isinstance(
quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
):
raise NotImplementedError(
"is_act_and_mul=False is supported only for unquantized "
"and ModelOpt FP8 moe for now"
)
if not current_platform.is_cuda():
raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA for now"
)
if self.enable_eplb: if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod
...@@ -1531,7 +1553,10 @@ class FusedMoE(CustomOp): ...@@ -1531,7 +1553,10 @@ class FusedMoE(CustomOp):
): ):
# Index the loaded weight for tp sharding. # Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
if self.moe_config.is_act_and_mul:
shard_size = expert_data.shape[shard_dim] // 2 shard_size = expert_data.shape[shard_dim] // 2
else:
shard_size = expert_data.shape[shard_dim]
if not load_full: if not load_full:
loaded_weight = loaded_weight.narrow( loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size shard_dim, shard_size * tp_rank, shard_size
......
...@@ -354,7 +354,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -354,7 +354,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.cutlass_fp8_supported = cutlass_fp8_supported() self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): if (
envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
and self.moe.is_act_and_mul
):
self.flashinfer_moe_backend = get_flashinfer_moe_backend() self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once( logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
...@@ -405,10 +409,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -405,10 +409,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
) )
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
if self.moe.is_act_and_mul:
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition
w13_weight = ModelWeightParameter( w13_weight = ModelWeightParameter(
data=torch.empty( data=torch.empty(
num_experts, num_experts,
2 * intermediate_size_per_partition, w13_up_dim,
hidden_size, hidden_size,
dtype=weight_dtype, dtype=weight_dtype,
), ),
...@@ -433,11 +442,16 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -433,11 +442,16 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
if self.quant_config.is_checkpoint_fp8_serialized: if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALES - Per-tensor scaling for ModelOpts # WEIGHT SCALES - Per-tensor scaling for ModelOpts
# Allocate 2 scales for w1 and w3 respectively. # For gated MoE, allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading. # They will be combined to a single scale after weight loading.
# For non-gated MoE, allocate 1 scale for w13.
if self.moe.is_act_and_mul:
w13_weight_scale_shape = (num_experts, 2)
else:
w13_weight_scale_shape = (num_experts, 1)
w13_weight_scale = PerTensorScaleParameter( w13_weight_scale = PerTensorScaleParameter(
data=torch.full( data=torch.full(
(num_experts, 2), w13_weight_scale_shape,
1.0, 1.0,
dtype=torch.float32, dtype=torch.float32,
), ),
...@@ -485,7 +499,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -485,7 +499,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# Fp8 moe kernel needs single weight scale for w13 per expert. # Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max of the w1 and w3 scales # We take the max of the w1 and w3 scales
# then dequant and requant each expert. # then dequant and requant each expert.
if layer.w13_weight_scale.dim() == 2: if (
layer.w13_weight_scale.dim() == 2
and layer.w13_weight_scale.shape[1] == 2
):
assert self.moe.is_act_and_mul, (
"w13_weight_scale should have 2 elements per expert "
"only for gated MoE"
)
# Get the maximum scale across w1 and w3 for each expert # Get the maximum scale across w1 and w3 for each expert
max_w13_scales = layer.w13_weight_scale.max(dim=1).values max_w13_scales = layer.w13_weight_scale.max(dim=1).values
......
...@@ -673,7 +673,9 @@ class MixtureOfExperts(Protocol): ...@@ -673,7 +673,9 @@ class MixtureOfExperts(Protocol):
def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]:
return isinstance(model, MixtureOfExperts) return (
isinstance(model, MixtureOfExperts) and getattr(model, "num_moe_layers", 0) > 0
)
@runtime_checkable @runtime_checkable
......
This diff is collapsed.
...@@ -185,6 +185,15 @@ class NemotronHConfig(PretrainedConfig): ...@@ -185,6 +185,15 @@ class NemotronHConfig(PretrainedConfig):
mamba_proj_bias=False, mamba_proj_bias=False,
mamba_chunk_size=256, mamba_chunk_size=256,
rescale_prenorm_residual=True, rescale_prenorm_residual=True,
n_routed_experts=8,
n_shared_experts=1,
moe_intermediate_size=7688,
moe_shared_expert_intermediate_size=7688,
num_experts_per_tok=2,
routed_scaling_factor=1.0,
n_group=1,
topk_group=1,
norm_topk_prob=True,
**kwargs, **kwargs,
): ):
self.vocab_size = vocab_size self.vocab_size = vocab_size
...@@ -241,6 +250,15 @@ class NemotronHConfig(PretrainedConfig): ...@@ -241,6 +250,15 @@ class NemotronHConfig(PretrainedConfig):
self.mamba_proj_bias = mamba_proj_bias self.mamba_proj_bias = mamba_proj_bias
self.chunk_size = mamba_chunk_size self.chunk_size = mamba_chunk_size
self.rescale_prenorm_residual = rescale_prenorm_residual self.rescale_prenorm_residual = rescale_prenorm_residual
self.n_routed_experts = n_routed_experts
self.n_shared_experts = n_shared_experts
self.moe_intermediate_size = moe_intermediate_size
self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size # noqa: E501
self.num_experts_per_tok = num_experts_per_tok
self.routed_scaling_factor = routed_scaling_factor
self.n_group = n_group
self.topk_group = topk_group
self.norm_topk_prob = norm_topk_prob
super().__init__( super().__init__(
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
...@@ -258,5 +276,7 @@ class NemotronHConfig(PretrainedConfig): ...@@ -258,5 +276,7 @@ class NemotronHConfig(PretrainedConfig):
else "attention" else "attention"
if self.hybrid_override_pattern[i] == "*" if self.hybrid_override_pattern[i] == "*"
else "mlp" else "mlp"
if self.hybrid_override_pattern[i] == "-"
else "moe"
for i in range(self.num_hidden_layers) for i in range(self.num_hidden_layers)
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment