Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1009e93c
Unverified
Commit
1009e93c
authored
Sep 17, 2024
by
sroy745
Committed by
GitHub
Sep 17, 2024
Browse files
[Encoder decoder] Add cuda graph support during decoding for encoder-decoder models (#7631)
parent
1b6de835
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
526 additions
and
112 deletions
+526
-112
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+7
-0
tests/encoder_decoder/__init__.py
tests/encoder_decoder/__init__.py
+0
-0
tests/encoder_decoder/test_e2e_correctness.py
tests/encoder_decoder/test_e2e_correctness.py
+98
-0
tests/worker/test_encoder_decoder_model_runner.py
tests/worker/test_encoder_decoder_model_runner.py
+160
-22
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+13
-4
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+9
-3
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+107
-6
vllm/config.py
vllm/config.py
+8
-33
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+4
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+4
-4
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+4
-2
vllm/utils.py
vllm/utils.py
+0
-5
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+36
-7
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+76
-21
vllm/worker/utils.py
vllm/worker/utils.py
+0
-4
No files found.
.buildkite/test-pipeline.yaml
View file @
1009e93c
...
...
@@ -252,6 +252,13 @@ steps:
-
export VLLM_WORKER_MULTIPROC_METHOD=spawn
-
bash ./run-tests.sh -c configs/models-small.txt -t
1
-
label
:
Encoder Decoder tests
# 5min
source_file_dependencies
:
-
vllm/
-
tests/encoder_decoder
commands
:
-
pytest -v -s encoder_decoder
-
label
:
OpenAI-Compatible Tool Use
# 20 min
fast_check
:
false
mirror_hardwares
:
[
amd
]
...
...
tests/encoder_decoder/__init__.py
0 → 100644
View file @
1009e93c
tests/encoder_decoder/test_e2e_correctness.py
0 → 100644
View file @
1009e93c
"""E2E tests to verify the correctness of the encoder-decoder framework
Run `pytest tests/encoder_decoder/test_e2e_correctness.py`.
"""
from
typing
import
List
,
Optional
,
Tuple
import
pytest
from
transformers
import
AutoModelForSeq2SeqLM
from
vllm.sequence
import
SampleLogprobs
from
vllm.utils
import
is_cpu
from
..conftest
import
DecoderPromptType
from
..models.utils
import
check_logprobs_close
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
decoder_prompt_type
:
DecoderPromptType
,
):
"""Sanitize vllm output to be comparable with hf output."""
output_ids
,
output_str
,
out_logprobs
=
vllm_output
hf_output_str
=
output_str
+
"</s>"
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
:
hf_output_str
=
"<s>"
+
hf_output_str
return
output_ids
,
hf_output_str
,
out_logprobs
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/bart-large-cnn"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
list
(
DecoderPromptType
))
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
is_cpu
(),
reason
=
"CPU backend is not currently supported with encoder/decoder models"
)
def
test_encoder_decoder_e2e
(
hf_runner
,
vllm_runner
,
example_encoder_decoder_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
decoder_prompt_type
:
DecoderPromptType
,
enforce_eager
:
bool
,
)
->
None
:
'''
End-to-End (E2E) test for the encoder-decoder framework.
This test evaluates the encoder-decoder functionality using the BART
model. We compare the outputs of the Hugging Face and vLLM
implementations to ensure that both implementations produce consistent
and correct results.
'''
test_case_prompts
=
example_encoder_decoder_prompts
[
decoder_prompt_type
]
# Configuration settings for HF baseline
hf_kwargs
=
{
"top_k"
:
None
,
"num_beams"
:
1
,
"repetition_penalty"
:
1.0
,
"top_p"
:
1.0
,
"length_penalty"
:
1.0
,
"early_stopping"
:
False
,
"no_repeat_ngram_size"
:
None
,
"min_length"
:
0
}
with
hf_runner
(
model
,
dtype
=
dtype
,
auto_cls
=
AutoModelForSeq2SeqLM
)
as
hf_model
:
hf_outputs
=
(
hf_model
.
generate_encoder_decoder_greedy_logprobs_limit
(
test_case_prompts
,
max_tokens
,
num_logprobs
,
**
hf_kwargs
,
))
with
vllm_runner
(
model
,
dtype
=
dtype
,
enforce_eager
=
enforce_eager
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_encoder_decoder_greedy_logprobs
(
test_case_prompts
,
max_tokens
,
num_logprobs
)
hf_skip_tokens
=
(
1
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
else
0
)
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
[
vllm_to_hf_output
(
vllm_output
,
decoder_prompt_type
)
for
vllm_output
in
vllm_outputs
],
name_0
=
"hf"
,
name_1
=
"vllm"
,
num_outputs_0_skip_tokens
=
hf_skip_tokens
,
)
tests/worker/test_encoder_decoder_model_runner.py
View file @
1009e93c
import
itertools
from
array
import
array
from
typing
import
List
...
...
@@ -7,13 +8,9 @@ import torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
is_cpu
from
vllm.utils
import
is_cpu
,
make_tensor_with_pad
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
# CUDA graph scenarios to test
#
# Currently CUDA graph is not supported
ENFORCE_EAGER
=
[
True
]
from
vllm.worker.model_runner
import
_get_graph_batch_size
BATCH_SIZES
=
[
1
,
4
,
16
,
64
,
256
]
...
...
@@ -40,8 +37,7 @@ def _create_model_runner(model: str, *args,
reason
=
"CPU backend is currently "
"unsupported for encoder/ "
"decoder models"
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
ENFORCE_EAGER
)
def
test_empty_seq_group
(
enforce_eager
,
):
def
test_empty_seq_group
():
"""Verify prepare prompt and decode returns empty output
for empty seq group list"""
...
...
@@ -52,7 +48,7 @@ def test_empty_seq_group(enforce_eager, ):
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enforce_eager
=
enforce_eager
,
enforce_eager
=
True
,
)
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
model_input
=
model_runner
.
_prepare_model_input_tensors
(
...
...
@@ -85,11 +81,7 @@ def test_empty_seq_group(enforce_eager, ):
"unsupported for encoder/ "
"decoder models"
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
ENFORCE_EAGER
)
def
test_prepare_prompt
(
batch_size
,
enforce_eager
,
):
def
test_prepare_prompt
(
batch_size
):
'''
Test the ability of the encoder/decoder model runner subclass to
produce prefill-phase model inputs & attention metadata.
...
...
@@ -115,7 +107,7 @@ def test_prepare_prompt(
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enforce_eager
=
enforce_eager
,
enforce_eager
=
True
,
)
seq_lens
:
List
[
int
]
=
[]
...
...
@@ -281,11 +273,7 @@ def test_prepare_prompt(
"unsupported for encoder/ "
"decoder models"
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
ENFORCE_EAGER
)
def
test_prepare_decode
(
batch_size
,
enforce_eager
,
):
def
test_prepare_decode
(
batch_size
):
'''
Test the ability of the encoder/decoder model runner subclass to
produce decode-phase model inputs & attention metadata.
...
...
@@ -311,7 +299,7 @@ def test_prepare_decode(
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enforce_eager
=
enforce_eager
,
enforce_eager
=
True
,
)
seq_lens
:
List
[
int
]
=
[]
...
...
@@ -428,7 +416,8 @@ def test_prepare_decode(
expected
,
)
# Cuda graph should is currently not supported for encoder/decoer.
# Model runner's CUDAGraph setting should be propagated to attention
# metadata.
assert
attn_metadata
.
use_cuda_graph
is
False
# Verify the lengths of input tokens & positions
...
...
@@ -484,3 +473,152 @@ def test_prepare_decode(
dtype
=
actual
.
dtype
,
)
assert
torch
.
equal
(
actual
,
expected
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
257
)))
def
test_prepare_decode_cuda_graph
(
batch_size
):
"""
Tests that for encoder-decoder models with CUDA Graph capture and replay
enabled, the tensors used during the decode phase are correctly padded
for varying input batch sizes.
"""
model_runner
=
_create_model_runner
(
"facebook/bart-base"
,
seed
=
0
,
dtype
=
"float16"
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enforce_eager
=
False
,
)
seq_lens
:
List
[
int
]
=
[]
encoder_seq_lens
:
List
[
int
]
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
]}
cross_block_table
=
[
2
]
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
seq_len
))))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
encoder_seq_len
))))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
seq_data
=
{
0
:
seq_data
},
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
block_tables
,
encoder_seq_data
=
encoder_seq_data
,
cross_block_table
=
cross_block_table
,
)
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_group_metadata_list
.
append
(
seq_group_metadata
)
model_input
=
model_runner
.
prepare_model_input
(
seq_group_metadata_list
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
return_seq_lens
=
model_input
.
seq_lens
slot_mapping
=
attn_metadata
.
slot_mapping
encoder_input_tokens
=
model_input
.
encoder_input_tokens
encoder_input_positions
=
model_input
.
encoder_input_positions
cross_slot_mapping
=
attn_metadata
.
cross_slot_mapping
# With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors
# accordingly.
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
cuda_graph_pad_size
=
graph_batch_size
-
batch_size
padded_seq_lens
=
seq_lens
+
list
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
padded_encoder_seq_lens
=
encoder_seq_lens
+
list
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
assert
return_seq_lens
==
padded_seq_lens
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
cross_slot_mapping
)
==
len
(
encoder_input_tokens
)
# Verify attention metadata
device
=
model_runner
.
device
assert
attn_metadata
.
num_prefills
==
0
assert
attn_metadata
.
num_decode_tokens
>
0
assert
torch
.
equal
(
attn_metadata
.
seq_lens_tensor
,
torch
.
tensor
(
padded_seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
seq_lens
==
padded_seq_lens
assert
attn_metadata
.
max_prefill_seq_len
==
0
assert
attn_metadata
.
max_decode_seq_len
==
max
(
seq_lens
)
# - Encoder attention metadata
assert
attn_metadata
.
encoder_seq_lens
==
padded_encoder_seq_lens
assert
torch
.
equal
(
attn_metadata
.
encoder_seq_lens_tensor
,
torch
.
tensor
(
padded_encoder_seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
max_encoder_seq_len
==
max
(
padded_encoder_seq_lens
)
assert
attn_metadata
.
num_encoder_tokens
==
sum
(
padded_encoder_seq_lens
)
# Verify block tables are correct for prompts
# - Decoder self-attention. Pad the block tables as expected.
expected
=
[
block_tables
[
0
]
for
_
in
range
(
batch_size
)]
expected
.
extend
([[]
for
_
in
range
(
cuda_graph_pad_size
)])
expected
=
make_tensor_with_pad
(
expected
,
max_len
=
64
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
,
)
assert
torch
.
equal
(
attn_metadata
.
block_tables
,
expected
,
)
# - Encoder/decoder cross-attention. Pad the cross-attention block tables
# as expected.
expected
=
[
cross_block_table
for
_
in
range
(
len
(
seq_group_metadata_list
))]
expected
.
extend
([[]
for
_
in
range
(
cuda_graph_pad_size
)])
expected
=
make_tensor_with_pad
(
expected
,
max_len
=
64
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
,
)
assert
torch
.
equal
(
attn_metadata
.
cross_block_tables
,
expected
,
)
# Model runner's CUDAGraph setting should be propagated to attention
# metadata.
assert
attn_metadata
.
use_cuda_graph
is
True
# Verify the lengths of input tokens & positions
# - Decoder
assert
len
(
input_tokens
)
==
len
(
padded_seq_lens
)
assert
len
(
input_positions
)
==
len
(
padded_seq_lens
)
# -- An indirect check that model_input.input_tokens
# and model_input.input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
input_tokens
,
input_positions
,
)
# - Encoder
assert
len
(
encoder_input_tokens
)
==
0
assert
len
(
encoder_input_tokens
)
==
0
# -- An indirect check that model_input.encoder_input_tokens
# and model_input.encoder_input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
encoder_input_tokens
,
encoder_input_positions
,
)
vllm/attention/backends/abstract.py
View file @
1009e93c
...
...
@@ -156,18 +156,27 @@ class AttentionState(ABC, Generic[T]):
...
@
abstractmethod
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
)
->
T
:
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
,
is_encoder_decoder_model
:
bool
=
False
)
->
T
:
"""Get attention metadata for CUDA graph capture of batch_size."""
...
@
abstractmethod
def
get_graph_input_buffers
(
self
,
attn_metadata
:
T
)
->
Dict
[
str
,
Any
]:
def
get_graph_input_buffers
(
self
,
attn_metadata
:
T
,
is_encoder_decoder_model
:
bool
=
False
)
->
Dict
[
str
,
Any
]:
"""Get attention-specific input buffers for CUDA graph capture."""
...
@
abstractmethod
def
prepare_graph_input_buffers
(
self
,
input_buffers
:
Dict
[
str
,
Any
],
attn_metadata
:
T
)
->
None
:
def
prepare_graph_input_buffers
(
self
,
input_buffers
:
Dict
[
str
,
Any
],
attn_metadata
:
T
,
is_encoder_decoder_model
:
bool
=
False
)
->
None
:
"""In-place modify input buffers dict for CUDA graph replay."""
...
...
...
vllm/attention/backends/flashinfer.py
View file @
1009e93c
...
...
@@ -172,7 +172,8 @@ class FlashInferState(AttentionState):
state
.
_prefill_wrapper
=
self
.
_get_prefill_wrapper
()
return
state
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
):
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
,
is_encoder_decoder_model
:
bool
=
False
):
assert
self
.
_is_graph_capturing
_indptr_buffer
=
self
.
_graph_indptr_buffer
[:
batch_size
+
1
]
_last_page_len_buffer
=
self
.
_graph_last_page_len_buffer
[:
batch_size
]
...
...
@@ -232,12 +233,17 @@ class FlashInferState(AttentionState):
attn_metadata
.
begin_forward
()
return
attn_metadata
def
get_graph_input_buffers
(
self
,
attn_metadata
):
def
get_graph_input_buffers
(
self
,
attn_metadata
,
is_encoder_decoder_model
:
bool
=
False
):
return
{
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
}
def
prepare_graph_input_buffers
(
self
,
input_buffers
,
attn_metadata
):
def
prepare_graph_input_buffers
(
self
,
input_buffers
,
attn_metadata
,
is_encoder_decoder_model
:
bool
=
False
):
return
def
begin_forward
(
self
,
model_input
):
...
...
vllm/attention/backends/utils.py
View file @
1009e93c
...
...
@@ -304,7 +304,8 @@ class CommonAttentionState(AttentionState):
assert
self
.
_is_graph_capturing
return
self
.
__class__
(
self
.
runner
)
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
):
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
,
is_encoder_decoder_model
:
bool
=
False
):
assert
self
.
_is_graph_capturing
attn_metadata
=
self
.
runner
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
...
...
@@ -322,21 +323,121 @@ class CommonAttentionState(AttentionState):
block_tables
=
self
.
_graph_block_tables
[:
batch_size
],
use_cuda_graph
=
True
,
)
if
is_encoder_decoder_model
:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"xformers"
,
\
f
"Expected attn_backend name to be 'xformers', but "
\
f
" got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_update_captured_metadata_for_enc_dec_model
(
batch_size
=
batch_size
,
attn_metadata
=
attn_metadata
)
return
attn_metadata
def
get_graph_input_buffers
(
self
,
attn_metadata
)
->
Dict
[
str
,
Any
]:
return
{
def
get_graph_input_buffers
(
self
,
attn_metadata
,
is_encoder_decoder_model
:
bool
=
False
)
->
Dict
[
str
,
Any
]:
input_buffers
=
{
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"seq_lens_tensor"
:
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
}
def
prepare_graph_input_buffers
(
self
,
input_buffers
,
attn_metadata
)
->
None
:
if
is_encoder_decoder_model
:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"xformers"
,
\
f
"Expected attn_backend name to be 'xformers', but "
\
f
" got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_add_additonal_input_buffers_for_enc_dec_model
(
attn_metadata
=
attn_metadata
,
input_buffers
=
input_buffers
)
return
input_buffers
def
prepare_graph_input_buffers
(
self
,
input_buffers
,
attn_metadata
,
is_encoder_decoder_model
:
bool
=
False
)
->
None
:
input_buffers
[
"seq_lens_tensor"
].
copy_
(
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
non_blocking
=
True
)
input_buffers
[
"block_tables"
].
copy_
(
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
if
is_encoder_decoder_model
:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"xformers"
,
\
f
"Expected attn_backend name to be 'xformers', but "
\
f
" got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_prepare_input_buffers_for_enc_dec_model
(
attn_metadata
,
input_buffers
)
def
begin_forward
(
self
,
model_input
)
->
None
:
return
def
_update_captured_metadata_for_enc_dec_model
(
self
,
batch_size
:
int
,
attn_metadata
):
"""
Updates the attention metadata parameters for CUDA graph capture in an
encoder-decoder model.
This method modifies attention-related tensors and metadata required
for CUDA graph capture in encoder-decoder models. Specifically, it
updates the cross-attention and encoder sequence tensors in the
AttentionMetadata object.
"""
# During decode phase the cross_slot_mapping will be empty. Hence set
# an empty tensor for CUDA Graph capture.
attn_metadata
.
cross_slot_mapping
=
torch
.
tensor
(
[],
dtype
=
torch
.
int
).
cuda
()
attn_metadata
.
cross_block_tables
=
torch
.
full
(
(
batch_size
,
self
.
runner
.
get_max_block_per_batch
()),
1
,
dtype
=
torch
.
int
).
cuda
()
attn_metadata
.
encoder_seq_lens
=
torch
.
full
((
batch_size
,
),
1
,
dtype
=
torch
.
int
).
cuda
()
attn_metadata
.
encoder_seq_lens_tensor
=
torch
.
full
(
(
batch_size
,
),
1
,
dtype
=
torch
.
int
).
cuda
()
attn_metadata
.
max_encoder_seq_len
=
self
.
runner
.
max_seq_len_to_capture
def
_add_additonal_input_buffers_for_enc_dec_model
(
self
,
attn_metadata
,
input_buffers
:
Dict
[
str
,
Any
]):
"""
Saves additional input buffers specific to the encoder-decoder model
from the attention metadata.
This method extracts and stores encoder-decoder related input buffers
from the `attn_metadata` into the `input_buffers` dictionary. The
buffers include encoder sequence lengths, cross-slot mappings, and
cross-block tables, which are essential for the encoder-decoder model
during CUDA graph replay.
"""
input_buffers
[
"encoder_seq_lens_tensor"
]
=
(
attn_metadata
.
decode_metadata
.
encoder_seq_lens_tensor
)
input_buffers
[
"cross_slot_mapping"
]
=
(
attn_metadata
.
decode_metadata
.
cross_slot_mapping
)
input_buffers
[
"cross_block_tables"
]
=
(
attn_metadata
.
decode_metadata
.
cross_block_tables
)
def
_prepare_input_buffers_for_enc_dec_model
(
self
,
attn_metadata
,
input_buffers
:
Dict
[
str
,
Any
]):
"""
Populates input buffers with data from the encoder-decoder model's
attention metadata.
This method fills the input buffers with encoder-decoder specific
tensors. It copies data from the `attn_metadata` and keyword arguments
(`kwargs`) into corresponding buffers in the `input_buffers` dictionary.
The copied data includes attention-related metadata as well as input
IDs and positional information for the encoder.
"""
input_buffers
[
"encoder_seq_lens_tensor"
].
copy_
(
attn_metadata
.
decode_metadata
.
encoder_seq_lens_tensor
,
non_blocking
=
True
)
input_buffers
[
"cross_slot_mapping"
].
copy_
(
attn_metadata
.
decode_metadata
.
cross_slot_mapping
,
non_blocking
=
True
)
input_buffers
[
"cross_block_tables"
].
copy_
(
attn_metadata
.
decode_metadata
.
cross_block_tables
,
non_blocking
=
True
)
vllm/config.py
View file @
1009e93c
...
...
@@ -16,9 +16,8 @@ from vllm.tracing import is_otel_available, otel_import_error_traceback
from
vllm.transformers_utils.config
import
(
ConfigFormat
,
get_config
,
get_hf_image_processor_config
,
get_hf_text_config
)
from
vllm.utils
import
(
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH
,
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
,
is_openvino
,
is_xpu
,
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
,
is_openvino
,
is_xpu
,
print_warning_once
)
if
TYPE_CHECKING
:
...
...
@@ -96,15 +95,15 @@ class ModelConfig:
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
If None, the user did not specify, so default to False -
except for encoder/decoder models, which currently require
eager mode.
If None, the user did not specify, so default to False.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode
to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
disable_sliding_window: Whether to disable sliding window. If True,
we will disable the sliding window functionality of the model.
If the model does not support sliding window, this argument is
...
...
@@ -186,32 +185,8 @@ class ModelConfig:
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
use_async_output_proc
=
use_async_output_proc
# Choose a default enforce_eager value if the user did not specify
# a value (enforce_eager is None)
if
getattr
(
self
.
hf_config
,
'is_encoder_decoder'
,
False
):
if
self
.
enforce_eager
is
None
:
# *Only for encoder/decoder models* and
# *only if enforce_eager is unset*, override
# to enforce_eager=True
#
# Add a logger message since it is *somewhat* non-intuitive that
# enforce_eager is True when the user has not specified its
# value.
logger
.
info
(
"Forcing enforce_eager == True because "
"enforce_eager setting was unspecified and "
"CUDAGraph is not supported with encoder/ "
"decoder models."
)
self
.
enforce_eager
=
True
if
not
self
.
enforce_eager
:
# Eager mode explicitly disabled by user for an encoder/
# decoder model; however CUDAGRAPH + encoder/decoder is
# not currently supported
raise
ValueError
(
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH
)
elif
self
.
enforce_eager
is
None
:
# *Only for decoder-only models*, enforce_eager
# defaults to False if unset. This is intuitive
# so no logging message needed.
# Set enforce_eager to False if the value is unset.
if
self
.
enforce_eager
is
None
:
self
.
enforce_eager
=
False
if
(
not
self
.
disable_sliding_window
...
...
vllm/engine/arg_utils.py
View file @
1009e93c
...
...
@@ -472,7 +472,10 @@ class EngineArgs:
default
=
EngineArgs
.
max_seq_len_to_capture
,
help
=
'Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.'
)
'larger than this, we fall back to eager mode. '
'Additionally for encoder-decoder models, if the '
'sequence length of the encoder input is larger '
'than this, we fall back to the eager mode.'
)
parser
.
add_argument
(
'--disable-custom-all-reduce'
,
action
=
'store_true'
,
default
=
EngineArgs
.
disable_custom_all_reduce
,
...
...
vllm/entrypoints/llm.py
View file @
1009e93c
...
...
@@ -88,7 +88,9 @@ class LLM:
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
disable_custom_all_reduce: See ParallelConfig
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
...
...
@@ -137,9 +139,7 @@ class LLM:
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False for decoder-only models and True
for encoder/decoder models, since encoder/decoder models
do not currently support CUDAGraph.
it defaults to False.
'''
if
"disable_log_stats"
not
in
kwargs
:
...
...
vllm/model_executor/models/bart.py
View file @
1009e93c
...
...
@@ -848,11 +848,13 @@ class BartForConditionalGeneration(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
encoder_input_ids
:
torch
.
Tensor
,
encoder_positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
*
,
encoder_input_ids
:
torch
.
Tensor
,
encoder_positions
:
torch
.
Tensor
,
**
kwargs
,
)
->
torch
.
Tensor
:
r
"""
Args:
...
...
vllm/utils.py
View file @
1009e93c
...
...
@@ -71,10 +71,6 @@ STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
"currently supported with encoder/"
"decoder models."
)
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH
=
(
"CUDAGraph is not "
"currently supported with encoder/"
"decoder models."
)
STR_NOT_IMPL_ENC_DEC_BACKEND
=
(
"XFormers is the only backend "
"currently supported with encoder/"
"decoder models."
)
...
...
@@ -98,7 +94,6 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
"STR_NOT_IMPL_ENC_DEC_PP"
:
STR_NOT_IMPL_ENC_DEC_PP
,
"STR_NOT_IMPL_ENC_DEC_MM"
:
STR_NOT_IMPL_ENC_DEC_MM
,
"STR_NOT_IMPL_ENC_DEC_SPEC_DEC"
:
STR_NOT_IMPL_ENC_DEC_SPEC_DEC
,
"STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH"
:
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH
,
"STR_NOT_IMPL_ENC_DEC_BACKEND"
:
STR_NOT_IMPL_ENC_DEC_BACKEND
,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER"
:
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER
,
"STR_NOT_IMPL_ENC_DEC_CPU"
:
STR_NOT_IMPL_ENC_DEC_CPU
...
...
vllm/worker/enc_dec_model_runner.py
View file @
1009e93c
import
dataclasses
import
itertools
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
cast
import
torch
...
...
@@ -24,7 +25,8 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput,
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_BACKEND
,
make_tensor_with_pad
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
,
_get_graph_batch_size
)
from
vllm.worker.model_runner_base
import
(
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
)
...
...
@@ -178,7 +180,15 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
raise
ValueError
(
"num_steps > 1 is not supported in "
"EncoderDecoderModelRunner"
)
model_executable
=
self
.
model
if
(
model_input
.
attn_metadata
is
not
None
and
model_input
.
attn_metadata
.
prefill_metadata
is
None
and
model_input
.
attn_metadata
.
decode_metadata
.
use_cuda_graph
):
assert
model_input
.
input_tokens
is
not
None
graph_batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_executable
=
self
.
graph_runners
[
model_input
.
virtual_engine
][
graph_batch_size
]
else
:
model_executable
=
self
.
model
seqlen_agnostic_kwargs
=
{
"finished_requests_ids"
:
model_input
.
finished_requests_ids
,
...
...
@@ -200,6 +210,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
if
not
self
.
is_driver_worker
:
return
[]
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Sample the next token.
output
:
SamplerOutput
=
self
.
model
.
sample
(
logits
=
logits
,
...
...
@@ -231,14 +244,12 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
"""
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
)
(
attn_metadata
,
encoder_input_tokens_tensor
,
encoder_input_positions_tensor
,
)
=
(
self
.
_prepare_encoder_model_input_tensors
(
seq_group_metadata_list
,
model_input
))
# Inject attn_metadata encoder/cross-attention fields &
# encoder input tokens/positions into model_input.
# Frozen dataclass fields cannot be modified, so use
...
...
@@ -437,11 +448,29 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
cross_block_tables
.
append
([]
if
(
cross_block_table
is
None
)
else
cross_block_table
)
# Convert cross-attention block tables to encoder input tensor
if
(
model_input
.
attn_metadata
is
not
None
and
model_input
.
attn_metadata
.
use_cuda_graph
):
# We will be using CUDA graph replay for this decode.
max_len_of_block_table
=
self
.
get_max_block_per_batch
()
batch_size
=
len
(
encoder_seq_lens
)
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
cuda_graph_pad_size
=
graph_batch_size
-
batch_size
# extend the cross_block_tables and encoder_seq_lens to match
# the graph_batch_size.
cross_block_tables
.
extend
([[]
for
_
in
range
(
cuda_graph_pad_size
)
])
encoder_seq_lens
.
extend
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
else
:
max_len_of_block_table
=
max
(
len
(
block_table
)
for
block_table
in
cross_block_tables
)
cross_block_tables
=
make_tensor_with_pad
(
cross_block_tables
,
max_len
=
max
(
len
(
block_table
)
for
block_table
in
cross_block_tables
),
max_len
=
max_len_of_block_table
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
...
...
vllm/worker/model_runner.py
View file @
1009e93c
...
...
@@ -243,6 +243,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prefix_cache_hit
:
bool
=
False
,
reinit
:
bool
=
False
,
reinit_use_defaults
:
bool
=
False
,
encoder_seq_len
:
int
=
0
,
):
if
reinit
:
assert
len
(
self
.
seq_ids
)
==
len
(
seq_ids
)
# type: ignore
...
...
@@ -256,6 +257,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
block_tables
=
block_tables
self
.
computed_block_nums
=
computed_block_nums
self
.
n_seqs
=
n_seqs
self
.
encoder_seq_len
=
encoder_seq_len
if
reinit
:
if
len
(
self
.
seq_ids
)
==
1
and
reinit_use_defaults
:
...
...
@@ -702,6 +704,11 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
assert
n_seqs
==
1
self
.
decode_only
=
False
encoder_seq_len
=
0
if
self
.
runner
.
model_config
.
is_encoder_decoder_model
:
encoder_seq_len
=
seq_group_metadata
.
encoder_seq_data
.
get_len
()
inter_data
=
self
.
init_cached_inter_data
(
request_id
=
seq_group_metadata
.
request_id
,
seq_ids
=
seq_ids
,
...
...
@@ -709,7 +716,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
block_tables
=
seq_group_metadata
.
block_tables
,
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
,
reinit
=
True
,
reinit_use_defaults
=
True
)
reinit_use_defaults
=
True
,
encoder_seq_len
=
encoder_seq_len
)
self
.
inter_data_list
.
append
(
inter_data
)
...
...
@@ -719,11 +727,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for
per_seq_group_fn
in
self
.
per_seq_group_compute_fns
:
per_seq_group_fn
(
inter_data
,
seq_group_metadata
)
def
_use_captured_graph
(
self
,
batch_size
:
int
,
max_decode_seq_len
:
int
)
->
bool
:
def
_use_captured_graph
(
self
,
batch_size
:
int
,
max_decode_seq_len
:
int
,
max_encoder_seq_len
:
int
=
0
)
->
bool
:
return
(
self
.
decode_only
and
not
self
.
runner
.
model_config
.
enforce_eager
and
batch_size
<=
self
.
runner
.
max_batchsize_to_capture
and
max_decode_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
)
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_decode_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
and
max_encoder_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
and
batch_size
<=
self
.
runner
.
max_batchsize_to_capture
)
def
build
(
self
)
->
ModelInputForGPU
:
"""Finalize the builder intermediate data and
...
...
@@ -763,15 +775,18 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
input_positions
.
extend
(
cur_input_positions
)
seq_lens
=
[]
query_lens
=
[]
max_decode_seq_len
=
0
max_encoder_seq_len
=
0
for
inter_data
in
self
.
inter_data_list
:
seq_lens
.
extend
(
inter_data
.
seq_lens
)
query_lens
.
extend
(
inter_data
.
query_lens
)
if
not
inter_data
.
is_prompt
:
max_decode_seq_len
=
max
(
max_decode_seq_len
,
max
(
inter_data
.
seq_lens
))
query_lens
=
[]
for
inter_data
in
self
.
inter_data_list
:
query_lens
.
extend
(
inter_data
.
query
_len
s
)
if
self
.
runner
.
model_config
.
is_encoder_decoder_model
:
max_encoder_seq_len
=
max
(
max_encoder_seq_len
,
inter_data
.
encoder_seq
_len
)
# Mapping from request IDs to sequence IDs. Used for Jamba models
# that manages the cache by itself.
...
...
@@ -781,8 +796,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
}
batch_size
=
len
(
input_tokens
)
use_captured_graph
=
self
.
_use_captured_graph
(
batch_size
,
max_decode_seq_len
)
use_captured_graph
=
self
.
_use_captured_graph
(
batch_size
,
max_decode_seq_len
,
max_encoder_seq_len
=
max_encoder_seq_len
)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
...
...
@@ -1364,7 +1381,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
for
batch_size
in
reversed
(
batch_size_capture_list
):
attn_metadata
=
(
self
.
attn_state
.
graph_capture_get_metadata_for_batch
(
batch_size
))
batch_size
,
is_encoder_decoder_model
=
self
.
model_config
.
is_encoder_decoder_model
))
if
self
.
lora_config
:
lora_mapping
=
LoRAMapping
(
...
...
@@ -1380,10 +1399,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
)
self
.
set_active_prompt_adapters
(
set
(),
prompt_adapter_mapping
)
graph_runner
=
CUDAGraphRunner
(
self
.
model
,
self
.
attn_backend
.
get_name
(),
self
.
attn_state
.
graph_clone
(
batch_size
))
self
.
attn_state
.
graph_clone
(
batch_size
),
self
.
model_config
.
is_encoder_decoder_model
)
capture_inputs
=
{
"input_ids"
:
...
...
@@ -1420,6 +1439,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
model
.
get_seqlen_agnostic_capture_inputs
(
batch_size
)
})
if
self
.
model_config
.
is_encoder_decoder_model
:
# add the additional inputs to capture for
# encoder-decoder models.
self
.
_update_inputs_to_capture_for_enc_dec_model
(
capture_inputs
)
graph_runner
.
capture
(
**
capture_inputs
)
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_runners
[
virtual_engine
][
batch_size
]
=
(
...
...
@@ -1430,6 +1455,24 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# This usually takes < 10 seconds.
logger
.
info
(
"Graph capturing finished in %.0f secs."
,
elapsed_time
)
def
_update_inputs_to_capture_for_enc_dec_model
(
self
,
capture_inputs
:
Dict
[
str
,
Any
]):
"""
Updates the set of input tensors needed for CUDA graph capture in an
encoder-decoder model.
This method modifies the provided `capture_inputs` dictionary by
adding tensors specific to encoder-decoder specific models that
need to be captured for CUDA Graph replay.
"""
# During the decode phase encoder_input_ids and encoder_positions are
# unset. Do the same thing for graph capture.
capture_inputs
[
"encoder_input_ids"
]
=
torch
.
tensor
(
[],
dtype
=
torch
.
long
).
cuda
()
capture_inputs
[
"encoder_positions"
]
=
torch
.
tensor
(
[],
dtype
=
torch
.
long
).
cuda
()
@
property
def
vocab_size
(
self
)
->
int
:
return
self
.
model_config
.
get_vocab_size
()
...
...
@@ -1629,7 +1672,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
class
CUDAGraphRunner
:
def
__init__
(
self
,
model
:
nn
.
Module
,
backend_name
:
str
,
attn_state
:
AttentionState
):
attn_state
:
AttentionState
,
is_encoder_decoder_model
:
bool
):
self
.
model
=
model
self
.
backend_name
=
backend_name
self
.
attn_state
=
attn_state
...
...
@@ -1638,6 +1681,7 @@ class CUDAGraphRunner:
self
.
output_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
_graph
:
Optional
[
torch
.
cuda
.
CUDAGraph
]
=
None
self
.
_is_encoder_decoder_model
=
is_encoder_decoder_model
@
property
def
graph
(
self
):
...
...
@@ -1671,8 +1715,9 @@ class CUDAGraphRunner:
intermediate_tensors
=
intermediate_inputs
,
**
kwargs
,
)
# Wait for the warm up operations to finish before proceeding with
# Graph Capture.
torch
.
cuda
.
synchronize
()
# Capture the graph.
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
,
stream
=
stream
):
...
...
@@ -1704,10 +1749,14 @@ class CUDAGraphRunner:
# Save the input and output buffers.
self
.
input_buffers
=
{
"input_ids"
:
input_ids
,
"positions"
:
positions
,
"kv_caches"
:
kv_caches
,
**
self
.
attn_state
.
get_graph_input_buffers
(
attn_metadata
),
"input_ids"
:
input_ids
,
"positions"
:
positions
,
"kv_caches"
:
kv_caches
,
**
self
.
attn_state
.
get_graph_input_buffers
(
attn_metadata
,
self
.
_is_encoder_decoder_model
),
**
kwargs
,
}
if
intermediate_inputs
is
not
None
:
...
...
@@ -1737,8 +1786,8 @@ class CUDAGraphRunner:
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn_metadata
.
slot_mapping
,
non_blocking
=
True
)
self
.
attn_state
.
prepare_graph_input_buffers
(
self
.
input_buffers
,
attn_metadata
)
self
.
attn_state
.
prepare_graph_input_buffers
(
self
.
input_buffers
,
attn_metadata
,
self
.
_is_encoder_decoder_model
)
if
"seqlen_agnostic_capture_inputs"
in
self
.
input_buffers
:
self
.
model
.
copy_inputs_before_cuda_graphs
(
self
.
input_buffers
,
**
kwargs
)
...
...
@@ -1752,6 +1801,12 @@ class CUDAGraphRunner:
if
key
!=
"model_execute_time"
and
key
!=
"model_forward_time"
:
self
.
input_buffers
[
key
].
copy_
(
intermediate_tensors
[
key
],
non_blocking
=
True
)
if
self
.
_is_encoder_decoder_model
:
self
.
input_buffers
[
"encoder_input_ids"
].
copy_
(
kwargs
[
'encoder_input_ids'
],
non_blocking
=
True
)
self
.
input_buffers
[
"encoder_positions"
].
copy_
(
kwargs
[
'encoder_positions'
],
non_blocking
=
True
)
# Run the graph.
self
.
graph
.
replay
()
# Return the output tensor.
...
...
vllm/worker/utils.py
View file @
1009e93c
...
...
@@ -47,10 +47,6 @@ def assert_enc_dec_mr_supported_scenario(
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_SPEC_DEC'
])
if
not
enc_dec_mr
.
model_config
.
enforce_eager
:
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH'
])
if
enc_dec_mr
.
prompt_adapter_config
is
not
None
:
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment