Commit 9ddd0f97 authored by zhuwenwen's avatar zhuwenwen
Browse files

update use_nn_moe

parent 90ddfba8
...@@ -348,6 +348,34 @@ def fused_add_rms_norm( ...@@ -348,6 +348,34 @@ def fused_add_rms_norm(
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
# layer norm ops (opt)
def rms_norm_opt(input: torch.Tensor, weight: torch.Tensor, out: torch.Tensor,
epsilon: float, training: Optional[bool]=False) -> None:
lightop.rmsnorm_forward(input, weight, out, epsilon, training)
def rms_norm_opt_fake(
input: torch.Tensor,
weight: torch.Tensor,
out: torch.Tensor,
epsilon: float,
training: Optional[bool] = False
) -> torch.Tensor:
return torch.empty_like(input)
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float, training: Optional[bool]=False, inplace: Optional[bool]=False) -> None:
lightop.rn_add_forward_autograd(input, residual, weight, epsilon, training, inplace)
def fused_add_rms_norm_opt_fake(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
training: Optional[bool] = False,
inplace: Optional[bool] = False
) -> torch.Tensor:
return torch.empty_like(input)
def fused_qk_norm_rope( def fused_qk_norm_rope(
qkv: torch.Tensor, qkv: torch.Tensor,
num_heads_q: int, num_heads_q: int,
...@@ -3513,3 +3541,17 @@ direct_register_custom_op( ...@@ -3513,3 +3541,17 @@ direct_register_custom_op(
mutates_args=[], mutates_args=[],
fake_impl=gptq_gemm_fake_, fake_impl=gptq_gemm_fake_,
) )
direct_register_custom_op(
op_name="rms_norm_opt",
op_func=rms_norm_opt,
mutates_args=[],
fake_impl=rms_norm_opt_fake,
)
direct_register_custom_op(
op_name="fused_add_rms_norm_opt",
op_func=fused_add_rms_norm_opt,
mutates_args=[],
fake_impl=fused_add_rms_norm_opt_fake,
)
\ No newline at end of file
...@@ -2882,6 +2882,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2882,6 +2882,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
use_nn_moe: bool | None = False,
): ):
# Check constraints. # Check constraints.
if self.quant_config.use_int4_w4a16: if self.quant_config.use_int4_w4a16:
...@@ -2967,6 +2968,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2967,6 +2968,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_channel_quant=self.per_act_token_quant, per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape, block_shape=self.block_shape,
B_bias=self.w1_bias, B_bias=self.w1_bias,
use_nn_moe=use_nn_moe,
) )
self.activation( self.activation(
...@@ -3005,6 +3007,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -3005,6 +3007,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_channel_quant=self.per_act_token_quant, per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape, block_shape=self.block_shape,
B_bias=self.w2_bias, B_bias=self.w2_bias,
use_nn_moe=use_nn_moe,
) )
# separate function is required for MoE + LoRA # separate function is required for MoE + LoRA
......
...@@ -306,8 +306,9 @@ class FusedMoERouterImpl(FusedMoERouter): ...@@ -306,8 +306,9 @@ class FusedMoERouterImpl(FusedMoERouter):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_fused_gate: bool | None = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
return self.layer._select_experts(hidden_states, router_logits) return self.layer._select_experts(hidden_states, router_logits, use_fused_gate)
# --8<-- [start:fused_moe] # --8<-- [start:fused_moe]
......
...@@ -1020,6 +1020,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1020,6 +1020,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
expert_tokens_meta: ExpertTokensMetadata | None, expert_tokens_meta: ExpertTokensMetadata | None,
use_nn_moe: bool | None = False,
) -> torch.Tensor: ) -> torch.Tensor:
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size( _, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids a1q, w1, w2, topk_ids
...@@ -1095,6 +1096,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1095,6 +1096,7 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace2=workspace2, workspace2=workspace2,
expert_tokens_meta=c_expert_tokens_meta, expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe,
) )
return fused_out return fused_out
...@@ -1178,6 +1180,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1178,6 +1180,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets This function computes a Mixture of Experts (MoE) layer using two sets
...@@ -1239,6 +1242,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1239,6 +1242,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
use_nn_moe=use_nn_moe,
) )
return self._finalize( return self._finalize(
......
...@@ -137,6 +137,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -137,6 +137,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
hidden_size: int, hidden_size: int,
intermediate_size_per_partition: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
use_nn_moe: bool,
**extra_weight_attrs, **extra_weight_attrs,
): ):
if self.moe.is_act_and_mul: if self.moe.is_act_and_mul:
...@@ -144,15 +145,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -144,15 +145,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else: else:
w13_up_dim = intermediate_size_per_partition w13_up_dim = intermediate_size_per_partition
# Fused gate_up_proj (column parallel) # Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter( if not use_nn_moe:
torch.empty( w13_weight = torch.nn.Parameter(
num_experts, torch.empty(
w13_up_dim, num_experts,
hidden_size, w13_up_dim,
dtype=params_dtype, hidden_size,
), dtype=params_dtype,
requires_grad=False, ),
) requires_grad=False,
)
else:
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
w13_up_dim,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight) layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
if self.moe.has_bias: if self.moe.has_bias:
...@@ -163,15 +175,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -163,15 +175,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w13_bias", w13_bias) layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs) set_weight_attrs(w13_bias, extra_weight_attrs)
# down_proj (row parallel) # down_proj (row parallel)
w2_weight = torch.nn.Parameter( if not use_nn_moe:
torch.empty( w2_weight = torch.nn.Parameter(
num_experts, torch.empty(
hidden_size, num_experts,
intermediate_size_per_partition, hidden_size,
dtype=params_dtype, intermediate_size_per_partition,
), dtype=params_dtype,
requires_grad=False, ),
) requires_grad=False,
)
else:
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
if self.moe.has_bias: if self.moe.has_bias:
......
...@@ -15,7 +15,6 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -15,7 +15,6 @@ from vllm.model_executor.layers.batch_invariant import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm import envs from vllm import envs
import lightop as op
def rms_norm( def rms_norm(
...@@ -28,7 +27,7 @@ def rms_norm( ...@@ -28,7 +27,7 @@ def rms_norm(
out = torch.empty_like(x) out = torch.empty_like(x)
# if envs.VLLM_USE_OPT_OP: # if envs.VLLM_USE_OPT_OP:
if False: if False:
op.rmsnorm_forward( ops.rms_norm_opt(
x, x,
weight, weight,
out, out,
...@@ -58,7 +57,7 @@ def fused_add_rms_norm( ...@@ -58,7 +57,7 @@ def fused_add_rms_norm(
), x + residual ), x + residual
# if envs.VLLM_USE_OPT_OP: # if envs.VLLM_USE_OPT_OP:
if False: if False:
op.rn_add_forward_autograd( ops.fused_add_rms_norm_opt(
x, x,
residual, residual,
weight, weight,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment