Commit f810bda3 authored by renzhc

Merge branch 'v0.5.4_dev' into v0.5.4_rzc

parents 4167eff9 48542418
......@@ -332,6 +332,18 @@ def rocblas_scaled_mm(a: torch.Tensor,
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def blaslt_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
m = a.shape[0]
n = b.shape[0]
k = a.shape[1]
_, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
return out
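For reference, a minimal pure-PyTorch sketch of the NT w8a8 scaled GEMM that `blaslt_scaled_mm` delegates to via `quant_ops.hipblaslt_w8a8_gemm`. The per-token `scale_a` ([m, 1]) and per-output-channel `scale_b` ([n, 1]) layouts are assumptions, not taken from the diff.

```python
import torch
from typing import Optional

def scaled_mm_nt_reference(a: torch.Tensor, b: torch.Tensor,
                           scale_a: torch.Tensor, scale_b: torch.Tensor,
                           out_dtype: torch.dtype,
                           bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # a: [m, k] int8, b: [n, k] int8 (hence the 'NT' layout), scales in float32.
    acc = a.to(torch.int32) @ b.to(torch.int32).T            # [m, n] int32
    out = acc.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)
```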
def triton_int8_gemm_helper(m: int,
n: int,
k: int,
......
......@@ -11,6 +11,8 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices
from sglang.srt.utils import get_bool_env_var
try:
from flash_mla import (
......@@ -104,6 +106,7 @@ class DCUMLABackend(AttentionBackend):
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON")
bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle():
......@@ -118,15 +121,27 @@ class DCUMLABackend(AttentionBackend):
dtype=torch.int32,
device=forward_batch.seq_lens.device
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr = self.req_to_token,
req_pool_indices_ptr = forward_batch.req_pool_indices,
page_kernel_lens_ptr = forward_batch.seq_lens,
kv_start_idx = None,
kv_indices_ptr = block_kv_indices,
req_to_token_ptr_stride = self.req_to_token.stride(0),
kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
......@@ -149,15 +164,27 @@ class DCUMLABackend(AttentionBackend):
dtype=torch.int32,
device=seq_lens.device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr = self.req_to_token,
req_pool_indices_ptr = forward_batch.req_pool_indices,
page_kernel_lens_ptr = seq_lens,
kv_start_idx = None,
kv_indices_ptr = block_kv_indices,
req_to_token_ptr_stride = self.req_to_token.stride(0),
kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads,
......@@ -185,15 +212,27 @@ class DCUMLABackend(AttentionBackend):
)
# Generate block_kv_indices (via the DCU kernel or the Triton kernel)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr = self.req_to_token.to(torch.int32),
req_pool_indices_ptr = forward_batch.req_pool_indices.to(torch.int32),
page_kernel_lens_ptr = seq_lens.to(torch.int32),
kv_start_idx = None,
kv_indices_ptr = block_kv_indices.to(torch.int32),
req_to_token_ptr_stride = self.req_to_token.stride(0),
kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
# MLA
mla_metadata, num_splits = get_mla_metadata(
......@@ -211,6 +250,7 @@ class DCUMLABackend(AttentionBackend):
self.flashattn_backend.init_forward_metadata(forward_batch)
def init_cuda_graph_state(
self,
max_bs: int,
......@@ -489,9 +529,10 @@ class DCUMLABackend(AttentionBackend):
k_rope: Optional[torch.Tensor] = None,
sinks=None,
):
if (
if ((
forward_batch.forward_mode == ForwardMode.EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
):
if not self.skip_prefill:
return self.flashattn_backend.forward_extend(
......
......@@ -674,16 +674,11 @@ class FlashAttentionBackend(AttentionBackend):
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
# if not self.use_mla:
if k_rope is None:
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v
)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, #layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
......
......@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import get_bool_env_var
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
def init_forward_metadata(self, forward_batch: ForwardBatch):
use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON")
bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle():
max_seqlen_pad = triton.cdiv(
......@@ -91,15 +95,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
dtype=torch.int32,
device=forward_batch.seq_lens.device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr = self.req_to_token,
req_pool_indices_ptr = forward_batch.req_pool_indices,
page_kernel_lens_ptr = forward_batch.seq_lens,
kv_start_idx = None,
kv_indices_ptr = block_kv_indices,
req_to_token_ptr_stride = self.req_to_token.stride(0),
kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
self.num_q_heads,
......@@ -121,15 +137,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
dtype=torch.int32,
device=seq_lens.device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr = self.req_to_token,
req_pool_indices_ptr = forward_batch.req_pool_indices,
page_kernel_lens_ptr = seq_lens,
kv_start_idx = None,
kv_indices_ptr = block_kv_indices,
req_to_token_ptr_stride = self.req_to_token.stride(0),
kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads,
......@@ -144,7 +172,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
)
else:
super().init_forward_metadata(forward_batch)
def init_cuda_graph_state(
self,
max_bs: int,
......
......@@ -1013,3 +1013,134 @@ def zero_experts_compute_triton(
)
return output
from triton.language.extra import libdevice
from typing import Optional
@triton.jit
def _per_token_quant_int8_one_kernel_opt(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
T_dim,
tokens_per_expert_ptr,
BLOCK: tl.constexpr
):
row_id = tl.program_id(0)
if tokens_per_expert_ptr is not None:
e = row_id // T_dim
t = row_id % T_dim
num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
if t >= num_valid_tokens_for_e:
return
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
@triton.jit
def _per_token_quant_int8_kernel_opt(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
E_dim,
T_dim,
tokens_per_expert_ptr,
BLOCK: tl.constexpr
):
token_idx_start = tl.program_id(0)
grid_size = tl.num_programs(0)
num_total_tokens = E_dim * T_dim
for token_idx in range(token_idx_start, num_total_tokens, grid_size):
is_valid_token = True
if tokens_per_expert_ptr is not None:
e = token_idx // T_dim
t = token_idx % T_dim
num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
if t >= num_valid_tokens_for_e:
is_valid_token = False
if is_valid_token:
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + token_idx * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + token_idx, scale_x)
def per_token_quant_int8_triton_opt(x: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None):
if x.dim() != 3:
raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
E, T, H = x.shape
N = H
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1, ), device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
num_warps = min(max(BLOCK // 256, 1), 8)
if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
num_warps = 1
num_tokens = E * T
grid_opt = num_tokens
if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
grid_opt = max(1, num_tokens // (T // 256))
_per_token_quant_int8_kernel_opt[(grid_opt, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
E_dim=E,
T_dim=T,
tokens_per_expert_ptr=tokens_per_expert,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
else:
_per_token_quant_int8_one_kernel_opt[(grid_opt, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
T_dim=T,
tokens_per_expert_ptr=tokens_per_expert,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
\ No newline at end of file
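A plain-PyTorch reference for what the two Triton kernels above compute, useful when unit-testing `per_token_quant_int8_triton_opt`. Zeroing the rows beyond each expert's valid token count is an assumption made here for easy comparison; the kernels simply skip those rows.

```python
import torch
from typing import Optional

def per_token_quant_int8_reference(x: torch.Tensor,
                                   tokens_per_expert: Optional[torch.Tensor] = None):
    # x: [E, T, H]; returns int8 values and per-token float32 scales of shape [E, T, 1].
    E, T, H = x.shape
    absmax = x.float().abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scales = absmax / 127.0
    x_q = torch.round(x.float() / scales).to(torch.int8)
    if tokens_per_expert is not None:
        # Rows past each expert's valid token count are untouched by the kernels;
        # zero them here so comparisons only look at valid rows.
        valid = (torch.arange(T, device=x.device).view(1, T, 1)
                 < tokens_per_expert.view(E, 1, 1))
        x_q = torch.where(valid, x_q, torch.zeros_like(x_q))
        scales = torch.where(valid, scales, torch.zeros_like(scales))
    return x_q, scales
```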
......@@ -2,7 +2,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from collections import defaultdict
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch
......@@ -20,6 +20,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
ep_scatter,
silu_and_mul_masked_post_quant_fwd,
tma_align_input_scale,
per_token_quant_int8_triton_opt,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput
......@@ -40,7 +41,7 @@ if TYPE_CHECKING:
DeepEPNormalOutput,
DispatchOutput,
)
from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked
from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked, m_grouped_w8a8_gemm_nt_contig_asm, fuse_silu_mul_quant
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
_is_hip = is_hip()
......@@ -605,6 +606,8 @@ class DeepEPMoE(EPMoE):
return self.forward_deepgemm_contiguous(dispatch_output)
elif self.use_w4a8_marlin:
return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
elif self.use_w8a8_marlin:
return self.forward_groupgemm_w8a8_marlin_contiguous(dispatch_output)
else:
raise ValueError(
f"Dispatch output is not supported"
......@@ -709,6 +712,111 @@ class DeepEPMoE(EPMoE):
)
return expert_output
def forward_groupgemm_w8a8_marlin_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
all_tokens = sum(num_recv_tokens_per_expert)
if all_tokens <= 0:
return hidden_states.bfloat16()
device = hidden_states.device
M = hidden_states.shape[0]
K = hidden_states.shape[1]
topk = topk_idx.shape[1]
active_experts = set()
token_expert_pos = [None] * M
for t in range(M):
lst = []
for pos in range(topk):
e = int(topk_idx[t, pos].item())
if e >= 0:
lst.append((e, pos))
active_experts.add(e)
token_expert_pos[t] = lst
active_experts = sorted(list(active_experts))
num_active = len(active_experts)
if num_active == 0:
return hidden_states.bfloat16()
counts = defaultdict(int)
for t in range(M):
for (e, pos) in token_expert_pos[t]:
counts[e] += 1
per_expert_block = {}
for e in active_experts:
cnt = counts.get(e, 0)
if cnt <= 0:
per_expert_block[e] = 0
else:
needed = ((cnt + 256 - 1) // 256) * 256 # next multiple of 256
per_expert_block[e] = max(256, needed)
expert_slot_offset = {}
offset = 0
for e in active_experts:
expert_slot_offset[e] = offset
offset += per_expert_block[e]
pad_M = offset
hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)
slot_counters = {e: 0 for e in active_experts}
token_row_weight_list = {t: [] for t in range(M)}
for t in range(M):
for (e, pos) in token_expert_pos[t]:
start = expert_slot_offset[e]
slot = slot_counters[e]
if slot >= per_expert_block[e]:
raise RuntimeError(f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}")
row = start + slot
hidden_states_packed[row] = hidden_states[t]
m_indices[row] = int(e)
slot_counters[e] += 1
w = topk_weights[t, pos].to(device=device)
w_f = w.float() if w.dtype != torch.float32 else w
token_row_weight_list[t].append((row, w_f))
q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
N = self.w13_weight.size(1)
gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
m_grouped_w8a8_gemm_nt_contig_asm(
(q_a1_all, q_a1_scale),
(self.w13_weight, self.w13_weight_scale),
gateup_output,
m_indices,
)
q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)
down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
down_output = m_grouped_w8a8_gemm_nt_contig_asm(
(q_a2_all, q_a2_scale),
(self.w2_weight, self.w2_weight_scale),
down_output,
m_indices,
)
result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
for t in range(M):
pairs = token_row_weight_list[t]
if not pairs:
continue
acc = None
for (row, w) in pairs:
vec = down_output[row].float()
weighted = vec * w
acc = weighted if acc is None else (acc + weighted)
result[t] = acc.to(result.dtype)
return result
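A purely illustrative sketch of the per-expert padding used in `forward_groupgemm_w8a8_marlin_contiguous` above: each active expert gets a contiguous block whose row count is its token count rounded up to a multiple of 256 (minimum 256), and the blocks are packed back to back. The helper name is made up.

```python
def per_expert_blocks(counts, align=256):
    offsets, offset = {}, 0
    for e in sorted(counts):
        block = max(align, ((counts[e] + align - 1) // align) * align)
        offsets[e] = (offset, block)   # (start row, padded block length)
        offset += block
    return offsets, offset             # offset == pad_M, the total padded row count

# Example: per_expert_blocks({0: 3, 5: 300}) -> ({0: (0, 256), 5: (256, 512)}, 768)
```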
def forward_deepgemm_contiguous(
self,
......@@ -899,10 +1007,10 @@ class DeepEPMoE(EPMoE):
# base shapes
num_groups, m, k = hidden_states.size()
expected_m = min(m, expected_m)
expected_m = min(m, expected_m)
# ---- first quant: ensure float input for quantizer ----
q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
# ---- weights & scales ----
w13_weight = self.w13_weight
......@@ -943,16 +1051,15 @@ class DeepEPMoE(EPMoE):
dispatch_output: DeepEPLLOutput,
):
hidden_states, _, _, _, masked_m, expected_m = dispatch_output
hidden_states, _, topk_ids, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
# base shapes
num_groups, m, k = hidden_states.size()
expected_m = min(m, expected_m)
expected_m = min(m, expected_m)
# ---- first quant: ensure float input for quantizer ----
q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
# ---- weights & scales ----
w13_weight = self.w13_weight
......
......@@ -308,7 +308,7 @@ class _DeepEPDispatcherImplBase:
self.params_bytes = 2
self.num_max_dispatch_tokens_per_rank = get_int_env_var(
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 64
)
# DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024
# and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
......@@ -441,18 +441,19 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
expert_alignment=1,
config=DeepEPConfig.get_instance().normal_dispatch_config,
)
# get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
# num_recv_tokens_per_expert,
# num_tokens_per_rank=num_tokens_per_rank,
# num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
# num_tokens_per_expert=num_tokens_per_expert,
# )
self.rank_expert_offset= get_moe_expert_parallel_rank() * ( self.num_experts // get_moe_expert_parallel_world_size())
recv_topk_ids = torch.where(
recv_topk_ids == -1,
self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
recv_topk_ids + self.rank_expert_offset)
if self.quant_config.get("quant_method") == "slimquant_w4a8_marlin":
self.rank_expert_offset= get_moe_expert_parallel_rank() * ( self.num_experts // get_moe_expert_parallel_world_size())
recv_topk_ids = torch.where(
recv_topk_ids == -1,
self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
recv_topk_ids + self.rank_expert_offset)
else:
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
num_recv_tokens_per_expert,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
num_tokens_per_expert=num_tokens_per_expert,
)
return (
recv_x,
recv_topk_ids,
......@@ -541,7 +542,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
"""
self.return_recv_hook = False
self.return_recv_hook = return_recv_hook
self.device_module = torch.get_device_module()
self.quant_config = {}
......@@ -724,6 +725,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
self.packed_recv_count = self.handle = None
return combined_hidden_states, event, hook
@torch._dynamo.disable()
def _get_buffer(self):
DeepEPBuffer.set_dispatch_mode_as_low_latency()
return DeepEPBuffer.get_deepep_buffer(
......
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import os
import logging
from contextlib import suppress
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, cast
......@@ -46,6 +46,9 @@ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import W8a8GetCacheJSON
logger = logging.getLogger(__name__)
__all__ = ["CompressedTensorsLinearMethod"]
......@@ -590,8 +593,33 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0]
k=layer.weight.shape[1]
if self.w8a8_strategy==1:
if [n,k] not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append([n,k])
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
self.tritonsingleton.triton_json_dict.update(configs_dict)
for key, value in configs_dict.items():
m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
elif self.w8a8_strategy==3:
layer.weight.data = layer.weight.data.T
else:
weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1)
layer.weight.data=_weight
self.tritonsingleton.gen_model_json()
layer.scheme.process_weights_after_loading(layer)
def create_weights(
......
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Callable, Optional
import torch
......@@ -19,11 +20,13 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
from sglang.srt.utils import is_cuda
from lmslim import quant_ops
from sglang.srt import _custom_ops as ops
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import int8_scaled_mm
from sglang.srt.utils import W8a8GetCacheJSON
W8A8_TRITONJSON=W8a8GetCacheJSON()
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
......@@ -33,6 +36,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) # TODO
@classmethod
def get_min_capability(cls) -> int:
......@@ -163,14 +167,70 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
)
layer.register_parameter("input_zero_point", input_zero_point)
@torch._dynamo.disable()
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
# TODO: add cutlass_scaled_mm_azp support
x_q, x_scale = per_token_quant_int8(x)
# TODO: fix with lmslim/lightop
return quant_ops.triton_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
)
# return quant_ops.custom_scaled_mm(x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias)
if self.w8a8_strategy==1:
m=x_q.shape[0]
k=x_q.shape[1]
n=layer.weight.shape[1]
if len(W8A8_TRITONJSON.triton_json_dict)==0:
best_config=None
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
if m<=16:
m_=m
elif m<=64:
m_ = (m + 3) & -4  # round up to the nearest multiple of 4
elif m<=160:
m_=(m + 7) & -8
elif m<200: #256
m_=160
elif m<480: #512
m_=256
elif m<960: #1024
m_=512
elif m<2048:
m_=1024
elif m<4096:
m_=2048
elif m<6000:
m_=4096
else:
m_=8192
best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
else:
best_config=None
return ops.triton_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias, best_config=best_config
)
elif self.w8a8_strategy==2:
return ops.cutlass_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
elif self.w8a8_strategy==3:
return ops.blaslt_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=None)
else:
return ops.rocblas_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
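For clarity, the m-bucketing in the `w8a8_strategy == 1` branch above, extracted into a standalone helper; the assumption is that the tuned JSON contains configs exactly for these bucketed m values.

```python
def bucket_m(m: int) -> int:
    if m <= 16:
        return m
    if m <= 64:
        return (m + 3) & -4   # round up to the nearest multiple of 4
    if m <= 160:
        return (m + 7) & -8   # round up to the nearest multiple of 8
    for limit, bucket in ((200, 160), (480, 256), (960, 512),
                          (2048, 1024), (4096, 2048), (6000, 4096)):
        if m < limit:
            return bucket
    return 8192
```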
......@@ -15,7 +15,7 @@ from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8,
per_token_quant_int8)
from sglang.srt import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
from sglang.srt.utils import W8a8GetCacheJSON
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
import os
......@@ -157,7 +157,6 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
)
layer.register_parameter("weight_scale", weight_scale)
@torch._dynamo.disable()  # TODO: performance tuning requires lmslim/lightop support
def apply(
self,
layer: torch.nn.Module,
......@@ -227,6 +226,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
elif self.w8a8_strategy==3:
return ops.blaslt_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=None)
else:
return ops.rocblas_scaled_mm(x_q,
layer.weight,
......
......@@ -13,7 +13,8 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import support_triton
from sglang.srt.utils import support_triton,get_bool_env_var
from sgl_kernel.kvcacheio import dcu_get_last_loc
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
......@@ -125,13 +126,17 @@ def get_last_loc(
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
if (
get_global_server_args().attention_backend != "ascend"
and get_global_server_args().attention_backend != "torch_native"
):
impl = get_last_loc_triton
use_sglang_get_last_loc = get_bool_env_var("SGLANG_GET_LAST_LOC")
if use_sglang_get_last_loc:
impl = dcu_get_last_loc
else:
impl = get_last_loc_torch
if (
get_global_server_args().attention_backend != "ascend"
and get_global_server_args().attention_backend != "torch_native"
):
impl = get_last_loc_triton
else:
impl = get_last_loc_torch
return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
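A minimal pure-PyTorch reference for the semantics shared by `get_last_loc_triton`, `get_last_loc_torch`, and the new `dcu_get_last_loc`: for each request, the cache location of its last prefix token, or -1 when the prefix is empty. Shapes are assumptions based on the call above.

```python
import torch

def get_last_loc_reference(req_to_token: torch.Tensor,
                           req_pool_indices: torch.Tensor,
                           prefix_lens: torch.Tensor) -> torch.Tensor:
    rows = req_to_token[req_pool_indices.long()]               # [bs, pool_len]
    idx = (prefix_lens - 1).clamp_min(0).long().unsqueeze(1)   # [bs, 1]
    last = rows.gather(1, idx).squeeze(1).to(torch.int64)
    return torch.where(prefix_lens > 0, last, torch.full_like(last, -1))
```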
......
......@@ -46,7 +46,11 @@ from sglang.srt.layers.dp_attention import (
set_dp_buffer_len,
set_is_extend_in_batch,
)
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton,get_bool_env_var
from sgl_kernel.kvcacheio import dcu_create_chunked_prefix_cache_kv_indices
import logging
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
......@@ -123,13 +127,13 @@ class ForwardMode(IntEnum):
# For fixed shape logits output in v2 eagle worker
return self == ForwardMode.DRAFT_EXTEND_V2
def is_extend_or_draft_extend_or_mixed(self): #nhb
def is_extend_or_draft_extend_or_mixed(self):
return (
self == ForwardMode.EXTEND
or self == ForwardMode.DRAFT_EXTEND
or self == ForwardMode.MIXED
or self == ForwardMode.SPLIT_PREFILL
or self == ForwardMode.DRAFT_EXTEND_V2
or self == ForwardMode.SPLIT_PREFILL
or self == ForwardMode.DRAFT_EXTEND_V2 #nhb
)
def is_cuda_graph(self):
......@@ -317,6 +321,8 @@ class ForwardBatch:
tbo_parent_token_range: Optional[Tuple[int, int]] = None
tbo_children: Optional[List[ForwardBatch]] = None
use_sglang_create_chunked_prefix_cache_kv_indices = get_bool_env_var("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES")
@classmethod
def init_new(
cls,
......@@ -363,13 +369,13 @@ class ForwardBatch:
if batch.extend_input_logprob_token_ids is not None:
ret.extend_input_logprob_token_ids_gpu = (
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
batch.extend_input_logprob_token_ids.pin_memory().to(device, non_blocking=True)
)
if enable_num_token_non_padded(model_runner.server_args):
ret.num_token_non_padded = torch.tensor(
len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True)
).pin_memory().to(device, non_blocking=True)
ret.num_token_non_padded_cpu = len(batch.input_ids)
# For MLP sync
......@@ -389,12 +395,12 @@ class ForwardBatch:
ret.global_num_tokens_cpu = global_num_tokens
ret.global_num_tokens_gpu = torch.tensor(
global_num_tokens, dtype=torch.int64
).to(device, non_blocking=True)
).pin_memory().to(device, non_blocking=True)
ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
ret.global_num_tokens_for_logprob_gpu = torch.tensor(
global_num_tokens_for_logprob, dtype=torch.int64
).to(device, non_blocking=True)
).pin_memory().to(device, non_blocking=True)
if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
......@@ -419,10 +425,10 @@ class ForwardBatch:
assert isinstance(batch.extend_prefix_lens, list)
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True)
).pin_memory().to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True)
).pin_memory().to(device, non_blocking=True)
ret.extend_num_tokens = batch.extend_num_tokens
positions, ret.extend_start_loc = compute_position(
model_runner.server_args.attention_backend,
......@@ -635,15 +641,28 @@ class ForwardBatch:
num_chunk_tokens, dtype=torch.int32, device=device
)
create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
self.req_to_token_pool.req_to_token,
self.req_pool_indices,
chunk_starts,
chunk_seq_lens,
chunk_cu_seq_lens,
chunk_kv_indices,
self.req_to_token_pool.req_to_token.shape[1],
)
if self.use_sglang_create_chunked_prefix_cache_kv_indices:
dcu_create_chunked_prefix_cache_kv_indices(
req_to_token = self.req_to_token_pool.req_to_token,
req_pool_indices = self.req_pool_indices,
chunk_starts = chunk_starts,
chunk_seq_lens = chunk_seq_lens,
chunk_cu_seq_lens = chunk_cu_seq_lens,
chunk_kv_indices = chunk_kv_indices,
col_num = self.req_to_token_pool.req_to_token.shape[1],
bs = self.batch_size,
)
else:
# logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
self.req_to_token_pool.req_to_token,
self.req_pool_indices,
chunk_starts,
chunk_seq_lens,
chunk_cu_seq_lens,
chunk_kv_indices,
self.req_to_token_pool.req_to_token.shape[1],
)
self.prefix_chunk_kv_indices.append(chunk_kv_indices)
def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
......
......@@ -237,7 +237,14 @@ class DraftBackendFactory:
return None
def _create_dcumla_prefill_backend(self):
logger.warning(
"flashmla prefill backend is not yet supported for draft extend."
# logger.warning(
# "flashmla prefill backend is not yet supported for draft extend."
# )
# return None
#nhb
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)
return None
return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
......@@ -29,6 +29,12 @@ from sglang.srt.speculative.spec_utils import (
)
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
from sglang.srt.utils import get_bool_env_var
from sgl_kernel.kvcacheio import dcu_assign_req_to_token_pool,dcu_assign_extend_cache_locs
import logging
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
......@@ -77,6 +83,9 @@ def assign_draft_cache_locs_page_size_1(
@dataclass
class EagleDraftInputV2Mixin:
use_sglang_assign_req_to_token_pool = get_bool_env_var("SGLANG_ASSIGN_REQ_TO_TOKEN_POOL")
def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
......@@ -112,15 +121,26 @@ class EagleDraftInputV2Mixin:
extend_num_tokens,
)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
self.allocate_lens,
new_allocate_lens,
out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
if self.use_sglang_assign_req_to_token_pool:
dcu_assign_req_to_token_pool(
req_pool_indices = batch.req_pool_indices,
req_to_token = batch.req_to_token_pool.req_to_token,
allocate_lens = self.allocate_lens,
new_allocate_lens = new_allocate_lens,
out_cache_loc = out_cache_loc,
shape = batch.req_to_token_pool.req_to_token.shape[1],
bs = bs,
)
else:
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
self.allocate_lens,
new_allocate_lens,
out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
self.allocate_lens = new_allocate_lens
# FIXME(lsyin): make this sync optional
......@@ -189,6 +209,9 @@ class EagleDraftInputV2Mixin:
@dataclass
class EagleVerifyInputV2Mixin:
use_sglang_assign_extend_cache_locs = get_bool_env_var("SGLANG_ASSIGN_EXTEND_CACHE_LOCS")
def prepare_for_v2_verify(
self: EagleVerifyInput,
req_to_token_pool: ReqToTokenPool,
......@@ -205,15 +228,26 @@ class EagleVerifyInputV2Mixin:
device=device,
)
assign_extend_cache_locs[(bs,)](
batch.req_pool_indices,
req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.draft_token_num,
batch.out_cache_loc,
req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
if self.use_sglang_assign_extend_cache_locs:
dcu_assign_extend_cache_locs(
batch.req_pool_indices,
req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.draft_token_num,
batch.out_cache_loc,
req_to_token_pool.req_to_token.shape[1],
bs,
)
else:
assign_extend_cache_locs[(bs,)](
batch.req_pool_indices,
req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.draft_token_num,
batch.out_cache_loc,
req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
# Get a forward batch
batch.forward_mode = ForwardMode.TARGET_VERIFY
......
......@@ -758,7 +758,7 @@ class TboForwardBatchPreparer:
# TODO we may make padding on both sub-batches to make it slightly more balanced
value_a = min(tbo_split_token_index, num_token_non_padded)
value_b = max(0, num_token_non_padded - tbo_split_token_index)
return torch.tensor([value_a, value_b], dtype=torch.int32).to(
return torch.tensor([value_a, value_b], dtype=torch.int32).pin_memory().to(
device=get_global_server_args().device, non_blocking=True
)
......
......@@ -3528,3 +3528,158 @@ def cached_triton_kernel(key_fn=None):
return CachedKernel(fn, key_fn)
return decorator
# from vllm
class W8a8GetCacheJSON:
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(W8a8GetCacheJSON, cls).__new__(cls, *args, **kwargs)
cls._instance._initialize()
return cls._instance
def _initialize(self):
current_folder_path = os.path.dirname(os.path.abspath(__file__))
json_folder_path=current_folder_path+'/../../lmslim/configs/w8a8'
self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', json_folder_path))
self.triton_json_dict={}
self.triton_moejson_dict={}
self.triton_json_list=[]
self.weight_shapes=[]
self.moe_weight_shapes=[]
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
device_name =arch_name+'_'+str(arch_cu)+'cu'
self.device_name=device_name
self.topk=1
self.quant_method=None
# Called at the end of loading; writes the model.json configuration file
def gen_model_json(self,E:Optional[int]=0,block_size:Optional[list]=None):
json_dir = os.getenv('LMSLIM_TUNING_JSON', "None")
if json_dir != "None" and os.path.exists(json_dir):
# Generate the model configuration file
# logger.info("model_tuning.json is at LMSLIM_TUNING_JSON:%s", json_dir)
config = {
"layers": {
"linear": {
"shapes": [],
"m_range":"None",
},
"moe": {
"shapes": [],
"m_range": "None",
"topk": self.topk
}
},
"quantization_config": {
"quant_method": self.quant_method,
"weight_block_size": "None"
}
}
# Process MoE shapes
for shape in self.moe_weight_shapes:
if len(shape) == 4:  # assumes the MoE shape is in [E, N1, N2, K] format
moe_config = {
"E": shape[0],
"N1": shape[1],
"N2": shape[2],
"K": shape[3], # 默认值
}
config["layers"]["moe"]["shapes"].append(moe_config)
for shape in self.weight_shapes:
config["layers"]["linear"]["shapes"].append(shape)
if block_size is not None:
config["quantization_config"]["weight_block_size"]=block_size
with open(json_dir+"/model.json", 'w') as f:
json.dump(config, f, indent=4)
# else:
# logger.info("LMSLIM_TUNING_JSON is not set")
def getspec_config(self,configs_dict,M,N,K):
if f"{M}_{N}_{K}" in configs_dict:
return configs_dict[f"{M}_{N}_{K}"]
else:
return None
def get_triton_cache(self,file_path,n,k):
# Used when not tuning; returns None directly if the file does not exist
cache_json_file=file_path
if os.path.exists(file_path):
#try:
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
return None
# Parse the whole cache into key:config entries of the form [M_N_K]: [config]
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
return configs_dict
def get_w8a8json_name(self,n,k):
return self.triton_json_dir+f"/W8A8_{n}_{k}_{self.device_name}.json"
def get_blockint8_triton_cache(self,file_path,n,k,block_n,block_k):
cache_json_file=file_path
if os.path.exists(file_path):
#try:
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
return None
# Parse the whole cache into key:config entries of the form [M_N_K]: [config]
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
return configs_dict
def get_blockint8json_name(self,n,k,block_n,block_k):
return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
def get_moeint8json_name(self,E,N1,N2,K,TOPK,
block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False):
if use_int4_w4a8:
if block_size is not None:
return self.triton_json_dir+f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
if block_size is not None:
return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
def get_moeint8_triton_cache(self,file_path,E,N1,N2,K,TOPK):
cache_json_file=file_path
if os.path.exists(file_path):
#try:
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
return None
# Parse the whole cache into key:config entries of the form [M_N_K]: [config1, config2]
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
return configs_dict
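An illustration (with made-up shapes) of the flattening done by `get_triton_cache` and its siblings: the on-disk JSON keyed as `"<N>_<K>" -> {"<M>": config}` becomes a flat `"<M>_<N>_<K>" -> config` dict, which is the key format `apply_weights` later looks up.

```python
cachedata = {"4096_7168": {"16": {"BLOCK_M": 16}, "256": {"BLOCK_M": 64}}}
configs_dict = {f"{m}_{nk}": cfg
                for nk, per_m in cachedata.items()
                for m, cfg in per_m.items()}
assert configs_dict["16_4096_7168"] == {"BLOCK_M": 16}
```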
......@@ -19,6 +19,14 @@ limitations under the License.
#include "sgl_kernel_ops.h"
TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From FlashMLA
*/
m.def("dcu_create_flashmla_kv_indices(Tensor req_to_token, Tensor req_pool_indices,Tensor page_kernel_lens, Tensor? kv_start_idx, Tensor kv_indices, int req_to_token_stride, int kv_indices_stride, int PAGED_SIZE) -> ()");
m.impl("dcu_create_flashmla_kv_indices", torch::kCUDA, &dcu_create_flashmla_kv_indices);
/*
* From csrc/activation
*/
......@@ -133,6 +141,15 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
*/
m.def("dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
m.def("dcu_create_chunked_prefix_cache_kv_indices(Tensor req_to_token, Tensor req_pool_indices, Tensor chunk_starts, Tensor chunk_seq_lens, Tensor chunk_cu_seq_lens, Tensor chunk_kv_indices, int col_num, int bs) -> ()");
m.impl("dcu_create_chunked_prefix_cache_kv_indices", torch::kCUDA, &dcu_create_chunked_prefix_cache_kv_indices);
m.def("dcu_assign_extend_cache_locs(Tensor req_pool_indices, Tensor req_to_token, Tensor start_offset, Tensor end_offset, Tensor out_cache_loc, int pool_len, int bs) -> ()");
m.impl("dcu_assign_extend_cache_locs", torch::kCUDA, &dcu_assign_extend_cache_locs);
m.def("dcu_get_last_loc(Tensor req_to_token, Tensor req_pool_indices, Tensor prefix_lens) -> Tensor");
m.impl("dcu_get_last_loc", torch::kCUDA, &dcu_get_last_loc);
m.def("dcu_assign_req_to_token_pool(Tensor req_pool_indices_ptr,Tensor req_to_token_ptr,Tensor allocate_lens_ptr,Tensor new_allocate_lens,Tensor out_cache_loc_ptr,int shape,int bs) -> ()");
m.impl("dcu_assign_req_to_token_pool",torch::kCUDA,&dcu_assign_req_to_token_pool);
m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
......
......@@ -836,4 +836,322 @@ void dcu_alloc_extend_kernel(
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
\ No newline at end of file
}
__global__ void launch_assign_req_to_token_pool(
const int64_t* req_pool_indices_ptr,
int32_t* req_to_token_ptr,
const int64_t* allocate_lens_ptr,
int64_t* new_allocate_lens,
int64_t* out_cache_loc_ptr,
int64_t shape,
int64_t bs)
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t kv_start = allocate_lens_ptr[pid];
int64_t kv_end = new_allocate_lens[pid];
int64_t pool_idx = req_pool_indices_ptr[pid];
int32_t* token_pool = (int32_t*)(req_to_token_ptr + pool_idx * shape);
int64_t sum_out_offset = 0;
for(int length_offset = 0; length_offset < pid;length_offset++){
int64_t start = allocate_lens_ptr[length_offset];
int64_t end = new_allocate_lens[length_offset];
sum_out_offset += (end- start);
}
int64_t* out_cache_ptr = out_cache_loc_ptr + sum_out_offset;
int64_t copy_length = kv_end - kv_start;
#pragma unroll(32)
for (int out_cache_index = 0; out_cache_index < copy_length; out_cache_index++) {
token_pool[kv_start + out_cache_index] = out_cache_ptr[out_cache_index];
}
}
void dcu_assign_req_to_token_pool(
const at::Tensor req_pool_indices_ptr,
at::Tensor req_to_token_ptr,
const at::Tensor allocate_lens_ptr,
at::Tensor new_allocate_lens,
at::Tensor out_cache_loc_ptr,
int64_t shape,
int64_t bs) {
const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
const int64_t* allocate_lens_ptr1 = static_cast<const int64_t*>(allocate_lens_ptr.data_ptr());
int64_t* new_allocate_lens1 = static_cast<int64_t*>(new_allocate_lens.data_ptr());
int64_t* out_cache_loc_ptr1 = static_cast<int64_t*>(out_cache_loc_ptr.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_assign_req_to_token_pool<<<grid_size, block_size, 0, torch_current_stream>>>(req_pool_indices_ptr1, req_to_token_ptr1, allocate_lens_ptr1, new_allocate_lens1, out_cache_loc_ptr1, shape, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
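Plain-PyTorch reference for `dcu_assign_req_to_token_pool` (illustrative only; the real kernel recomputes each request's output offset on the GPU): request i's slice of `out_cache_loc` is written into row `req_pool_indices[i]` of `req_to_token`, between `allocate_lens[i]` and `new_allocate_lens[i]`.

```python
import torch

def assign_req_to_token_pool_reference(req_pool_indices, req_to_token,
                                       allocate_lens, new_allocate_lens,
                                       out_cache_loc):
    offset = 0
    for i in range(req_pool_indices.numel()):
        start, end = int(allocate_lens[i]), int(new_allocate_lens[i])
        row = int(req_pool_indices[i])
        n = end - start
        req_to_token[row, start:end] = out_cache_loc[offset:offset + n].to(req_to_token.dtype)
        offset += n
```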
__global__ void get_last_loc_kernel(
const int32_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices_tensor,
const int64_t* __restrict__ prefix_lens_tensor,
int64_t* __restrict__ result,
int64_t num_tokens,
int64_t req_to_token_stride){
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= num_tokens) return;
int64_t pre_len = prefix_lens_tensor[pid];
if (pre_len > 0) {
int64_t req_idx = req_pool_indices_tensor[pid];
int64_t token_idx = req_idx * req_to_token_stride + (pre_len - 1);
result[pid] = static_cast<int64_t>(req_to_token[token_idx]);
} else {
result[pid] = static_cast<int64_t>(-1);
}
}
at::Tensor dcu_get_last_loc(
const at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor prefix_lens) {
TORCH_CHECK(req_to_token.device().is_cuda(), "req_to_token must be CUDA tensor");
TORCH_CHECK(req_pool_indices.device().is_cuda(), "req_pool_indices must be CUDA tensor");
TORCH_CHECK(prefix_lens.device().is_cuda(), "prefix_lens must be CUDA tensor");
TORCH_CHECK(req_to_token.dim() == 2, "req_to_token must be 2D tensor [batch, seq_len]");
TORCH_CHECK(prefix_lens.dim() == 1, "prefix_lens must be 1D");
TORCH_CHECK(req_pool_indices.dim() == 1, "req_pool_indices must be 1D");
int64_t num_tokens = prefix_lens.numel();
TORCH_CHECK(req_pool_indices.numel() == num_tokens, "req_pool_indices must have same length as prefix_lens");
int64_t req_to_token_stride = req_to_token.stride(0);
auto req_to_token_c = req_to_token.contiguous();
auto req_pool_indices_c = req_pool_indices.contiguous();
auto prefix_lens_c = prefix_lens.contiguous();
const int32_t* req_to_token_ptr = req_to_token_c.data_ptr<int32_t>();
const int64_t* req_pool_indices_ptr = req_pool_indices_c.data_ptr<int64_t>();
const int64_t* prefix_lens_ptr = prefix_lens_c.data_ptr<int64_t>();
auto result = at::empty_like(prefix_lens_c);
int64_t* result_ptr = result.data_ptr<int64_t>();
const int64_t block_size = 64;
const int64_t grid_size = (num_tokens + block_size - 1) / block_size;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
get_last_loc_kernel<<<grid_size, block_size, 0, stream>>>(
req_to_token_ptr,
req_pool_indices_ptr,
prefix_lens_ptr,
result_ptr,
num_tokens,
req_to_token_stride
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return result;
}
__global__ void launch_assign_extend_cache_locs_kernel(
const int64_t* __restrict__ req_pool_indices, // [bs]
const int32_t* __restrict__ req_to_token, // [max_num_req, pool_len]
const int64_t* __restrict__ start_offset, // [bs]
const int64_t* __restrict__ end_offset, // [bs]
int64_t* __restrict__ out_cache_loc, // [sum(draft_token_num)]
int64_t pool_len,
int64_t bs)
{
int pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t kv_start = start_offset[pid];
int64_t kv_end = end_offset[pid];
int64_t req_id = req_pool_indices[pid];
int64_t out_offset = 0;
for (int i = 0; i < pid; ++i) {
out_offset += end_offset[i] - start_offset[i];
}
const int32_t* src = req_to_token + req_id * pool_len + kv_start;
int64_t* dst = out_cache_loc + out_offset;
for (int64_t i = 0; i < kv_end - kv_start; ++i) {
dst[i] = src[i];
}
}
void dcu_assign_extend_cache_locs(
const at::Tensor req_pool_indices,
const at::Tensor req_to_token,
const at::Tensor start_offset,
const at::Tensor end_offset,
at::Tensor out_cache_loc,
int64_t pool_len,
int64_t bs)
{
const int64_t* req_pool_indices_ptr = req_pool_indices.data_ptr<int64_t>();
const int32_t* req_to_token_ptr = req_to_token.data_ptr<int32_t>();
const int64_t* start_offset_ptr = start_offset.data_ptr<int64_t>();
const int64_t* end_offset_ptr = end_offset.data_ptr<int64_t>();
int64_t* out_cache_loc_ptr = out_cache_loc.data_ptr<int64_t>();
constexpr int64_t threads = 128;
int64_t blocks = (bs + threads - 1) / threads;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
launch_assign_extend_cache_locs_kernel<<<blocks, threads, 0, stream>>>(
req_pool_indices_ptr,
req_to_token_ptr,
start_offset_ptr,
end_offset_ptr,
out_cache_loc_ptr,
pool_len,
bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
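The counterpart sketch for `dcu_assign_extend_cache_locs`, which copies in the opposite direction: token locations are gathered from `req_to_token` into `out_cache_loc` at cumulative per-request offsets. Plain PyTorch, for reference only.

```python
import torch

def assign_extend_cache_locs_reference(req_pool_indices, req_to_token,
                                       start_offset, end_offset, out_cache_loc):
    offset = 0
    for i in range(req_pool_indices.numel()):
        s, e = int(start_offset[i]), int(end_offset[i])
        row = int(req_pool_indices[i])
        out_cache_loc[offset:offset + (e - s)] = req_to_token[row, s:e].to(out_cache_loc.dtype)
        offset += e - s
```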
template<int PAGED_SIZE>
__global__ void dcu_create_flashmla_kv_indices_kernel(
const int32_t* __restrict__ req_to_token,
const int32_t* __restrict__ req_pool_indices,
const int32_t* __restrict__ page_kernel_lens,
const int32_t* __restrict__ kv_start_idx,
int32_t* __restrict__ kv_indices,
int req_to_token_stride,
int kv_indices_stride)
{
int pid = blockIdx.x; // batch index
int req_pool_index = req_pool_indices[pid];
int kv_start = 0;
int kv_end = 0;
if (kv_start_idx != nullptr) {
kv_start = kv_start_idx[pid];
kv_end = kv_start;
}
kv_end += page_kernel_lens[pid];
int total_len = kv_end - kv_start;
int num_pages = (total_len + PAGED_SIZE - 1) / PAGED_SIZE;
for (int pg = 0; pg < num_pages; ++pg) {
int offset = pg * PAGED_SIZE;
// token id = req_to_token[req_pool_index][kv_start + offset]
int64_t token =
req_to_token[req_pool_index * req_to_token_stride + kv_start + offset];
// page index
kv_indices[pid * kv_indices_stride + pg] = token / PAGED_SIZE;
}
}
void dcu_create_flashmla_kv_indices(
const at::Tensor& req_to_token,
const at::Tensor& req_pool_indices,
const at::Tensor& page_kernel_lens,
const c10::optional<at::Tensor>& kv_start_idx,
at::Tensor& kv_indices,
int64_t req_to_token_stride,
int64_t kv_indices_stride,
int64_t PAGED_SIZE)
{
TORCH_CHECK(req_to_token.is_cuda(), "req_to_token must be CUDA tensor");
TORCH_CHECK(kv_indices.is_cuda(), "kv_indices must be CUDA tensor");
int bs = req_pool_indices.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
dim3 grid(bs);
dim3 block(1);
const int32_t* kv_start_idx_ptr = nullptr;
if (kv_start_idx.has_value()) {
kv_start_idx_ptr = kv_start_idx.value().data_ptr<int32_t>();
}
if (PAGED_SIZE == 64) {
dcu_create_flashmla_kv_indices_kernel<64><<<grid, block, 0, stream>>>(
req_to_token.data_ptr<int32_t>(),
req_pool_indices.data_ptr<int32_t>(),
page_kernel_lens.data_ptr<int32_t>(),
kv_start_idx_ptr,
kv_indices.data_ptr<int32_t>(),
req_to_token_stride,
kv_indices_stride
);
} else {
TORCH_CHECK(false, "Unsupported PAGED_SIZE");
}
}
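What the templated kernel above computes, written as a plain-PyTorch loop: the first token of every 64-token page of a request is read from `req_to_token` and integer-divided by `PAGED_SIZE` to obtain the page index. Reference only; the real kernel launches one block per request.

```python
import torch

def create_flashmla_kv_indices_reference(req_to_token, req_pool_indices,
                                         page_kernel_lens, kv_start_idx,
                                         kv_indices, paged_size: int = 64):
    for pid in range(req_pool_indices.numel()):
        start = int(kv_start_idx[pid]) if kv_start_idx is not None else 0
        length = int(page_kernel_lens[pid])
        num_pages = (length + paged_size - 1) // paged_size
        row = int(req_pool_indices[pid])
        for pg in range(num_pages):
            tok = int(req_to_token[row, start + pg * paged_size])
            kv_indices[pid, pg] = tok // paged_size
```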
__global__ void launch_create_chunked_prefix_cache_kv_indices(
int32_t* req_to_token_ptr,
const int64_t* req_pool_indices_ptr,
const int32_t* chunk_starts_ptr,
const int32_t* chunk_seq_lens_ptr,
const int32_t* chunk_cu_seq_lens_ptr,
int32_t* chunk_kv_indices_ptr,
int64_t col_num,
int64_t bs)
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t req_pool_index = req_pool_indices_ptr[pid];
int64_t chunk_kv_indices_offset = chunk_cu_seq_lens_ptr[pid];
int32_t chunk_start_pos = chunk_starts_ptr[pid];
int32_t chunk_seq_len = chunk_seq_lens_ptr[pid];
#pragma unroll(32)
for(int32_t offset = 0;offset < chunk_seq_len;offset++){
chunk_kv_indices_ptr[chunk_kv_indices_offset+offset] = req_to_token_ptr[req_pool_index * col_num + chunk_start_pos + offset];
}
}
void dcu_create_chunked_prefix_cache_kv_indices(
at::Tensor req_to_token_ptr,
const at::Tensor req_pool_indices_ptr,
const at::Tensor chunk_starts_ptr,
const at::Tensor chunk_seq_lens_ptr,
const at::Tensor chunk_cu_seq_lens_ptr,
at::Tensor chunk_kv_indices_ptr,
int64_t col_num,
int64_t bs) {
int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
const int32_t* chunk_starts_ptr1 = static_cast<const int32_t*>(chunk_starts_ptr.data_ptr());
const int32_t* chunk_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_seq_lens_ptr.data_ptr());
const int32_t* chunk_cu_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_cu_seq_lens_ptr.data_ptr());
int32_t* chunk_kv_indices_ptr1 = static_cast<int32_t*>(chunk_kv_indices_ptr.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_create_chunked_prefix_cache_kv_indices<<<grid_size, block_size, 0, torch_current_stream>>>(req_to_token_ptr1, req_pool_indices_ptr1, chunk_starts_ptr1, chunk_seq_lens_ptr1, chunk_cu_seq_lens_ptr1,chunk_kv_indices_ptr1, col_num, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
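And the plain-PyTorch equivalent of `dcu_create_chunked_prefix_cache_kv_indices`: each request's `[chunk_starts[i], chunk_starts[i] + chunk_seq_lens[i])` slice of `req_to_token` is copied into `chunk_kv_indices` starting at `chunk_cu_seq_lens[i]`.

```python
import torch

def create_chunked_prefix_cache_kv_indices_reference(req_to_token, req_pool_indices,
                                                     chunk_starts, chunk_seq_lens,
                                                     chunk_cu_seq_lens, chunk_kv_indices):
    for i in range(req_pool_indices.numel()):
        s, n, o = int(chunk_starts[i]), int(chunk_seq_lens[i]), int(chunk_cu_seq_lens[i])
        row = int(req_pool_indices[i])
        chunk_kv_indices[o:o + n] = req_to_token[row, s:s + n]
```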
......@@ -538,6 +538,7 @@ void segment_packbits(
/*
* From csrc/kvcacheio
*/
void dcu_create_extend_after_decode_spec_info(
const at::Tensor verified_id,
const at::Tensor seq_lens,
......@@ -545,6 +546,49 @@ void dcu_create_extend_after_decode_spec_info(
at::Tensor positions,
at::Tensor new_verified_id,
int64_t bs);
void dcu_create_chunked_prefix_cache_kv_indices(
at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor chunk_starts,
const at::Tensor chunk_seq_lens,
const at::Tensor chunk_cu_seq_lens,
at::Tensor chunk_kv_indices,
int64_t col_num,
int64_t bs);
void dcu_create_flashmla_kv_indices(
const at::Tensor& req_to_token,
const at::Tensor& req_pool_indices,
const at::Tensor& page_kernel_lens,
const c10::optional<at::Tensor>& kv_start_idx,
at::Tensor& kv_indices,
int64_t req_to_token_stride,
int64_t kv_indices_stride,
int64_t PAGED_SIZE);
void dcu_assign_extend_cache_locs(
const at::Tensor req_pool_indices,
const at::Tensor req_to_token,
const at::Tensor start_offset,
const at::Tensor end_offset,
at::Tensor out_cache_loc,
int64_t pool_len,
int64_t bs);
at::Tensor dcu_get_last_loc(
const at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor prefix_lens);
void dcu_assign_req_to_token_pool(
const at::Tensor req_pool_indices_ptr,
at::Tensor req_to_token_ptr,
const at::Tensor allocate_lens_ptr,
at::Tensor new_allocate_lens,
at::Tensor out_cache_loc_ptr,
int64_t shape,
int64_t bs);
void dcu_alloc_extend_kernel(
const at::Tensor pre_lens_ptr,
......
......@@ -13,6 +13,26 @@ _IMPORT_ERROR = ImportError(
"Failed to load sgl_kernel.flashmla_ops extension. Ensure CUDA Driver >= 12.4"
)
def dcu_create_flashmla_kv_indices(
req_to_token_ptr,
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride,
kv_indices_ptr_stride,
PAGED_SIZE = 64,
):
torch.ops.sgl_kernel.dcu_create_flashmla_kv_indices(req_to_token_ptr,
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride,
kv_indices_ptr_stride,
PAGED_SIZE,
)
def get_mla_metadata(
cache_seqlens: torch.Tensor,
......