Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
19836092
Unverified
Commit
19836092
authored
Sep 23, 2025
by
Benjamin Chislett
Committed by
GitHub
Sep 24, 2025
Browse files
[Bugfix] Use a separate FlashInfer workspace buffer for trtllm-gen (#25520)
parent
d06b5a95
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
2 deletions
+12
-2
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+12
-2
No files found.
vllm/v1/attention/backends/flashinfer.py
View file @
19836092
...
@@ -48,6 +48,16 @@ FP4_DTYPE = torch.uint8
...
@@ -48,6 +48,16 @@ FP4_DTYPE = torch.uint8
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
trtllm_gen_workspace_buffer
=
None
def
_get_trtllm_gen_workspace_buffer
():
global
trtllm_gen_workspace_buffer
if
trtllm_gen_workspace_buffer
is
None
:
trtllm_gen_workspace_buffer
=
torch
.
zeros
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
'cuda'
)
return
trtllm_gen_workspace_buffer
@
triton
.
jit
@
triton
.
jit
def
_trtllm_prefill_attn_kvfp8_dequant
(
def
_trtllm_prefill_attn_kvfp8_dequant
(
...
@@ -862,7 +872,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -862,7 +872,7 @@ class FlashInferImpl(AttentionImpl):
else
:
else
:
# prefill_query may be non-contiguous
# prefill_query may be non-contiguous
prefill_query
=
prefill_query
.
contiguous
()
prefill_query
=
prefill_query
.
contiguous
()
workspace_buffer
=
prefill_wrapper
.
_float
_workspace_buffer
workspace_buffer
=
_get_trtllm_gen
_workspace_buffer
()
block_tables_prefill
=
attn_metadata
.
block_table_tensor
[
block_tables_prefill
=
attn_metadata
.
block_table_tensor
[
num_decode_tokens
:]
num_decode_tokens
:]
seq_lens_prefill
=
attn_metadata
.
seq_lens
[
num_decode_tokens
:]
seq_lens_prefill
=
attn_metadata
.
seq_lens
[
num_decode_tokens
:]
...
@@ -943,7 +953,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -943,7 +953,7 @@ class FlashInferImpl(AttentionImpl):
else
:
else
:
# decode_query may be non-contiguous
# decode_query may be non-contiguous
decode_query
=
decode_query
.
contiguous
()
decode_query
=
decode_query
.
contiguous
()
workspace_buffer
=
decode_wrapper
.
_float
_workspace_buffer
workspace_buffer
=
_get_trtllm_gen
_workspace_buffer
()
block_tables_decode
=
attn_metadata
.
\
block_tables_decode
=
attn_metadata
.
\
block_table_tensor
[:
num_decode_tokens
]
block_table_tensor
[:
num_decode_tokens
]
seq_lens_decode
=
attn_metadata
.
seq_lens
[:
num_decode_tokens
]
seq_lens_decode
=
attn_metadata
.
seq_lens
[:
num_decode_tokens
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment