Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b41fb9d3
Unverified
Commit
b41fb9d3
authored
Nov 12, 2024
by
sroy745
Committed by
GitHub
Nov 12, 2024
Browse files
[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers (#9982)
Signed-off-by:
Sourashis Roy
<
sroy@roblox.com
>
parent
7c655279
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
117 additions
and
80 deletions
+117
-80
tests/encoder_decoder/test_e2e_correctness.py
tests/encoder_decoder/test_e2e_correctness.py
+8
-1
tests/models/encoder_decoder/vision_language/test_mllama.py
tests/models/encoder_decoder/vision_language/test_mllama.py
+63
-37
tests/test_config.py
tests/test_config.py
+2
-0
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+38
-14
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+6
-28
No files found.
tests/encoder_decoder/test_e2e_correctness.py
View file @
b41fb9d3
...
...
@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple
import
pytest
from
transformers
import
AutoModelForSeq2SeqLM
from
vllm.attention.selector
import
(
_Backend
,
from
vllm.attention.selector
import
(
_Backend
,
_cached_get_attn_backend
,
global_force_attn_backend_context_manager
)
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
SampleLogprobs
...
...
@@ -34,6 +34,13 @@ def vllm_to_hf_output(
return
output_ids
,
hf_output_str
,
out_logprobs
@
pytest
.
fixture
(
autouse
=
True
)
def
clear_cache
():
"""Fixture to clear backend cache before each test."""
_cached_get_attn_backend
.
cache_clear
()
# Clear the cache
yield
# This allows the test to run
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/bart-large-cnn"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
LIST_ENC_DEC_SUPPORTED_BACKENDS
)
...
...
tests/models/encoder_decoder/vision_language/test_mllama.py
View file @
b41fb9d3
...
...
@@ -4,6 +4,8 @@ import pytest
from
transformers
import
(
AutoConfig
,
AutoModelForVision2Seq
,
AutoTokenizer
,
BatchEncoding
)
from
vllm.attention.selector
import
(
_Backend
,
_cached_get_attn_backend
,
global_force_attn_backend_context_manager
)
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
...
...
@@ -14,6 +16,8 @@ from ...utils import check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT
=
3
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
"stop_sign"
:
"<|image|><|begin_of_text|>The meaning of the image is"
,
...
...
@@ -221,6 +225,13 @@ def _run_test(
)
@
pytest
.
fixture
(
autouse
=
True
)
def
clear_cache
():
"""Fixture to clear backend cache before each test."""
_cached_get_attn_backend
.
cache_clear
()
# Clear the cache
yield
# This allows the test to run
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -244,20 +255,26 @@ def _run_test(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
LIST_ENC_DEC_SUPPORTED_BACKENDS
)
def
test_models_single_leading_image
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
sizes
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
run_test
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
sizes
=
sizes
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
num_logprobs
,
attn_backend
:
_Backend
)
->
None
:
with
global_force_attn_backend_context_manager
(
attn_backend
):
if
attn_backend
==
_Backend
.
FLASH_ATTN
:
# Flash Attention works only with bfloat16 data-type
dtype
=
'bfloat16'
run_test
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
sizes
=
sizes
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
@
large_gpu_test
(
min_gb
=
48
)
...
...
@@ -265,9 +282,10 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
LIST_ENC_DEC_SUPPORTED_BACKENDS
)
def
test_models_multi_leading_images
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
model
,
dtype
,
max_tokens
,
num_logprobs
,
attn_backend
:
_Backend
)
->
None
:
stop_sign
=
image_assets
[
0
].
pil_image
cherry_blossom
=
image_assets
[
1
].
pil_image
...
...
@@ -291,17 +309,20 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
cherry_blossom
.
resize
((
512
,
1024
)),
],
])]
_run_test
(
hf_runner
,
vllm_runner
,
inputs
,
model
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
with
global_force_attn_backend_context_manager
(
attn_backend
):
if
attn_backend
==
_Backend
.
FLASH_ATTN
:
# Flash Attention works only with bfloat16 data-type
dtype
=
'bfloat16'
_run_test
(
hf_runner
,
vllm_runner
,
inputs
,
model
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
@
large_gpu_test
(
min_gb
=
48
)
...
...
@@ -309,8 +330,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
LIST_ENC_DEC_SUPPORTED_BACKENDS
)
def
test_models_interleaved_images
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
dtype
,
max_tokens
,
num_logprobs
,
attn_backend
:
_Backend
)
->
None
:
stop_sign
=
image_assets
[
0
].
pil_image
cherry_blossom
=
image_assets
[
1
].
pil_image
...
...
@@ -325,14 +348,17 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
[
stop_sign
],
[
stop_sign
,
cherry_blossom
],
])]
_run_test
(
hf_runner
,
vllm_runner
,
inputs
,
model
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
with
global_force_attn_backend_context_manager
(
attn_backend
):
if
attn_backend
==
_Backend
.
FLASH_ATTN
:
# Flash Attention works only with bfloat16 data-type
dtype
=
'bfloat16'
_run_test
(
hf_runner
,
vllm_runner
,
inputs
,
model
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
tests/test_config.py
View file @
b41fb9d3
...
...
@@ -243,6 +243,8 @@ def test_rope_customization():
assert
longchat_model_config
.
max_model_len
==
4096
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Encoder Decoder models not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
((
"model_id"
,
"is_encoder_decoder"
),
[
(
"facebook/opt-125m"
,
False
),
(
"facebook/bart-base"
,
True
),
...
...
vllm/model_executor/models/mllama.py
View file @
b41fb9d3
...
...
@@ -32,6 +32,8 @@ from transformers.models.mllama.processing_mllama import (
import
vllm.distributed.parallel_state
as
ps
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.backends.xformers
import
XFormersMetadata
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
...
...
@@ -799,12 +801,13 @@ class MllamaTextCrossAttention(nn.Module):
q
=
self
.
q_norm
(
q
)
if
attention_mask
is
not
None
:
output
=
self
.
attention_with_mask
(
q
,
k
,
v
,
kv_cache
,
attention_mask
,
kv_range_for_decode
,
attn_metadata
)
output
=
self
.
_
attention_with_mask
(
q
,
k
,
v
,
kv_cache
,
attention_mask
,
kv_range_for_decode
,
attn_metadata
)
else
:
output
=
self
.
attn
(
q
,
output
=
self
.
attn
(
q
.
view
(
-
1
,
self
.
num_local_heads
*
self
.
head_dim
),
k
,
v
,
kv_cache
,
...
...
@@ -813,7 +816,7 @@ class MllamaTextCrossAttention(nn.Module):
out
,
_
=
self
.
o_proj
(
output
)
return
out
def
attention_with_mask
(
def
_
attention_with_mask
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
...
@@ -824,14 +827,35 @@ class MllamaTextCrossAttention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
# Skip writing kv-cache for the initial profiling run.
if
len
(
kv_cache
.
shape
)
==
3
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_local_key_value_heads
,
self
.
head_dim
)
cached_k
=
torch
.
cat
([
k
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
cached_v
=
torch
.
cat
([
v
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
PagedAttention
.
write_to_paged_cache
(
cached_k
,
cached_v
,
key_cache
,
value_cache
,
attn_metadata
.
cross_slot_mapping
,
"auto"
,
1.0
,
1.0
)
if
len
(
kv_cache
.
shape
)
>
1
:
if
isinstance
(
attn_metadata
,
FlashAttentionMetadata
):
cached_k
=
torch
.
cat
([
k
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
cached_v
=
torch
.
cat
([
v
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
cached_k
,
cached_v
,
kv_cache
[
0
],
kv_cache
[
1
],
attn_metadata
.
cross_slot_mapping
,
# type: ignore[union-attr]
"auto"
,
1.0
,
1.0
,
)
elif
isinstance
(
attn_metadata
,
XFormersMetadata
):
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_local_key_value_heads
,
self
.
head_dim
)
cached_k
=
torch
.
cat
([
k
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
cached_v
=
torch
.
cat
([
v
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
PagedAttention
.
write_to_paged_cache
(
cached_k
,
cached_v
,
key_cache
,
value_cache
,
attn_metadata
.
cross_slot_mapping
,
"auto"
,
1.0
,
1.0
)
else
:
raise
ValueError
(
f
"Unsupported AttentionMetadata
{
type
(
attn_metadata
)
}
"
f
"class found. Expected the AttentionMetadata to "
f
"be either XFormersMetadata or FlashAttentionMetadata."
)
# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
# standard causal mask, neither a block diagonal mask which
...
...
vllm/worker/enc_dec_model_runner.py
View file @
b41fb9d3
...
...
@@ -9,15 +9,13 @@ from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata
)
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.attention.selector
import
(
_Backend
,
get_env_variable_attn_backend
,
get_global_forced_attn_backend
,
global_force_attn_backend
)
from
vllm.config
import
ModelConfig
,
VllmConfig
get_global_forced_attn_backend
)
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.utils
import
get_architecture_class_name
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
MultiModalRegistry
)
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -35,11 +33,6 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger
=
init_logger
(
__name__
)
# The Mllama model has PagedAttention specific logic because of which it
# can only be run with the XFORMERS backend
# TODO Make Mllama model work with Flash Attention backend.
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS
=
[
"MllamaForConditionalGeneration"
]
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
EncoderDecoderModelInput
(
ModelInputForGPUWithSamplingMetadata
):
...
...
@@ -97,7 +90,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
models) but these arguments are present here for compatibility with
the base-class constructor.
'''
self
.
_maybe_force_supported_attention_backend
(
vllm_config
.
model_config
)
self
.
_maybe_force_supported_attention_backend
()
super
().
__init__
(
vllm_config
=
vllm_config
,
...
...
@@ -108,12 +101,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# Crash for unsupported encoder/scenarios
assert_enc_dec_mr_supported_scenario
(
self
)
def
_is_xformers_only_encoder_decoder_model
(
self
,
model
:
ModelConfig
)
->
bool
:
return
get_architecture_class_name
(
model
)
in
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS
def
_maybe_force_supported_attention_backend
(
self
,
model
:
ModelConfig
):
def
_maybe_force_supported_attention_backend
(
self
):
'''
Force vLLM to use the XFormers attention backend,
which is currently the only supported option.
...
...
@@ -128,23 +116,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
maybe_global_forced_backend
=
get_global_forced_attn_backend
()
is_forced_by_global
=
maybe_global_forced_backend
is
not
None
is_forced_by_env_var
=
maybe_env_var_forced_backend
is
not
None
if
not
(
is_forced_by_global
or
is_forced_by_env_var
)
\
and
self
.
_is_xformers_only_encoder_decoder_model
(
model
):
# The user has not already specified an attention backend
# override
logger
.
info
(
"Encoder-Decoder Model Architecture %s requires XFormers "
"backend; overriding backend auto-selection and "
"forcing XFormers."
,
get_architecture_class_name
(
model
))
global_force_attn_backend
(
_Backend
.
XFORMERS
)
elif
is_forced_by_global
:
if
is_forced_by_global
:
# noqa: SIM102
# Backend override enforced by global variable takes
# precedence over vLLM backend environment variable.
if
maybe_global_forced_backend
not
in
\
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]:
raise_backend_err
()
elif
is_forced_by_env_var
:
elif
is_forced_by_env_var
:
# noqa: SIM102
# Backend override enforced by vLLM backend
# environment variable
if
maybe_env_var_forced_backend
not
in
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment