change / sglang · Commits

Commit f6d91d7e, authored Nov 19, 2025 by niuhb

Merge remote-tracking branch 'origin/v0.5.4_dev_shangxl' into v0.5.4_dev

Parents: 769353e6, b08d561e

Showing 12 changed files with 700 additions and 99 deletions (+700 -99)
python/sglang/srt/layers/attention/dcu_mla_backend.py          +70  -29
python/sglang/srt/layers/attention/flashattention_backend.py    +4   -9
python/sglang/srt/layers/attention/flashmla_backend.py         +48  -20
python/sglang/srt/mem_cache/common.py                          +12   -7
python/sglang/srt/model_executor/forward_batch_info.py         +31  -12
python/sglang/srt/speculative/draft_utils.py                   +10   -3
python/sglang/srt/speculative/eagle_info_v2.py                 +52  -18
sgl-kernel/csrc/common_extension_rocm.cc                       +17   -0
sgl-kernel/csrc/kvcacheio/transfer.cu                         +319   -1
sgl-kernel/include/sgl_kernel_ops.h                            +44   -0
sgl-kernel/python/sgl_kernel/flash_mla.py                      +20   -0
sgl-kernel/python/sgl_kernel/kvcacheio.py                      +73   -0
python/sglang/srt/layers/attention/dcu_mla_backend.py

@@ -11,6 +11,8 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices
+from sglang.srt.utils import get_bool_env_var

 try:
     from flash_mla import (
@@ -104,6 +106,7 @@ class DCUMLABackend(AttentionBackend):
         )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
+        use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON")
         bs = forward_batch.batch_size
         if forward_batch.forward_mode.is_decode_or_idle():
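This env-var gate is the pattern the whole commit repeats: a flag read by get_bool_env_var selects the new DCU C++ kernel, with the existing Triton kernel as the fallback. Note that although the flag's name ends in TRITON, setting it selects the DCU path in this diff. A minimal self-contained sketch of the dispatch, assuming get_bool_env_var follows the usual convention of treating "1"/"true" as enabled (dcu_impl and triton_impl are hypothetical stand-ins for the real kernels):

import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Assumed semantics: common truthy spellings enable the flag.
    return os.getenv(name, default).lower() in ("1", "true", "yes")

def dcu_impl():
    return "dcu kernel"     # stand-in for dcu_create_flashmla_kv_indices

def triton_impl():
    return "triton kernel"  # stand-in for create_flashmla_kv_indices_triton

def build_kv_indices():
    # Mirrors the commit: flag set -> DCU kernel, else fall back to Triton.
    if get_bool_env_var("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON"):
        return dcu_impl()
    return triton_impl()

if __name__ == "__main__":
    os.environ["SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON"] = "1"
    print(build_kv_indices())  # -> "dcu kernel"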
@@ -118,15 +121,27 @@ class DCUMLABackend(AttentionBackend):
                 dtype=torch.int32,
                 device=forward_batch.seq_lens.device,
             )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
             mla_metadata, num_splits = get_mla_metadata(
                 forward_batch.seq_lens.to(torch.int32),
@@ -149,15 +164,27 @@ class DCUMLABackend(AttentionBackend):
                 dtype=torch.int32,
                 device=seq_lens.device,
             )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
                 self.num_draft_tokens * self.num_q_heads,
@@ -185,15 +212,27 @@ class DCUMLABackend(AttentionBackend):
             )
             # Call the Triton kernel to generate block_kv_indices
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token.to(torch.int32),
+                    req_pool_indices_ptr=forward_batch.req_pool_indices.to(torch.int32),
+                    page_kernel_lens_ptr=forward_batch.seq_lens.to(torch.int32),
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices.to(torch.int32),
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
             # MLA
             mla_metadata, num_splits = get_mla_metadata(
@@ -211,6 +250,7 @@ class DCUMLABackend(AttentionBackend):
             self.flashattn_backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(
         self,
         max_bs: int,
@@ -489,9 +529,10 @@ class DCUMLABackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks=None,
     ):
-        if (
+        if ((
             forward_batch.forward_mode == ForwardMode.EXTEND
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
-        ):
+        )):
             if not self.skip_prefill:
                 return self.flashattn_backend.forward_extend(
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -674,16 +674,11 @@ class FlashAttentionBackend(AttentionBackend):
                 if not layer.is_cross_attention
                 else forward_batch.encoder_out_cache_loc
             )
-            # if not self.use_mla:
             if k_rope is None:
-                if not self.use_mla:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
-                else:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                    # layer.k_scale, layer.v_scale
+                )
             else:
                 forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                     layer,
python/sglang/srt/layers/attention/flashmla_backend.py

@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import get_bool_env_var
+from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

     def init_forward_metadata(self, forward_batch: ForwardBatch):
+        use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO")
         bs = forward_batch.batch_size
         if forward_batch.forward_mode.is_decode_or_idle():
             max_seqlen_pad = triton.cdiv(
@@ -91,15 +95,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 dtype=torch.int32,
                 device=forward_batch.seq_lens.device,
             )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
             mla_metadata, num_splits = get_mla_metadata(
                 forward_batch.seq_lens.to(torch.int32),
                 self.num_q_heads,
@@ -121,15 +137,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 dtype=torch.int32,
                 device=seq_lens.device,
             )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
+            if use_sglang_create_flashmla_kv_indices_triton:
+                dcu_create_flashmla_kv_indices(
+                    req_to_token_ptr=self.req_to_token,
+                    req_pool_indices_ptr=forward_batch.req_pool_indices,
+                    page_kernel_lens_ptr=forward_batch.seq_lens,
+                    kv_start_idx=None,
+                    kv_indices_ptr=block_kv_indices,
+                    req_to_token_ptr_stride=self.req_to_token.stride(0),
+                    kv_indices_ptr_stride=max_seqlen_pad,
+                )
+            else:
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
                 self.num_draft_tokens * self.num_q_heads,
@@ -144,7 +172,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             )
         else:
             super().init_forward_metadata(forward_batch)

     def init_cuda_graph_state(
         self,
         max_bs: int,
python/sglang/srt/mem_cache/common.py

@@ -13,7 +13,8 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import support_triton
+from sglang.srt.utils import support_triton, get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_get_last_loc

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -125,13 +126,17 @@ def get_last_loc(
     req_pool_indices_tensor: torch.Tensor,
     prefix_lens_tensor: torch.Tensor,
 ) -> torch.Tensor:
-    if (
-        get_global_server_args().attention_backend != "ascend"
-        and get_global_server_args().attention_backend != "torch_native"
-    ):
-        impl = get_last_loc_triton
-    else:
-        impl = get_last_loc_torch
+    use_sglang_get_last_loc = get_bool_env_var("SGLANG_GET_LAST_LOC")
+    if use_sglang_get_last_loc:
+        impl = dcu_get_last_loc
+    else:
+        if (
+            get_global_server_args().attention_backend != "ascend"
+            and get_global_server_args().attention_backend != "torch_native"
+        ):
+            impl = get_last_loc_triton
+        else:
+            impl = get_last_loc_torch

     return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
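The semantics of dcu_get_last_loc can be read off the get_last_loc_kernel added in transfer.cu later in this commit: for each request, return the cache slot of the last prefix token, or -1 when there is no prefix. A vectorized PyTorch reference (a sketch for validation; assumes a contiguous 2-D req_to_token, in which case the kernel's row-stride indexing matches plain fancy indexing):

import torch

def get_last_loc_reference(req_to_token, req_pool_indices, prefix_lens):
    # result[i] = req_to_token[req_pool_indices[i], prefix_lens[i] - 1]
    # if prefix_lens[i] > 0 else -1
    has_prefix = prefix_lens > 0
    rows = req_pool_indices.long()
    cols = (prefix_lens - 1).clamp(min=0).long()
    last = req_to_token[rows, cols].long()
    return torch.where(has_prefix, last, torch.full_like(last, -1))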
python/sglang/srt/model_executor/forward_batch_info.py

@@ -46,7 +46,11 @@ from sglang.srt.layers.dp_attention import (
     set_dp_buffer_len,
     set_is_extend_in_batch,
 )
-from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
+from sglang.srt.utils import get_compiler_backend, is_npu, support_triton, get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_create_chunked_prefix_cache_kv_indices
+import logging
+
+logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -128,8 +132,8 @@ class ForwardMode(IntEnum):
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
-            or self == ForwardMode.SPLIT_PREFILL
-            or self == ForwardMode.DRAFT_EXTEND_V2
+            or self == ForwardMode.SPLIT_PREFILL
+            or self == ForwardMode.DRAFT_EXTEND_V2  # nhb
         )

     def is_cuda_graph(self):
@@ -317,6 +321,8 @@ class ForwardBatch:
     tbo_parent_token_range: Optional[Tuple[int, int]] = None
     tbo_children: Optional[List[ForwardBatch]] = None

+    use_sglang_create_chunked_prefix_cache_kv_indices = get_bool_env_var("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES")
+
     @classmethod
     def init_new(
         cls,
@@ -635,15 +641,28 @@ class ForwardBatch:
                 num_chunk_tokens, dtype=torch.int32, device=device
             )
-            create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
-                self.req_to_token_pool.req_to_token,
-                self.req_pool_indices,
-                chunk_starts,
-                chunk_seq_lens,
-                chunk_cu_seq_lens,
-                chunk_kv_indices,
-                self.req_to_token_pool.req_to_token.shape[1],
-            )
+            if self.use_sglang_create_chunked_prefix_cache_kv_indices:
+                dcu_create_chunked_prefix_cache_kv_indices(
+                    req_to_token=self.req_to_token_pool.req_to_token,
+                    req_pool_indices=self.req_pool_indices,
+                    chunk_starts=chunk_starts,
+                    chunk_seq_lens=chunk_seq_lens,
+                    chunk_cu_seq_lens=chunk_cu_seq_lens,
+                    chunk_kv_indices=chunk_kv_indices,
+                    col_num=self.req_to_token_pool.req_to_token.shape[1],
+                    bs=self.batch_size,
+                )
+            else:
+                logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
+                create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
+                    self.req_to_token_pool.req_to_token,
+                    self.req_pool_indices,
+                    chunk_starts,
+                    chunk_seq_lens,
+                    chunk_cu_seq_lens,
+                    chunk_kv_indices,
+                    self.req_to_token_pool.req_to_token.shape[1],
+                )
             self.prefix_chunk_kv_indices.append(chunk_kv_indices)

     def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
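One consequence of the hunk above: use_sglang_create_chunked_prefix_cache_kv_indices is a class attribute, so the flag is latched once when the class body executes at import time, not per batch. A sketch of the resulting ordering requirement (assuming standard os.environ semantics and that get_bool_env_var returns a truthy value for "1"):

import os

# Export before sglang is imported, or the class attribute on ForwardBatch
# will already have latched the unset (falsy) value.
os.environ["SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES"] = "1"

from sglang.srt.model_executor.forward_batch_info import ForwardBatch  # noqa: E402

assert ForwardBatch.use_sglang_create_chunked_prefix_cache_kv_indices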
python/sglang/srt/speculative/draft_utils.py

@@ -237,7 +237,14 @@ class DraftBackendFactory:
         return None

     def _create_dcumla_prefill_backend(self):
-        logger.warning(
-            "flashmla prefill backend is not yet supported for draft extend."
-        )
-        return None
+        # logger.warning(
+        #     "flashmla prefill backend is not yet supported for draft extend."
+        # )
+        # return None
+        # nhb
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
python/sglang/srt/speculative/eagle_info_v2.py

@@ -29,6 +29,12 @@ from sglang.srt.speculative.spec_utils import (
 )
 from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_assign_req_to_token_pool, dcu_assign_extend_cache_locs
+import logging
+
+logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from sglang.srt.managers.tp_worker import TpModelWorker
     from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
@@ -77,6 +83,9 @@ def assign_draft_cache_locs_page_size_1(
 @dataclass
 class EagleDraftInputV2Mixin:
+    use_sglang_assign_req_to_token_pool = get_bool_env_var(
+        "SGLANG_ASSIGN_REQ_TO_TOKEN_POOL"
+    )
+
     def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
         from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
@@ -112,15 +121,26 @@ class EagleDraftInputV2Mixin:
                 extend_num_tokens,
             )
-        assign_req_to_token_pool[(bs,)](
-            batch.req_pool_indices,
-            batch.req_to_token_pool.req_to_token,
-            self.allocate_lens,
-            new_allocate_lens,
-            out_cache_loc,
-            batch.req_to_token_pool.req_to_token.shape[1],
-            next_power_of_2(bs),
-        )
+        if self.use_sglang_assign_req_to_token_pool:
+            dcu_assign_req_to_token_pool(
+                req_pool_indices=batch.req_pool_indices,
+                req_to_token=batch.req_to_token_pool.req_to_token,
+                allocate_lens=self.allocate_lens,
+                new_allocate_lens=new_allocate_lens,
+                out_cache_loc=out_cache_loc,
+                shape=batch.req_to_token_pool.req_to_token.shape[1],
+                bs=bs,
+            )
+        else:
+            assign_req_to_token_pool[(bs,)](
+                batch.req_pool_indices,
+                batch.req_to_token_pool.req_to_token,
+                self.allocate_lens,
+                new_allocate_lens,
+                out_cache_loc,
+                batch.req_to_token_pool.req_to_token.shape[1],
+                next_power_of_2(bs),
+            )
         self.allocate_lens = new_allocate_lens

         # FIXME(lsyin): make this sync optional
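What dcu_assign_req_to_token_pool does can be inferred from the CUDA kernel added in transfer.cu below: each request's newly allocated cache slots are copied out of the flat out_cache_loc buffer into its req_to_token row, between the old and new allocation lengths. A loop-level Python sketch of those semantics:

import torch

def assign_req_to_token_pool_reference(req_pool_indices, req_to_token,
                                       allocate_lens, new_allocate_lens,
                                       out_cache_loc):
    # Writes out_cache_loc slices into req_to_token rows; the per-request
    # output offset is the running sum of earlier requests' growth.
    offset = 0
    for i in range(req_pool_indices.numel()):
        row = int(req_pool_indices[i])
        start, end = int(allocate_lens[i]), int(new_allocate_lens[i])
        n = end - start
        req_to_token[row, start:end] = out_cache_loc[offset:offset + n].to(req_to_token.dtype)
        offset += n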
@@ -189,6 +209,9 @@ class EagleDraftInputV2Mixin:
 @dataclass
 class EagleVerifyInputV2Mixin:
+    use_sglang_assign_extend_cache_locs = get_bool_env_var(
+        "SGLANG_ASSIGN_EXTEND_CACHE_LOCS"
+    )
+
     def prepare_for_v2_verify(
         self: EagleVerifyInput,
         req_to_token_pool: ReqToTokenPool,
@@ -205,15 +228,26 @@ class EagleVerifyInputV2Mixin:
             device=device,
         )
-        assign_extend_cache_locs[(bs,)](
-            batch.req_pool_indices,
-            req_to_token_pool.req_to_token,
-            batch.seq_lens,
-            batch.seq_lens + self.draft_token_num,
-            batch.out_cache_loc,
-            req_to_token_pool.req_to_token.shape[1],
-            next_power_of_2(bs),
-        )
+        if self.use_sglang_assign_extend_cache_locs:
+            dcu_assign_extend_cache_locs(
+                batch.req_pool_indices,
+                req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + self.draft_token_num,
+                batch.out_cache_loc,
+                req_to_token_pool.req_to_token.shape[1],
+                bs,
+            )
+        else:
+            assign_extend_cache_locs[(bs,)](
+                batch.req_pool_indices,
+                req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + self.draft_token_num,
+                batch.out_cache_loc,
+                req_to_token_pool.req_to_token.shape[1],
+                next_power_of_2(bs),
+            )

         # Get a forward batch
         batch.forward_mode = ForwardMode.TARGET_VERIFY
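dcu_assign_extend_cache_locs runs in the opposite direction from the previous kernel: per the CUDA implementation below, it gathers the cache slots of each request's draft-token window [start, end) out of its req_to_token row into the flat out_cache_loc buffer. A loop-level Python sketch of those semantics:

import torch

def assign_extend_cache_locs_reference(req_pool_indices, req_to_token,
                                       start_offset, end_offset, out_cache_loc):
    # Reads req_to_token row slices into the flat out_cache_loc buffer, again
    # with cumulative per-request offsets.
    offset = 0
    for i in range(req_pool_indices.numel()):
        row = int(req_pool_indices[i])
        start, end = int(start_offset[i]), int(end_offset[i])
        n = end - start
        out_cache_loc[offset:offset + n] = req_to_token[row, start:end].to(out_cache_loc.dtype)
        offset += n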
sgl-kernel/csrc/common_extension_rocm.cc

@@ -19,6 +19,14 @@ limitations under the License.
 #include "sgl_kernel_ops.h"

 TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
+  /*
+   * From FlashMLA
+   */
+  m.def(
+      "dcu_create_flashmla_kv_indices(Tensor req_to_token, Tensor req_pool_indices, Tensor page_kernel_lens, "
+      "Tensor? kv_start_idx, Tensor kv_indices, int req_to_token_stride, int kv_indices_stride, int PAGED_SIZE) -> ()");
+  m.impl("dcu_create_flashmla_kv_indices", torch::kCUDA, &dcu_create_flashmla_kv_indices);
+
   /*
    * From csrc/activation
    */
@@ -133,6 +141,15 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   */
   m.def(
       "dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
   m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
+  m.def(
+      "dcu_create_chunked_prefix_cache_kv_indices(Tensor req_to_token, Tensor req_pool_indices, Tensor chunk_starts, Tensor chunk_seq_lens, Tensor chunk_cu_seq_lens, Tensor chunk_kv_indices, int col_num, int bs) -> ()");
+  m.impl("dcu_create_chunked_prefix_cache_kv_indices", torch::kCUDA, &dcu_create_chunked_prefix_cache_kv_indices);
+  m.def(
+      "dcu_assign_extend_cache_locs(Tensor req_pool_indices, Tensor req_to_token, Tensor start_offset, Tensor end_offset, Tensor out_cache_loc, int pool_len, int bs) -> ()");
+  m.impl("dcu_assign_extend_cache_locs", torch::kCUDA, &dcu_assign_extend_cache_locs);
+  m.def("dcu_get_last_loc(Tensor req_to_token, Tensor req_pool_indices, Tensor prefix_lens) -> Tensor");
+  m.impl("dcu_get_last_loc", torch::kCUDA, &dcu_get_last_loc);
+  m.def(
+      "dcu_assign_req_to_token_pool(Tensor req_pool_indices_ptr, Tensor req_to_token_ptr, Tensor allocate_lens_ptr, Tensor new_allocate_lens, Tensor out_cache_loc_ptr, int shape, int bs) -> ()");
+  m.impl("dcu_assign_req_to_token_pool", torch::kCUDA, &dcu_assign_req_to_token_pool);
   m.def(
       "dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
   m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
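Each m.def/m.impl pair above surfaces the op under the torch.ops namespace once the compiled extension is loaded; the Python wrappers added later in this commit (sgl_kernel/kvcacheio.py, sgl_kernel/flash_mla.py) are thin shims over exactly these calls. A sketch, assuming importing sgl_kernel loads the extension and registers the ops:

import torch
import sgl_kernel  # loads the compiled extension and registers the sgl_kernel ops

def call_dcu_get_last_loc(req_to_token, req_pool_indices, prefix_lens):
    # Schema "dcu_get_last_loc(...) -> Tensor" maps to this dispatcher call.
    return torch.ops.sgl_kernel.dcu_get_last_loc(
        req_to_token, req_pool_indices, prefix_lens
    )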
sgl-kernel/csrc/kvcacheio/transfer.cu

@@ -836,4 +836,322 @@ void dcu_alloc_extend_kernel(
   cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
   launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
       pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
\ No newline at end of file
+}
+__global__ void launch_assign_req_to_token_pool(
+    const int64_t* req_pool_indices_ptr,
+    int32_t* req_to_token_ptr,
+    const int64_t* allocate_lens_ptr,
+    int64_t* new_allocate_lens,
+    int64_t* out_cache_loc_ptr,
+    int64_t shape,
+    int64_t bs) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs) return;
+
+  int64_t kv_start = allocate_lens_ptr[pid];
+  int64_t kv_end = new_allocate_lens[pid];
+  int64_t pool_idx = req_pool_indices_ptr[pid];
+  int32_t* token_pool = (int32_t*)(req_to_token_ptr + pool_idx * shape);
+
+  int64_t sum_out_offset = 0;
+  for (int length_offset = 0; length_offset < pid; length_offset++) {
+    int64_t start = allocate_lens_ptr[length_offset];
+    int64_t end = new_allocate_lens[length_offset];
+    sum_out_offset += (end - start);
+  }
+  int64_t* out_cache_ptr = out_cache_loc_ptr + sum_out_offset;
+  int64_t copy_length = kv_end - kv_start;
+
+#pragma unroll(32)
+  for (int out_cache_index = 0; out_cache_index < copy_length; out_cache_index++) {
+    token_pool[kv_start + out_cache_index] = out_cache_ptr[out_cache_index];
+  }
+}
+void dcu_assign_req_to_token_pool(
+    const at::Tensor req_pool_indices_ptr,
+    at::Tensor req_to_token_ptr,
+    const at::Tensor allocate_lens_ptr,
+    at::Tensor new_allocate_lens,
+    at::Tensor out_cache_loc_ptr,
+    int64_t shape,
+    int64_t bs) {
+  const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
+  int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
+  const int64_t* allocate_lens_ptr1 = static_cast<const int64_t*>(allocate_lens_ptr.data_ptr());
+  int64_t* new_allocate_lens1 = static_cast<int64_t*>(new_allocate_lens.data_ptr());
+  int64_t* out_cache_loc_ptr1 = static_cast<int64_t*>(out_cache_loc_ptr.data_ptr());
+
+  int64_t block_size = 64;
+  int64_t grid_size = (bs + block_size - 1) / block_size;
+  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
+  launch_assign_req_to_token_pool<<<grid_size, block_size, 0, torch_current_stream>>>(
+      req_pool_indices_ptr1, req_to_token_ptr1, allocate_lens_ptr1, new_allocate_lens1,
+      out_cache_loc_ptr1, shape, bs);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
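Note that each thread in the kernel above recomputes its output offset by summing (new_allocate_lens - allocate_lens) over all lower-indexed requests, O(bs) work per thread. An equivalent offset calculation could be precomputed host-side with an exclusive cumulative sum; a sketch of that equivalence (an illustration only, not what the commit does):

import torch

def exclusive_offsets(allocate_lens: torch.Tensor, new_allocate_lens: torch.Tensor) -> torch.Tensor:
    # offset[i] = sum of per-request growth for requests 0..i-1,
    # which is what the per-thread loop computes.
    per_req = new_allocate_lens - allocate_lens
    return torch.cumsum(per_req, dim=0) - per_req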
+__global__ void get_last_loc_kernel(
+    const int32_t* __restrict__ req_to_token,
+    const int64_t* __restrict__ req_pool_indices_tensor,
+    const int64_t* __restrict__ prefix_lens_tensor,
+    int64_t* __restrict__ result,
+    int64_t num_tokens,
+    int64_t req_to_token_stride) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= num_tokens) return;
+
+  int64_t pre_len = prefix_lens_tensor[pid];
+  if (pre_len > 0) {
+    int64_t req_idx = req_pool_indices_tensor[pid];
+    int64_t token_idx = req_idx * req_to_token_stride + (pre_len - 1);
+    result[pid] = static_cast<int64_t>(req_to_token[token_idx]);
+  } else {
+    result[pid] = static_cast<int64_t>(-1);
+  }
+}
+at::Tensor dcu_get_last_loc(
+    const at::Tensor req_to_token, const at::Tensor req_pool_indices, const at::Tensor prefix_lens) {
+  TORCH_CHECK(req_to_token.device().is_cuda(), "req_to_token must be CUDA tensor");
+  TORCH_CHECK(req_pool_indices.device().is_cuda(), "req_pool_indices must be CUDA tensor");
+  TORCH_CHECK(prefix_lens.device().is_cuda(), "prefix_lens must be CUDA tensor");
+  TORCH_CHECK(req_to_token.dim() == 2, "req_to_token must be 2D tensor [batch, seq_len]");
+  TORCH_CHECK(prefix_lens.dim() == 1, "prefix_lens must be 1D");
+  TORCH_CHECK(req_pool_indices.dim() == 1, "req_pool_indices must be 1D");
+
+  int64_t num_tokens = prefix_lens.numel();
+  TORCH_CHECK(req_pool_indices.numel() == num_tokens, "req_pool_indices must have same length as prefix_lens");
+
+  int64_t req_to_token_stride = req_to_token.stride(0);
+
+  auto req_to_token_c = req_to_token.contiguous();
+  auto req_pool_indices_c = req_pool_indices.contiguous();
+  auto prefix_lens_c = prefix_lens.contiguous();
+
+  const int32_t* req_to_token_ptr = req_to_token_c.data_ptr<int32_t>();
+  const int64_t* req_pool_indices_ptr = req_pool_indices_c.data_ptr<int64_t>();
+  const int64_t* prefix_lens_ptr = prefix_lens_c.data_ptr<int64_t>();
+
+  auto result = at::empty_like(prefix_lens_c);
+  int64_t* result_ptr = result.data_ptr<int64_t>();
+
+  const int64_t block_size = 64;
+  const int64_t grid_size = (num_tokens + block_size - 1) / block_size;
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  get_last_loc_kernel<<<grid_size, block_size, 0, stream>>>(
+      req_to_token_ptr, req_pool_indices_ptr, prefix_lens_ptr, result_ptr, num_tokens, req_to_token_stride);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  return result;
+}
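A quick parity check against the pure-PyTorch reference given after mem_cache/common.py above (a sketch; assumes a CUDA/DCU device and the dtypes this host function expects: int32 req_to_token, int64 indices and prefix lengths):

import torch
from sgl_kernel.kvcacheio import dcu_get_last_loc

req_to_token = torch.randint(0, 10_000, (8, 256), dtype=torch.int32, device="cuda")
req_pool_indices = torch.arange(8, dtype=torch.int64, device="cuda")
prefix_lens = torch.tensor([0, 1, 5, 17, 33, 64, 128, 256], dtype=torch.int64, device="cuda")

out = dcu_get_last_loc(req_to_token, req_pool_indices, prefix_lens)
ref = get_last_loc_reference(req_to_token, req_pool_indices, prefix_lens)  # sketch from above
assert torch.equal(out, ref)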
+__global__ void launch_assign_extend_cache_locs_kernel(
+    const int64_t* __restrict__ req_pool_indices,  // [bs]
+    const int32_t* __restrict__ req_to_token,      // [max_num_req, pool_len]
+    const int64_t* __restrict__ start_offset,      // [bs]
+    const int64_t* __restrict__ end_offset,        // [bs]
+    int64_t* __restrict__ out_cache_loc,           // [sum(draft_token_num)]
+    int64_t pool_len,
+    int64_t bs) {
+  int pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs) return;
+
+  int64_t kv_start = start_offset[pid];
+  int64_t kv_end = end_offset[pid];
+  int64_t req_id = req_pool_indices[pid];
+
+  int64_t out_offset = 0;
+  for (int i = 0; i < pid; ++i) {
+    out_offset += end_offset[i] - start_offset[i];
+  }
+
+  const int32_t* src = req_to_token + req_id * pool_len + kv_start;
+  int64_t* dst = out_cache_loc + out_offset;
+  for (int64_t i = 0; i < kv_end - kv_start; ++i) {
+    dst[i] = src[i];
+  }
+}
+void dcu_assign_extend_cache_locs(
+    const at::Tensor req_pool_indices,
+    const at::Tensor req_to_token,
+    const at::Tensor start_offset,
+    const at::Tensor end_offset,
+    at::Tensor out_cache_loc,
+    int64_t pool_len,
+    int64_t bs) {
+  const int64_t* req_pool_indices_ptr = req_pool_indices.data_ptr<int64_t>();
+  const int32_t* req_to_token_ptr = req_to_token.data_ptr<int32_t>();
+  const int64_t* start_offset_ptr = start_offset.data_ptr<int64_t>();
+  const int64_t* end_offset_ptr = end_offset.data_ptr<int64_t>();
+  int64_t* out_cache_loc_ptr = out_cache_loc.data_ptr<int64_t>();
+
+  constexpr int64_t threads = 128;
+  int64_t blocks = (bs + threads - 1) / threads;
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  launch_assign_extend_cache_locs_kernel<<<blocks, threads, 0, stream>>>(
+      req_pool_indices_ptr, req_to_token_ptr, start_offset_ptr, end_offset_ptr,
+      out_cache_loc_ptr, pool_len, bs);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+template <int PAGED_SIZE>
+__global__ void dcu_create_flashmla_kv_indices_kernel(
+    const int32_t* __restrict__ req_to_token,
+    const int32_t* __restrict__ req_pool_indices,
+    const int32_t* __restrict__ page_kernel_lens,
+    const int32_t* __restrict__ kv_start_idx,
+    int32_t* __restrict__ kv_indices,
+    int req_to_token_stride,
+    int kv_indices_stride) {
+  int pid = blockIdx.x;  // batch index
+  int req_pool_index = req_pool_indices[pid];
+
+  int kv_start = 0;
+  int kv_end = 0;
+  if (kv_start_idx != nullptr) {
+    kv_start = kv_start_idx[pid];
+    kv_end = kv_start;
+  }
+  kv_end += page_kernel_lens[pid];
+
+  int total_len = kv_end - kv_start;
+  int num_pages = (total_len + PAGED_SIZE - 1) / PAGED_SIZE;
+
+  for (int pg = 0; pg < num_pages; ++pg) {
+    int offset = pg * PAGED_SIZE;
+    // token id = req_to_token[req_pool_index][kv_start + offset]
+    int64_t token = req_to_token[req_pool_index * req_to_token_stride + kv_start + offset];
+    // page index
+    kv_indices[pid * kv_indices_stride + pg] = token / PAGED_SIZE;
+  }
+}
+void dcu_create_flashmla_kv_indices(
+    const at::Tensor& req_to_token,
+    const at::Tensor& req_pool_indices,
+    const at::Tensor& page_kernel_lens,
+    const c10::optional<at::Tensor>& kv_start_idx,
+    at::Tensor& kv_indices,
+    int64_t req_to_token_stride,
+    int64_t kv_indices_stride,
+    int64_t PAGED_SIZE) {
+  TORCH_CHECK(req_to_token.is_cuda(), "req_to_token must be CUDA tensor");
+  TORCH_CHECK(kv_indices.is_cuda(), "kv_indices must be CUDA tensor");
+
+  int bs = req_pool_indices.size(0);
+  auto stream = at::cuda::getCurrentCUDAStream();
+  dim3 grid(bs);
+  dim3 block(1);
+
+  const int32_t* kv_start_idx_ptr = nullptr;
+  if (kv_start_idx.has_value()) {
+    kv_start_idx_ptr = kv_start_idx.value().data_ptr<int32_t>();
+  }
+
+  if (PAGED_SIZE == 64) {
+    dcu_create_flashmla_kv_indices_kernel<64><<<grid, block, 0, stream>>>(
+        req_to_token.data_ptr<int32_t>(),
+        req_pool_indices.data_ptr<int32_t>(),
+        page_kernel_lens.data_ptr<int32_t>(),
+        kv_start_idx_ptr,
+        kv_indices.data_ptr<int32_t>(),
+        req_to_token_stride,
+        kv_indices_stride);
+  } else {
+    TORCH_CHECK(false, "Unsupported PAGED_SIZE");
+  }
+}
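Reading the templated kernel above: one block per request samples the first token of every PAGED_SIZE-sized window of its req_to_token row and records that token's page number. A per-request Python sketch of the same computation (it matches the kernel only when cache slots are allocated page-aligned, so the window's first token determines the whole page):

import torch

def flashmla_kv_indices_reference(req_to_token, req_pool_indices,
                                  page_kernel_lens, kv_start_idx,
                                  kv_indices, paged_size=64):
    for b in range(req_pool_indices.numel()):
        row = int(req_pool_indices[b])
        start = int(kv_start_idx[b]) if kv_start_idx is not None else 0
        length = int(page_kernel_lens[b])
        num_pages = (length + paged_size - 1) // paged_size
        for pg in range(num_pages):
            token = int(req_to_token[row, start + pg * paged_size])
            kv_indices[b, pg] = token // paged_size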
+__global__ void launch_create_chunked_prefix_cache_kv_indices(
+    int32_t* req_to_token_ptr,
+    const int64_t* req_pool_indices_ptr,
+    const int32_t* chunk_starts_ptr,
+    const int32_t* chunk_seq_lens_ptr,
+    const int32_t* chunk_cu_seq_lens_ptr,
+    int32_t* chunk_kv_indices_ptr,
+    int64_t col_num,
+    int64_t bs) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs) return;
+
+  int64_t req_pool_index = req_pool_indices_ptr[pid];
+  int64_t chunk_kv_indices_offset = chunk_cu_seq_lens_ptr[pid];
+  int32_t chunk_start_pos = chunk_starts_ptr[pid];
+  int32_t chunk_seq_len = chunk_seq_lens_ptr[pid];
+
+#pragma unroll(32)
+  for (int32_t offset = 0; offset < chunk_seq_len; offset++) {
+    chunk_kv_indices_ptr[chunk_kv_indices_offset + offset] =
+        req_to_token_ptr[req_pool_index * col_num + chunk_start_pos + offset];
+  }
+}
+void dcu_create_chunked_prefix_cache_kv_indices(
+    at::Tensor req_to_token_ptr,
+    const at::Tensor req_pool_indices_ptr,
+    const at::Tensor chunk_starts_ptr,
+    const at::Tensor chunk_seq_lens_ptr,
+    const at::Tensor chunk_cu_seq_lens_ptr,
+    at::Tensor chunk_kv_indices_ptr,
+    int64_t col_num,
+    int64_t bs) {
+  int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
+  const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
+  const int32_t* chunk_starts_ptr1 = static_cast<const int32_t*>(chunk_starts_ptr.data_ptr());
+  const int32_t* chunk_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_seq_lens_ptr.data_ptr());
+  const int32_t* chunk_cu_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_cu_seq_lens_ptr.data_ptr());
+  int32_t* chunk_kv_indices_ptr1 = static_cast<int32_t*>(chunk_kv_indices_ptr.data_ptr());
+
+  int64_t block_size = 64;
+  int64_t grid_size = (bs + block_size - 1) / block_size;
+  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
+  launch_create_chunked_prefix_cache_kv_indices<<<grid_size, block_size, 0, torch_current_stream>>>(
+      req_to_token_ptr1, req_pool_indices_ptr1, chunk_starts_ptr1, chunk_seq_lens_ptr1,
+      chunk_cu_seq_lens_ptr1, chunk_kv_indices_ptr1, col_num, bs);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
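The chunked prefix-cache kernel is a per-request gather: copy req_to_token[row, start : start + len] into the flat chunk_kv_indices slab at that request's cumulative offset. A loop-level Python sketch of the semantics, usable as a reference when validating the DCU path:

import torch

def chunked_prefix_kv_indices_reference(req_to_token, req_pool_indices,
                                        chunk_starts, chunk_seq_lens,
                                        chunk_cu_seq_lens, chunk_kv_indices):
    for i in range(req_pool_indices.numel()):
        row = int(req_pool_indices[i])
        start = int(chunk_starts[i])
        n = int(chunk_seq_lens[i])
        off = int(chunk_cu_seq_lens[i])
        chunk_kv_indices[off:off + n] = req_to_token[row, start:start + n]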
sgl-kernel/include/sgl_kernel_ops.h

@@ -538,6 +538,7 @@ void segment_packbits(
 /*
  * From csrc/kvcacheio
  */
 void dcu_create_extend_after_decode_spec_info(
     const at::Tensor verified_id,
     const at::Tensor seq_lens,

@@ -545,6 +546,49 @@ void dcu_create_extend_after_decode_spec_info(
     at::Tensor positions,
     at::Tensor new_verified_id,
     int64_t bs);
+void dcu_create_chunked_prefix_cache_kv_indices(
+    at::Tensor req_to_token,
+    const at::Tensor req_pool_indices,
+    const at::Tensor chunk_starts,
+    const at::Tensor chunk_seq_lens,
+    const at::Tensor chunk_cu_seq_lens,
+    at::Tensor chunk_kv_indices,
+    int64_t col_num,
+    int64_t bs);
+void dcu_create_flashmla_kv_indices(
+    const at::Tensor& req_to_token,
+    const at::Tensor& req_pool_indices,
+    const at::Tensor& page_kernel_lens,
+    const c10::optional<at::Tensor>& kv_start_idx,
+    at::Tensor& kv_indices,
+    int64_t req_to_token_stride,
+    int64_t kv_indices_stride,
+    int64_t PAGED_SIZE);
+void dcu_assign_extend_cache_locs(
+    const at::Tensor req_pool_indices,
+    const at::Tensor req_to_token,
+    const at::Tensor start_offset,
+    const at::Tensor end_offset,
+    at::Tensor out_cache_loc,
+    int64_t pool_len,
+    int64_t bs);
+at::Tensor dcu_get_last_loc(
+    const at::Tensor req_to_token, const at::Tensor req_pool_indices, const at::Tensor prefix_lens);
+void dcu_assign_req_to_token_pool(
+    const at::Tensor req_pool_indices_ptr,
+    at::Tensor req_to_token_ptr,
+    const at::Tensor allocate_lens_ptr,
+    at::Tensor new_allocate_lens,
+    at::Tensor out_cache_loc_ptr,
+    int64_t shape,
+    int64_t bs);
 void dcu_alloc_extend_kernel(
     const at::Tensor pre_lens_ptr,
sgl-kernel/python/sgl_kernel/flash_mla.py

@@ -13,6 +13,26 @@ _IMPORT_ERROR = ImportError(
     "Failed to load sgl_kernel.flashmla_ops extension. Ensure CUDA Driver >= 12.4"
 )

+def dcu_create_flashmla_kv_indices(
+    req_to_token_ptr,
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    req_to_token_ptr_stride,
+    kv_indices_ptr_stride,
+    PAGED_SIZE=64,
+):
+    torch.ops.sgl_kernel.dcu_create_flashmla_kv_indices(
+        req_to_token_ptr,
+        req_pool_indices_ptr,
+        page_kernel_lens_ptr,
+        kv_start_idx,
+        kv_indices_ptr,
+        req_to_token_ptr_stride,
+        kv_indices_ptr_stride,
+        PAGED_SIZE,
+    )

 def get_mla_metadata(
     cache_seqlens: torch.Tensor,
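A usage sketch for the wrapper above (shapes and values assumed purely for illustration; the registered schema takes int32 tensors, and max_pages must be at least ceil(max(seq_lens) / 64)):

import torch
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices

bs, pool_len, max_pages = 4, 4096, 64
req_to_token = torch.zeros((bs, pool_len), dtype=torch.int32, device="cuda")
req_pool_indices = torch.arange(bs, dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([100, 900, 64, 1], dtype=torch.int32, device="cuda")
block_kv_indices = torch.full((bs, max_pages), -1, dtype=torch.int32, device="cuda")

dcu_create_flashmla_kv_indices(
    req_to_token_ptr=req_to_token,
    req_pool_indices_ptr=req_pool_indices,
    page_kernel_lens_ptr=seq_lens,
    kv_start_idx=None,
    kv_indices_ptr=block_kv_indices,
    req_to_token_ptr_stride=req_to_token.stride(0),
    kv_indices_ptr_stride=max_pages,
)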
sgl-kernel/python/sgl_kernel/kvcacheio.py

@@ -293,3 +293,76 @@ def transfer_kv_all_layer_mla_lf_pf(
         block_quota,
         num_warps_per_block,
     )

+def dcu_assign_req_to_token_pool(
+    req_pool_indices: torch.Tensor,
+    req_to_token: torch.Tensor,
+    allocate_lens: torch.Tensor,
+    new_allocate_lens: torch.Tensor,
+    out_cache_loc: torch.Tensor,
+    shape: int,
+    bs: int,
+):
+    torch.ops.sgl_kernel.dcu_assign_req_to_token_pool(
+        req_pool_indices,
+        req_to_token,
+        allocate_lens,
+        new_allocate_lens,
+        out_cache_loc,
+        shape,
+        bs,
+    )
+
+def dcu_get_last_loc(
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    prefix_lens: torch.Tensor,
+):
+    result = torch.ops.sgl_kernel.dcu_get_last_loc(
+        req_to_token,
+        req_pool_indices,
+        prefix_lens,
+    )
+    return result
+
+def dcu_assign_extend_cache_locs(
+    req_pool_indices: torch.Tensor,
+    req_to_token: torch.Tensor,
+    start_offset: torch.Tensor,
+    end_offset: torch.Tensor,
+    out_cache_loc: torch.Tensor,
+    pool_len: int,
+    bs: int,
+):
+    torch.ops.sgl_kernel.dcu_assign_extend_cache_locs(
+        req_pool_indices,
+        req_to_token,
+        start_offset,
+        end_offset,
+        out_cache_loc,
+        pool_len,
+        bs,
+    )
+
+def dcu_create_chunked_prefix_cache_kv_indices(
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    chunk_starts: torch.Tensor,
+    chunk_seq_lens: torch.Tensor,
+    chunk_cu_seq_lens: torch.Tensor,
+    chunk_kv_indices: torch.Tensor,
+    col_num: int,
+    bs: int,
+):
+    torch.ops.sgl_kernel.dcu_create_chunked_prefix_cache_kv_indices(
+        req_to_token,
+        req_pool_indices,
+        chunk_starts,
+        chunk_seq_lens,
+        chunk_cu_seq_lens,
+        chunk_kv_indices,
+        col_num,
+        bs,
+    )
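Taken together, every new DCU kernel in this commit sits behind an environment flag, with the pre-existing Triton/Torch path as the default when the flags are unset. A sketch of enabling them all before the server process imports sglang (flag names exactly as they appear in this diff; note several gates are class attributes evaluated at import time):

import os

for flag in (
    "SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON",       # DCU flashmla kv-indices path
    "SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO",    # gate read in flashmla_backend.py
    "SGLANG_GET_LAST_LOC",                            # dcu_get_last_loc in mem_cache/common.py
    "SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES",  # DCU chunked prefix-cache indices
    "SGLANG_ASSIGN_REQ_TO_TOKEN_POOL",                # dcu_assign_req_to_token_pool
    "SGLANG_ASSIGN_EXTEND_CACHE_LOCS",                # dcu_assign_extend_cache_locs
):
    os.environ[flag] = "1"

import sglang  # noqa: E402  import after setting the flags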