Commit 9ddd0f97 authored by zhuwenwen's avatar zhuwenwen
Browse files

update use_nn_moe

parent 90ddfba8
......@@ -348,6 +348,34 @@ def fused_add_rms_norm(
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
# layer norm ops (opt)
def rms_norm_opt(input: torch.Tensor, weight: torch.Tensor, out: torch.Tensor,
epsilon: float, training: Optional[bool]=False) -> None:
lightop.rmsnorm_forward(input, weight, out, epsilon, training)
def rms_norm_opt_fake(
input: torch.Tensor,
weight: torch.Tensor,
out: torch.Tensor,
epsilon: float,
training: Optional[bool] = False
) -> torch.Tensor:
return torch.empty_like(input)
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float, training: Optional[bool]=False, inplace: Optional[bool]=False) -> None:
lightop.rn_add_forward_autograd(input, residual, weight, epsilon, training, inplace)
def fused_add_rms_norm_opt_fake(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
training: Optional[bool] = False,
inplace: Optional[bool] = False
) -> torch.Tensor:
return torch.empty_like(input)
def fused_qk_norm_rope(
qkv: torch.Tensor,
num_heads_q: int,
......@@ -3513,3 +3541,17 @@ direct_register_custom_op(
mutates_args=[],
fake_impl=gptq_gemm_fake_,
)
direct_register_custom_op(
op_name="rms_norm_opt",
op_func=rms_norm_opt,
mutates_args=[],
fake_impl=rms_norm_opt_fake,
)
direct_register_custom_op(
op_name="fused_add_rms_norm_opt",
op_func=fused_add_rms_norm_opt,
mutates_args=[],
fake_impl=fused_add_rms_norm_opt_fake,
)
\ No newline at end of file
......@@ -2882,6 +2882,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
use_nn_moe: bool | None = False,
):
# Check constraints.
if self.quant_config.use_int4_w4a16:
......@@ -2967,6 +2968,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=self.w1_bias,
use_nn_moe=use_nn_moe,
)
self.activation(
......@@ -3005,6 +3007,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=self.w2_bias,
use_nn_moe=use_nn_moe,
)
# separate function is required for MoE + LoRA
......
......@@ -306,8 +306,9 @@ class FusedMoERouterImpl(FusedMoERouter):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
use_fused_gate: bool | None = False,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.layer._select_experts(hidden_states, router_logits)
return self.layer._select_experts(hidden_states, router_logits, use_fused_gate)
# --8<-- [start:fused_moe]
......
......@@ -1020,6 +1020,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
expert_tokens_meta: ExpertTokensMetadata | None,
use_nn_moe: bool | None = False,
) -> torch.Tensor:
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids
......@@ -1095,6 +1096,7 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace2=workspace2,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe,
)
return fused_out
......@@ -1178,6 +1180,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
......@@ -1239,6 +1242,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_tokens_meta=expert_tokens_meta,
use_nn_moe=use_nn_moe,
)
return self._finalize(
......
......@@ -137,6 +137,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
use_nn_moe: bool,
**extra_weight_attrs,
):
if self.moe.is_act_and_mul:
......@@ -144,6 +145,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
w13_up_dim = intermediate_size_per_partition
# Fused gate_up_proj (column parallel)
if not use_nn_moe:
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
......@@ -153,6 +155,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
),
requires_grad=False,
)
else:
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
w13_up_dim,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.moe.has_bias:
......@@ -163,6 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
# down_proj (row parallel)
if not use_nn_moe:
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
......@@ -172,6 +185,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
),
requires_grad=False,
)
else:
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
if self.moe.has_bias:
......
......@@ -15,7 +15,6 @@ from vllm.model_executor.layers.batch_invariant import (
)
from vllm.platforms import current_platform
from vllm import envs
import lightop as op
def rms_norm(
......@@ -28,7 +27,7 @@ def rms_norm(
out = torch.empty_like(x)
# if envs.VLLM_USE_OPT_OP:
if False:
op.rmsnorm_forward(
ops.rms_norm_opt(
x,
weight,
out,
......@@ -58,7 +57,7 @@ def fused_add_rms_norm(
), x + residual
# if envs.VLLM_USE_OPT_OP:
if False:
op.rn_add_forward_autograd(
ops.fused_add_rms_norm_opt(
x,
residual,
weight,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment