Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e61bac87
Unverified
Commit
e61bac87
authored
Aug 19, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 19, 2025
Browse files
[Misc] Minor refactoring for FlashInfer backend (#23147)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
80141bbf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
65 additions
and
91 deletions
+65
-91
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+65
-91
No files found.
vllm/v1/attention/backends/flashinfer.py
View file @
e61bac87
...
@@ -10,8 +10,7 @@ import torch
...
@@ -10,8 +10,7 @@ import torch
from
flashinfer
import
(
BatchDecodeWithPagedKVCacheWrapper
,
from
flashinfer
import
(
BatchDecodeWithPagedKVCacheWrapper
,
BatchPrefillWithPagedKVCacheWrapper
,
BatchPrefillWithPagedKVCacheWrapper
,
MultiLevelCascadeAttentionWrapper
)
MultiLevelCascadeAttentionWrapper
)
from
flashinfer.decode
import
(
_get_range_buf
,
get_seq_lens
,
from
flashinfer.decode
import
_get_range_buf
,
trtllm_batch_decode_with_kv_cache
trtllm_batch_decode_with_kv_cache
)
from
flashinfer.prefill
import
trtllm_batch_context_with_kv_cache
from
flashinfer.prefill
import
trtllm_batch_context_with_kv_cache
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -142,19 +141,10 @@ class FlashInferMetadata:
...
@@ -142,19 +141,10 @@ class FlashInferMetadata:
# The number of entries in the last page of each request in
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size] (CPU for plan)
# the paged kv cache, shape: [batch_size] (CPU for plan)
paged_kv_last_page_len_cpu
:
torch
.
Tensor
paged_kv_last_page_len_cpu
:
torch
.
Tensor
# The number of query/output heads
num_qo_heads
:
int
# The number of key/value heads
num_kv_heads
:
int
# The dimension of the attention heads
head_dim
:
int
# Block size of vllm
page_size
:
int
# The data type of the paged kv cache
kv_data_type
:
torch
.
dtype
# The data type of the query
# The data type of the query
q_data_type
:
torch
.
dtype
q_data_type
:
torch
.
dtype
seq_lens_cpu
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
# For flashinfer trtllm batch decode
# For flashinfer trtllm batch decode
...
@@ -185,10 +175,6 @@ class FlashInferMetadata:
...
@@ -185,10 +175,6 @@ class FlashInferMetadata:
qo_indptr_gpu
:
Optional
[
torch
.
Tensor
]
=
None
qo_indptr_gpu
:
Optional
[
torch
.
Tensor
]
=
None
paged_kv_indptr_gpu
:
Optional
[
torch
.
Tensor
]
=
None
paged_kv_indptr_gpu
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
if
self
.
head_dim
is
not
None
:
FlashInferBackend
.
validate_head_size
(
self
.
head_dim
)
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
]):
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
]):
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
...
@@ -201,13 +187,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -201,13 +187,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
device
=
device
self
.
device
=
device
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
model_config
=
vllm_config
.
model_config
self
.
kv_cache_spec
=
kv_cache_spec
self
.
kv_cache_spec
=
kv_cache_spec
self
.
_workspace_buffer
=
None
self
.
_workspace_buffer
=
None
self
.
_prefill_wrapper
=
None
# Wrapper for prefill/append
self
.
_prefill_wrapper
=
None
# Wrapper for prefill/append
self
.
_decode_wrapper
=
None
# Wrapper for decode (general shape)
self
.
_decode_wrapper
=
None
# Wrapper for decode (general shape)
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
compilation_config
=
vllm_config
.
compilation_config
max_num_pages_per_req
=
cdiv
(
vllm_config
.
model_config
.
max_model_len
,
max_num_pages_per_req
=
cdiv
(
self
.
model_config
.
max_model_len
,
self
.
kv_cache_spec
.
block_size
)
self
.
kv_cache_spec
.
block_size
)
max_num_reqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
max_num_reqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
max_num_pages
=
max_num_reqs
*
max_num_pages_per_req
max_num_pages
=
max_num_reqs
*
max_num_pages_per_req
...
@@ -221,6 +208,29 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -221,6 +208,29 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
_decode_cudagraph_max_bs
=
min
(
self
.
_decode_cudagraph_max_bs
=
min
(
max_num_reqs
,
self
.
compilation_config
.
max_capture_size
)
max_num_reqs
,
self
.
compilation_config
.
max_capture_size
)
self
.
num_qo_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
vllm_config
.
parallel_config
)
self
.
num_kv_heads
=
self
.
kv_cache_spec
.
num_kv_heads
self
.
head_dim
=
self
.
kv_cache_spec
.
head_size
FlashInferBackend
.
validate_head_size
(
self
.
head_dim
)
self
.
page_size
=
self
.
kv_cache_spec
.
block_size
self
.
enable_fusion
=
(
self
.
compilation_config
.
pass_config
.
enable_attn_fusion
)
self
.
q_data_type
=
self
.
model_config
.
dtype
self
.
cache_dtype
=
self
.
cache_config
.
cache_dtype
if
self
.
cache_dtype
.
startswith
(
"fp8"
):
self
.
kv_cache_dtype
=
(
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
cache_dtype
))
# Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
if
self
.
enable_fusion
:
self
.
q_data_type
=
self
.
kv_cache_dtype
else
:
self
.
kv_cache_dtype
=
self
.
kv_cache_spec
.
dtype
self
.
use_tensor_cores
=
(
envs
.
VLLM_FLASHINFER_FORCE_TENSOR_CORES
or
(
self
.
num_qo_heads
//
self
.
num_kv_heads
>
4
))
self
.
_cascade_wrapper
=
None
# Wrapper for cascade attention
self
.
_cascade_wrapper
=
None
# Wrapper for cascade attention
# Global hyperparameters shared by all attention layers
# Global hyperparameters shared by all attention layers
...
@@ -282,14 +292,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -282,14 +292,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
decode_wrapper
=
self
.
_decode_wrapper
decode_wrapper
=
self
.
_decode_wrapper
if
decode_wrapper
is
None
:
if
decode_wrapper
is
None
:
num_qo_heads
=
(
self
.
vllm_config
.
model_config
.
get_num_attention_heads
(
self
.
vllm_config
.
parallel_config
))
num_kv_heads
=
self
.
vllm_config
.
model_config
.
get_num_kv_heads
(
self
.
vllm_config
.
parallel_config
)
use_tensor_cores
=
envs
.
VLLM_FLASHINFER_FORCE_TENSOR_CORES
or
(
num_qo_heads
//
num_kv_heads
>
4
)
if
use_cudagraph
:
if
use_cudagraph
:
paged_kv_indptr
=
self
.
paged_kv_indptr
[:
batch_size
+
1
]
paged_kv_indptr
=
self
.
paged_kv_indptr
[:
batch_size
+
1
]
paged_kv_indices
=
self
.
paged_kv_indices
paged_kv_indices
=
self
.
paged_kv_indices
...
@@ -306,7 +308,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -306,7 +308,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr_buffer
=
paged_kv_indptr
,
paged_kv_indptr_buffer
=
paged_kv_indptr
,
paged_kv_indices_buffer
=
paged_kv_indices
,
paged_kv_indices_buffer
=
paged_kv_indices
,
paged_kv_last_page_len_buffer
=
paged_kv_last_page_len
,
paged_kv_last_page_len_buffer
=
paged_kv_last_page_len
,
use_tensor_cores
=
use_tensor_cores
)
use_tensor_cores
=
self
.
use_tensor_cores
)
# save the decode wrapper
# save the decode wrapper
if
use_cudagraph
:
if
use_cudagraph
:
...
@@ -342,16 +344,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -342,16 +344,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata
.
shared_kv_last_page_len_cpu
,
attn_metadata
.
shared_kv_last_page_len_cpu
,
attn_metadata
.
paged_kv_last_page_len_cpu
attn_metadata
.
paged_kv_last_page_len_cpu
],
],
attn_metadata
.
num_qo_heads
,
self
.
num_qo_heads
,
attn_metadata
.
num_kv_heads
,
self
.
num_kv_heads
,
attn_metadata
.
head_dim
,
self
.
head_dim
,
attn_metadata
.
page_size
,
self
.
page_size
,
causal
=
True
,
causal
=
True
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
attn_metadata
.
q_data_type
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
attn_metadata
.
kv_data_
type
,
kv_data_type
=
self
.
kv_cache_d
type
,
)
)
else
:
else
:
# Regular attention (common case).
# Regular attention (common case).
...
@@ -383,17 +385,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -383,17 +385,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
attn_metadata
.
paged_kv_last_page_len_cpu
[
prefill_start
:],
paged_kv_last_page_len_cpu
[
prefill_start
:],
attn_metadata
.
num_qo_heads
,
self
.
num_qo_heads
,
attn_metadata
.
num_kv_heads
,
self
.
num_kv_heads
,
attn_metadata
.
head_dim
,
self
.
head_dim
,
attn_metadata
.
page_size
,
self
.
page_size
,
causal
=
True
,
causal
=
True
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
,
logits_soft_cap
,
q_data_type
=
attn_metadata
.
q_data_type
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
attn_metadata
.
kv_data_
type
,
kv_data_type
=
self
.
kv_cache_d
type
,
)
)
else
:
else
:
attn_metadata
.
qo_indptr_gpu
=
qo_indptr_cpu
.
to
(
self
.
device
)
attn_metadata
.
qo_indptr_gpu
=
qo_indptr_cpu
.
to
(
self
.
device
)
...
@@ -435,18 +437,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -435,18 +437,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
paged_kv_indptr_cpu
[:
num_input_tokens
+
1
],
self
.
paged_kv_indptr_cpu
[:
num_input_tokens
+
1
],
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_indices
,
self
.
paged_kv_last_page_len_cpu
[:
num_input_tokens
],
self
.
paged_kv_last_page_len_cpu
[:
num_input_tokens
],
attn_metadata
.
num_qo_heads
,
attn_metadata
.
seq_lens_cpu
[:
num_input_tokens
],
attn_metadata
.
num_kv_heads
,
self
.
num_qo_heads
,
attn_metadata
.
head_dim
,
self
.
num_kv_heads
,
attn_metadata
.
page_size
,
self
.
head_dim
,
self
.
page_size
,
# Disable flashinfer's pos encoding and use vllm's rope.
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode
=
"NONE"
,
pos_encoding_mode
=
"NONE"
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
,
logits_soft_cap
,
q_data_type
=
attn_metadata
.
q_data_type
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
attn_metadata
.
kv_data_
type
,
kv_data_type
=
self
.
kv_cache_d
type
,
)
)
def
build
(
self
,
def
build
(
self
,
...
@@ -458,9 +461,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -458,9 +461,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
\
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
\
split_decodes_and_prefills
(
common_attn_metadata
)
split_decodes_and_prefills
(
common_attn_metadata
)
page_size
=
self
.
kv_cache_spec
.
block
_size
page_size
=
self
.
page
_size
max_q_len
=
common_attn_metadata
.
max_query_len
max_q_len
=
common_attn_metadata
.
max_query_len
max_seq_len
=
common_attn_metadata
.
seq_lens_cpu
.
max
()
max_seq_len
=
common_attn_metadata
.
seq_lens_cpu
.
max
()
.
item
()
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
...
@@ -495,7 +498,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -495,7 +498,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
shared_kv_page_indices_cpu
=
None
shared_kv_page_indices_cpu
=
None
shared_kv_last_page_len_cpu
=
None
shared_kv_last_page_len_cpu
=
None
max_num_blocks
=
block_table_bounds_cpu
.
max
()
max_num_blocks
=
block_table_bounds_cpu
.
max
()
.
item
()
block_table_bounds
=
block_table_bounds_cpu
.
to
(
self
.
device
,
block_table_bounds
=
block_table_bounds_cpu
.
to
(
self
.
device
,
non_blocking
=
True
)
non_blocking
=
True
)
mask
=
(
self
.
block_table_arange
[:
max_num_blocks
].
unsqueeze
(
0
)
mask
=
(
self
.
block_table_arange
[:
max_num_blocks
].
unsqueeze
(
0
)
...
@@ -520,42 +523,23 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -520,42 +523,23 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_last_page_len_cpu
,
paged_kv_last_page_len_cpu
,
out
=
self
.
paged_kv_last_page_len_cpu
[:
num_reqs
])
out
=
self
.
paged_kv_last_page_len_cpu
[:
num_reqs
])
cache_dtype
=
self
.
cache_config
.
cache_dtype
if
cache_dtype
.
startswith
(
"fp8"
):
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
cache_dtype
)
else
:
kv_cache_dtype
=
self
.
kv_cache_spec
.
dtype
config
=
self
.
vllm_config
num_qo_heads
=
config
.
model_config
.
get_num_attention_heads
(
config
.
parallel_config
)
num_kv_heads
=
self
.
kv_cache_spec
.
num_kv_heads
head_dim
=
self
.
kv_cache_spec
.
head_size
# Check if any layer uses sinks (requires TRTLLM attention)
# Check if any layer uses sinks (requires TRTLLM attention)
has_sinks
=
self
.
global_hyperparameters
.
has_sinks
has_sinks
=
self
.
global_hyperparameters
.
has_sinks
# Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
prefill_use_trtllm
=
use_trtllm_attention
(
self
.
num_qo_heads
,
q_dtype
=
config
.
model_config
.
dtype
self
.
num_kv_heads
,
enable_fusion
=
config
.
compilation_config
.
pass_config
.
enable_attn_fusion
if
cache_dtype
.
startswith
(
"fp8"
)
and
enable_fusion
:
q_dtype
=
kv_cache_dtype
prefill_use_trtllm
=
use_trtllm_attention
(
num_qo_heads
,
num_kv_heads
,
num_prefill_tokens
,
num_prefill_tokens
,
max_seq_len
,
max_seq_len
,
cache_dtype
,
self
.
cache_dtype
,
q_d
type
,
self
.
q_data_
type
,
is_prefill
=
True
,
is_prefill
=
True
,
has_sinks
=
has_sinks
)
has_sinks
=
has_sinks
)
decode_use_trtllm
=
use_trtllm_attention
(
num_qo_heads
,
decode_use_trtllm
=
use_trtllm_attention
(
self
.
num_qo_heads
,
num_kv_heads
,
self
.
num_kv_heads
,
num_decode_tokens
,
num_decode_tokens
,
max_seq_len
,
max_seq_len
,
cache_dtype
,
self
.
cache_dtype
,
q_d
type
,
self
.
q_data_
type
,
is_prefill
=
False
,
is_prefill
=
False
,
has_sinks
=
has_sinks
)
has_sinks
=
has_sinks
)
...
@@ -566,12 +550,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -566,12 +550,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indices
=
paged_kv_indices
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_last_page_len_cpu
=
self
.
paged_kv_last_page_len_cpu
=
self
.
paged_kv_last_page_len_cpu
[:
num_reqs
],
paged_kv_last_page_len_cpu
[:
num_reqs
],
num_qo_heads
=
num_qo_heads
,
q_data_type
=
self
.
q_data_type
,
num_kv_heads
=
num_kv_heads
,
seq_lens_cpu
=
seq_lens_cpu
,
head_dim
=
head_dim
,
page_size
=
page_size
,
kv_data_type
=
kv_cache_dtype
,
q_data_type
=
q_dtype
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
max_q_len
=
max_q_len
,
max_q_len
=
max_q_len
,
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
...
@@ -910,6 +890,7 @@ def fast_plan_decode(
...
@@ -910,6 +890,7 @@ def fast_plan_decode(
indptr_cpu
:
torch
.
Tensor
,
indptr_cpu
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
last_page_len_cpu
:
torch
.
Tensor
,
last_page_len_cpu
:
torch
.
Tensor
,
seq_lens_cpu
:
torch
.
Tensor
,
num_qo_heads
:
int
,
num_qo_heads
:
int
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
head_dim
:
int
,
head_dim
:
int
,
...
@@ -987,9 +968,6 @@ def fast_plan_decode(
...
@@ -987,9 +968,6 @@ def fast_plan_decode(
kv_data_type
=
getattr
(
torch
,
kv_data_type
)
if
isinstance
(
kv_data_type
=
getattr
(
torch
,
kv_data_type
)
if
isinstance
(
kv_data_type
,
str
)
else
kv_data_type
kv_data_type
,
str
)
else
kv_data_type
if
self
.
use_tensor_cores
:
qo_indptr_host
=
_get_range_buf
(
batch_size
+
1
,
"cpu"
)
if
batch_size
!=
self
.
_fixed_batch_size
:
if
batch_size
!=
self
.
_fixed_batch_size
:
raise
ValueError
(
raise
ValueError
(
"The batch size should be fixed in cudagraph mode, the runtime "
"The batch size should be fixed in cudagraph mode, the runtime "
...
@@ -1006,12 +984,8 @@ def fast_plan_decode(
...
@@ -1006,12 +984,8 @@ def fast_plan_decode(
self
.
_paged_kv_last_page_len_buf
.
copy_
(
last_page_len_cpu
,
self
.
_paged_kv_last_page_len_buf
.
copy_
(
last_page_len_cpu
,
non_blocking
=
True
)
non_blocking
=
True
)
indptr_host
=
indptr_cpu
last_page_len_host
=
last_page_len_cpu
if
self
.
use_tensor_cores
:
if
self
.
use_tensor_cores
:
kv_lens_arr_host
=
get_seq_lens
(
indptr_host
,
last_page_len_host
,
qo_indptr_host
=
_get_range_buf
(
batch_size
+
1
,
"cpu"
)
page_size
)
try
:
try
:
# Make sure we pass exactly 15 arguments for tensor core version
# Make sure we pass exactly 15 arguments for tensor core version
...
@@ -1020,8 +994,8 @@ def fast_plan_decode(
...
@@ -1020,8 +994,8 @@ def fast_plan_decode(
self
.
_int_workspace_buffer
,
self
.
_int_workspace_buffer
,
self
.
_pin_memory_int_workspace_buffer
,
self
.
_pin_memory_int_workspace_buffer
,
qo_indptr_host
,
qo_indptr_host
,
indptr_
host
,
indptr_
cpu
,
kv
_lens_
arr_host
,
seq
_lens_
cpu
,
batch_size
,
# total_num_rows
batch_size
,
# total_num_rows
batch_size
,
batch_size
,
num_qo_heads
,
num_qo_heads
,
...
@@ -1041,7 +1015,7 @@ def fast_plan_decode(
...
@@ -1041,7 +1015,7 @@ def fast_plan_decode(
self
.
_float_workspace_buffer
,
self
.
_float_workspace_buffer
,
self
.
_int_workspace_buffer
,
self
.
_int_workspace_buffer
,
self
.
_pin_memory_int_workspace_buffer
,
self
.
_pin_memory_int_workspace_buffer
,
indptr_
host
,
indptr_
cpu
,
batch_size
,
batch_size
,
num_qo_heads
,
num_qo_heads
,
num_kv_heads
,
num_kv_heads
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment