"vscode:/vscode.git/clone" did not exist on "1e40c81bf493fdcf63423ee117fbecb4397ee4f6"
Commit bdda6c42 authored by niuhb

Merge branch 'v0.5.4_dev' into 'v0.5.4_dev_shangxl'

# Conflicts:
#   python/sglang/srt/layers/attention/dcu_mla_backend.py
#   python/sglang/srt/layers/attention/flashattention_backend.py
#   python/sglang/srt/model_executor/forward_batch_info.py
#   sgl-kernel/csrc/common_extension_rocm.cc
#   sgl-kernel/include/sgl_kernel_ops.h
parents b08d561e 769353e6
@@ -105,14 +105,14 @@ class DCUMLABackend(AttentionBackend):
             skip_prefill=False,
         )
 
-    def _build_decode_metadata(
-        self,
-        forward_batch: ForwardBatch,
-        seq_lens: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
         bs = forward_batch.batch_size
-        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+        if forward_batch.forward_mode.is_decode_or_idle():
+            max_seqlen_pad = triton.cdiv(
+                forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
+            )
             # Paging follows the official vLLM blog post
             block_kv_indices = torch.full(
@@ -487,7 +487,7 @@ class DCUMLABackend(AttentionBackend):
         )
         return o
 
-    @torch._dynamo.disable()
+    @torch._dynamo.disable()  # NOTE: decode with an FP8 KV cache does not support torch.compile
     def forward_decode(
         self,
         q: torch.Tensor,
@@ -514,7 +514,7 @@ class DCUMLABackend(AttentionBackend):
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
+        num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
         if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
                               torch.float8_e5m2, torch.float8_e5m2fnuz):
             if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
@@ -526,7 +526,7 @@ class DCUMLABackend(AttentionBackend):
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
-                forward_batch.seq_lens.to(torch.int32),
+                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                 layer.scaling,
                 k_scale,
                 kv_cache_dtype=kv_cache_dtype,
@@ -536,13 +536,13 @@ class DCUMLABackend(AttentionBackend):
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
-                forward_batch.seq_lens.to(torch.int32),
+                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                 layer.scaling,
             )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
-    @torch._dynamo.disable()  # NOTE: untested
+    @torch._dynamo.disable()
     def forward_extend(
         self,
         q: torch.Tensor,
@@ -556,12 +556,9 @@ class DCUMLABackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks=None,
     ):
-        if save_kv_cache and self.num_draft_tokens == 0: #nhb
-            return self.forward_decode(q,k,v,layer,forward_batch, save_kv_cache)
-        if ((
+        if (
             forward_batch.forward_mode == ForwardMode.EXTEND
-            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
         ):
             if not self.skip_prefill:
                 return self.flashattn_backend.forward_extend(
@@ -574,14 +571,19 @@ class DCUMLABackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
 
         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
+        num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
         if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
                               torch.float8_e5m2, torch.float8_e5m2fnuz):
             if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
@@ -593,7 +595,7 @@ class DCUMLABackend(AttentionBackend):
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
-                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                 layer.scaling,
                 k_scale,
                 kv_cache_dtype=kv_cache_dtype,
@@ -603,13 +605,12 @@ class DCUMLABackend(AttentionBackend):
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
-                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                 layer.scaling,
             )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
 
 class DCUMLAMultiStepDraftBackend:
     """
     Wrap multiple flashmla attention backends as one for multiple consecutive
...
@@ -329,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
         self.is_hybrid = model_runner.is_hybrid
+        self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
+        self.v_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
         if self.is_hybrid:
             self.full_to_swa_index_mapping = (
                 model_runner.token_to_kv_pool.full_to_swa_index_mapping
@@ -609,10 +611,13 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_q = metadata.max_seq_len_k
             metadata.cu_seqlens_q = metadata.cu_seqlens_k
 
-            # Setup local attention if enabled
-            if forward_batch.forward_mode == ForwardMode.EXTEND:
+            # # Setup local attention if enabled
+            # if forward_batch.forward_mode == ForwardMode.EXTEND:
+            #     self._init_local_attn_metadata(forward_batch, metadata, device)
+            if forward_batch.forward_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND_V2):
                 self._init_local_attn_metadata(forward_batch, metadata, device)
 
         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
             assert (
@@ -691,7 +696,8 @@ class FlashAttentionBackend(AttentionBackend):
             layer.sliding_window_size is not None and layer.sliding_window_size > -1
         )
         window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
-        k_descale, v_descale = None, None
+        # k_descale, v_descale = None, None
+        k_descale, v_descale = self.k_scale, self.v_scale
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
@@ -775,55 +781,53 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k = metadata.encoder_cu_seqlens_k
                 window_size = (-1, -1)
 
-            result = flash_attn_with_kvcache(
-                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                k_cache=key_cache,
-                v_cache=value_cache,
-                page_table=page_table,
-                cache_seqlens=cache_seqlens,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
-                max_seqlen_q=max_seqlen_q,
-                softmax_scale=layer.scaling,
-                causal=False if use_cascade_attn else causal,
-                window_size=window_size,
-                softcap=layer.logit_cap,
-                k_descale=k_descale,
-                v_descale=v_descale,
-                return_softmax_lse=use_cascade_attn,
-                num_splits=self.num_splits,
-                **kwargs,
-            )
-            if use_cascade_attn:
-                o, softmax_lse, *rest = result
-                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
-                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k_cache=key_cache,
-                    v_cache=value_cache,
-                    page_table=self.forward_metadata_spec_decode_expand.page_table,
-                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
-                    softmax_scale=layer.scaling,
-                    causal=False,
-                    window_size=window_size,
-                    softcap=layer.logit_cap,
-                    k_descale=k_descale,
-                    v_descale=v_descale,
-                    return_softmax_lse=True,
-                    num_splits=self.num_splits,
-                    **kwargs,
-                )
-                o, _ = merge_state_v2_wrapper(
-                    o,
-                    softmax_lse.T.contiguous(),
-                    o_expand,
-                    softmax_lse_expand.T.contiguous(),
-                )
-            else:
-                o = result
+            if forward_batch.attn_attend_prefix_cache:
+                assert not get_global_server_args().disable_chunked_prefix_cache
+                # MHA for chunked prefix kv cache when running model with MLA
+                assert forward_batch.prefix_chunk_idx is not None
+                assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                assert forward_batch.prefix_chunk_max_seq_lens is not None
+
+                chunk_idx = forward_batch.prefix_chunk_idx
+                assert chunk_idx >= 0
+                assert forward_batch.mha_return_lse
+                output = flash_attn_varlen_func(
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                    max_seqlen_q=metadata.max_seq_len_q,
+                    max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=True,
+                    **kwargs,
+                )
+            else:
+                output = flash_attn_varlen_func(
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k=metadata.cu_seqlens_q,
+                    max_seqlen_q=metadata.max_seq_len_q,
+                    max_seqlen_k=metadata.max_seq_len_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=forward_batch.mha_return_lse,
+                    **kwargs,
+                )
+            if forward_batch.mha_return_lse:
+                output, lse, *rest = output
+                lse = torch.transpose(lse, 0, 1).contiguous()
+                return output, lse
+            return output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
             if (
                 forward_batch.attn_attend_prefix_cache is not None
@@ -852,6 +856,8 @@ class FlashAttentionBackend(AttentionBackend):
                     max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                     softmax_scale=layer.scaling,
                     causal=False,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
                     return_softmax_lse=True,
                     **kwargs,
                 )
@@ -866,6 +872,8 @@ class FlashAttentionBackend(AttentionBackend):
                     max_seqlen_k=metadata.max_seq_len_q,
                     softmax_scale=layer.scaling,
                     causal=True,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
                     return_softmax_lse=forward_batch.mha_return_lse,
                     **kwargs,
                 )
@@ -975,10 +983,16 @@ class FlashAttentionBackend(AttentionBackend):
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )
-        if not self.use_mla:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-            )
+        # if not self.use_mla:
+        if k_rope is None:
+            if not self.use_mla:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v
+                )
         else:
             forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                 layer,
@@ -1020,7 +1034,8 @@ class FlashAttentionBackend(AttentionBackend):
         if sinks is not None:
             kwargs["sinks"] = sinks
 
-        k_descale, v_descale = None, None
+        # k_descale, v_descale = None, None
+        k_descale, v_descale = self.k_scale, self.v_scale
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
@@ -1034,7 +1049,6 @@ class FlashAttentionBackend(AttentionBackend):
         k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
 
         if not self.use_mla:
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
             )
@@ -1086,26 +1100,33 @@ class FlashAttentionBackend(AttentionBackend):
                 **kwargs,
             )
         else:
-            page_table = metadata.page_table
-            cache_seqlens = metadata.cache_seqlens_int32
-            cu_seqlens_k = metadata.cu_seqlens_k
-            max_seqlen_q = metadata.max_seq_len_q
-            q_reshaped = q.contiguous().view(
-                -1, layer.tp_q_head_num, layer.head_dim
-            )
-
-            # Default: single-token self-attention
-            result = flash_attn_with_kvcache(
-                q=q_reshaped,
-                k_cache=key_cache,
-                v_cache=value_cache,
-                page_table=page_table,
-                cache_seqlens=cache_seqlens,
-                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_q,
-                softmax_scale=layer.scaling,
-                causal=False if use_cascade_attn else causal,
-                window_size=window_size,
-                softcap=layer.logit_cap,
-                k_descale=k_descale,
+            cu_seqlens_q = metadata.cu_seqlens_q
+            max_seqlen_q = metadata.max_seq_len_q
+            page_table = metadata.page_table
+            cu_seqlens_k = metadata.cu_seqlens_k
+            cache_seqlens = metadata.cache_seqlens_int32
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+            if max_seqlen_q > 1:
+                result = flash_attn_varlen_func(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_k=max_seqlen_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
@@ -1114,36 +1135,26 @@ class FlashAttentionBackend(AttentionBackend):
                     num_splits=self.num_splits,
                     **kwargs,
                 )
-            if use_cascade_attn:
-                o, softmax_lse, *rest = result
-                o_expand, softmax_lse_expand, *rest_expand = (
-                    flash_attn_with_kvcache(
-                        q=q_reshaped,
-                        k_cache=key_cache,
-                        v_cache=value_cache,
-                        page_table=self.forward_metadata_spec_decode_expand.page_table,
-                        cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                        cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                        cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                        max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
-                        softmax_scale=layer.scaling,
-                        causal=False,
-                        window_size=window_size,
-                        softcap=layer.logit_cap,
-                        k_descale=k_descale,
-                        v_descale=v_descale,
-                        return_softmax_lse=True,
-                        num_splits=self.num_splits,
-                        **kwargs,
-                    )
-                )
-                o, _ = merge_state_v2(
-                    o,
-                    softmax_lse.T.contiguous(),
-                    o_expand,
-                    softmax_lse_expand.T.contiguous(),
-                )
-            else:
-                o = result
+            else:
+                result = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=page_table,
+                    cache_seqlens=cache_seqlens,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                    max_seqlen_q=max_seqlen_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=use_cascade_attn,
+                    num_splits=self.num_splits,
+                    **kwargs,
+                )
+            o = result
         else:
             # Do absorbed multi-latent attention
...
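A note on the descale change above: the backend now always hands the attention kernels a k_descale/v_descale tensor initialized to 1.0 instead of None. The sketch below is a standalone illustration (not code from this commit) of how a per-tensor KV descale conventionally enters the attention math, and why a value of 1.0 is numerically a no-op for a non-quantized cache; the real kernel applies the factor internally.

import torch

def attn_with_kv_descale(q, k_q, v_q, k_descale, v_descale, scale):
    # Reference: cached K/V are "dequantized" by multiplying with their
    # descale factor before the usual softmax(Q K^T * scale) V computation.
    k = k_q.float() * k_descale
    v = v_q.float() * v_descale
    attn = torch.softmax(q.float() @ k.transpose(-1, -2) * scale, dim=-1)
    return attn @ v

q = torch.randn(1, 4, 16, 64)
k = torch.randn(1, 4, 16, 64)
v = torch.randn(1, 4, 16, 64)
scale = 64 ** -0.5

# With a descale of 1.0 (what self.k_scale / self.v_scale hold above),
# the result matches plain attention, so passing the tensors unconditionally
# does not change the math for unquantized KV caches.
ref = attn_with_kv_descale(q, k, v, torch.tensor(1.0), torch.tensor(1.0), scale)
plain = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1) @ v
assert torch.allclose(ref, plain, atol=1e-5)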
@@ -42,8 +42,8 @@ def flash_attn_with_kvcache(
 ):
     return flash_attn_with_kvcache_interface(
         q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
-        k_cache=k_cache,
-        v_cache=v_cache,
+        k_cache=k_cache.view(q.dtype),
+        v_cache=v_cache.view(q.dtype),
         block_table=page_table,
         cache_seqlens=cache_seqlens,
         softmax_scale=softmax_scale,
@@ -83,8 +83,8 @@ def flash_attn_varlen_func(
 ):
     return flash_attn_varlen_func_interface(
         q=q,
-        k=k,
-        v=v,
+        k=k.view(q.dtype),
+        v=v.view(q.dtype),
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
@@ -92,4 +92,5 @@ def flash_attn_varlen_func(
         softmax_scale=softmax_scale,
         causal=causal,
         return_attn_probs=return_softmax_lse,
+        softcap=softcap,
     )
\ No newline at end of file
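The `.view(q.dtype)` calls above reinterpret the KV tensors' raw bytes as q's dtype; PyTorch's `Tensor.view(dtype)` only relabels the storage without any numeric conversion or copy. A minimal standalone sketch of that behavior (illustration only, tensor names are made up):

import torch

# A contiguous bf16 buffer standing in for a cache tensor.
buf = torch.randn(4, 8, dtype=torch.bfloat16)

# Relabel the same 2-byte elements as float16: no copy, no conversion,
# just a reinterpretation of the underlying bits.
as_fp16 = buf.view(torch.float16)
assert as_fp16.shape == buf.shape
assert as_fp16.data_ptr() == buf.data_ptr()

# Viewing as a narrower dtype (e.g. uint8) widens the last dimension instead.
as_bytes = buf.view(torch.uint8)
assert as_bytes.shape == (4, 16)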
@@ -4,7 +4,7 @@ import warnings
 
 import torch
 
-from sglang.srt.utils import get_bool_env_var
+from sglang.srt.utils import get_bool_env_var, direct_register_custom_op
 
 _USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")
 
@@ -20,13 +20,48 @@ else:
     ds_cat = None
 
-def concat_decode_opt(A:torch.Tensor, B:torch.Tensor, dim:int):
-    assert dim==2 , "tensor dim must be 3 and concat dim must be 2"
+# TODO: registering this op on its own still has some issues
+def ds_cat_wrapper(A: torch.Tensor,
+                   B: torch.Tensor,
+                   dim: int,
+                   mode: int) -> torch.Tensor:
     output_shape = list(A.shape)
     output_shape[dim] = A.shape[dim] + B.shape[dim]
     C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
-    mode = 0
-    if dim!=0 :
-        ds_cat( A, B, C, mode)
+    ds_cat(A, B, C, mode)
     return C
+
+
+def ds_cat_fake(A: torch.Tensor,
+                B: torch.Tensor,
+                dim: int,
+                mode: int) -> torch.Tensor:
+    # Use the standard torch.cat as the fake (meta) implementation
+    return torch.cat([A, B], dim=dim)
+
+
+direct_register_custom_op(
+    op_name="ds_cat",
+    op_func=ds_cat_wrapper,
+    mutates_args=[],  # no arguments are mutated, only a new tensor is returned
+    fake_impl=ds_cat_fake
+)
+
+
+def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
+    assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
+    mode = 0
+    if dim != 0:
+        return torch.ops.sglang.ds_cat(A, B, dim, mode)
     assert False, "not support"
+
+# def concat_decode_opt(A:torch.Tensor, B:torch.Tensor, dim:int):
+#     assert dim==2 , "tensor dim must be 3 and concat dim must be 2"
+#     output_shape = list(A.shape)
+#     output_shape[dim] = A.shape[dim] + B.shape[dim]
#     C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
+#     mode=0
+#     if dim!=0 :
+#         ds_cat(A, B, C, mode)
+#         return C
+#     assert False, "not support"
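For context on why the `fake_impl` matters: it is what lets torch.compile trace through the opaque `ds_cat` kernel, since during tracing only output metadata (shape, dtype, device) is needed. The sketch below shows the same registration pattern with stock `torch.library.custom_op` (PyTorch 2.4+); the `mylib::pad_cat` op is hypothetical and not part of this commit, which uses sglang's own `direct_register_custom_op` helper instead.

import torch

# Register an opaque custom op; a real implementation could call into a
# hand-written kernel, here it simply wraps torch.cat.
@torch.library.custom_op("mylib::pad_cat", mutates_args=())
def pad_cat(a: torch.Tensor, b: torch.Tensor, dim: int) -> torch.Tensor:
    return torch.cat([a, b], dim=dim)

# The fake/meta implementation runs during tracing: it never touches data,
# it only reports the output shape, dtype, and device.
@pad_cat.register_fake
def _(a, b, dim):
    out_shape = list(a.shape)
    out_shape[dim] = a.shape[dim] + b.shape[dim]
    return a.new_empty(out_shape)

@torch.compile(fullgraph=True)
def fused(a, b):
    return torch.ops.mylib.pad_cat(a, b, 2) * 2.0

x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 5)
print(fused(x, y).shape)  # torch.Size([2, 3, 9])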
@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
-from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu, direct_register_custom_op
 from sglang.srt.utils.offloader import get_offloader
 
 if TYPE_CHECKING:
@@ -57,6 +57,105 @@ if _use_aiter:
 
 logger = logging.getLogger(__name__)
#------ custom op for lightop
def m_grouped_w4a8_gemm_nt_masked_wrapper(
a0: torch.Tensor, a1: torch.Tensor,
b0: torch.Tensor, b1: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m_per_group: int
) -> torch.Tensor:
return m_grouped_w4a8_gemm_nt_masked(
(a0, a1),
(b0, b1),
d,
masked_m,
expected_m_per_group,
config={"MODE": 1000,}
)
def m_grouped_w4a8_gemm_nt_masked_fake(
a0: torch.Tensor, a1: torch.Tensor,
b0: torch.Tensor, b1: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m_per_group: int
) -> torch.Tensor:
return d
def m_grouped_w8a8_gemm_nt_masked_wrapper(
a0: torch.Tensor, a1: torch.Tensor,
b0: torch.Tensor, b1: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m_per_group: int
) -> torch.Tensor:
return m_grouped_w8a8_gemm_nt_masked(
(a0, a1),
(b0, b1),
d,
masked_m,
expected_m_per_group,
config={"MODE": 1000,}
)
def m_grouped_w8a8_gemm_nt_masked_fake(
a0: torch.Tensor, a1: torch.Tensor,
b0: torch.Tensor, b1: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m_per_group: int
) -> torch.Tensor:
return d
def fuse_silu_mul_quant_ep_wrapper(
input: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None,
num_local_tokens_tensor: Optional[torch.Tensor] = None,
topk:int=1,
expect_m:int=-1) -> tuple[torch.Tensor, torch.Tensor]:
return fuse_silu_mul_quant_ep(
input,
tokens_per_expert,
num_local_tokens_tensor,
topk,
expect_m
)
def fuse_silu_mul_quant_ep_fake(
input: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None,
num_local_tokens_tensor: Optional[torch.Tensor] = None,
topk:int=1,
expect_m:int=-1) -> tuple[torch.Tensor, torch.Tensor]:
E, T, H = input.shape
d = H // 2
output = torch.empty(E, T, d, dtype=torch.int8, device=input.device)
scales = torch.empty((E, T, 1),
device=input.device,
dtype=torch.float32)
return output, scales
direct_register_custom_op(
op_name="m_grouped_w4a8_gemm_nt_masked",
op_func=m_grouped_w4a8_gemm_nt_masked_wrapper,
mutates_args=[],
fake_impl=m_grouped_w4a8_gemm_nt_masked_fake
)
direct_register_custom_op(
op_name="m_grouped_w8a8_gemm_nt_masked",
op_func=m_grouped_w8a8_gemm_nt_masked_wrapper,
mutates_args=[],
fake_impl=m_grouped_w8a8_gemm_nt_masked_fake
)
direct_register_custom_op(
op_name="fuse_silu_mul_quant_ep",
op_func=fuse_silu_mul_quant_ep_wrapper,
mutates_args=[],
fake_impl=fuse_silu_mul_quant_ep_fake
)
#------
 # TODO(kaixih@nvidia): ideally we should merge this logic into
 # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@@ -815,23 +914,23 @@ class DeepEPMoE(EPMoE):
         gateup_output = torch.empty((num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16)
 
         # ---- first GEMM ----
-        m_grouped_w4a8_gemm_nt_masked(
-            (q_a1_all, q_a1_scale),
-            (w13_weight, w13_scales),
+        torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
+            q_a1_all, q_a1_scale,
+            w13_weight, w13_scales,
             gateup_output,
             masked_m,
             expected_m,
         )
 
-        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
+        q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(gateup_output, masked_m)
 
         # ---- second GEMM ----
         n2 = w2_scales.size(1)
         down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
-        m_grouped_w4a8_gemm_nt_masked(
-            (q_a2_all, q_a2_scale),
-            (w2_weight, w2_scales),
+        torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
+            q_a2_all, q_a2_scale,
+            w2_weight, w2_scales,
             down_output,
             masked_m,
             expected_m,
@@ -865,23 +964,23 @@ class DeepEPMoE(EPMoE):
         gateup_output = torch.empty((num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16)
 
         # ---- first GEMM ----
-        m_grouped_w8a8_gemm_nt_masked(
-            (q_a1_all, q_a1_scale),
-            (w13_weight, w13_scales),
+        torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
+            q_a1_all, q_a1_scale,
+            w13_weight, w13_scales,
            gateup_output,
             masked_m,
             expected_m,
         )
 
-        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
+        q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(gateup_output, masked_m)
 
         # ---- second GEMM ----
         n2 = w2_scales.size(1)
         down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
-        m_grouped_w8a8_gemm_nt_masked(
-            (q_a2_all, q_a2_scale),
-            (w2_weight, w2_scales),
+        torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
+            q_a2_all, q_a2_scale,
+            w2_weight, w2_scales,
             down_output,
             masked_m,
             expected_m,
...
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from __future__ import annotations
@@ -15,6 +16,7 @@ from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.utils import set_weight_attrs
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.utils import get_moe_a2a_backend
 
 try:
     from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
 except Exception:
@@ -77,7 +79,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
             "weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")
-        self.use_deepep = True
+        self.use_deepep = get_moe_a2a_backend().is_deepep()
 
         per_channel = (
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
             and self.input_quant.strategy == QuantizationStrategy.TOKEN)
...
@@ -163,12 +163,14 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         )
         layer.register_parameter("input_zero_point", input_zero_point)
 
+    @torch._dynamo.disable()
     def apply_weights(
         self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
     ) -> torch.Tensor:
         # TODO: add cutlass_scaled_mm_azp support
         x_q, x_scale = per_token_quant_int8(x)
+        # TODO: fix with lmslim/lightop
         return quant_ops.triton_scaled_mm(
             x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
         )
@@ -157,7 +157,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
         )
         layer.register_parameter("weight_scale", weight_scale)
 
-    @torch._dynamo.disable()
+    @torch._dynamo.disable()  # TODO: performance optimization needs lmslim/lightop support
     def apply(
         self,
         layer: torch.nn.Module,
...
@@ -214,7 +214,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
         self.moe_runner_config = moe_runner_config
         self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
 
-    @torch._dynamo.disable()
+    @torch._dynamo.disable()  # TODO: performance optimization needs lmslim/lightop support
     def apply(
         self,
         layer: torch.nn.Module,
@@ -253,6 +253,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
         )
         return StandardCombineInput(hidden_states=output)
 
+    @torch._dynamo.disable()  # TODO: performance optimization needs lmslim/lightop support
     def apply_with_shared_output(
         self,
         layer: torch.nn.Module,
...
@@ -1193,6 +1193,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens = seq_lens_tensor
         self.seq_lens_cpu = seq_lens_cpu
         self.extend_num_tokens = extend_num_tokens
+        self.loc_tensor = torch.tensor([-1], device=self.device)
 
         # Allocate memory
         out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
...
@@ -361,7 +361,8 @@ def alloc_for_extend(
     else:
         # Paged allocation - build last_loc
         last_loc = [
-            (t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
+            # (t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
+            (t[-1:] if len(t) > 0 else batch.loc_tensor)
             for t in prefix_tensors
         ]
         out_cache_loc = alloc_paged_token_slots_extend(
...
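The two changes above are one optimization: `ScheduleBatch` now pre-allocates a one-element sentinel tensor (`loc_tensor`) once, and `alloc_for_extend` reuses it for empty prefixes instead of constructing a fresh GPU tensor inside the list comprehension for every request. Below is a minimal standalone sketch of that allocation-hoisting pattern; the class and function names are made up for illustration and are not from this commit.

import torch

class BatchState:
    """Stand-in for a batch object; names are illustrative only."""

    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        self.device = device
        # Allocated once per batch: torch.tensor(...) on a CUDA device implies
        # a small allocation plus a host-to-device copy every time it runs, so
        # keeping it out of per-request loops avoids repeated overhead.
        self.sentinel_loc = torch.tensor([-1], device=self.device)

def build_last_loc(batch: BatchState, prefix_tensors: list) -> list:
    # Reuse the cached sentinel for empty prefixes rather than materializing
    # a new one-element tensor for each of them.
    return [
        t[-1:] if len(t) > 0 else batch.sentinel_loc
        for t in prefix_tensors
    ]

batch = BatchState()
prefixes = [
    torch.arange(4, device=batch.device),
    torch.empty(0, dtype=torch.long, device=batch.device),
]
print(build_last_loc(batch, prefixes))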
@@ -127,7 +127,7 @@ class ForwardMode(IntEnum):
         # For fixed shape logits output in v2 eagle worker
         return self == ForwardMode.DRAFT_EXTEND_V2
 
-    def is_extend_or_draft_extend_or_mixed(self):
+    def is_extend_or_draft_extend_or_mixed(self):  #nhb
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
...
@@ -2241,6 +2241,7 @@ class ModelRunner:
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
 
         if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
...
@@ -37,7 +37,8 @@ from sglang.srt.speculative.spec_utils import (
     get_src_tgt_cache_loc,
     get_target_cache_loc,
 )
-from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import is_cuda, is_hip, next_power_of_2, get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info
 
 if is_cuda():
     from sgl_kernel import (
@@ -615,6 +616,8 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
     new_seq_lens: Optional[torch.Tensor] = None
     verify_done: Optional[torch.cuda.Event] = None
 
+    use_sglang_create_extend_after_decode_spec_info = get_bool_env_var("SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO")
+
     def __post_init__(self):
         super().__init__(SpecInputType.EAGLE_DRAFT)
@@ -679,6 +682,16 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
         self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
         self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
 
+        if self.use_sglang_create_extend_after_decode_spec_info:
+            dcu_create_extend_after_decode_spec_info(
+                verified_id=batch.input_ids,
+                seq_lens=batch.seq_lens,
+                accept_lens=self.accept_length,
+                positions=self.positions,
+                new_verified_id=self.verified_id,
+                bs=max(speculative_num_steps + 1, len(batch.seq_lens)),
+            )
+        else:
-        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
-            batch.input_ids,
-            batch.seq_lens,
+            create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
+                batch.input_ids,
+                batch.seq_lens,
...
@@ -139,6 +139,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/kvcacheio
    */
+  m.def("dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
+  m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
   m.def("dcu_create_chunked_prefix_cache_kv_indices(Tensor req_to_token, Tensor req_pool_indices, Tensor chunk_starts, Tensor chunk_seq_lens, Tensor chunk_cu_seq_lens, Tensor chunk_kv_indices, int col_num, int bs) -> ()");
   m.impl("dcu_create_chunked_prefix_cache_kv_indices", torch::kCUDA, &dcu_create_chunked_prefix_cache_kv_indices);
   m.def("dcu_assign_extend_cache_locs(Tensor req_pool_indices, Tensor req_to_token, Tensor start_offset, Tensor end_offset, Tensor out_cache_loc, int pool_len, int bs) -> ()");
...
@@ -693,6 +693,65 @@ __global__ void launch_alloc_extend_kernel(
     out_indices[output_idx] = start_loc * page_size + offset;
   }
 }
__global__ void launch_create_extend_after_decode_spec_info_int32_kernel(
const int32_t* verified_id_ptr,
const int64_t* seq_lens_ptr,
const int32_t* accept_lens_ptr,
int64_t* positions_ptr,
int32_t* new_verified_id_ptr,
int64_t bs) {
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t seq_length = seq_lens_ptr[pid];
int32_t accept_length = accept_lens_ptr[pid];
int32_t accept_len_cumsum = 0;
for (int32_t offset = 0; offset < pid; offset++) {
accept_len_cumsum += accept_lens_ptr[offset];
}
int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
for (int32_t offset = 0; offset < accept_length && offset < bs; offset++)
{
positions_ptr1[offset] = seq_length - accept_length + offset;
}
int32_t verified_idx = accept_len_cumsum + accept_length - 1;
new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}
__global__ void launch_create_extend_after_decode_spec_info_int64_kernel(
const int32_t* verified_id_ptr,
const int64_t* seq_lens_ptr,
const int64_t* accept_lens_ptr,
int64_t* positions_ptr,
int32_t* new_verified_id_ptr,
int64_t bs) {
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t seq_length = seq_lens_ptr[pid];
int64_t accept_length = accept_lens_ptr[pid];
int64_t accept_len_cumsum = 0;
for (int64_t offset = 0; offset < pid; offset++) {
accept_len_cumsum += accept_lens_ptr[offset];
}
int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
for (int64_t offset = 0; offset < accept_length && offset < bs; offset++)
{
positions_ptr1[offset] = seq_length - accept_length + offset;
}
int64_t verified_idx = accept_len_cumsum + accept_length - 1;
new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}
 void dcu_alloc_decode_kernel(
     const at::Tensor seq_lens_ptr,
@@ -714,6 +773,49 @@ void dcu_alloc_decode_kernel(
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
void dcu_create_extend_after_decode_spec_info(
const at::Tensor verified_id,
const at::Tensor seq_lens,
const at::Tensor accept_lens,
at::Tensor positions,
at::Tensor new_verified_id,
int64_t bs) {
const int32_t* verified_id_ptr;
const int64_t* seq_lens_ptr;
const int32_t* accept_lens_ptr_int32;
const int64_t* accept_lens_ptr_int64;
int64_t* positions_ptr;
int32_t* new_verified_id_ptr;
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
if (accept_lens.dtype() == torch::kInt32)
{
verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
accept_lens_ptr_int32 = static_cast<const int32_t*>(accept_lens.data_ptr());
positions_ptr = static_cast<int64_t*>(positions.data_ptr());
new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
launch_create_extend_after_decode_spec_info_int32_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int32, positions_ptr, new_verified_id_ptr, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
else
{
verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
accept_lens_ptr_int64 = static_cast<const int64_t*>(accept_lens.data_ptr());
positions_ptr = static_cast<int64_t*>(positions.data_ptr());
new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
launch_create_extend_after_decode_spec_info_int64_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int64, positions_ptr, new_verified_id_ptr, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
};
 void dcu_alloc_extend_kernel(
     const at::Tensor pre_lens_ptr,
     const at::Tensor seq_lens_ptr,
...
@@ -538,6 +538,13 @@ void segment_packbits(
 /*
  * From csrc/kvcacheio
  */
+void dcu_create_extend_after_decode_spec_info(
+    const at::Tensor verified_id,
+    const at::Tensor seq_lens,
+    const at::Tensor accept_lens,
+    at::Tensor positions,
+    at::Tensor new_verified_id,
+    int64_t bs);
 void dcu_create_chunked_prefix_cache_kv_indices(
     at::Tensor req_to_token,
     const at::Tensor req_pool_indices,
...
@@ -9,6 +9,22 @@ def is_hip() -> bool:
 _is_hip = is_hip()
 
+def dcu_create_extend_after_decode_spec_info(
+    verified_id: torch.Tensor,
+    seq_lens: torch.Tensor,
+    accept_lens: torch.Tensor,
+    positions: torch.Tensor,
+    new_verified_id: torch.Tensor,
+    bs: int,
+):
+    torch.ops.sgl_kernel.dcu_create_extend_after_decode_spec_info(
+        verified_id,
+        seq_lens,
+        accept_lens,
+        positions,
+        new_verified_id,
+        bs,
+    )
+
 def dcu_alloc_extend_kernel(
     pre_lens_ptr: torch.Tensor,
...
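For readers who do not want to trace the CUDA kernels above, the computation behind dcu_create_extend_after_decode_spec_info can be summarized with a small PyTorch reference (a sketch for illustration only, not part of the commit; it ignores the kernel's extra clamp of the inner loop by bs): for request i, it writes the positions of the accepted draft tokens (seq_len - accept_len ... seq_len - 1) into the flattened positions buffer at the cumulative accept-length offset, and copies the last accepted token id into new_verified_id[i].

import torch

def create_extend_after_decode_spec_info_ref(
    verified_id: torch.Tensor,      # flattened accepted token ids (int32)
    seq_lens: torch.Tensor,         # per-request sequence length (int64)
    accept_lens: torch.Tensor,      # per-request number of accepted tokens
    positions: torch.Tensor,        # out: flattened token positions (int64)
    new_verified_id: torch.Tensor,  # out: last accepted token id per request (int32)
) -> None:
    # Exclusive prefix sum of accept_lens gives each request's write offset.
    offsets = torch.cumsum(accept_lens, dim=0) - accept_lens
    for i in range(seq_lens.numel()):
        start = int(offsets[i])
        n = int(accept_lens[i])
        seq_len = int(seq_lens[i])
        # Positions of the accepted tokens within the sequence.
        positions[start:start + n] = torch.arange(
            seq_len - n, seq_len, dtype=positions.dtype
        )
        # The last accepted token becomes the next verified id.
        new_verified_id[i] = verified_id[start + n - 1]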