Commit b80ae5e9 authored by maxiao1
Browse files

adaptation w4a8 tp

parent b091a7a5
...@@ -170,9 +170,7 @@ class RMSNorm(CustomOp): ...@@ -170,9 +170,7 @@ class RMSNorm(CustomOp):
output = torch.empty_like(x) output = torch.empty_like(x)
residual_out = torch.empty_like(x) residual_out = torch.empty_like(x)
fused_add_rms_norm( fused_add_rms_norm(
output,
x, x,
residual_out,
residual, residual,
self.weight.data, self.weight.data,
self.variance_epsilon, self.variance_epsilon,
...@@ -180,7 +178,9 @@ class RMSNorm(CustomOp): ...@@ -180,7 +178,9 @@ class RMSNorm(CustomOp):
return output, residual_out return output, residual_out
except TypeError: except TypeError:
fused_add_rms_norm( fused_add_rms_norm(
output,
x, x,
residual_out,
residual, residual,
self.weight.data, self.weight.data,
self.variance_epsilon, self.variance_epsilon,
......
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput # from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
import torch import torch
from sglang.srt import _custom_ops as ops from sglang.srt import _custom_ops as ops
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs
...@@ -218,8 +218,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -218,8 +218,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput, dispatch_output,
) -> CombineInput: ) :
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
x = dispatch_output.hidden_states x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output topk_output = dispatch_output.topk_output
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
...@@ -241,7 +242,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -241,7 +242,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
use_int4_w4a8=True, use_int4_w4a8=True,
per_channel_quant=True, per_channel_quant=True,
activation=layer.moe_runner_config.activation, activation=layer.moe_runner_config.activation,
expert_map=layer.expert_map_gpu, # expert_map=layer.expert_map_gpu,
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input, apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
global_num_experts=layer.moe_runner_config.num_experts, global_num_experts=layer.moe_runner_config.num_experts,
w1_scale=(layer.w13_weight_scale), w1_scale=(layer.w13_weight_scale),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment