Unverified Commit b41fb9d3 authored by sroy745, committed by GitHub
Browse files

[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers (#9982)


Signed-off-by: Sourashis Roy <sroy@roblox.com>
parent 7c655279
...@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple ...@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple
import pytest import pytest
from transformers import AutoModelForSeq2SeqLM from transformers import AutoModelForSeq2SeqLM
from vllm.attention.selector import (_Backend, from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
...@@ -34,6 +34,13 @@ def vllm_to_hf_output( ...@@ -34,6 +34,13 @@ def vllm_to_hf_output(
return output_ids, hf_output_str, out_logprobs return output_ids, hf_output_str, out_logprobs
@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the cached attention-backend lookup ahead of every test.

    Declared with ``autouse=True``, so pytest applies it to each test in
    this module without the test having to request it.
    """
    # Drop any memoized result held by `_cached_get_attn_backend`.
    # NOTE(review): presumably needed because the tests here are
    # parametrized over different attention backends -- confirm.
    _cached_get_attn_backend.cache_clear()
    # Hand control to the test body; nothing runs after the test finishes.
    yield
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
......
...@@ -4,6 +4,8 @@ import pytest ...@@ -4,6 +4,8 @@ import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding) BatchEncoding)
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
...@@ -14,6 +16,8 @@ from ...utils import check_logprobs_close ...@@ -14,6 +16,8 @@ from ...utils import check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT = 3 _LIMIT_IMAGE_PER_PROMPT = 3
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign": "stop_sign":
"<|image|><|begin_of_text|>The meaning of the image is", "<|image|><|begin_of_text|>The meaning of the image is",
...@@ -221,6 +225,13 @@ def _run_test( ...@@ -221,6 +225,13 @@ def _run_test(
) )
@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the cached attention-backend lookup ahead of every test.

    Declared with ``autouse=True``, so pytest applies it to each test in
    this module without the test having to request it.
    """
    # Drop any memoized result held by `_cached_get_attn_backend`.
    # NOTE(review): presumably needed because the tests here force a
    # different backend per parametrization -- confirm.
    _cached_get_attn_backend.cache_clear()
    # Hand control to the test body; nothing runs after the test finishes.
    yield
@large_gpu_test(min_gb=48) @large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -244,9 +255,15 @@ def _run_test( ...@@ -244,9 +255,15 @@ def _run_test(
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
model, sizes, dtype, max_tokens, model, sizes, dtype, max_tokens,
num_logprobs) -> None: num_logprobs,
attn_backend: _Backend) -> None:
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
run_test( run_test(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
...@@ -265,9 +282,10 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, ...@@ -265,9 +282,10 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens, model, dtype, max_tokens, num_logprobs,
num_logprobs) -> None: attn_backend: _Backend) -> None:
stop_sign = image_assets[0].pil_image stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image cherry_blossom = image_assets[1].pil_image
...@@ -291,7 +309,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, ...@@ -291,7 +309,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
cherry_blossom.resize((512, 1024)), cherry_blossom.resize((512, 1024)),
], ],
])] ])]
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
_run_test( _run_test(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
...@@ -309,8 +330,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, ...@@ -309,8 +330,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
dtype, max_tokens, num_logprobs) -> None: dtype, max_tokens, num_logprobs,
attn_backend: _Backend) -> None:
stop_sign = image_assets[0].pil_image stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image cherry_blossom = image_assets[1].pil_image
...@@ -325,7 +348,10 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, ...@@ -325,7 +348,10 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
[stop_sign], [stop_sign],
[stop_sign, cherry_blossom], [stop_sign, cherry_blossom],
])] ])]
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
_run_test( _run_test(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
......
...@@ -243,6 +243,8 @@ def test_rope_customization(): ...@@ -243,6 +243,8 @@ def test_rope_customization():
assert longchat_model_config.max_model_len == 4096 assert longchat_model_config.max_model_len == 4096
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Encoder Decoder models not supported on ROCm.")
@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [ @pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
("facebook/opt-125m", False), ("facebook/opt-125m", False),
("facebook/bart-base", True), ("facebook/bart-base", True),
......
...@@ -32,6 +32,8 @@ from transformers.models.mllama.processing_mllama import ( ...@@ -32,6 +32,8 @@ from transformers.models.mllama.processing_mllama import (
import vllm.distributed.parallel_state as ps import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -799,12 +801,13 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -799,12 +801,13 @@ class MllamaTextCrossAttention(nn.Module):
q = self.q_norm(q) q = self.q_norm(q)
if attention_mask is not None: if attention_mask is not None:
output = self.attention_with_mask(q, k, v, kv_cache, output = self._attention_with_mask(q, k, v, kv_cache,
attention_mask, attention_mask,
kv_range_for_decode, kv_range_for_decode,
attn_metadata) attn_metadata)
else: else:
output = self.attn(q, output = self.attn(q.view(-1,
self.num_local_heads * self.head_dim),
k, k,
v, v,
kv_cache, kv_cache,
...@@ -813,7 +816,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -813,7 +816,7 @@ class MllamaTextCrossAttention(nn.Module):
out, _ = self.o_proj(output) out, _ = self.o_proj(output)
return out return out
def attention_with_mask( def _attention_with_mask(
self, self,
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
...@@ -824,7 +827,22 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -824,7 +827,22 @@ class MllamaTextCrossAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run. # Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) == 3: if len(kv_cache.shape) > 1:
if isinstance(attn_metadata, FlashAttentionMetadata):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
torch.ops._C_cache_ops.reshape_and_cache_flash(
cached_k,
cached_v,
kv_cache[0],
kv_cache[1],
attn_metadata.
cross_slot_mapping, # type: ignore[union-attr]
"auto",
1.0,
1.0,
)
elif isinstance(attn_metadata, XFormersMetadata):
key_cache, value_cache = PagedAttention.split_kv_cache( key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim) kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
...@@ -832,6 +850,12 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -832,6 +850,12 @@ class MllamaTextCrossAttention(nn.Module):
PagedAttention.write_to_paged_cache( PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache, cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0) attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
else:
raise ValueError(
f"Unsupported AttentionMetadata {type(attn_metadata)} "
f"class found. Expected the AttentionMetadata to "
f"be either XFormersMetadata or FlashAttentionMetadata.")
# We have to call torch.sdpa for prefill when using a # We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a # custom cross-attention mask. Because the mask is not a
# standard causal mask, neither a block diagonal mask which # standard causal mask, neither a block diagonal mask which
......
...@@ -9,15 +9,13 @@ from vllm.attention.backends.abstract import (AttentionBackend, ...@@ -9,15 +9,13 @@ from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata) AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend, get_global_forced_attn_backend)
global_force_attn_backend) from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.utils import get_architecture_class_name
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry) MultiModalRegistry)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -35,11 +33,6 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario ...@@ -35,11 +33,6 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger = init_logger(__name__) logger = init_logger(__name__)
# The Mllama model has PagedAttention specific logic because of which it
# can only be run with the XFORMERS backend
# TODO Make Mllama model work with Flash Attention backend.
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
...@@ -97,7 +90,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -97,7 +90,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
models) but these arguments are present here for compatibility with models) but these arguments are present here for compatibility with
the base-class constructor. the base-class constructor.
''' '''
self._maybe_force_supported_attention_backend(vllm_config.model_config) self._maybe_force_supported_attention_backend()
super().__init__( super().__init__(
vllm_config=vllm_config, vllm_config=vllm_config,
...@@ -108,12 +101,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -108,12 +101,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# Crash for unsupported encoder/scenarios # Crash for unsupported encoder/scenarios
assert_enc_dec_mr_supported_scenario(self) assert_enc_dec_mr_supported_scenario(self)
def _is_xformers_only_encoder_decoder_model(self, def _maybe_force_supported_attention_backend(self):
model: ModelConfig) -> bool:
return get_architecture_class_name(
model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS
def _maybe_force_supported_attention_backend(self, model: ModelConfig):
''' '''
Force vLLM to use the XFormers attention backend, Force vLLM to use the XFormers attention backend,
which is currently the only supported option. which is currently the only supported option.
...@@ -128,23 +116,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -128,23 +116,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
maybe_global_forced_backend = get_global_forced_attn_backend() maybe_global_forced_backend = get_global_forced_attn_backend()
is_forced_by_global = maybe_global_forced_backend is not None is_forced_by_global = maybe_global_forced_backend is not None
is_forced_by_env_var = maybe_env_var_forced_backend is not None is_forced_by_env_var = maybe_env_var_forced_backend is not None
if is_forced_by_global: # noqa: SIM102
if not (is_forced_by_global or is_forced_by_env_var) \
and self._is_xformers_only_encoder_decoder_model(model):
# The user has not already specified an attention backend
# override
logger.info(
"Encoder-Decoder Model Architecture %s requires XFormers "
"backend; overriding backend auto-selection and "
"forcing XFormers.", get_architecture_class_name(model))
global_force_attn_backend(_Backend.XFORMERS)
elif is_forced_by_global:
# Backend override enforced by global variable takes # Backend override enforced by global variable takes
# precedence over vLLM backend environment variable. # precedence over vLLM backend environment variable.
if maybe_global_forced_backend not in\ if maybe_global_forced_backend not in\
[_Backend.XFORMERS, _Backend.FLASH_ATTN]: [_Backend.XFORMERS, _Backend.FLASH_ATTN]:
raise_backend_err() raise_backend_err()
elif is_forced_by_env_var: elif is_forced_by_env_var: # noqa: SIM102
# Backend override enforced by vLLM backend # Backend override enforced by vLLM backend
# environment variable # environment variable
if maybe_env_var_forced_backend not in\ if maybe_env_var_forced_backend not in\
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment