Commit c03b2816 authored by renzhc

Support torch compile with minimal changes; fill in some missing interfaces.

parent eb4b015f
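The changes below lean on two standard torch.compile escape hatches: kernels that Dynamo cannot trace are registered as custom ops with a fake (meta) implementation, and a few call paths are kept eager with `@torch._dynamo.disable()`. As a minimal, self-contained sketch of the first mechanism using the plain PyTorch 2.4+ `torch.library` API (the commit itself goes through sglang's `direct_register_custom_op` helper; the `demo::scaled_silu` op and all names below are illustrative only):

```python
import torch

# Illustrative only: "demo::scaled_silu" is not part of this commit.
@torch.library.custom_op("demo::scaled_silu", mutates_args=())
def scaled_silu(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Stand-in for an opaque vendor kernel that Dynamo cannot trace into.
    return torch.nn.functional.silu(x) * scale

@scaled_silu.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Fake/meta implementation: only shapes, dtypes and devices matter here;
    # torch.compile uses it to trace through the op without running the kernel.
    return torch.empty_like(x)

@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    # Registered ops are reachable through the dispatcher namespace.
    return torch.ops.demo.scaled_silu(x, 2.0)

print(f(torch.randn(4)))
```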
......@@ -274,6 +274,64 @@ def triton_scaled_mm(a: torch.Tensor,
    return quant_ops.triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias, best_config)


def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
`cutlass_scaled_mm` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
In order to support blockwise scaling like found in DeepSeek V3 we also
support extended "group" broadcast rules. We extend the numpy-style
broadcasting rules with the following rule:
"if the extent of a dimension in the source shape is between 1 and
corresponding extent in the target shape we repeat each element along
that dimension src_shape[dim] // target_shape[dim] times consecutively"
example if we have:
a = [[1, 2], and target_shape = (2, 4)
[3, 4]]
then we would expand a to:
a = [[1, 1, 2, 2],
[3, 3, 4, 4]]
currently we only support the case:
scale_a.shape * [1, 128] == a.shape
scale_b.shape * [128, 128] == b.shape
"""
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.shape[0] == b.shape[1]
                            and bias.dtype == out_dtype)
    # m = a.shape[0]
    # n = b.shape[1]
    # cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    # if current_platform.is_rocm() or not cutlass_compatible_b:
    #     from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
    #         triton_scaled_mm)
    #     return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    # out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    # return out
    # return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)


def rocblas_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def triton_int8_gemm_helper(m: int,
                            n: int,
                            k: int,
......
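The "group" broadcast rule described in the `cutlass_scaled_mm` docstring above can be reproduced in a few lines of plain PyTorch. The helper below is hypothetical (not part of the commit) and only mirrors the docstring's (2, 2) → (2, 4) example:

```python
import torch

# Hypothetical helper that mirrors the docstring's expansion rule:
# each element is repeated target_shape[dim] // src_shape[dim] times
# consecutively along every dimension whose extents differ.
def group_broadcast(t: torch.Tensor, target_shape) -> torch.Tensor:
    for dim, (src, tgt) in enumerate(zip(t.shape, target_shape)):
        if src != tgt:
            t = t.repeat_interleave(tgt // src, dim=dim)
    return t

a = torch.tensor([[1, 2],
                  [3, 4]])
print(group_broadcast(a, (2, 4)))
# tensor([[1, 1, 2, 2],
#         [3, 3, 4, 4]])
```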
......@@ -362,6 +362,7 @@ class DCUMLABackend(AttentionBackend):
        )
        return o

    @torch._dynamo.disable()
    def forward_decode(
        self,
        q: torch.Tensor,
......@@ -416,6 +417,7 @@ class DCUMLABackend(AttentionBackend):
        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    @torch._dynamo.disable()  # NOTE: untested
    def forward_extend(
        self,
        q: torch.Tensor,
......
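The `@torch._dynamo.disable()` decorators added above are the second escape hatch: the decorated method always runs eagerly, and torch.compile graph-breaks cleanly around the call instead of failing to trace it. A minimal illustration, with a made-up class and shapes:

```python
import torch

class TinyDecode(torch.nn.Module):
    @torch._dynamo.disable()
    def forward_decode(self, q: torch.Tensor) -> torch.Tensor:
        # Stand-in for a kernel path Dynamo cannot trace; it always runs eagerly.
        return q * 2.0

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        # The compiled graph breaks around the disabled method and resumes after it.
        return self.forward_decode(q) + 1.0

m = torch.compile(TinyDecode())
print(m(torch.randn(4)))
```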
......@@ -45,6 +45,7 @@ from sglang.srt.layers.moe import (
    should_use_flashinfer_trtllm_moe,
)
from sglang.srt.utils import (
    direct_register_custom_op,
    cpu_has_amx_support,
    get_bool_env_var,
    get_compiler_backend,
......@@ -87,6 +88,39 @@ if _use_lightop:
if _is_npu:
    import torch_npu
# ------- custom op for moe_fused_gate
def moe_fused_gate_dcu(gating_output: torch.Tensor, correction_bias: torch.Tensor, num_expert_group: int,
                       topk_group: int, topk: int,
                       num_fused_shared_experts: int, routed_scaling_factor: float) -> tuple[torch.Tensor, torch.Tensor]:
    topk_weights, topk_ids = op.moe_fused_gate(
        gating_output.to(dtype=torch.float32),  # or bfloat16
        correction_bias,
        num_expert_group,
        topk_group,
        topk,
        num_fused_shared_experts,  # 0 in vllm
        routed_scaling_factor,
    )
    return topk_weights, topk_ids


def moe_fused_gate_fake(gating_output: torch.Tensor, correction_bias: torch.Tensor, num_expert_group: int,
                        topk_group: int, topk: int,
                        num_fused_shared_experts: int, routed_scaling_factor: float) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty((gating_output.size(0), topk),
                       dtype=gating_output.dtype,
                       device=gating_output.device), \
           torch.empty((gating_output.size(0), topk),
                       dtype=gating_output.dtype,
                       device=gating_output.device)


direct_register_custom_op(
    op_name="moe_fused_gate_dcu",
    op_func=moe_fused_gate_dcu,
    mutates_args=[],
    fake_impl=moe_fused_gate_fake,
)
# -------
# -------------------------------- TopKConfig ---------------------------------------
......@@ -732,7 +766,7 @@ def biased_grouped_topk_gpu(
        return topk_weights, topk_ids
    elif _use_lightop:
        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        topk_weights, topk_ids = op.moe_fused_gate(
        topk_weights, topk_ids = torch.ops.sglang.moe_fused_gate_dcu(
            gating_output.to(dtype=torch.float32),  # or bfloat16
            correction_bias,
            num_expert_group,
......
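Routing the call through `torch.ops.sglang.moe_fused_gate_dcu` (rather than calling `op.moe_fused_gate` directly) lets torch.compile treat the DCU kernel as a single dispatcher node whose outputs are described by `moe_fused_gate_fake` during tracing. A hedged usage sketch, assuming a DCU/ROCm build where the lightop kernel is available; the expert counts, shapes, and device below are purely illustrative:

```python
import torch

# Illustrative shapes: 8 tokens, 256 experts, 8 expert groups, top-4 groups, top-8 experts.
gating_output = torch.randn(8, 256, dtype=torch.float32, device="cuda")
correction_bias = torch.zeros(256, dtype=torch.float32, device="cuda")

@torch.compile
def route(g: torch.Tensor, b: torch.Tensor):
    # During tracing only the fake impl runs; the real lightop kernel runs at call time.
    return torch.ops.sglang.moe_fused_gate_dcu(g, b, 8, 4, 8, 0, 1.0)

topk_weights, topk_ids = route(gating_output, correction_bias)
```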
......@@ -154,6 +154,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
        )
        layer.register_parameter("weight_scale", weight_scale)

    @torch._dynamo.disable()
    def apply(
        self,
        layer: torch.nn.Module,
......
......@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Optional
import torch
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import set_weight_attrs, get_bool_env_var
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearBase
......@@ -213,8 +213,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
    ):
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    @torch._dynamo.disable()
    def apply(
        self,
        layer: torch.nn.Module,
......