Commit 4167eff9 authored by renzhc

Merge branch 'v0.5.4_dev' of http://developer.sourcefind.cn/codes/OpenDAS/sglang into v0.5.4_rzc

parents 8da47f19 0dc51b09
......@@ -103,54 +103,112 @@ class DCUMLABackend(AttentionBackend):
skip_prefill=False,
)
def _build_decode_metadata(
self,
forward_batch: ForwardBatch,
seq_lens: torch.Tensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
def init_forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
if forward_batch.forward_mode.is_decode_or_idle():
max_seqlen_pad = triton.cdiv(
forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
)
            # paging layout follows the vLLM official blog post
block_kv_indices = torch.full(
(bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32), self.num_q_heads, 1
)
return (mla_metadata, num_splits), num_splits, block_kv_indices
def init_forward_metadata(self, forward_batch: ForwardBatch):
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=forward_batch.seq_lens.device
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
if forward_batch.forward_mode.is_decode_or_idle():
            # decode path uses FlashMLA
(mla_metadata, num_splits), num_splits_t, block_kv_indices = (
self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
self.num_q_heads,
1
)
self.forward_metadata = VllmMLADecodeMetadata(
mla_metadata, num_splits_t, block_kv_indices
mla_metadata,
num_splits,
block_kv_indices
)
elif forward_batch.forward_mode.is_target_verify():
seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens
seq_lens = forward_batch.seq_lens + self.num_draft_tokens
(mla_metadata, num_splits), num_splits_t, block_kv_indices = (
self._build_decode_metadata(forward_batch, seq_lens)
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=seq_lens.device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads,
1,
)
self.forward_metadata = VllmMLADecodeMetadata(
mla_metadata, num_splits_t, block_kv_indices
mla_metadata,
num_splits,
block_kv_indices
)
else:
if not self.skip_prefill:
# === DRAFT_EXTEND_V2 MLA metadata === nhb
if forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2:
bs = forward_batch.batch_size
seq_lens_cpu = forward_batch.seq_lens_cpu
seq_lens = forward_batch.seq_lens
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=seq_lens.device,
)
                # invoke the Triton kernel to build block_kv_indices
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
# MLA
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_q_heads,
1,
)
# save forward_metadata
self.forward_metadata = VllmMLADecodeMetadata(
mla_metadata,
num_splits,
block_kv_indices,
)
self.flashattn_backend.init_forward_metadata(forward_batch)
def init_cuda_graph_state(
......@@ -389,7 +447,7 @@ class DCUMLABackend(AttentionBackend):
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
torch.float8_e5m2, torch.float8_e5m2fnuz):
if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
......@@ -401,7 +459,7 @@ class DCUMLABackend(AttentionBackend):
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
forward_batch.seq_lens.to(torch.int32),
(forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
layer.scaling,
k_scale,
kv_cache_dtype=kv_cache_dtype,
......@@ -411,7 +469,7 @@ class DCUMLABackend(AttentionBackend):
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
forward_batch.seq_lens.to(torch.int32),
(forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
layer.scaling,
)
......@@ -431,12 +489,9 @@ class DCUMLABackend(AttentionBackend):
k_rope: Optional[torch.Tensor] = None,
sinks=None,
):
if save_kv_cache:
            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
if ((
if (
forward_batch.forward_mode == ForwardMode.EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
):
if not self.skip_prefill:
return self.flashattn_backend.forward_extend(
......@@ -449,14 +504,19 @@ class DCUMLABackend(AttentionBackend):
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
torch.float8_e5m2, torch.float8_e5m2fnuz):
if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
......@@ -468,7 +528,7 @@ class DCUMLABackend(AttentionBackend):
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
(forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
layer.scaling,
k_scale,
kv_cache_dtype=kv_cache_dtype,
......@@ -478,13 +538,12 @@ class DCUMLABackend(AttentionBackend):
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
(forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
layer.scaling,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
class DCUMLAMultiStepDraftBackend:
"""
Wrap multiple flashmla attention backends as one for multiple consecutive
......
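For readers following the decode, target-verify, and DRAFT_EXTEND_V2 branches above: all three build the same paged block_kv_indices table, sized as ceil(max_seq_len / PAGE_SIZE) and padded with -1, before handing it to get_mla_metadata. The snippet below is a plain-PyTorch approximation of that construction for small inputs, not the Triton kernel itself; PAGE_SIZE = 64 and the assumption that req_to_token holds one row of token slot indices per request pool entry are illustrative only.

import torch

PAGE_SIZE = 64  # assumed page size for this sketch

def build_block_kv_indices(req_to_token, req_pool_indices, seq_lens):
    # Same sizing rule as triton.cdiv(seq_lens.max(), PAGE_SIZE) in the diff.
    bs = seq_lens.numel()
    max_seqlen_pad = (int(seq_lens.max().item()) + PAGE_SIZE - 1) // PAGE_SIZE
    block_kv_indices = torch.full(
        (bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
    )
    for i in range(bs):
        n_pages = (int(seq_lens[i].item()) + PAGE_SIZE - 1) // PAGE_SIZE
        # One entry per page: the page index of the first token slot in that page.
        page_starts = torch.arange(n_pages, device=seq_lens.device) * PAGE_SIZE
        first_slots = req_to_token[req_pool_indices[i], page_starts]
        block_kv_indices[i, :n_pages] = (first_slots // PAGE_SIZE).to(torch.int32)
    return block_kv_indices

The target-verify branch differs only in that it uses seq_lens + num_draft_tokens for both the sizing and the Triton call.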
......@@ -329,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
self.skip_prefill = skip_prefill
self.is_hybrid = model_runner.is_hybrid
self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
self.v_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
if self.is_hybrid:
self.full_to_swa_index_mapping = (
model_runner.token_to_kv_pool.full_to_swa_index_mapping
......@@ -598,6 +600,7 @@ class FlashAttentionBackend(AttentionBackend):
if (
any(forward_batch.extend_prefix_lens_cpu)
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2 #nhb
):
extend_seq_lens = forward_batch.extend_seq_lens
metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
......@@ -608,10 +611,13 @@ class FlashAttentionBackend(AttentionBackend):
metadata.max_seq_len_q = metadata.max_seq_len_k
metadata.cu_seqlens_q = metadata.cu_seqlens_k
# Setup local attention if enabled
if forward_batch.forward_mode == ForwardMode.EXTEND:
# # Setup local attention if enabled
# if forward_batch.forward_mode == ForwardMode.EXTEND:
# self._init_local_attn_metadata(forward_batch, metadata, device)
if forward_batch.forward_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND_V2):
self._init_local_attn_metadata(forward_batch, metadata, device)
# Encoder metadata for cross attention
if forward_batch.encoder_lens is not None:
assert (
......@@ -668,10 +674,16 @@ class FlashAttentionBackend(AttentionBackend):
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
# if not self.use_mla:
if k_rope is None:
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
......@@ -690,7 +702,8 @@ class FlashAttentionBackend(AttentionBackend):
layer.sliding_window_size is not None and layer.sliding_window_size > -1
)
window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
k_descale, v_descale = None, None
# k_descale, v_descale = None, None
k_descale, v_descale = self.k_scale, self.v_scale
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
......@@ -774,55 +787,53 @@ class FlashAttentionBackend(AttentionBackend):
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)
result = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
if forward_batch.attn_attend_prefix_cache:
assert not get_global_server_args().disable_chunked_prefix_cache
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert forward_batch.prefix_chunk_max_seq_lens is not None
chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0
assert forward_batch.mha_return_lse
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
**kwargs,
)
o, _ = merge_state_v2_wrapper(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=forward_batch.mha_return_lse,
**kwargs,
)
if forward_batch.mha_return_lse:
output, lse, *rest = output
lse = torch.transpose(lse, 0, 1).contiguous()
return output, lse
return output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
if (
forward_batch.attn_attend_prefix_cache is not None
......@@ -851,6 +862,8 @@ class FlashAttentionBackend(AttentionBackend):
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
**kwargs,
)
......@@ -865,6 +878,8 @@ class FlashAttentionBackend(AttentionBackend):
max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=forward_batch.mha_return_lse,
**kwargs,
)
......@@ -974,10 +989,16 @@ class FlashAttentionBackend(AttentionBackend):
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
# if not self.use_mla:
if k_rope is None:
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
......@@ -1019,7 +1040,8 @@ class FlashAttentionBackend(AttentionBackend):
if sinks is not None:
kwargs["sinks"] = sinks
k_descale, v_descale = None, None
# k_descale, v_descale = None, None
k_descale, v_descale = self.k_scale, self.v_scale
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
......@@ -1033,7 +1055,6 @@ class FlashAttentionBackend(AttentionBackend):
k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
if not self.use_mla:
# Do multi-head attention
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
)
......@@ -1085,65 +1106,62 @@ class FlashAttentionBackend(AttentionBackend):
**kwargs,
)
else:
cu_seqlens_q = metadata.cu_seqlens_q
max_seqlen_q = metadata.max_seq_len_q
page_table = metadata.page_table
cache_seqlens = metadata.cache_seqlens_int32
cu_seqlens_k = metadata.cu_seqlens_k
max_seqlen_q = metadata.max_seq_len_q
q_reshaped = q.contiguous().view(
-1, layer.tp_q_head_num, layer.head_dim
cache_seqlens = metadata.cache_seqlens_int32
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
# Default: single-token self-attention
result = flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = (
flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
**kwargs,
)
)
o, _ = merge_state_v2(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
if layer.is_cross_attention:
page_table = metadata.encoder_page_table
cache_seqlens = metadata.encoder_lens_int32
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)
if max_seqlen_q > 1:
result = flash_attn_varlen_func(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
else:
o = result
result = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
o = result
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
......
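A note on the k_descale/v_descale change above: instead of passing None, the backend now always hands the attention kernel the unit-scale tensors registered in __init__ (self.k_scale / self.v_scale). Conceptually, a per-tensor descale is the factor an fp8 KV cache is multiplied by after casting back to the compute dtype, so a value of 1.0 leaves non-quantized caches unchanged. The snippet below is a conceptual sketch with made-up names, not the kernel's internal math.

import torch

def dequantize_kv(k_cache_lowp, v_cache_lowp, k_descale, v_descale, dtype=torch.bfloat16):
    # With k_descale == v_descale == 1.0 this reduces to a plain dtype cast,
    # which is why unit scales are a safe default for non-fp8 caches.
    k = k_cache_lowp.to(dtype) * k_descale.to(dtype)
    v = v_cache_lowp.to(dtype) * v_descale.to(dtype)
    return k, v

k_descale = torch.tensor([1.0], dtype=torch.float32)
v_descale = torch.tensor([1.0], dtype=torch.float32)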
......@@ -41,18 +41,18 @@ def flash_attn_with_kvcache(
ver=3,
):
return flash_attn_with_kvcache_interface(
q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
k_cache=k_cache,
v_cache=v_cache,
block_table=page_table,
cache_seqlens=cache_seqlens,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
return_softmax_lse=return_softmax_lse,
num_splits=num_splits,
)
q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
k_cache=k_cache.view(q.dtype),
v_cache=v_cache.view(q.dtype),
block_table=page_table,
cache_seqlens=cache_seqlens,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
return_softmax_lse=return_softmax_lse,
num_splits=num_splits,
)
def flash_attn_varlen_func(
q,
......@@ -83,8 +83,8 @@ def flash_attn_varlen_func(
):
return flash_attn_varlen_func_interface(
q=q,
k=k,
v=v,
k=k.view(q.dtype),
v=v.view(q.dtype),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
......@@ -92,4 +92,5 @@ def flash_attn_varlen_func(
softmax_scale=softmax_scale,
causal=causal,
return_attn_probs=return_softmax_lse,
softcap=softcap,
)
\ No newline at end of file
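The wrapper above folds the flat varlen query back into a (batch, max_seqlen_q, heads, head_dim) layout before calling the underlying interface, which only works when every request in the batch shares the same query length (1 for decode, num_draft_tokens for target verify). The following is a tiny self-contained shape check of that reshape; the sizes are made up.

import torch

bs, max_seqlen_q, heads, head_dim = 4, 3, 8, 128
q_flat = torch.randn(bs * max_seqlen_q, heads, head_dim)

# Same reshape as q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]).
q_batched = q_flat.contiguous().view(-1, max_seqlen_q, q_flat.shape[-2], q_flat.shape[-1])
assert q_batched.shape == (bs, max_seqlen_q, heads, head_dim)

# Round-trips back to the flat layout used elsewhere in the backend.
assert torch.equal(q_batched.reshape(-1, heads, head_dim), q_flat)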
......@@ -1193,6 +1193,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens = seq_lens_tensor
self.seq_lens_cpu = seq_lens_cpu
self.extend_num_tokens = extend_num_tokens
self.loc_tensor = torch.tensor([-1], device=self.device)
# Allocate memory
out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
......
......@@ -1940,7 +1940,7 @@ class Scheduler(
batch.spec_info = batch_result.next_draft_input
batch.spec_info.future_indices = future_indices
batch.sampling_info.is_all_greedy = True #nhb
# batch.spec_info = EagleDraftInput(
# future_indices=future_indices,
# verify_done=batch_result.next_draft_input.verify_done,
......
......@@ -356,7 +356,8 @@ def alloc_for_extend(
else:
# Paged allocation - build last_loc
last_loc = [
(t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
# (t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
(t[-1:] if len(t) > 0 else batch.loc_tensor)
for t in prefix_tensors
]
out_cache_loc = alloc_paged_token_slots_extend(
......
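The alloc_for_extend change above is a small allocation hoist: the [-1] fallback used for requests with an empty prefix is now a single reusable tensor (batch.loc_tensor) instead of a fresh device tensor constructed per request inside the list comprehension. A minimal sketch of the same pattern, with illustrative names:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
NEG_ONE = torch.tensor([-1], device=device)  # allocated once, like batch.loc_tensor

def last_locs(prefix_tensors):
    # t[-1:] keeps the last cached location; empty prefixes fall back to -1
    # without constructing a new device tensor inside the loop.
    return [(t[-1:] if len(t) > 0 else NEG_ONE) for t in prefix_tensors]

prefixes = [torch.arange(5, device=device), torch.empty(0, dtype=torch.long, device=device)]
print(last_locs(prefixes))  # [tensor([4]), tensor([-1])]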
......@@ -123,12 +123,13 @@ class ForwardMode(IntEnum):
# For fixed shape logits output in v2 eagle worker
return self == ForwardMode.DRAFT_EXTEND_V2
def is_extend_or_draft_extend_or_mixed(self):
def is_extend_or_draft_extend_or_mixed(self): #nhb
return (
self == ForwardMode.EXTEND
or self == ForwardMode.DRAFT_EXTEND
or self == ForwardMode.MIXED
or self == ForwardMode.SPLIT_PREFILL
or self == ForwardMode.SPLIT_PREFILL
or self == ForwardMode.DRAFT_EXTEND_V2
)
def is_cuda_graph(self):
......
......@@ -2241,6 +2241,7 @@ class ModelRunner:
and self.graph_runner
and self.graph_runner.can_run(forward_batch)
)
if can_run_graph:
ret = self.graph_runner.replay(
forward_batch,
......
......@@ -37,7 +37,8 @@ from sglang.srt.speculative.spec_utils import (
get_src_tgt_cache_loc,
get_target_cache_loc,
)
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2, get_bool_env_var
from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info
if is_cuda():
from sgl_kernel import (
......@@ -615,6 +616,8 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
new_seq_lens: Optional[torch.Tensor] = None
verify_done: Optional[torch.cuda.Event] = None
use_sglang_create_extend_after_decode_spec_info = get_bool_env_var("SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO")
def __post_init__(self):
super().__init__(SpecInputType.EAGLE_DRAFT)
......@@ -679,14 +682,24 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
batch.input_ids,
batch.seq_lens,
self.accept_length,
self.positions,
self.verified_id,
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
)
if self.use_sglang_create_extend_after_decode_spec_info:
dcu_create_extend_after_decode_spec_info(
                verified_id=batch.input_ids,
                seq_lens=batch.seq_lens,
                accept_lens=self.accept_length,
                positions=self.positions,
                new_verified_id=self.verified_id,
                bs=max(speculative_num_steps + 1, len(batch.seq_lens)),
)
else:
create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
batch.input_ids,
batch.seq_lens,
self.accept_length,
self.positions,
self.verified_id,
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
)
def generate_attn_arg_prefill(
self,
......
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_bf16.h>
#endif
#include <algorithm>
#include <optional>
......
......@@ -131,6 +131,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From csrc/kvcacheio
*/
m.def("dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
......
......@@ -693,6 +693,65 @@ __global__ void launch_alloc_extend_kernel(
out_indices[output_idx] = start_loc * page_size + offset;
}
}
__global__ void launch_create_extend_after_decode_spec_info_int32_kernel(
const int32_t* verified_id_ptr,
const int64_t* seq_lens_ptr,
const int32_t* accept_lens_ptr,
int64_t* positions_ptr,
int32_t* new_verified_id_ptr,
int64_t bs) {
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t seq_length = seq_lens_ptr[pid];
int32_t accept_length = accept_lens_ptr[pid];
int32_t accept_len_cumsum = 0;
for (int32_t offset = 0; offset < pid; offset++) {
accept_len_cumsum += accept_lens_ptr[offset];
}
int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
for (int32_t offset = 0; offset < accept_length && offset < bs; offset++)
{
positions_ptr1[offset] = seq_length - accept_length + offset;
}
int32_t verified_idx = accept_len_cumsum + accept_length - 1;
new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}
__global__ void launch_create_extend_after_decode_spec_info_int64_kernel(
const int32_t* verified_id_ptr,
const int64_t* seq_lens_ptr,
const int64_t* accept_lens_ptr,
int64_t* positions_ptr,
int32_t* new_verified_id_ptr,
int64_t bs) {
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t seq_length = seq_lens_ptr[pid];
int64_t accept_length = accept_lens_ptr[pid];
int64_t accept_len_cumsum = 0;
for (int64_t offset = 0; offset < pid; offset++) {
accept_len_cumsum += accept_lens_ptr[offset];
}
int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
for (int64_t offset = 0; offset < accept_length && offset < bs; offset++)
{
positions_ptr1[offset] = seq_length - accept_length + offset;
}
int64_t verified_idx = accept_len_cumsum + accept_length - 1;
new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}
void dcu_alloc_decode_kernel(
const at::Tensor seq_lens_ptr,
......@@ -714,6 +773,49 @@ void dcu_alloc_decode_kernel(
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
void dcu_create_extend_after_decode_spec_info(
const at::Tensor verified_id,
const at::Tensor seq_lens,
const at::Tensor accept_lens,
at::Tensor positions,
at::Tensor new_verified_id,
int64_t bs) {
const int32_t* verified_id_ptr;
const int64_t* seq_lens_ptr;
const int32_t* accept_lens_ptr_int32;
const int64_t* accept_lens_ptr_int64;
int64_t* positions_ptr;
int32_t* new_verified_id_ptr;
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
if (accept_lens.dtype() == torch::kInt32)
{
verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
accept_lens_ptr_int32 = static_cast<const int32_t*>(accept_lens.data_ptr());
positions_ptr = static_cast<int64_t*>(positions.data_ptr());
new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
launch_create_extend_after_decode_spec_info_int32_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int32, positions_ptr, new_verified_id_ptr, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
else
{
verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
accept_lens_ptr_int64 = static_cast<const int64_t*>(accept_lens.data_ptr());
positions_ptr = static_cast<int64_t*>(positions.data_ptr());
new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
launch_create_extend_after_decode_spec_info_int64_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int64, positions_ptr, new_verified_id_ptr, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
};
void dcu_alloc_extend_kernel(
const at::Tensor pre_lens_ptr,
const at::Tensor seq_lens_ptr,
......
......@@ -538,6 +538,14 @@ void segment_packbits(
/*
* From csrc/kvcacheio
*/
void dcu_create_extend_after_decode_spec_info(
const at::Tensor verified_id,
const at::Tensor seq_lens,
const at::Tensor accept_lens,
at::Tensor positions,
at::Tensor new_verified_id,
int64_t bs);
void dcu_alloc_extend_kernel(
const at::Tensor pre_lens_ptr,
const at::Tensor seq_lens_ptr,
......
......@@ -9,6 +9,22 @@ def is_hip() -> bool:
_is_hip = is_hip()
def dcu_create_extend_after_decode_spec_info(
verified_id: torch.Tensor,
seq_lens: torch.Tensor,
accept_lens: torch.Tensor,
positions: torch.Tensor,
new_verified_id: torch.Tensor,
bs: int,
):
torch.ops.sgl_kernel.dcu_create_extend_after_decode_spec_info(
verified_id,
seq_lens,
accept_lens,
positions,
new_verified_id,
bs,
)
def dcu_alloc_extend_kernel(
pre_lens_ptr: torch.Tensor,
......
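For testing the new op end to end, a plain-PyTorch reference of what the CUDA kernels above compute can be handy. The function below mirrors the int32/int64 kernels (running cumulative sum of accept lengths, per-token positions, and the last accepted id per request) on small CPU inputs; it is a sanity-check sketch assuming accept_lens[i] >= 1, not the op itself.

import torch

def create_extend_after_decode_spec_info_ref(verified_id, seq_lens, accept_lens):
    # verified_id: flat int token ids, one entry per accepted token
    # seq_lens: sequence length per request
    # accept_lens: number of accepted draft tokens per request (assumed >= 1)
    positions = torch.empty(int(accept_lens.sum()), dtype=torch.int64)
    new_verified_id = torch.empty(seq_lens.numel(), dtype=torch.int32)
    cumsum = 0
    for i in range(seq_lens.numel()):
        acc, seq = int(accept_lens[i]), int(seq_lens[i])
        # positions of request i's accepted tokens inside its sequence
        positions[cumsum : cumsum + acc] = torch.arange(seq - acc, seq)
        # the last accepted token becomes the new verified id for request i
        new_verified_id[i] = verified_id[cumsum + acc - 1]
        cumsum += acc
    return positions, new_verified_id

Comparing this against dcu_create_extend_after_decode_spec_info (with positions and new_verified_id preallocated on the device, as the wrapper expects) gives a quick correctness check on toy batches.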