Commit f810bda3 authored by renzhc

Merge branch 'v0.5.4_dev' into v0.5.4_rzc

parents 4167eff9 48542418
...@@ -332,6 +332,18 @@ def rocblas_scaled_mm(a: torch.Tensor, ...@@ -332,6 +332,18 @@ def rocblas_scaled_mm(a: torch.Tensor,
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias) return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def blaslt_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
m = a.shape[0]
n = b.shape[0]
k = a.shape[1]
_, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
return out
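
For reference, a minimal usage sketch of the new blaslt_scaled_mm wrapper, assuming the 'NT' layout implied by the shape handling above (a is int8 [M, K], b is int8 [N, K], scales are per-row float32) and that the quant_ops extension imported by this module is available; the concrete shapes below are illustrative, not taken from the patch:

import torch

M, N, K = 32, 11008, 4096
a = torch.randint(-127, 127, (M, K), dtype=torch.int8, device="cuda")
b = torch.randint(-127, 127, (N, K), dtype=torch.int8, device="cuda")    # 'NT': b is [N, K]
scale_a = torch.rand(M, 1, dtype=torch.float32, device="cuda")           # per-token scales
scale_b = torch.rand(N, 1, dtype=torch.float32, device="cuda")           # per-channel scales
out = blaslt_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)  # -> [M, N]
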
def triton_int8_gemm_helper(m: int, def triton_int8_gemm_helper(m: int,
n: int, n: int,
k: int, k: int,
......
...@@ -11,6 +11,8 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend ...@@ -11,6 +11,8 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices
from sglang.srt.utils import get_bool_env_var
try: try:
from flash_mla import ( from flash_mla import (
...@@ -104,6 +106,7 @@ class DCUMLABackend(AttentionBackend): ...@@ -104,6 +106,7 @@ class DCUMLABackend(AttentionBackend):
) )
def init_forward_metadata(self, forward_batch: ForwardBatch): def init_forward_metadata(self, forward_batch: ForwardBatch):
use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON")
bs = forward_batch.batch_size bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle(): if forward_batch.forward_mode.is_decode_or_idle():
...@@ -118,15 +121,27 @@ class DCUMLABackend(AttentionBackend): ...@@ -118,15 +121,27 @@ class DCUMLABackend(AttentionBackend):
dtype=torch.int32, dtype=torch.int32,
device=forward_batch.seq_lens.device device=forward_batch.seq_lens.device
) )
create_flashmla_kv_indices_triton[(bs,)]( if use_sglang_create_flashmla_kv_indices_triton:
self.req_to_token, dcu_create_flashmla_kv_indices(
forward_batch.req_pool_indices, req_to_token_ptr = self.req_to_token,
forward_batch.seq_lens, req_pool_indices_ptr = forward_batch.req_pool_indices,
None, page_kernel_lens_ptr = forward_batch.seq_lens,
block_kv_indices, kv_start_idx = None,
self.req_to_token.stride(0), kv_indices_ptr = block_kv_indices,
max_seqlen_pad, req_to_token_ptr_stride = self.req_to_token.stride(0),
) kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata( mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32), forward_batch.seq_lens.to(torch.int32),
...@@ -149,15 +164,27 @@ class DCUMLABackend(AttentionBackend): ...@@ -149,15 +164,27 @@ class DCUMLABackend(AttentionBackend):
dtype=torch.int32, dtype=torch.int32,
device=seq_lens.device, device=seq_lens.device,
) )
create_flashmla_kv_indices_triton[(bs,)]( if use_sglang_create_flashmla_kv_indices_triton:
self.req_to_token, dcu_create_flashmla_kv_indices(
forward_batch.req_pool_indices, req_to_token_ptr = self.req_to_token,
seq_lens, req_pool_indices_ptr = forward_batch.req_pool_indices,
None, page_kernel_lens_ptr = forward_batch.seq_lens,
block_kv_indices, kv_start_idx = None,
self.req_to_token.stride(0), kv_indices_ptr = block_kv_indices,
max_seqlen_pad, req_to_token_ptr_stride = self.req_to_token.stride(0),
) kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata( mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32), seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads, self.num_draft_tokens * self.num_q_heads,
...@@ -185,15 +212,27 @@ class DCUMLABackend(AttentionBackend): ...@@ -185,15 +212,27 @@ class DCUMLABackend(AttentionBackend):
) )
# Invoke the Triton kernel to build block_kv_indices
create_flashmla_kv_indices_triton[(bs,)]( if use_sglang_create_flashmla_kv_indices_triton:
self.req_to_token, dcu_create_flashmla_kv_indices(
forward_batch.req_pool_indices, req_to_token_ptr = self.req_to_token.to(torch.int32),
seq_lens, req_pool_indices_ptr = forward_batch.req_pool_indices.to(torch.int32),
None, page_kernel_lens_ptr = forward_batch.seq_lens.to(torch.int32),
block_kv_indices, kv_start_idx = None,
self.req_to_token.stride(0), kv_indices_ptr = block_kv_indices.to(torch.int32),
max_seqlen_pad, req_to_token_ptr_stride = self.req_to_token.stride(0),
) kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
# MLA # MLA
mla_metadata, num_splits = get_mla_metadata( mla_metadata, num_splits = get_mla_metadata(
...@@ -211,6 +250,7 @@ class DCUMLABackend(AttentionBackend): ...@@ -211,6 +250,7 @@ class DCUMLABackend(AttentionBackend):
self.flashattn_backend.init_forward_metadata(forward_batch) self.flashattn_backend.init_forward_metadata(forward_batch)
def init_cuda_graph_state( def init_cuda_graph_state(
self, self,
max_bs: int, max_bs: int,
...@@ -489,9 +529,10 @@ class DCUMLABackend(AttentionBackend): ...@@ -489,9 +529,10 @@ class DCUMLABackend(AttentionBackend):
k_rope: Optional[torch.Tensor] = None, k_rope: Optional[torch.Tensor] = None,
sinks=None, sinks=None,
): ):
if (
if ((
forward_batch.forward_mode == ForwardMode.EXTEND forward_batch.forward_mode == ForwardMode.EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
): ):
if not self.skip_prefill: if not self.skip_prefill:
return self.flashattn_backend.forward_extend( return self.flashattn_backend.forward_extend(
......
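
A hedged consistency-check sketch for the two code paths gated above: it runs the existing Triton kernel and the new dcu_create_flashmla_kv_indices op on the same inputs and compares the resulting page tables. The dtypes, the -1 fill value, and PAGED_SIZE=64 are assumptions, not taken from the patch:

import torch
import triton
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices

PAGED_SIZE = 64                                   # assumed FlashMLA page size
bs, max_ctx = 4, 1024
req_to_token = torch.arange(bs * max_ctx, dtype=torch.int32, device="cuda").view(bs, max_ctx)
req_pool_indices = torch.arange(bs, dtype=torch.int64, device="cuda")            # assumed dtype
seq_lens = torch.tensor([100, 257, 512, 1], dtype=torch.int64, device="cuda")    # assumed dtype
max_seqlen_pad = triton.cdiv(int(seq_lens.max().item()), PAGED_SIZE)

ref = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32, device="cuda")
out = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32, device="cuda")
create_flashmla_kv_indices_triton[(bs,)](
    req_to_token, req_pool_indices, seq_lens, None, ref,
    req_to_token.stride(0), max_seqlen_pad,
)
dcu_create_flashmla_kv_indices(
    req_to_token_ptr=req_to_token, req_pool_indices_ptr=req_pool_indices,
    page_kernel_lens_ptr=seq_lens, kv_start_idx=None, kv_indices_ptr=out,
    req_to_token_ptr_stride=req_to_token.stride(0), kv_indices_ptr_stride=max_seqlen_pad,
)
torch.testing.assert_close(ref, out)              # both paths should build the same page table
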
...@@ -674,16 +674,11 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -674,16 +674,11 @@ class FlashAttentionBackend(AttentionBackend):
if not layer.is_cross_attention if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc else forward_batch.encoder_out_cache_loc
) )
# if not self.use_mla:
if k_rope is None: if k_rope is None:
if not self.use_mla: forward_batch.token_to_kv_pool.set_kv_buffer(
forward_batch.token_to_kv_pool.set_kv_buffer( layer, cache_loc, k, v, #layer.k_scale, layer.v_scale
layer, cache_loc, k, v, layer.k_scale, layer.v_scale )
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v
)
else: else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer( forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer, layer,
......
...@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton ...@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import get_bool_env_var
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
...@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend): ...@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
def init_forward_metadata(self, forward_batch: ForwardBatch): def init_forward_metadata(self, forward_batch: ForwardBatch):
use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON")
bs = forward_batch.batch_size bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle(): if forward_batch.forward_mode.is_decode_or_idle():
max_seqlen_pad = triton.cdiv( max_seqlen_pad = triton.cdiv(
...@@ -91,15 +95,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend): ...@@ -91,15 +95,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
dtype=torch.int32, dtype=torch.int32,
device=forward_batch.seq_lens.device, device=forward_batch.seq_lens.device,
) )
create_flashmla_kv_indices_triton[(bs,)]( if use_sglang_create_flashmla_kv_indices_triton:
self.req_to_token, dcu_create_flashmla_kv_indices(
forward_batch.req_pool_indices, req_to_token_ptr = self.req_to_token,
forward_batch.seq_lens, req_pool_indices_ptr = forward_batch.req_pool_indices,
None, page_kernel_lens_ptr = forward_batch.seq_lens,
block_kv_indices, kv_start_idx = None,
self.req_to_token.stride(0), kv_indices_ptr = block_kv_indices,
max_seqlen_pad, req_to_token_ptr_stride = self.req_to_token.stride(0),
) kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata( mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32), forward_batch.seq_lens.to(torch.int32),
self.num_q_heads, self.num_q_heads,
...@@ -121,15 +137,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend): ...@@ -121,15 +137,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
dtype=torch.int32, dtype=torch.int32,
device=seq_lens.device, device=seq_lens.device,
) )
create_flashmla_kv_indices_triton[(bs,)]( if use_sglang_create_flashmla_kv_indices_triton:
self.req_to_token, dcu_create_flashmla_kv_indices(
forward_batch.req_pool_indices, req_to_token_ptr = self.req_to_token,
seq_lens, req_pool_indices_ptr = forward_batch.req_pool_indices,
None, page_kernel_lens_ptr = forward_batch.seq_lens,
block_kv_indices, kv_start_idx = None,
self.req_to_token.stride(0), kv_indices_ptr = block_kv_indices,
max_seqlen_pad, req_to_token_ptr_stride = self.req_to_token.stride(0),
) kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata( mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32), seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads, self.num_draft_tokens * self.num_q_heads,
...@@ -144,7 +172,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend): ...@@ -144,7 +172,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
) )
else: else:
super().init_forward_metadata(forward_batch) super().init_forward_metadata(forward_batch)
def init_cuda_graph_state( def init_cuda_graph_state(
self, self,
max_bs: int, max_bs: int,
......
...@@ -1013,3 +1013,134 @@ def zero_experts_compute_triton( ...@@ -1013,3 +1013,134 @@ def zero_experts_compute_triton(
) )
return output return output
from triton.language.extra import libdevice
from typing import Optional
@triton.jit
def _per_token_quant_int8_one_kernel_opt(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
T_dim,
tokens_per_expert_ptr,
BLOCK: tl.constexpr
):
row_id = tl.program_id(0)
if tokens_per_expert_ptr is not None:
e = row_id // T_dim
t = row_id % T_dim
num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
if t >= num_valid_tokens_for_e:
return
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
@triton.jit
def _per_token_quant_int8_kernel_opt(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
E_dim,
T_dim,
tokens_per_expert_ptr,
BLOCK: tl.constexpr
):
token_idx_start = tl.program_id(0)
grid_size = tl.num_programs(0)
num_total_tokens = E_dim * T_dim
for token_idx in range(token_idx_start, num_total_tokens, grid_size):
is_valid_token = True
if tokens_per_expert_ptr is not None:
e = token_idx // T_dim
t = token_idx % T_dim
num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
if t >= num_valid_tokens_for_e:
is_valid_token = False
if is_valid_token:
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + token_idx * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + token_idx, scale_x)
def per_token_quant_int8_triton_opt(x: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None):
if x.dim() != 3:
raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
E, T, H = x.shape
N = H
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1, ), device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
num_warps = min(max(BLOCK // 256, 1), 8)
if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
num_warps = 1
num_tokens = E * T
grid_opt = num_tokens
if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
grid_opt = max(1, num_tokens // (T // 256))
_per_token_quant_int8_kernel_opt[(grid_opt, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
E_dim=E,
T_dim=T,
tokens_per_expert_ptr=tokens_per_expert,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
else:
_per_token_quant_int8_one_kernel_opt[(grid_opt, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
T_dim=T,
tokens_per_expert_ptr=tokens_per_expert,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
\ No newline at end of file
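
A minimal usage sketch for the new per_token_quant_int8_triton_opt helper, assuming a CUDA device; the [E, T, H] layout and the meaning of tokens_per_expert follow the kernel code above, while the concrete sizes are illustrative:

import torch

E, T, H = 8, 256, 4096
x = torch.randn(E, T, H, dtype=torch.bfloat16, device="cuda")
tokens_per_expert = torch.tensor([256, 17, 0, 256, 3, 128, 64, 200],
                                 dtype=torch.int32, device="cuda")
x_q, scales = per_token_quant_int8_triton_opt(x, tokens_per_expert)
# x_q: int8 [E, T, H]; scales: float32 [E, T, 1].
# Rows with t >= tokens_per_expert[e] are skipped by the kernel and stay uninitialized.
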
...@@ -2,7 +2,7 @@ from __future__ import annotations ...@@ -2,7 +2,7 @@ from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from collections import defaultdict
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch import torch
...@@ -20,6 +20,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( ...@@ -20,6 +20,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
ep_scatter, ep_scatter,
silu_and_mul_masked_post_quant_fwd, silu_and_mul_masked_post_quant_fwd,
tma_align_input_scale, tma_align_input_scale,
per_token_quant_int8_triton_opt,
) )
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
...@@ -40,7 +41,7 @@ if TYPE_CHECKING: ...@@ -40,7 +41,7 @@ if TYPE_CHECKING:
DeepEPNormalOutput, DeepEPNormalOutput,
DispatchOutput, DispatchOutput,
) )
from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked, m_grouped_w8a8_gemm_nt_contig_asm, fuse_silu_mul_quant
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
_is_hip = is_hip() _is_hip = is_hip()
...@@ -605,6 +606,8 @@ class DeepEPMoE(EPMoE): ...@@ -605,6 +606,8 @@ class DeepEPMoE(EPMoE):
return self.forward_deepgemm_contiguous(dispatch_output) return self.forward_deepgemm_contiguous(dispatch_output)
elif self.use_w4a8_marlin: elif self.use_w4a8_marlin:
return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output) return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
elif self.use_w8a8_marlin:
return self.forward_groupgemm_w8a8_marlin_contiguous(dispatch_output)
else: else:
raise ValueError( raise ValueError(
f"Dispatch output is not supported" f"Dispatch output is not supported"
...@@ -709,6 +712,111 @@ class DeepEPMoE(EPMoE): ...@@ -709,6 +712,111 @@ class DeepEPMoE(EPMoE):
) )
return expert_output return expert_output
def forward_groupgemm_w8a8_marlin_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
all_tokens = sum(num_recv_tokens_per_expert)
if all_tokens <= 0:
return hidden_states.bfloat16()
device = hidden_states.device
M = hidden_states.shape[0]
K = hidden_states.shape[1]
topk = topk_idx.shape[1]
active_experts = set()
token_expert_pos = [None] * M
for t in range(M):
lst = []
for pos in range(topk):
e = int(topk_idx[t, pos].item())
if e >= 0:
lst.append((e, pos))
active_experts.add(e)
token_expert_pos[t] = lst
active_experts = sorted(list(active_experts))
num_active = len(active_experts)
if num_active == 0:
return hidden_states.bfloat16()
counts = defaultdict(int)
for t in range(M):
for (e, pos) in token_expert_pos[t]:
counts[e] += 1
per_expert_block = {}
for e in active_experts:
cnt = counts.get(e, 0)
if cnt <= 0:
per_expert_block[e] = 0
else:
needed = ((cnt + 256 - 1) // 256) * 256 # next multiple of 256
per_expert_block[e] = max(256, needed)
expert_slot_offset = {}
offset = 0
for e in active_experts:
expert_slot_offset[e] = offset
offset += per_expert_block[e]
pad_M = offset
hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)
slot_counters = {e: 0 for e in active_experts}
token_row_weight_list = {t: [] for t in range(M)}
for t in range(M):
for (e, pos) in token_expert_pos[t]:
start = expert_slot_offset[e]
slot = slot_counters[e]
if slot >= per_expert_block[e]:
raise RuntimeError(f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}")
row = start + slot
hidden_states_packed[row] = hidden_states[t]
m_indices[row] = int(e)
slot_counters[e] += 1
w = topk_weights[t, pos].to(device=device)
w_f = w.float() if w.dtype != torch.float32 else w
token_row_weight_list[t].append((row, w_f))
q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
N = self.w13_weight.size(1)
gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
m_grouped_w8a8_gemm_nt_contig_asm(
(q_a1_all, q_a1_scale),
(self.w13_weight, self.w13_weight_scale),
gateup_output,
m_indices,
)
q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)
down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
down_output = m_grouped_w8a8_gemm_nt_contig_asm(
(q_a2_all, q_a2_scale),
(self.w2_weight, self.w2_weight_scale),
down_output,
m_indices,
)
result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
for t in range(M):
pairs = token_row_weight_list[t]
if not pairs:
continue
acc = None
for (row, w) in pairs:
vec = down_output[row].float()
weighted = vec * w
acc = weighted if acc is None else (acc + weighted)
result[t] = acc.to(result.dtype)
return result
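
The w8a8-marlin contiguous path above replicates each token once per routed expert, groups the copies by expert, and pads every expert's block to a multiple of 256 rows so the grouped GEMM sees contiguous per-expert segments, with m_indices recording which expert owns each packed row (-1 marks padding rows). A small sketch of just that padding arithmetic, with an illustrative helper name:

def pad_to_block(count: int, block: int = 256) -> int:
    # Rows reserved for an expert that received `count` tokens in the packing above.
    if count <= 0:
        return 0
    return max(block, ((count + block - 1) // block) * block)

assert pad_to_block(0) == 0
assert pad_to_block(1) == 256
assert pad_to_block(256) == 256
assert pad_to_block(257) == 512
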
def forward_deepgemm_contiguous( def forward_deepgemm_contiguous(
self, self,
...@@ -899,10 +1007,10 @@ class DeepEPMoE(EPMoE): ...@@ -899,10 +1007,10 @@ class DeepEPMoE(EPMoE):
# base shapes # base shapes
num_groups, m, k = hidden_states.size() num_groups, m, k = hidden_states.size()
expected_m = min(m, expected_m) expected_m = min(m, expected_m)
# ---- first quant: ensure float input for quantizer ---- # ---- first quant: ensure float input for quantizer ----
q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states) q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
# ---- weights & scales ---- # ---- weights & scales ----
w13_weight = self.w13_weight w13_weight = self.w13_weight
...@@ -943,16 +1051,15 @@ class DeepEPMoE(EPMoE): ...@@ -943,16 +1051,15 @@ class DeepEPMoE(EPMoE):
dispatch_output: DeepEPLLOutput, dispatch_output: DeepEPLLOutput,
): ):
hidden_states, _, _, _, masked_m, expected_m = dispatch_output hidden_states, _, topk_ids, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu" assert self.moe_runner_config.activation == "silu"
# base shapes # base shapes
num_groups, m, k = hidden_states.size() num_groups, m, k = hidden_states.size()
expected_m = min(m, expected_m) expected_m = min(m, expected_m)
# ---- first quant: ensure float input for quantizer ---- # ---- first quant: ensure float input for quantizer ----
q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states) q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
# ---- weights & scales ---- # ---- weights & scales ----
w13_weight = self.w13_weight w13_weight = self.w13_weight
......
...@@ -308,7 +308,7 @@ class _DeepEPDispatcherImplBase: ...@@ -308,7 +308,7 @@ class _DeepEPDispatcherImplBase:
self.params_bytes = 2 self.params_bytes = 2
self.num_max_dispatch_tokens_per_rank = get_int_env_var( self.num_max_dispatch_tokens_per_rank = get_int_env_var(
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128 "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 64
) )
# DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024 # DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024
# and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it # and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
...@@ -441,18 +441,19 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): ...@@ -441,18 +441,19 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
expert_alignment=1, expert_alignment=1,
config=DeepEPConfig.get_instance().normal_dispatch_config, config=DeepEPConfig.get_instance().normal_dispatch_config,
) )
# get_global_expert_distribution_recorder().on_deepep_dispatch_normal( if self.quant_config.get("quant_method") == "slimquant_w4a8_marlin":
# num_recv_tokens_per_expert, self.rank_expert_offset= get_moe_expert_parallel_rank() * ( self.num_experts // get_moe_expert_parallel_world_size())
# num_tokens_per_rank=num_tokens_per_rank, recv_topk_ids = torch.where(
# num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, recv_topk_ids == -1,
# num_tokens_per_expert=num_tokens_per_expert, self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
# ) recv_topk_ids + self.rank_expert_offset)
self.rank_expert_offset= get_moe_expert_parallel_rank() * ( self.num_experts // get_moe_expert_parallel_world_size()) else:
recv_topk_ids = torch.where( get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
recv_topk_ids == -1, num_recv_tokens_per_expert,
self.num_experts - 1 if self.rank_expert_offset == 0 else 0, num_tokens_per_rank=num_tokens_per_rank,
recv_topk_ids + self.rank_expert_offset) num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
num_tokens_per_expert=num_tokens_per_expert,
)
return ( return (
recv_x, recv_x,
recv_topk_ids, recv_topk_ids,
...@@ -541,7 +542,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): ...@@ -541,7 +542,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256 num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
""" """
self.return_recv_hook = False self.return_recv_hook = return_recv_hook
self.device_module = torch.get_device_module() self.device_module = torch.get_device_module()
self.quant_config = {} self.quant_config = {}
...@@ -724,6 +725,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): ...@@ -724,6 +725,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
self.packed_recv_count = self.handle = None self.packed_recv_count = self.handle = None
return combined_hidden_states, event, hook return combined_hidden_states, event, hook
@torch._dynamo.disable()
def _get_buffer(self): def _get_buffer(self):
DeepEPBuffer.set_dispatch_mode_as_low_latency() DeepEPBuffer.set_dispatch_mode_as_low_latency()
return DeepEPBuffer.get_deepep_buffer( return DeepEPBuffer.get_deepep_buffer(
......
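
In the slimquant_w4a8_marlin branch above, slots whose routed expert is not on this rank arrive as -1 and are redirected to a sentinel expert, while valid local ids are shifted by rank_expert_offset into the global expert-id space. A small sketch of that torch.where remap with illustrative numbers (16 experts, 2 EP ranks, this rank owning experts 8..15):

import torch

num_experts, ep_rank, ep_size = 16, 1, 2
rank_expert_offset = ep_rank * (num_experts // ep_size)         # 8
recv_topk_ids = torch.tensor([[0, 3, -1], [7, -1, -1]])          # local ids, -1 = not for this rank
remapped = torch.where(
    recv_topk_ids == -1,
    num_experts - 1 if rank_expert_offset == 0 else 0,           # sentinel expert id
    recv_topk_ids + rank_expert_offset,
)
# remapped -> [[8, 11, 0], [15, 0, 0]]
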
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from __future__ import annotations from __future__ import annotations
import os
import logging import logging
from contextlib import suppress from contextlib import suppress
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, cast from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, cast
...@@ -46,6 +46,9 @@ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod ...@@ -46,6 +46,9 @@ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import W8a8GetCacheJSON
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = ["CompressedTensorsLinearMethod"] __all__ = ["CompressedTensorsLinearMethod"]
...@@ -590,8 +593,33 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -590,8 +593,33 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: CompressedTensorsConfig): def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0]
k=layer.weight.shape[1]
if self.w8a8_strategy==1:
if [n,k] not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append([n,k])
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
self.tritonsingleton.triton_json_dict.update(configs_dict)
for key, value in configs_dict.items():
m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
elif self.w8a8_strategy==3:
layer.weight.data = layer.weight.data.T
else:
weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1)
layer.weight.data=_weight
self.tritonsingleton.gen_model_json()
layer.scheme.process_weights_after_loading(layer) layer.scheme.process_weights_after_loading(layer)
def create_weights( def create_weights(
......
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
from typing import Callable, Optional from typing import Callable, Optional
import torch import torch
...@@ -19,11 +20,13 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import ( ...@@ -19,11 +20,13 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from sglang.srt.layers.quantization.utils import requantize_with_max_scale from sglang.srt.layers.quantization.utils import requantize_with_max_scale
from sglang.srt.utils import is_cuda from sglang.srt.utils import is_cuda
from lmslim import quant_ops
from sglang.srt import _custom_ops as ops
_is_cuda = is_cuda() _is_cuda = is_cuda()
if _is_cuda: if _is_cuda:
from sgl_kernel import int8_scaled_mm from sgl_kernel import int8_scaled_mm
from sglang.srt.utils import W8a8GetCacheJSON
W8A8_TRITONJSON=W8a8GetCacheJSON()
class CompressedTensorsW8A8Int8(CompressedTensorsScheme): class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...@@ -33,6 +36,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -33,6 +36,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.strategy = strategy self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric self.input_symmetric = input_symmetric
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) # TODO
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
...@@ -163,14 +167,70 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -163,14 +167,70 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
) )
layer.register_parameter("input_zero_point", input_zero_point) layer.register_parameter("input_zero_point", input_zero_point)
@torch._dynamo.disable()
def apply_weights( def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor: ) -> torch.Tensor:
# TODO: add cutlass_scaled_mm_azp support # TODO: add cutlass_scaled_mm_azp support
x_q, x_scale = per_token_quant_int8(x) x_q, x_scale = per_token_quant_int8(x)
# return quant_ops.custom_scaled_mm(x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias)
# TODO: fix with lmslim/lightop
return quant_ops.triton_scaled_mm( if self.w8a8_strategy==1:
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias m=x_q.shape[0]
) k=x_q.shape[1]
n=layer.weight.shape[1]
if len(W8A8_TRITONJSON.triton_json_dict)==0:
best_config=None
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
if m<=16:
m_=m
elif m<=64:
m_= (m + 3) & -4  # round up to the nearest multiple of 4
elif m<=160:
m_=(m + 7) & -8
elif m<200: #256
m_=160
elif m<480: #512
m_=256
elif m<960: #1024
m_=512
elif m<2048:
m_=1024
elif m<4096:
m_=2048
elif m<6000:
m_=4096
else:
m_=8192
best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
else:
best_config=None
return ops.triton_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias, best_config=best_config
)
elif self.w8a8_strategy==2:
return ops.cutlass_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
elif self.w8a8_strategy==3:
return ops.blaslt_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=None)
else:
return ops.rocblas_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
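
The best_config lookup above first buckets the runtime M to the nearest tuned M before reading W8A8_TRITONJSON.triton_json_dict; the sketch below restates that bucketing as a standalone function (the function name is illustrative, the thresholds are copied from the branch above):

def bucket_m(m: int) -> int:
    # Map the runtime M to the tuned M used in the "<M>_<N>_<K>" config key.
    if m <= 16:
        return m
    if m <= 64:
        return (m + 3) & -4       # round up to a multiple of 4
    if m <= 160:
        return (m + 7) & -8       # round up to a multiple of 8
    for limit, bucket in ((200, 160), (480, 256), (960, 512),
                          (2048, 1024), (4096, 2048), (6000, 4096)):
        if m < limit:
            return bucket
    return 8192

assert bucket_m(70) == 72 and bucket_m(300) == 256 and bucket_m(5000) == 4096
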
...@@ -15,7 +15,7 @@ from lmslim.layers.gemm.int8_utils import ( ...@@ -15,7 +15,7 @@ from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_group_quant_int8,
per_token_quant_int8) per_token_quant_int8)
from sglang.srt import _custom_ops as ops from sglang.srt import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from sglang.srt.utils import W8a8GetCacheJSON
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
import os import os
...@@ -157,7 +157,6 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -157,7 +157,6 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
) )
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
@torch._dynamo.disable() # TODO: performance optimization requires support from lmslim/lightop
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -227,6 +226,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -227,6 +226,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
scale_b=layer.weight_scale, scale_b=layer.weight_scale,
out_dtype=x.dtype, out_dtype=x.dtype,
bias=bias) bias=bias)
elif self.w8a8_strategy==3:
return ops.blaslt_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=None)
else: else:
return ops.rocblas_scaled_mm(x_q, return ops.rocblas_scaled_mm(x_q,
layer.weight, layer.weight,
......
...@@ -13,7 +13,8 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache ...@@ -13,7 +13,8 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.server_args import get_global_server_args from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import support_triton from sglang.srt.utils import support_triton,get_bool_env_var
from sgl_kernel.kvcacheio import dcu_get_last_loc
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
...@@ -125,13 +126,17 @@ def get_last_loc( ...@@ -125,13 +126,17 @@ def get_last_loc(
req_pool_indices_tensor: torch.Tensor, req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor, prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
if ( use_sglang_get_last_loc = get_bool_env_var("SGLANG_GET_LAST_LOC")
get_global_server_args().attention_backend != "ascend" if use_sglang_get_last_loc:
and get_global_server_args().attention_backend != "torch_native" impl = dcu_get_last_loc
):
impl = get_last_loc_triton
else: else:
impl = get_last_loc_torch if (
get_global_server_args().attention_backend != "ascend"
and get_global_server_args().attention_backend != "torch_native"
):
impl = get_last_loc_triton
else:
impl = get_last_loc_torch
return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor) return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
......
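
For context, dcu_get_last_loc returns, for each request, the KV-cache slot of its last prefix token. The sketch below is a plain-torch reading of the CUDA kernel added later in this commit, not the library's own get_last_loc_torch; the function name is illustrative:

import torch

def get_last_loc_reference(req_to_token: torch.Tensor,
                           req_pool_indices: torch.Tensor,
                           prefix_lens: torch.Tensor) -> torch.Tensor:
    # result[i] = req_to_token[req_pool_indices[i], prefix_lens[i] - 1], or -1 if there is no prefix.
    return torch.where(
        prefix_lens > 0,
        req_to_token[req_pool_indices, torch.clamp(prefix_lens - 1, min=0)].to(torch.int64),
        torch.full_like(prefix_lens, -1),
    )
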
...@@ -46,7 +46,11 @@ from sglang.srt.layers.dp_attention import ( ...@@ -46,7 +46,11 @@ from sglang.srt.layers.dp_attention import (
set_dp_buffer_len, set_dp_buffer_len,
set_is_extend_in_batch, set_is_extend_in_batch,
) )
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton from sglang.srt.utils import get_compiler_backend, is_npu, support_triton,get_bool_env_var
from sgl_kernel.kvcacheio import dcu_create_chunked_prefix_cache_kv_indices
import logging
logger = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
...@@ -123,13 +127,13 @@ class ForwardMode(IntEnum): ...@@ -123,13 +127,13 @@ class ForwardMode(IntEnum):
# For fixed shape logits output in v2 eagle worker # For fixed shape logits output in v2 eagle worker
return self == ForwardMode.DRAFT_EXTEND_V2 return self == ForwardMode.DRAFT_EXTEND_V2
def is_extend_or_draft_extend_or_mixed(self): #nhb def is_extend_or_draft_extend_or_mixed(self):
return ( return (
self == ForwardMode.EXTEND self == ForwardMode.EXTEND
or self == ForwardMode.DRAFT_EXTEND or self == ForwardMode.DRAFT_EXTEND
or self == ForwardMode.MIXED or self == ForwardMode.MIXED
or self == ForwardMode.SPLIT_PREFILL or self == ForwardMode.SPLIT_PREFILL
or self == ForwardMode.DRAFT_EXTEND_V2 or self == ForwardMode.DRAFT_EXTEND_V2 #nhb
) )
def is_cuda_graph(self): def is_cuda_graph(self):
...@@ -317,6 +321,8 @@ class ForwardBatch: ...@@ -317,6 +321,8 @@ class ForwardBatch:
tbo_parent_token_range: Optional[Tuple[int, int]] = None tbo_parent_token_range: Optional[Tuple[int, int]] = None
tbo_children: Optional[List[ForwardBatch]] = None tbo_children: Optional[List[ForwardBatch]] = None
use_sglang_create_chunked_prefix_cache_kv_indices = get_bool_env_var("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES")
@classmethod @classmethod
def init_new( def init_new(
cls, cls,
...@@ -363,13 +369,13 @@ class ForwardBatch: ...@@ -363,13 +369,13 @@ class ForwardBatch:
if batch.extend_input_logprob_token_ids is not None: if batch.extend_input_logprob_token_ids is not None:
ret.extend_input_logprob_token_ids_gpu = ( ret.extend_input_logprob_token_ids_gpu = (
batch.extend_input_logprob_token_ids.to(device, non_blocking=True) batch.extend_input_logprob_token_ids.pin_memory().to(device, non_blocking=True)
) )
if enable_num_token_non_padded(model_runner.server_args): if enable_num_token_non_padded(model_runner.server_args):
ret.num_token_non_padded = torch.tensor( ret.num_token_non_padded = torch.tensor(
len(batch.input_ids), dtype=torch.int32 len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True) ).pin_memory().to(device, non_blocking=True)
ret.num_token_non_padded_cpu = len(batch.input_ids) ret.num_token_non_padded_cpu = len(batch.input_ids)
# For MLP sync # For MLP sync
...@@ -389,12 +395,12 @@ class ForwardBatch: ...@@ -389,12 +395,12 @@ class ForwardBatch:
ret.global_num_tokens_cpu = global_num_tokens ret.global_num_tokens_cpu = global_num_tokens
ret.global_num_tokens_gpu = torch.tensor( ret.global_num_tokens_gpu = torch.tensor(
global_num_tokens, dtype=torch.int64 global_num_tokens, dtype=torch.int64
).to(device, non_blocking=True) ).pin_memory().to(device, non_blocking=True)
ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
ret.global_num_tokens_for_logprob_gpu = torch.tensor( ret.global_num_tokens_for_logprob_gpu = torch.tensor(
global_num_tokens_for_logprob, dtype=torch.int64 global_num_tokens_for_logprob, dtype=torch.int64
).to(device, non_blocking=True) ).pin_memory().to(device, non_blocking=True)
if ret.forward_mode.is_idle(): if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), dtype=torch.int64, device=device) ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
...@@ -419,10 +425,10 @@ class ForwardBatch: ...@@ -419,10 +425,10 @@ class ForwardBatch:
assert isinstance(batch.extend_prefix_lens, list) assert isinstance(batch.extend_prefix_lens, list)
ret.extend_seq_lens = torch.tensor( ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32 batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True) ).pin_memory().to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor( ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32 batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True) ).pin_memory().to(device, non_blocking=True)
ret.extend_num_tokens = batch.extend_num_tokens ret.extend_num_tokens = batch.extend_num_tokens
positions, ret.extend_start_loc = compute_position( positions, ret.extend_start_loc = compute_position(
model_runner.server_args.attention_backend, model_runner.server_args.attention_backend,
...@@ -635,15 +641,28 @@ class ForwardBatch: ...@@ -635,15 +641,28 @@ class ForwardBatch:
num_chunk_tokens, dtype=torch.int32, device=device num_chunk_tokens, dtype=torch.int32, device=device
) )
create_chunked_prefix_cache_kv_indices[(self.batch_size,)]( if self.use_sglang_create_chunked_prefix_cache_kv_indices:
self.req_to_token_pool.req_to_token, dcu_create_chunked_prefix_cache_kv_indices(
self.req_pool_indices, req_to_token = self.req_to_token_pool.req_to_token,
chunk_starts, req_pool_indices = self.req_pool_indices,
chunk_seq_lens, chunk_starts = chunk_starts,
chunk_cu_seq_lens, chunk_seq_lens = chunk_seq_lens,
chunk_kv_indices, chunk_cu_seq_lens = chunk_cu_seq_lens,
self.req_to_token_pool.req_to_token.shape[1], chunk_kv_indices = chunk_kv_indices,
) col_num = self.req_to_token_pool.req_to_token.shape[1],
bs = self.batch_size,
)
else:
# logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
self.req_to_token_pool.req_to_token,
self.req_pool_indices,
chunk_starts,
chunk_seq_lens,
chunk_cu_seq_lens,
chunk_kv_indices,
self.req_to_token_pool.req_to_token.shape[1],
)
self.prefix_chunk_kv_indices.append(chunk_kv_indices) self.prefix_chunk_kv_indices.append(chunk_kv_indices)
def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0): def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
......
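
The .pin_memory() calls added in this hunk matter because non_blocking=True only overlaps a host-to-device copy with GPU work when the source CPU tensor is page-locked; from pageable memory the copy degrades to an effectively synchronous transfer. A minimal illustration of the pattern (values are arbitrary):

import torch

vals = [3, 7, 11]
# Pinning the staging tensor keeps the H2D copy truly asynchronous with respect to the host.
t = torch.tensor(vals, dtype=torch.int32).pin_memory().to("cuda", non_blocking=True)
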
...@@ -237,7 +237,14 @@ class DraftBackendFactory: ...@@ -237,7 +237,14 @@ class DraftBackendFactory:
return None return None
def _create_dcumla_prefill_backend(self): def _create_dcumla_prefill_backend(self):
logger.warning( # logger.warning(
"flashmla prefill backend is not yet supported for draft extend." # "flashmla prefill backend is not yet supported for draft extend."
# )
# return None
#nhb
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
) )
return None
return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
...@@ -29,6 +29,12 @@ from sglang.srt.speculative.spec_utils import ( ...@@ -29,6 +29,12 @@ from sglang.srt.speculative.spec_utils import (
) )
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2 from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
from sglang.srt.utils import get_bool_env_var
from sgl_kernel.kvcacheio import dcu_assign_req_to_token_pool,dcu_assign_extend_cache_locs
import logging
logger = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
...@@ -77,6 +83,9 @@ def assign_draft_cache_locs_page_size_1( ...@@ -77,6 +83,9 @@ def assign_draft_cache_locs_page_size_1(
@dataclass @dataclass
class EagleDraftInputV2Mixin: class EagleDraftInputV2Mixin:
use_sglang_assign_req_to_token_pool = get_bool_env_var("SGLANG_ASSIGN_REQ_TO_TOKEN_POOL")
def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch): def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
...@@ -112,15 +121,26 @@ class EagleDraftInputV2Mixin: ...@@ -112,15 +121,26 @@ class EagleDraftInputV2Mixin:
extend_num_tokens, extend_num_tokens,
) )
assign_req_to_token_pool[(bs,)]( if self.use_sglang_assign_req_to_token_pool:
batch.req_pool_indices, dcu_assign_req_to_token_pool(
batch.req_to_token_pool.req_to_token, req_pool_indices = batch.req_pool_indices,
self.allocate_lens, req_to_token = batch.req_to_token_pool.req_to_token,
new_allocate_lens, allocate_lens = self.allocate_lens,
out_cache_loc, new_allocate_lens = new_allocate_lens,
batch.req_to_token_pool.req_to_token.shape[1], out_cache_loc = out_cache_loc,
next_power_of_2(bs), shape = batch.req_to_token_pool.req_to_token.shape[1],
) bs = bs,
)
else:
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
self.allocate_lens,
new_allocate_lens,
out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
self.allocate_lens = new_allocate_lens self.allocate_lens = new_allocate_lens
# FIXME(lsyin): make this sync optional # FIXME(lsyin): make this sync optional
...@@ -189,6 +209,9 @@ class EagleDraftInputV2Mixin: ...@@ -189,6 +209,9 @@ class EagleDraftInputV2Mixin:
@dataclass @dataclass
class EagleVerifyInputV2Mixin: class EagleVerifyInputV2Mixin:
use_sglang_assign_extend_cache_locs = get_bool_env_var("SGLANG_ASSIGN_EXTEND_CACHE_LOCS")
def prepare_for_v2_verify( def prepare_for_v2_verify(
self: EagleVerifyInput, self: EagleVerifyInput,
req_to_token_pool: ReqToTokenPool, req_to_token_pool: ReqToTokenPool,
...@@ -205,15 +228,26 @@ class EagleVerifyInputV2Mixin: ...@@ -205,15 +228,26 @@ class EagleVerifyInputV2Mixin:
device=device, device=device,
) )
assign_extend_cache_locs[(bs,)]( if self.use_sglang_assign_extend_cache_locs:
batch.req_pool_indices, dcu_assign_extend_cache_locs(
req_to_token_pool.req_to_token, batch.req_pool_indices,
batch.seq_lens, req_to_token_pool.req_to_token,
batch.seq_lens + self.draft_token_num, batch.seq_lens,
batch.out_cache_loc, batch.seq_lens + self.draft_token_num,
req_to_token_pool.req_to_token.shape[1], batch.out_cache_loc,
next_power_of_2(bs), req_to_token_pool.req_to_token.shape[1],
) bs,
)
else:
assign_extend_cache_locs[(bs,)](
batch.req_pool_indices,
req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.draft_token_num,
batch.out_cache_loc,
req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
# Get a forward batch # Get a forward batch
batch.forward_mode = ForwardMode.TARGET_VERIFY batch.forward_mode = ForwardMode.TARGET_VERIFY
......
...@@ -758,7 +758,7 @@ class TboForwardBatchPreparer: ...@@ -758,7 +758,7 @@ class TboForwardBatchPreparer:
# TODO we may make padding on both sub-batches to make it slightly more balanced # TODO we may make padding on both sub-batches to make it slightly more balanced
value_a = min(tbo_split_token_index, num_token_non_padded) value_a = min(tbo_split_token_index, num_token_non_padded)
value_b = max(0, num_token_non_padded - tbo_split_token_index) value_b = max(0, num_token_non_padded - tbo_split_token_index)
return torch.tensor([value_a, value_b], dtype=torch.int32).to( return torch.tensor([value_a, value_b], dtype=torch.int32).pin_memory().to(
device=get_global_server_args().device, non_blocking=True device=get_global_server_args().device, non_blocking=True
) )
......
...@@ -3528,3 +3528,158 @@ def cached_triton_kernel(key_fn=None): ...@@ -3528,3 +3528,158 @@ def cached_triton_kernel(key_fn=None):
return CachedKernel(fn, key_fn) return CachedKernel(fn, key_fn)
return decorator return decorator
# from vllm
class W8a8GetCacheJSON:
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(W8a8GetCacheJSON, cls).__new__(cls, *args, **kwargs)
cls._instance._initialize()
return cls._instance
def _initialize(self):
current_folder_path = os.path.dirname(os.path.abspath(__file__))
json_folder_path=current_folder_path+'/../../lmslim/configs/w8a8'
self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', json_folder_path))
self.triton_json_dict={}
self.triton_moejson_dict={}
self.triton_json_list=[]
self.weight_shapes=[]
self.moe_weight_shapes=[]
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
device_name =arch_name+'_'+str(arch_cu)+'cu'
self.device_name=device_name
self.topk=1
self.quant_method=None
# Destructor-style step: generates the model.json configuration file at the end
def gen_model_json(self,E:Optional[int]=0,block_size:Optional[list]=None):
json_dir = os.getenv('LMSLIM_TUNING_JSON', "None")
if json_dir != "None" and os.path.exists(json_dir):
# Generate the model configuration file
# logger.info("model_tuning.json is at LMSLIM_TUNING_JSON:%s", json_dir)
config = {
"layers": {
"linear": {
"shapes": [],
"m_range":"None",
},
"moe": {
"shapes": [],
"m_range": "None",
"topk": self.topk
}
},
"quantization_config": {
"quant_method": self.quant_method,
"weight_block_size": "None"
}
}
# Process MoE shapes
for shape in self.moe_weight_shapes:
if len(shape) == 4: # assume the MoE shape is [E, N1, N2, K]
moe_config = {
"E": shape[0],
"N1": shape[1],
"N2": shape[2],
"K": shape[3], # 默认值
}
config["layers"]["moe"]["shapes"].append(moe_config)
for shape in self.weight_shapes:
config["layers"]["linear"]["shapes"].append(shape)
if block_size is not None:
config["quantization_config"]["weight_block_size"]=block_size
with open(json_dir+"/model.json", 'w') as f:
json.dump(config, f, indent=4)
# else:
# logger.info("LMSLIM_TUNING_JSON is not set")
def getspec_config(self,configs_dict,M,N,K):
if f"{M}_{N}_{K}" in configs_dict:
return configs_dict[f"{M}_{N}_{K}"]
else:
return None
def get_triton_cache(self,file_path,n,k):
# Used when not tuning; returns None directly if the file does not exist
cache_json_file=file_path
if os.path.exists(file_path):
#try:
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
return None
# Parse the entire cache into key:config form: [M_N_K]: [config]
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
return configs_dict
def get_w8a8json_name(self,n,k):
return self.triton_json_dir+f"/W8A8_{n}_{k}_{self.device_name}.json"
def get_blockint8_triton_cache(self,file_path,n,k,block_n,block_k):
cache_json_file=file_path
if os.path.exists(file_path):
#try:
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
return None
# Parse the entire cache into key:config form: [M_N_K]: [config]
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
return configs_dict
def get_blockint8json_name(self,n,k,block_n,block_k):
return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
def get_moeint8json_name(self,E,N1,N2,K,TOPK,
block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False):
if use_int4_w4a8:
if block_size is not None:
return self.triton_json_dir+f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
if block_size is not None:
return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
def get_moeint8_triton_cache(self,file_path,E,N1,N2,K,TOPK):
cache_json_file=file_path
if os.path.exists(file_path):
#try:
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
return None
# Parse the entire cache into key:config form: [M_N_K]: [config1, config2]
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
return configs_dict
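
A hedged sketch of how the singleton above is consumed by CompressedTensorsLinearMethod earlier in this commit; the file name depends on TRITON_JSON_DIR and the detected device, and the shape values here are illustrative:

cache = W8a8GetCacheJSON()                  # singleton: repeated constructions return the same instance
n, k = 11008, 4096                          # illustrative GEMM weight shape
json_file = cache.get_w8a8json_name(n, k)   # e.g. <TRITON_JSON_DIR>/W8A8_11008_4096_<arch>_<ncu>cu.json
configs = cache.get_triton_cache(json_file, n, k)   # None if the JSON file does not exist
if configs:
    cache.triton_json_dict.update(configs)  # keys look like "<M>_<N>_<K>"
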
...@@ -19,6 +19,14 @@ limitations under the License. ...@@ -19,6 +19,14 @@ limitations under the License.
#include "sgl_kernel_ops.h" #include "sgl_kernel_ops.h"
TORCH_LIBRARY_EXPAND(sgl_kernel, m) { TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From FlashMLA
*/
m.def("dcu_create_flashmla_kv_indices(Tensor req_to_token, Tensor req_pool_indices,Tensor page_kernel_lens, Tensor? kv_start_idx, Tensor kv_indices, int req_to_token_stride, int kv_indices_stride, int PAGED_SIZE) -> ()");
m.impl("dcu_create_flashmla_kv_indices", torch::kCUDA, &dcu_create_flashmla_kv_indices);
/* /*
* From csrc/activation * From csrc/activation
*/ */
...@@ -133,6 +141,15 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { ...@@ -133,6 +141,15 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
*/ */
m.def("dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()"); m.def("dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info); m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
m.def("dcu_create_chunked_prefix_cache_kv_indices(Tensor req_to_token, Tensor req_pool_indices, Tensor chunk_starts, Tensor chunk_seq_lens, Tensor chunk_cu_seq_lens, Tensor chunk_kv_indices, int col_num, int bs) -> ()");
m.impl("dcu_create_chunked_prefix_cache_kv_indices", torch::kCUDA, &dcu_create_chunked_prefix_cache_kv_indices);
m.def("dcu_assign_extend_cache_locs(Tensor req_pool_indices, Tensor req_to_token, Tensor start_offset, Tensor end_offset, Tensor out_cache_loc, int pool_len, int bs) -> ()");
m.impl("dcu_assign_extend_cache_locs", torch::kCUDA, &dcu_assign_extend_cache_locs);
m.def("dcu_get_last_loc(Tensor req_to_token, Tensor req_pool_indices, Tensor prefix_lens) -> Tensor");
m.impl("dcu_get_last_loc", torch::kCUDA, &dcu_get_last_loc);
m.def("dcu_assign_req_to_token_pool(Tensor req_pool_indices_ptr,Tensor req_to_token_ptr,Tensor allocate_lens_ptr,Tensor new_allocate_lens,Tensor out_cache_loc_ptr,int shape,int bs) -> ()");
m.impl("dcu_assign_req_to_token_pool",torch::kCUDA,&dcu_assign_req_to_token_pool);
m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()"); m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel); m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
......
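
Assuming TORCH_LIBRARY_EXPAND(sgl_kernel, m) puts these schemas under the sgl_kernel namespace, the new ops should also be reachable through torch.ops once the extension is imported; the sketch below only illustrates the registered dcu_get_last_loc schema, and the Python wrappers in sgl_kernel.kvcacheio used elsewhere in this commit remain the intended entry points:

import torch
import sgl_kernel  # loads the extension so the TORCH_LIBRARY registration above runs

req_to_token = torch.zeros((8, 2048), dtype=torch.int32, device="cuda")
req_pool_indices = torch.arange(4, dtype=torch.int64, device="cuda")
prefix_lens = torch.tensor([0, 5, 17, 2048], dtype=torch.int64, device="cuda")
# Assumed namespace: torch.ops.sgl_kernel.* mirrors the m.def(...) schemas above.
last_loc = torch.ops.sgl_kernel.dcu_get_last_loc(req_to_token, req_pool_indices, prefix_lens)
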
...@@ -836,4 +836,322 @@ void dcu_alloc_extend_kernel( ...@@ -836,4 +836,322 @@ void dcu_alloc_extend_kernel(
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream(); cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size); launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
} }
\ No newline at end of file
__global__ void launch_assign_req_to_token_pool(
const int64_t* req_pool_indices_ptr,
int32_t* req_to_token_ptr,
const int64_t* allocate_lens_ptr,
int64_t* new_allocate_lens,
int64_t* out_cache_loc_ptr,
int64_t shape,
int64_t bs)
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t kv_start = allocate_lens_ptr[pid];
int64_t kv_end = new_allocate_lens[pid];
int64_t pool_idx = req_pool_indices_ptr[pid];
int32_t* token_pool = (int32_t*)(req_to_token_ptr + pool_idx * shape);
int64_t sum_out_offset = 0;
for(int length_offset = 0; length_offset < pid;length_offset++){
int64_t start = allocate_lens_ptr[length_offset];
int64_t end = new_allocate_lens[length_offset];
sum_out_offset += (end- start);
}
int64_t* out_cache_ptr = out_cache_loc_ptr + sum_out_offset;
int64_t copy_length = kv_end - kv_start;
#pragma unroll(32)
for (int out_cache_index = 0; out_cache_index < copy_length; out_cache_index++) {
token_pool[kv_start + out_cache_index] = out_cache_ptr[out_cache_index];
}
}
void dcu_assign_req_to_token_pool(
const at::Tensor req_pool_indices_ptr,
at::Tensor req_to_token_ptr,
const at::Tensor allocate_lens_ptr,
at::Tensor new_allocate_lens,
at::Tensor out_cache_loc_ptr,
int64_t shape,
int64_t bs) {
const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
const int64_t* allocate_lens_ptr1 = static_cast<const int64_t*>(allocate_lens_ptr.data_ptr());
int64_t* new_allocate_lens1 = static_cast<int64_t*>(new_allocate_lens.data_ptr());
int64_t* out_cache_loc_ptr1 = static_cast<int64_t*>(out_cache_loc_ptr.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_assign_req_to_token_pool<<<grid_size, block_size, 0, torch_current_stream>>>(req_pool_indices_ptr1, req_to_token_ptr1, allocate_lens_ptr1, new_allocate_lens1, out_cache_loc_ptr1, shape, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
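For reference, the copy performed by dcu_assign_req_to_token_pool can be expressed as the following pure-PyTorch sketch; the helper name is hypothetical and the 2-D view of req_to_token stands in for the pointer-plus-stride arithmetic, which may be useful when unit-testing the kernel.
import torch

def assign_req_to_token_pool_ref(req_pool_indices, req_to_token,
                                 allocate_lens, new_allocate_lens,
                                 out_cache_loc):
    # For request i, write its newly allocated cache locations into
    # req_to_token[pool_idx, old_len:new_len]; the per-request slices of
    # out_cache_loc are packed back-to-back in batch order.
    offset = 0
    for i in range(req_pool_indices.numel()):
        pool_idx = int(req_pool_indices[i])
        start = int(allocate_lens[i])
        end = int(new_allocate_lens[i])
        req_to_token[pool_idx, start:end] = \
            out_cache_loc[offset:offset + (end - start)].to(req_to_token.dtype)
        offset += end - start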
__global__ void get_last_loc_kernel(
const int32_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices_tensor,
const int64_t* __restrict__ prefix_lens_tensor,
int64_t* __restrict__ result,
int64_t num_tokens,
int64_t req_to_token_stride){
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= num_tokens) return;
int64_t pre_len = prefix_lens_tensor[pid];
if (pre_len > 0) {
int64_t req_idx = req_pool_indices_tensor[pid];
int64_t token_idx = req_idx * req_to_token_stride + (pre_len - 1);
result[pid] = static_cast<int64_t>(req_to_token[token_idx]);
} else {
result[pid] = static_cast<int64_t>(-1);
}
}
at::Tensor dcu_get_last_loc(
const at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor prefix_lens) {
TORCH_CHECK(req_to_token.device().is_cuda(), "req_to_token must be CUDA tensor");
TORCH_CHECK(req_pool_indices.device().is_cuda(), "req_pool_indices must be CUDA tensor");
TORCH_CHECK(prefix_lens.device().is_cuda(), "prefix_lens must be CUDA tensor");
TORCH_CHECK(req_to_token.dim() == 2, "req_to_token must be 2D tensor [batch, seq_len]");
TORCH_CHECK(prefix_lens.dim() == 1, "prefix_lens must be 1D");
TORCH_CHECK(req_pool_indices.dim() == 1, "req_pool_indices must be 1D");
int64_t num_tokens = prefix_lens.numel();
TORCH_CHECK(req_pool_indices.numel() == num_tokens, "req_pool_indices must have same length as prefix_lens");
auto req_to_token_c = req_to_token.contiguous();
auto req_pool_indices_c = req_pool_indices.contiguous();
auto prefix_lens_c = prefix_lens.contiguous();
// Read the row stride from the contiguous copy actually passed to the kernel,
// so a non-contiguous input cannot desynchronize pointer and stride.
int64_t req_to_token_stride = req_to_token_c.stride(0);
const int32_t* req_to_token_ptr = req_to_token_c.data_ptr<int32_t>();
const int64_t* req_pool_indices_ptr = req_pool_indices_c.data_ptr<int64_t>();
const int64_t* prefix_lens_ptr = prefix_lens_c.data_ptr<int64_t>();
auto result = at::empty_like(prefix_lens_c);
int64_t* result_ptr = result.data_ptr<int64_t>();
const int64_t block_size = 64;
const int64_t grid_size = (num_tokens + block_size - 1) / block_size;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
get_last_loc_kernel<<<grid_size, block_size, 0, stream>>>(
req_to_token_ptr,
req_pool_indices_ptr,
prefix_lens_ptr,
result_ptr,
num_tokens,
req_to_token_stride
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return result;
}
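A hedged PyTorch reference for the same lookup, handy when checking the kernel's output; the helper name and the 2-D advanced indexing are illustrative assumptions.
import torch

def get_last_loc_ref(req_to_token, req_pool_indices, prefix_lens):
    # result[i] = req_to_token[req_pool_indices[i], prefix_lens[i] - 1],
    # or -1 when the request has no cached prefix.
    result = torch.full_like(prefix_lens, -1)
    mask = prefix_lens > 0
    rows = req_pool_indices[mask]
    cols = prefix_lens[mask] - 1
    result[mask] = req_to_token[rows, cols].to(result.dtype)
    return result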
__global__ void launch_assign_extend_cache_locs_kernel(
const int64_t* __restrict__ req_pool_indices, // [bs]
const int32_t* __restrict__ req_to_token, // [max_num_req, pool_len]
const int64_t* __restrict__ start_offset, // [bs]
const int64_t* __restrict__ end_offset, // [bs]
int64_t* __restrict__ out_cache_loc, // [sum(draft_token_num)]
int64_t pool_len,
int64_t bs)
{
int pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t kv_start = start_offset[pid];
int64_t kv_end = end_offset[pid];
int64_t req_id = req_pool_indices[pid];
int64_t out_offset = 0;
for (int i = 0; i < pid; ++i) {
out_offset += end_offset[i] - start_offset[i];
}
const int32_t* src = req_to_token + req_id * pool_len + kv_start;
int64_t* dst = out_cache_loc + out_offset;
for (int64_t i = 0; i < kv_end - kv_start; ++i) {
dst[i] = src[i];
}
}
void dcu_assign_extend_cache_locs(
const at::Tensor req_pool_indices,
const at::Tensor req_to_token,
const at::Tensor start_offset,
const at::Tensor end_offset,
at::Tensor out_cache_loc,
int64_t pool_len,
int64_t bs)
{
const int64_t* req_pool_indices_ptr = req_pool_indices.data_ptr<int64_t>();
const int32_t* req_to_token_ptr = req_to_token.data_ptr<int32_t>();
const int64_t* start_offset_ptr = start_offset.data_ptr<int64_t>();
const int64_t* end_offset_ptr = end_offset.data_ptr<int64_t>();
int64_t* out_cache_loc_ptr = out_cache_loc.data_ptr<int64_t>();
constexpr int64_t threads = 128;
int64_t blocks = (bs + threads - 1) / threads;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
launch_assign_extend_cache_locs_kernel<<<blocks, threads, 0, stream>>>(
req_pool_indices_ptr,
req_to_token_ptr,
start_offset_ptr,
end_offset_ptr,
out_cache_loc_ptr,
pool_len,
bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
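The gather done by dcu_assign_extend_cache_locs corresponds to this pure-PyTorch sketch (hypothetical helper, written only to document the expected result):
import torch

def assign_extend_cache_locs_ref(req_pool_indices, req_to_token,
                                 start_offset, end_offset, out_cache_loc):
    # Read req_to_token[req_id, start:end] for each request and pack the
    # slices contiguously into out_cache_loc, in batch order.
    offset = 0
    for i in range(req_pool_indices.numel()):
        req_id = int(req_pool_indices[i])
        s, e = int(start_offset[i]), int(end_offset[i])
        out_cache_loc[offset:offset + (e - s)] = \
            req_to_token[req_id, s:e].to(out_cache_loc.dtype)
        offset += e - s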
template<int PAGED_SIZE>
__global__ void dcu_create_flashmla_kv_indices_kernel(
const int32_t* __restrict__ req_to_token,
const int32_t* __restrict__ req_pool_indices,
const int32_t* __restrict__ page_kernel_lens,
const int32_t* __restrict__ kv_start_idx,
int32_t* __restrict__ kv_indices,
int req_to_token_stride,
int kv_indices_stride)
{
int pid = blockIdx.x; // batch index
int req_pool_index = req_pool_indices[pid];
int kv_start = 0;
int kv_end = 0;
if (kv_start_idx != nullptr) {
kv_start = kv_start_idx[pid];
kv_end = kv_start;
}
kv_end += page_kernel_lens[pid];
int total_len = kv_end - kv_start;
int num_pages = (total_len + PAGED_SIZE - 1) / PAGED_SIZE;
for (int pg = 0; pg < num_pages; ++pg) {
int offset = pg * PAGED_SIZE;
// token id = req_to_token[req_pool_index][kv_start + offset]
int64_t token =
req_to_token[req_pool_index * req_to_token_stride + kv_start + offset];
// page index = first token id of this page / PAGED_SIZE
kv_indices[pid * kv_indices_stride + pg] = token / PAGED_SIZE;
}
}
void dcu_create_flashmla_kv_indices(
const at::Tensor& req_to_token,
const at::Tensor& req_pool_indices,
const at::Tensor& page_kernel_lens,
const c10::optional<at::Tensor>& kv_start_idx,
at::Tensor& kv_indices,
int64_t req_to_token_stride,
int64_t kv_indices_stride,
int64_t PAGED_SIZE)
{
TORCH_CHECK(req_to_token.is_cuda(), "req_to_token must be CUDA tensor");
TORCH_CHECK(kv_indices.is_cuda(), "kv_indices must be CUDA tensor");
int bs = req_pool_indices.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
dim3 grid(bs);
dim3 block(1);
const int32_t* kv_start_idx_ptr = nullptr;
if (kv_start_idx.has_value()) {
kv_start_idx_ptr = kv_start_idx.value().data_ptr<int32_t>();
}
if (PAGED_SIZE == 64) {
dcu_create_flashmla_kv_indices_kernel<64><<<grid, block, 0, stream>>>(
req_to_token.data_ptr<int32_t>(),
req_pool_indices.data_ptr<int32_t>(),
page_kernel_lens.data_ptr<int32_t>(),
kv_start_idx_ptr,
kv_indices.data_ptr<int32_t>(),
req_to_token_stride,
kv_indices_stride
);
} else {
TORCH_CHECK(false, "Unsupported PAGED_SIZE");
}
}
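Functionally, the kernel maps each request to one page index per PAGED_SIZE tokens; a minimal PyTorch sketch of that mapping follows (helper name and 2-D indexing are assumptions made for illustration):
import torch

def create_flashmla_kv_indices_ref(req_to_token, req_pool_indices,
                                   page_kernel_lens, kv_start_idx,
                                   kv_indices, paged_size=64):
    # For each request, take the first token id of every page and divide it
    # by the page size to obtain the page index.
    for i in range(req_pool_indices.numel()):
        pool_idx = int(req_pool_indices[i])
        start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
        length = int(page_kernel_lens[i])
        num_pages = (length + paged_size - 1) // paged_size
        for pg in range(num_pages):
            token = int(req_to_token[pool_idx, start + pg * paged_size])
            kv_indices[i, pg] = token // paged_size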
__global__ void launch_create_chunked_prefix_cache_kv_indices(
int32_t* req_to_token_ptr,
const int64_t* req_pool_indices_ptr,
const int32_t* chunk_starts_ptr,
const int32_t* chunk_seq_lens_ptr,
const int32_t* chunk_cu_seq_lens_ptr,
int32_t* chunk_kv_indices_ptr,
int64_t col_num,
int64_t bs)
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t req_pool_index = req_pool_indices_ptr[pid];
int64_t chunk_kv_indices_offset = chunk_cu_seq_lens_ptr[pid];
int32_t chunk_start_pos = chunk_starts_ptr[pid];
int32_t chunk_seq_len = chunk_seq_lens_ptr[pid];
#pragma unroll(32)
for (int32_t offset = 0; offset < chunk_seq_len; offset++) {
chunk_kv_indices_ptr[chunk_kv_indices_offset + offset] = req_to_token_ptr[req_pool_index * col_num + chunk_start_pos + offset];
}
}
void dcu_create_chunked_prefix_cache_kv_indices(
at::Tensor req_to_token_ptr,
const at::Tensor req_pool_indices_ptr,
const at::Tensor chunk_starts_ptr,
const at::Tensor chunk_seq_lens_ptr,
const at::Tensor chunk_cu_seq_lens_ptr,
at::Tensor chunk_kv_indices_ptr,
int64_t col_num,
int64_t bs) {
int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
const int32_t* chunk_starts_ptr1 = static_cast<const int32_t*>(chunk_starts_ptr.data_ptr());
const int32_t* chunk_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_seq_lens_ptr.data_ptr());
const int32_t* chunk_cu_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_cu_seq_lens_ptr.data_ptr());
int32_t* chunk_kv_indices_ptr1 = static_cast<int32_t*>(chunk_kv_indices_ptr.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_create_chunked_prefix_cache_kv_indices<<<grid_size, block_size, 0, torch_current_stream>>>(req_to_token_ptr1, req_pool_indices_ptr1, chunk_starts_ptr1, chunk_seq_lens_ptr1, chunk_cu_seq_lens_ptr1,chunk_kv_indices_ptr1, col_num, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
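A pure-PyTorch sketch of what dcu_create_chunked_prefix_cache_kv_indices writes (the helper is hypothetical and only mirrors the indexing above for testing purposes):
import torch

def create_chunked_prefix_cache_kv_indices_ref(req_to_token, req_pool_indices,
                                               chunk_starts, chunk_seq_lens,
                                               chunk_cu_seq_lens, chunk_kv_indices):
    # Copy each request's chunk window of the token table into the flat
    # chunk_kv_indices buffer at that request's cumulative offset.
    for i in range(req_pool_indices.numel()):
        pool_idx = int(req_pool_indices[i])
        out_off = int(chunk_cu_seq_lens[i])
        start = int(chunk_starts[i])
        length = int(chunk_seq_lens[i])
        chunk_kv_indices[out_off:out_off + length] = \
            req_to_token[pool_idx, start:start + length]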
...@@ -538,6 +538,7 @@ void segment_packbits(
/*
 * From csrc/kvcacheio
 */
void dcu_create_extend_after_decode_spec_info(
const at::Tensor verified_id,
const at::Tensor seq_lens,
...@@ -545,6 +546,49 @@ void dcu_create_extend_after_decode_spec_info(
at::Tensor positions,
at::Tensor new_verified_id,
int64_t bs);
void dcu_create_chunked_prefix_cache_kv_indices(
at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor chunk_starts,
const at::Tensor chunk_seq_lens,
const at::Tensor chunk_cu_seq_lens,
at::Tensor chunk_kv_indices,
int64_t col_num,
int64_t bs);
void dcu_create_flashmla_kv_indices(
const at::Tensor& req_to_token,
const at::Tensor& req_pool_indices,
const at::Tensor& page_kernel_lens,
const c10::optional<at::Tensor>& kv_start_idx,
at::Tensor& kv_indices,
int64_t req_to_token_stride,
int64_t kv_indices_stride,
int64_t PAGED_SIZE);
void dcu_assign_extend_cache_locs(
const at::Tensor req_pool_indices,
const at::Tensor req_to_token,
const at::Tensor start_offset,
const at::Tensor end_offset,
at::Tensor out_cache_loc,
int64_t pool_len,
int64_t bs);
at::Tensor dcu_get_last_loc(
const at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor prefix_lens);
void dcu_assign_req_to_token_pool(
const at::Tensor req_pool_indices_ptr,
at::Tensor req_to_token_ptr,
const at::Tensor allocate_lens_ptr,
at::Tensor new_allocate_lens,
at::Tensor out_cache_loc_ptr,
int64_t shape,
int64_t bs);
void dcu_alloc_extend_kernel(
const at::Tensor pre_lens_ptr,
...
...@@ -13,6 +13,26 @@ _IMPORT_ERROR = ImportError(
    "Failed to load sgl_kernel.flashmla_ops extension. Ensure CUDA Driver >= 12.4"
)
def dcu_create_flashmla_kv_indices(
    req_to_token_ptr,
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride,
    kv_indices_ptr_stride,
    PAGED_SIZE=64,
):
    torch.ops.sgl_kernel.dcu_create_flashmla_kv_indices(
        req_to_token_ptr,
        req_pool_indices_ptr,
        page_kernel_lens_ptr,
        kv_start_idx,
        kv_indices_ptr,
        req_to_token_ptr_stride,
        kv_indices_ptr_stride,
        PAGED_SIZE,
    )
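A minimal usage sketch for the wrapper above, assuming the sgl_kernel extension is loaded and tensors live on the device torch exposes as "cuda" on this build; the shapes and values are made up for illustration only.
import torch

bs, pool_len, max_pages = 2, 512, 4
req_to_token = torch.arange(bs * pool_len, dtype=torch.int32, device="cuda").view(bs, pool_len)
req_pool_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
page_kernel_lens = torch.tensor([100, 200], dtype=torch.int32, device="cuda")
kv_indices = torch.zeros(bs, max_pages, dtype=torch.int32, device="cuda")

dcu_create_flashmla_kv_indices(
    req_to_token,
    req_pool_indices,
    page_kernel_lens,
    None,                      # kv_start_idx: start from position 0
    kv_indices,
    req_to_token.stride(0),
    kv_indices.stride(0),
)                              # PAGED_SIZE defaults to 64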
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
...