Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e9630458
Unverified
Commit
e9630458
authored
Aug 06, 2024
by
Bongwon Jang
Committed by
GitHub
Aug 05, 2024
Browse files
[SpecDecode] Support FlashInfer in DraftModelRunner (#6926)
parent
82a1b1a8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
47 additions
and
0 deletions
+47
-0
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+47
-0
No files found.
vllm/spec_decode/draft_model_runner.py
View file @
e9630458
...
...
@@ -11,6 +11,17 @@ except ModuleNotFoundError:
from
vllm.attention.backends.rocm_flash_attn
import
(
ROCmFlashAttentionMetadata
as
FlashAttentionMetadata
)
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
256
*
1024
*
1024
except
ImportError
:
BatchDecodeWithPagedKVCacheWrapper
=
None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
0
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
...
...
@@ -79,6 +90,11 @@ class TP1DraftModelRunner(ModelRunner):
return_hidden_states
=
return_hidden_states
,
)
self
.
flashinfer_decode_workspace_buffer
=
None
self
.
flashinfer_decode_wrapper
=
None
self
.
flashinfer_prefill_workspace_buffer
=
None
self
.
flashinfer_prefill_wrapper
=
None
def
_update_flash_attn_metadata
(
self
,
attn_metadata
,
num_seqs
,
num_queries
):
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
...
...
@@ -286,6 +302,37 @@ class TP1DraftModelRunner(ModelRunner):
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_mapping
)
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
assert
model_input
.
attn_metadata
is
not
None
assert
model_input
.
input_tokens
is
not
None
if
self
.
flashinfer_decode_workspace_buffer
is
None
:
self
.
flashinfer_decode_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
flashinfer_decode_wrapper
=
\
BatchDecodeWithPagedKVCacheWrapper
(
self
.
flashinfer_decode_workspace_buffer
,
"NHD"
)
self
.
flashinfer_prefill_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
flashinfer_prefill_wrapper
=
\
BatchPrefillWithPagedKVCacheWrapper
(
self
.
flashinfer_prefill_workspace_buffer
,
"NHD"
)
model_input
.
attn_metadata
.
prefill_wrapper
=
\
self
.
flashinfer_prefill_wrapper
if
model_input
.
attn_metadata
.
use_cuda_graph
:
batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_input
.
attn_metadata
.
decode_wrapper
=
\
self
.
graph_runners
[
model_input
.
virtual_engine
][
batch_size
].
flashinfer_decode_wrapper
else
:
model_input
.
attn_metadata
.
decode_wrapper
=
\
self
.
flashinfer_decode_wrapper
model_input
.
attn_metadata
.
begin_forward
()
# Detect exec mode
assert
model_input
.
attn_metadata
is
not
None
use_cuda_graph
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment