Commit b80ae5e9 authored by maxiao1's avatar maxiao1
Browse files

adaptation w4a8 tp

parent b091a7a5
......@@ -170,9 +170,7 @@ class RMSNorm(CustomOp):
output = torch.empty_like(x)
residual_out = torch.empty_like(x)
fused_add_rms_norm(
output,
x,
residual_out,
residual,
self.weight.data,
self.variance_epsilon,
......@@ -180,7 +178,9 @@ class RMSNorm(CustomOp):
return output, residual_out
except TypeError:
fused_add_rms_norm(
output,
x,
residual_out,
residual,
self.weight.data,
self.variance_epsilon,
......
from typing import Any, Callable, Dict, List, Optional
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
# from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
import torch
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import set_weight_attrs
......@@ -218,8 +218,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
def apply(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
dispatch_output,
) :
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
......@@ -241,7 +242,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
use_int4_w4a8=True,
per_channel_quant=True,
activation=layer.moe_runner_config.activation,
expert_map=layer.expert_map_gpu,
# expert_map=layer.expert_map_gpu,
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
global_num_experts=layer.moe_runner_config.num_experts,
w1_scale=(layer.w13_weight_scale),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment