Commit 1eaad6d1 authored by lizhigong

Merge branch 'v0.5.4_rzc' into 'v0.5.4_dev'

Add the environment variable SGLANG_USE_LIGHTOP to gate lightop's fused rotary_emb and moe_gated operators (disabled by default); fix the incorrect logic in RMSNorm.forward_hip.

See merge request OpenDAS/sglang!13
parents 63c8d8d0 f453578b
@@ -163,6 +163,9 @@ class Envs:
SGLANG_USE_AITER = EnvBool(False)
SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)
# DCU Lightop
SGLANG_USE_LIGHTOP = EnvBool(False)
# Quantization
SGLANG_INT4_WEIGHT = EnvBool(False)
......
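The flag defaults to off, so existing deployments are unaffected. A minimal sketch of how it is expected to be consumed (the get_bool_env_var below is a simplified stand-in for the sglang.srt.utils helper, and setting SGLANG_USE_LIGHTOP=1 before launch is the assumed way to enable it):

# Minimal sketch (not part of the diff): gating the lightop paths on SGLANG_USE_LIGHTOP.
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Simplified stand-in for sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

_use_lightop = get_bool_env_var("SGLANG_USE_LIGHTOP")  # disabled by default

if _use_lightop:
    from lightop import op  # DCU fused rotary_emb / moe gating kernels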
@@ -167,8 +167,6 @@ class RMSNorm(CustomOp):
if residual is not None:
try:
output = torch.empty_like(x)
residual_out = torch.empty_like(x)
fused_add_rms_norm(
x,
residual,
@@ -177,6 +175,8 @@ class RMSNorm(CustomOp):
)
return x, residual
except TypeError:
output = torch.empty_like(x)
residual_out = torch.empty_like(x)
fused_add_rms_norm(
output,
x,
......
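Reconstructed from the two RMSNorm hunks above: the fix moves the output/residual_out allocations into the TypeError fallback, the only branch that needs them; on the fast path fused_add_rms_norm works in place and x, residual are returned directly. The exact argument orders are truncated in the diff, so the sketch below passes the kernel in as a callable and treats both signatures as assumptions.

# Sketch of the corrected forward_hip control flow (not the exact sglang code).
from typing import Callable, Tuple
import torch

def fused_add_rms_norm_dispatch(
    fused_add_rms_norm: Callable,  # extension kernel; signature varies by build
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    try:
        # In-place variant: results land in x and residual.
        fused_add_rms_norm(x, residual, weight, eps)
        return x, residual
    except TypeError:
        # Out-of-place variant: allocate the buffers only when this fallback
        # is taken (previously they were allocated before the try block and
        # wasted on the fast path).
        output = torch.empty_like(x)
        residual_out = torch.empty_like(x)
        fused_add_rms_norm(output, x, residual, residual_out, weight, eps)
        return output, residual_out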
@@ -28,6 +28,8 @@ from typing import (
runtime_checkable,
)
from numpy import dtype
import torch
import torch.nn.functional as F
@@ -68,6 +70,7 @@ _is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_use_lightop = get_bool_env_var("SGLANG_USE_LIGHTOP")
if _is_cuda:
from sgl_kernel import moe_fused_gate
@@ -79,6 +82,8 @@ if _use_aiter:
from aiter import biased_grouped_topk as aiter_biased_grouped_topk
except ImportError:
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
if _use_lightop:
from lightop import op
if _is_npu:
import torch_npu
@@ -725,6 +730,18 @@ def biased_grouped_topk_gpu(
routed_scaling_factor,
)
return topk_weights, topk_ids
elif _use_lightop:
assert not apply_routed_scaling_factor_on_output, "Not implemented"
topk_weights, topk_ids = op.moe_fused_gate(
gating_output.to(dtype=torch.float32), # or bfloat16
correction_bias,
num_expert_group,
topk_group,
topk,
0,  # num_fused_shared_experts (hard-coded to 0, as in the vLLM call)
routed_scaling_factor,
)
return topk_weights, topk_ids
else:
return biased_grouped_topk_impl(
hidden_states,
......
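The lightop branch is assumed to compute the same DeepSeek-style biased grouped routing as the reference biased_grouped_topk_impl path it bypasses. For readers unfamiliar with that routing, a plain-PyTorch reference sketch (simplified: tie-breaking, padding handling, and the renormalize flag are omitted):

# Reference sketch (not part of the diff) of biased grouped top-k routing.
import torch

def biased_grouped_topk_reference(
    gating_output: torch.Tensor,    # [num_tokens, num_experts] router logits
    correction_bias: torch.Tensor,  # [num_experts], used only for ranking
    num_expert_group: int,
    topk_group: int,
    topk: int,
    routed_scaling_factor: float,
):
    num_tokens, num_experts = gating_output.shape
    scores = gating_output.sigmoid()
    scores_for_choice = scores + correction_bias.unsqueeze(0)

    # Rank expert groups by the sum of their two highest biased scores,
    # then keep only the best topk_group groups.
    group_scores = (
        scores_for_choice.view(num_tokens, num_expert_group, -1)
        .topk(2, dim=-1).values.sum(dim=-1)
    )
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_expert_group, num_experts // num_expert_group)
        .reshape(num_tokens, num_experts)
    )

    # Pick the top-k experts inside the kept groups, weight them by the
    # unbiased sigmoid scores, renormalize, and apply the scaling factor.
    masked = scores_for_choice.masked_fill(score_mask == 0, float("-inf"))
    topk_ids = masked.topk(topk, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights * routed_scaling_factor, topk_ids.to(torch.int32)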
@@ -22,6 +22,8 @@ from sglang.srt.utils import (
is_xpu,
)
from sglang.srt.utils import direct_register_custom_op
_is_cuda = is_cuda()
_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
@@ -29,6 +31,7 @@ _is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_xpu = is_xpu()
_use_lightop = get_bool_env_var("SGLANG_USE_LIGHTOP")
if _is_cuda:
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
@@ -57,6 +60,34 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
# For DCU: Triton kernel that applies DeepSeek-scaling rotary embedding in
# GPT-J (interleaved) style, rotating the rotary slice of q in place.
@triton.jit
def deepseek_scaling_rotary_emb_kernel_gptj(cos_sin, q, stride1: int,
stride2: int, stride_cs: int,
dim1: int, dim2: int, dim3: int,
BLOCK_SIZE: tl.constexpr):
pid0 = tl.program_id(0)
pid1 = tl.program_id(1)
pid2 = tl.program_id(2)
offsets_cs = tl.arange(0, BLOCK_SIZE) + pid2 * BLOCK_SIZE
offsets_q = tl.arange(0, BLOCK_SIZE * 2) + pid2 * BLOCK_SIZE * 2
offsets = pid0 * stride1 + pid1 * stride2 + offsets_q
mask = offsets_cs < dim3
mask2 = offsets_q < dim3 * 2
v_cos = tl.load(cos_sin + pid0 * stride_cs + offsets_cs, mask=mask)
v_cos2 = tl.interleave(v_cos, v_cos)
v_sin = tl.load(cos_sin + pid0 * stride_cs + dim3 + offsets_cs, mask=mask)
v_sin2 = tl.interleave(v_sin, v_sin)
x12 = tl.load(q + offsets, mask=mask2)
x1, x2 = tl.split(x12.reshape([BLOCK_SIZE, 2]))
# we are both reading and writing 'q'; make sure all warps are in sync
tl.debug_barrier()
x12_ = tl.ravel(tl.join(-x2, x1))
x12 = x12 * v_cos2 + x12_ * v_sin2
tl.store(q + offsets, x12, mask=mask2)
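# --- Illustrative reference (not part of the diff) ----------------------------
# The kernel above performs the same interleaved GPT-J-style rotation as the
# _rotate_gptj / repeat_interleave fallback later in this file, just in place
# and tiled over (token, head, BLOCK_SIZE) Triton programs. Plain-PyTorch
# equivalent for the rotary slice of one head:
def _rotate_gptj_reference(q_rot: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q_rot: [num_tokens, rotary_dim]; cos/sin: [num_tokens, rotary_dim // 2]
    x1, x2 = q_rot[..., 0::2], q_rot[..., 1::2]
    cos2 = cos.repeat_interleave(2, dim=-1)               # [c0, c0, c1, c1, ...]
    sin2 = sin.repeat_interleave(2, dim=-1)
    rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)  # [-x2_0, x1_0, -x2_1, x1_1, ...]
    return q_rot * cos2 + rotated * sin2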
def _apply_rotary_emb(
x: torch.Tensor,
@@ -736,7 +767,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
# Re-dispatch
if _is_hip:
self._forward_method = self.forward_native
if _use_lightop:
self._forward_method = self.forward_dcu
else:
self._forward_method = self.forward_native
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base ** (
@@ -778,6 +812,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
sin = freqs.sin() * self.mscale
cache = torch.cat((cos, sin), dim=-1)
return cache
def rotary_embedding_deepseek_fuse(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
head_size: int, cos_sin_cache: torch.Tensor,
is_neox_style: bool) -> None:
from lightop import op
op.rotary_embedding_deepseek_fuse(positions, query, key, head_size, cos_sin_cache, is_neox_style)
def rotary_embedding_deepseek_fuse_fake(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
head_size: int, cos_sin_cache: torch.Tensor,
is_neox_style: bool) -> None:
pass
direct_register_custom_op(
op_name="rotary_embedding_deepseek_fuse",
op_func=rotary_embedding_deepseek_fuse,
mutates_args=["query", "key"],
fake_impl=rotary_embedding_deepseek_fuse_fake,
)
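# --- Illustrative usage (not part of the diff) ---------------------------------
# The registration above exposes the lightop kernel as
# torch.ops.sglang.rotary_embedding_deepseek_fuse (called in forward_dcu below);
# the fake_impl gives torch.compile / meta-tensor tracing a side-effect-free
# stand-in so lightop is not needed at trace time. Hedged example call, with
# shapes and dtypes as assumptions (head_size == rotary_dim == 64, DCU/ROCm
# build with lightop installed):
def _rotary_embedding_deepseek_fuse_example() -> None:
    num_tokens, num_heads, rotary_dim = 8, 16, 64
    positions = torch.arange(num_tokens, device="cuda")
    query = torch.randn(num_tokens, num_heads, rotary_dim, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(num_tokens, 1, rotary_dim, device="cuda", dtype=torch.bfloat16)
    cos_sin_cache = torch.randn(4096, rotary_dim, device="cuda", dtype=torch.float32)
    # Mutates query/key in place and returns None (see mutates_args above).
    torch.ops.sglang.rotary_embedding_deepseek_fuse(
        positions, query, key, rotary_dim, cos_sin_cache, False  # is_neox_style=False
    )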
def forward_native(
self,
@@ -819,6 +871,77 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
query = query_rot
key = key_rot
return query.to(dtype), key.to(dtype)
def forward_dcu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert key is not None
if self.cos_sin_cache.device != positions.device:
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
if query.device.type == 'cuda' and not self.is_neox_style: # not self.reference ?
assert len(query.shape) == 3
def call(q):
BLOCK_SIZE = 64
grid = (
q.shape[-3],
q.shape[-2],
triton.cdiv(self.rotary_dim // 2, BLOCK_SIZE),
)
deepseek_scaling_rotary_emb_kernel_gptj[grid](
cos_sin,
q,
stride1=q.stride()[-3],
stride2=q.stride()[-2],
stride_cs=cos_sin.stride()[-2],
dim1=q.shape[0],
dim2=q.shape[1],
dim3=self.rotary_dim // 2,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=1)
if _use_lightop:
torch.ops.sglang.rotary_embedding_deepseek_fuse(positions, query, key, self.head_size, self.cos_sin_cache, self.is_neox_style)
else:
call(query)
call(key)
return query, key
else:
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
return query, key
def forward_npu(
self,
......