Commit f810bda3 authored by renzhc's avatar renzhc
Browse files

Merge branch 'v0.5.4_dev' into v0.5.4_rzc

parents 4167eff9 48542418
......@@ -293,3 +293,76 @@ def transfer_kv_all_layer_mla_lf_pf(
block_quota,
num_warps_per_block,
)
def dcu_assign_req_to_token_pool(
req_pool_indices:torch.Tensor,
req_to_token:torch.Tensor,
allocate_lens:torch.Tensor,
new_allocate_lens:torch.Tensor,
out_cache_loc:torch.Tensor,
shape:int,
bs:int,
):
torch.ops.sgl_kernel.dcu_assign_req_to_token_pool(
req_pool_indices,
req_to_token,
allocate_lens,
new_allocate_lens,
out_cache_loc,
shape,
bs,
)
def dcu_get_last_loc(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
prefix_lens: torch.Tensor,
):
result = torch.ops.sgl_kernel.dcu_get_last_loc(
req_to_token,
req_pool_indices,
prefix_lens,
)
return result
def dcu_assign_extend_cache_locs(
req_pool_indices: torch.Tensor,
req_to_token: torch.Tensor,
start_offset: torch.Tensor,
end_offset: torch.Tensor,
out_cache_loc: torch.Tensor,
pool_len: int,
bs: int,
):
torch.ops.sgl_kernel.dcu_assign_extend_cache_locs(
req_pool_indices,
req_to_token,
start_offset,
end_offset,
out_cache_loc,
pool_len,
bs,
)
def dcu_create_chunked_prefix_cache_kv_indices(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
chunk_starts: torch.Tensor,
chunk_seq_lens: torch.Tensor,
chunk_cu_seq_lens: torch.Tensor,
chunk_kv_indices: torch.Tensor,
col_num: int,
bs: int,
):
torch.ops.sgl_kernel.dcu_create_chunked_prefix_cache_kv_indices(
req_to_token,
req_pool_indices,
chunk_starts,
chunk_seq_lens,
chunk_cu_seq_lens,
chunk_kv_indices,
col_num,
bs,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment