Commit 994bbcac authored by liucong


Merge branch 'v0.5.4_dev_liucong' of http://developer.sourcefind.cn/codes/OpenDAS/sglang into v0.5.4_dev_liucong
parents 33fbf3ca 9aea97cc
......@@ -163,6 +163,9 @@ class Envs:
SGLANG_USE_AITER = EnvBool(False)
SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)
# DCU Lightop
SGLANG_USE_LIGHTOP = EnvBool(False)
# Quantization
SGLANG_INT4_WEIGHT = EnvBool(False)
......
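Note: SGLANG_USE_LIGHTOP is consumed via get_bool_env_var in the hunks further down. A minimal sketch of that gating pattern follows; only the flag name and the lightop import come from this diff, the helper body is illustrative (the real one lives in sglang.srt.utils).

```python
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Illustrative parser; assumes the usual "1"/"true"/"yes" convention.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

_use_lightop = get_bool_env_var("SGLANG_USE_LIGHTOP")

if _use_lightop:
    # Imported only when the flag is set, so it stays a no-op on platforms
    # where the lightop package is absent.
    from lightop import op  # noqa: F401
```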
......@@ -432,11 +432,18 @@ class DCUMLABackend(AttentionBackend):
layer: "RadixAttention",
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
sinks=None,
):
if (
if save_kv_cache:
return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
if ((
forward_batch.forward_mode == ForwardMode.EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
):
# flash_attn does not support fp8, so fp8 cannot run the extend path correctly
if not self.skip_prefill:
......@@ -444,7 +451,7 @@ class DCUMLABackend(AttentionBackend):
# q, k, v, layer, forward_batch, save_kv_cache, sinks
# )
return self.flashattn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, sinks
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
)
else:
raise RuntimeError("skip prefill but use forward_extend")
......
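For readability, a runnable toy that mirrors only the control flow this hunk introduces (decode short-circuit, then extend delegation with the new q_rope/k_rope arguments threaded through); the enum and return strings are stand-ins, not the real classes.

```python
from enum import Enum, auto

class ForwardMode(Enum):  # stand-in for the real ForwardMode enum
    EXTEND = auto()
    DECODE = auto()
    DRAFT_EXTEND = auto()

def dispatch(forward_mode, save_kv_cache, skip_prefill):
    # Mirrors the branching of the hunk above with the tensor plumbing removed.
    if save_kv_cache:
        return "forward_decode"
    if forward_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
        if not skip_prefill:
            return "flashattn_backend.forward_extend (+ q_rope, k_rope, sinks)"
        raise RuntimeError("skip prefill but use forward_extend")

assert dispatch(ForwardMode.DECODE, True, False) == "forward_decode"
assert dispatch(ForwardMode.EXTEND, False, False).startswith("flashattn_backend")
```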
......@@ -86,9 +86,10 @@ def flash_attn_varlen_func(
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
return_attn_probs=return_softmax_lse,
)
\ No newline at end of file
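The fix above stops the key-side metadata from silently reusing the query-side values; the two only coincide when every request has equal query and key lengths. An illustrative call through the wrapper patched above, with distinct lengths (tensor shapes and values are made up for the example):

```python
import torch

# Two requests: 4 and 6 new query tokens attending over 16 and 32 key tokens.
cu_seqlens_q = torch.tensor([0, 4, 10], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 16, 48], dtype=torch.int32, device="cuda")
num_heads, head_dim = 8, 128
q = torch.randn(10, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(48, num_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(48, num_heads, head_dim, dtype=torch.float16, device="cuda")

out = flash_attn_varlen_func(
    q=q, k=k, v=v,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,   # key-side prefix sums, not the query-side ones
    max_seqlen_q=6,
    max_seqlen_k=32,
    softmax_scale=head_dim ** -0.5,
    causal=True,
)
```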
......@@ -167,8 +167,6 @@ class RMSNorm(CustomOp):
if residual is not None:
try:
output = torch.empty_like(x)
residual_out = torch.empty_like(x)
fused_add_rms_norm(
x,
residual,
......@@ -177,6 +175,8 @@ class RMSNorm(CustomOp):
)
return x, residual
except TypeError:
output = torch.empty_like(x)
residual_out = torch.empty_like(x)
fused_add_rms_norm(
output,
x,
......
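The effect of this hunk is that output and residual_out are now allocated only in the TypeError fallback, so the in-place fused_add_rms_norm signature no longer pays for unused buffers. A minimal sketch of that signature-probing pattern, with the kernel passed in as a stand-in; the fallback argument order beyond the (output, x, ...) prefix visible above is an assumption.

```python
import torch

def call_fused_add_rms_norm(kernel, x, residual, weight, eps):
    # `kernel` stands in for fused_add_rms_norm; only the fallback branch
    # allocates destination buffers, which is exactly what the hunk defers.
    try:
        # Variant that mutates x and residual in place.
        kernel(x, residual, weight, eps)
        return x, residual
    except TypeError:
        # Variant that writes into explicit outputs; remaining argument
        # order after (output, x, ...) is assumed, check the real kernel.
        output = torch.empty_like(x)
        residual_out = torch.empty_like(x)
        kernel(output, x, residual_out, residual, weight, eps)
        return output, residual_out
```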
......@@ -28,6 +28,8 @@ from typing import (
runtime_checkable,
)
from numpy import dtype
import torch
import torch.nn.functional as F
......@@ -68,6 +70,7 @@ _is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_use_lightop = get_bool_env_var("SGLANG_USE_LIGHTOP")
if _is_cuda:
from sgl_kernel import moe_fused_gate
......@@ -79,6 +82,8 @@ if _use_aiter:
from aiter import biased_grouped_topk as aiter_biased_grouped_topk
except ImportError:
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
if _use_lightop:
from lightop import op as op
if _is_npu:
import torch_npu
......@@ -725,6 +730,18 @@ def biased_grouped_topk_gpu(
routed_scaling_factor,
)
return topk_weights, topk_ids
elif _use_lightop:
assert not apply_routed_scaling_factor_on_output, "Not implemented"
topk_weights, topk_ids = op.moe_fused_gate(
gating_output.to(dtype=torch.float32), # or bfloat16
correction_bias,
num_expert_group,
topk_group,
topk,
0, # 0 in vllm
routed_scaling_factor,
)
return topk_weights, topk_ids
else:
return biased_grouped_topk_impl(
hidden_states,
......
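A shape/usage sketch for the new lightop branch. The op.moe_fused_gate argument order is taken from the hunk; the sixth positional argument (kept 0, matching the "0 in vllm" note) is presumed to be a fused-shared-experts count, and the [num_tokens, topk] output shapes are the usual convention — both are assumptions, not guarantees about lightop.

```python
import torch
from lightop import op  # requires SGLANG_USE_LIGHTOP and the lightop package

num_tokens, num_experts = 8, 256
num_expert_group, topk_group, topk = 8, 4, 8
routed_scaling_factor = 2.5  # illustrative value

gating_output = torch.randn(num_tokens, num_experts, device="cuda")
correction_bias = torch.zeros(num_experts, device="cuda")

topk_weights, topk_ids = op.moe_fused_gate(
    gating_output.to(torch.float32),
    correction_bias,
    num_expert_group,
    topk_group,
    topk,
    0,                      # assumed: fused shared-expert count ("0 in vllm")
    routed_scaling_factor,
)
# Expected: topk_weights [num_tokens, topk] float, topk_ids [num_tokens, topk] int.
```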
......@@ -22,6 +22,8 @@ from sglang.srt.utils import (
is_xpu,
)
from sglang.srt.utils import direct_register_custom_op
_is_cuda = is_cuda()
_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
......@@ -29,6 +31,7 @@ _is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_xpu = is_xpu()
_use_lightop = get_bool_env_var("SGLANG_USE_LIGHTOP")
if _is_cuda:
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
......@@ -57,6 +60,34 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
# for dcu
@triton.jit
def deepseek_scaling_rotary_emb_kernel_gptj(cos_sin, q, stride1: int,
stride2: int, stride_cs: int,
dim1: int, dim2: int, dim3: int,
BLOCK_SIZE: tl.constexpr):
pid0 = tl.program_id(0)
pid1 = tl.program_id(1)
pid2 = tl.program_id(2)
offsets_cs = tl.arange(0, BLOCK_SIZE) + pid2 * BLOCK_SIZE
offsets_q = tl.arange(0, BLOCK_SIZE * 2) + pid2 * BLOCK_SIZE * 2
offsets = pid0 * stride1 + pid1 * stride2 + offsets_q
mask = offsets_cs < dim3
mask2 = offsets_q < dim3 * 2
v_cos = tl.load(cos_sin + pid0 * stride_cs + offsets_cs, mask=mask)
v_cos2 = tl.interleave(v_cos, v_cos)
v_sin = tl.load(cos_sin + pid0 * stride_cs + dim3 + offsets_cs, mask=mask)
v_sin2 = tl.interleave(v_sin, v_sin)
x12 = tl.load(q + offsets, mask=mask2)
x1, x2 = tl.split(x12.reshape([BLOCK_SIZE, 2]))
# we are both reading and writing 'q'; make sure all warps are in sync
tl.debug_barrier()
x12_ = tl.ravel(tl.join(-x2, x1))
x12 = x12 * v_cos2 + x12_ * v_sin2
tl.store(q + offsets, x12, mask=mask2)
def _apply_rotary_emb(
x: torch.Tensor,
......@@ -736,7 +767,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
# Re-dispatch
if _is_hip:
self._forward_method = self.forward_native
if _use_lightop:
self._forward_method = self.forward_dcu
else:
self._forward_method = self.forward_native
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base ** (
......@@ -778,6 +812,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
sin = freqs.sin() * self.mscale
cache = torch.cat((cos, sin), dim=-1)
return cache
def rotary_embedding_deepseek_fuse(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
head_size: int, cos_sin_cache: torch.Tensor,
is_neox_style: bool) -> None:
from lightop import op
op.rotary_embedding_deepseek_fuse(positions, query, key, head_size, cos_sin_cache, is_neox_style)
def rotary_embedding_deepseek_fuse_fake(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
head_size: int, cos_sin_cache: torch.Tensor,
is_neox_style: bool) -> None:
pass
direct_register_custom_op(
op_name="rotary_embedding_deepseek_fuse",
op_func=rotary_embedding_deepseek_fuse,
mutates_args=["query", "key"],
fake_impl=rotary_embedding_deepseek_fuse_fake,
)
def forward_native(
self,
......@@ -819,6 +871,77 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
query = query_rot
key = key_rot
return query.to(dtype), key.to(dtype)
def forward_dcu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert key is not None
if self.cos_sin_cache.device != positions.device:
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
if query.device.type == 'cuda' and not self.is_neox_style: # not self.reference ?
assert len(query.shape) == 3
def call(q):
BLOCK_SIZE = 64
grid = (
q.shape[-3],
q.shape[-2],
triton.cdiv(self.rotary_dim // 2, BLOCK_SIZE),
)
deepseek_scaling_rotary_emb_kernel_gptj[grid](
cos_sin,
q,
stride1=q.stride()[-3],
stride2=q.stride()[-2],
stride_cs=cos_sin.stride()[-2],
dim1=q.shape[0],
dim2=q.shape[1],
dim3=self.rotary_dim // 2,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=1)
if _use_lightop:
torch.ops.sglang.rotary_embedding_deepseek_fuse(positions, query, key, self.head_size, self.cos_sin_cache, self.is_neox_style)
else:
call(query)
call(key)
return query, key
else:
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
return query, key
def forward_npu(
self,
......
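As a readability aid for the Triton kernel added in this file, a plain-PyTorch reference of the interleaved (GPT-J style) rotation it computes per (token, head); it mirrors _rotate_gptj and the v_cos2/v_sin2 interleaving in the kernel, and is not the kernel itself.

```python
import torch

def rotate_gptj_reference(x, cos, sin):
    # x:   [..., rotary_dim]       interleaved pairs (x1, x2, x1, x2, ...)
    # cos: [..., rotary_dim // 2]
    # sin: [..., rotary_dim // 2]
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    cos2 = cos.repeat_interleave(2, dim=-1)                # v_cos2 in the kernel
    sin2 = sin.repeat_interleave(2, dim=-1)                # v_sin2 in the kernel
    rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)   # matches _rotate_gptj
    return x * cos2 + rotated * sin2
```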
......@@ -1618,7 +1618,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
self.out_cache_loc = None
self.seq_lens_sum = self.seq_lens.sum()
self.seq_lens_sum = self.seq_lens.sum().item()
self.output_ids = self.output_ids[keep_indices_device]
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob:
......
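Background for the .item() change: Tensor.sum() yields a 0-dim tensor, not a Python int, while seq_lens_sum is used as plain-int bookkeeping elsewhere. A tiny demonstration:

```python
import torch

seq_lens = torch.tensor([3, 5, 7])
s = seq_lens.sum()
print(type(s), s.shape)   # <class 'torch.Tensor'> torch.Size([])
print(type(s.item()))     # <class 'int'>
# Mixing the 0-dim tensor into plain-int arithmetic or serialization is what
# the explicit .item() in the hunk above avoids.
```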
......@@ -174,6 +174,7 @@ MLA_ATTENTION_BACKENDS = [
CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
"flashinfer",
"fa3",
"dcu_mla",
"fa4",
"flashmla",
"cutlass_mla",
......@@ -2238,7 +2239,6 @@ class ModelRunner:
and self.graph_runner
and self.graph_runner.can_run(forward_batch)
)
if can_run_graph:
ret = self.graph_runner.replay(
forward_batch,
......
......@@ -185,6 +185,7 @@ elif _is_hip:
from sglang.srt.layers.quantization.awq_triton import (
awq_dequantize_triton as awq_dequantize,
)
from sgl_kernel import merge_state_v2
elif _is_npu:
import custom_ops # noqa: F401
import sgl_kernel_npu # noqa: F401
......
......@@ -4,7 +4,7 @@
#include <algorithm>
#include <optional>
#include "pytorch_extension_utils.h"
#include "pytorch_extension_utils_rocm.h"
// Helper functions to convert between different data types
// (float, half, bfloat16) for the merge attention states kernel.
......@@ -27,6 +27,19 @@ inline __device__ void from_float(__nv_bfloat16& d, float s) {
d = __float2bfloat16(s);
}
inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim());
for (int i = 0; i < a.dim(); ++i) {
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")");
}
}
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
......
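For reference, the standard log-sum-exp merge of two partial attention states that this kernel family implements (cf. section 2.2 of the paper cited above). This sketches the math only; the exact tensor layout and dtypes are defined by merge_state_v2's schema, not by this snippet.

```python
import torch

def merge_attn_states_reference(v_a, s_a, v_b, s_b):
    # v_a, v_b: partial outputs, e.g. [num_tokens, num_heads, head_dim]
    # s_a, s_b: per-head log-sum-exp values, e.g. [num_tokens, num_heads]
    s_merged = torch.logaddexp(s_a, s_b)
    w_a = torch.exp(s_a - s_merged).unsqueeze(-1)
    w_b = torch.exp(s_b - s_merged).unsqueeze(-1)
    v_merged = v_a * w_a + v_b * w_b
    return v_merged, s_merged
```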
......@@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
/*
* From csrc/attention
*/
m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
/*
* From csrc/allreduce
*/
......
......@@ -50,6 +50,7 @@ sources = [
"csrc/moe/moe_topk_softmax_kernels.cu",
"csrc/speculative/eagle_utils.cu",
"csrc/kvcacheio/transfer.cu",
"csrc/attention/merge_attn_states.cu",
]
cxx_flags = ["-O3"]
......