Unverified commit f825c6bd, authored by Maximilien de Bayser and committed by GitHub
Browse files

Support encoder_only attention for FlexAttention (#22273)


Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
parent 41b67f42
...@@ -9,7 +9,9 @@ import pytest ...@@ -9,7 +9,9 @@ import pytest
import torch import torch
from packaging import version from packaging import version
from vllm import LLM, SamplingParams from vllm import SamplingParams
from ..models.utils import check_embeddings_close
TORCH_VERSION = version.parse(torch.__version__) TORCH_VERSION = version.parse(torch.__version__)
MINIMUM_TORCH_VERSION = version.parse("2.7.0") MINIMUM_TORCH_VERSION = version.parse("2.7.0")
...@@ -28,7 +30,7 @@ def set_seed(seed): ...@@ -28,7 +30,7 @@ def set_seed(seed):
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION, not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7", reason="CUDA not available or PyTorch version < 2.7",
) )
def test_flex_attention_vs_default_backend(monkeypatch): def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
"""Test that FlexAttention produces the same outputs as the default backend. """Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with This test compares the outputs from the FlexAttention backend with
...@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(monkeypatch): ...@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(monkeypatch):
""" """
model_name = "Qwen/Qwen2.5-1.5B-Instruct" model_name = "Qwen/Qwen2.5-1.5B-Instruct"
seed = 42 seed = 42
max_tokens = 32 max_tokens = 24
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
...@@ -54,33 +56,30 @@ def test_flex_attention_vs_default_backend(monkeypatch): ...@@ -54,33 +56,30 @@ def test_flex_attention_vs_default_backend(monkeypatch):
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
set_seed(seed) set_seed(seed)
with vllm_runner(model_name,
llm_flex = LLM( runner="generate",
model_name, tensor_parallel_size=1,
tensor_parallel_size=1, num_gpu_blocks_override=128,
num_gpu_blocks_override=128, enforce_eager=True) as llm_flex:
enforce_eager=True, output_flex = llm_flex.generate(prompts, sampling_params)
)
output_flex = llm_flex.generate(prompts, sampling_params)
# Run with default backend # Run with default backend
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
set_seed(seed) set_seed(seed)
llm_default = LLM( with vllm_runner(model_name,
model_name, runner="generate",
tensor_parallel_size=1, tensor_parallel_size=1,
num_gpu_blocks_override=128, num_gpu_blocks_override=128,
enforce_eager=True, enforce_eager=True) as llm_default:
) output_default = llm_default.generate(prompts, sampling_params)
output_default = llm_default.generate(prompts, sampling_params)
# Compare outputs from both backends # Compare outputs from both backends
for i, (flex_result, for i, (flex_result,
default_result) in enumerate(zip(output_flex, output_default)): default_result) in enumerate(zip(output_flex, output_default)):
prompt = prompts[i] prompt = prompts[i]
flex_text = flex_result.outputs[0].text flex_text = flex_result[1][0]
default_text = default_result.outputs[0].text default_text = default_result[1][0]
assert flex_text == default_text, ( assert flex_text == default_text, (
f"FlexAttention output doesn't match default for: {prompt!r}\n" f"FlexAttention output doesn't match default for: {prompt!r}\n"
...@@ -88,5 +87,54 @@ def test_flex_attention_vs_default_backend(monkeypatch): ...@@ -88,5 +87,54 @@ def test_flex_attention_vs_default_backend(monkeypatch):
f"Default: {default_text!r}") f"Default: {default_text!r}")
@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
    """Check FlexAttention embeddings against the default backend.

    The same encoder model embeds the same prompts twice: once with
    ``VLLM_ATTENTION_BACKEND`` set to ``FLEX_ATTENTION`` and once without
    that override. The two sets of embeddings must agree within ``tol=1e-2``
    according to ``check_embeddings_close``.
    """
    model_name = "BAAI/bge-base-en-v1.5"
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    # Both runs must use identical engine settings; only the environment
    # differs between them.
    engine_kwargs = {
        "runner": "pooling",
        "dtype": torch.bfloat16,
        "tensor_parallel_size": 1,
        "max_model_len": 100,
        "enforce_eager": True,
    }

    def embed_under_env(env_overrides):
        # The monkeypatch context restores the environment on exit, so the
        # backend override of one run does not leak into the next.
        with monkeypatch.context() as m:
            for env_name, env_value in env_overrides.items():
                m.setenv(env_name, env_value)
            with vllm_runner(model_name, **engine_kwargs) as llm:
                return llm.embed(prompts)

    # Run with flex attention
    flex_outputs = embed_under_env({
        "VLLM_USE_V1": "1",
        "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
    })

    # Run with default backend
    default_outputs = embed_under_env({"VLLM_USE_V1": "1"})

    check_embeddings_close(
        embeddings_0_lst=flex_outputs,
        embeddings_1_lst=default_outputs,
        name_0="flex",
        name_1="default",
        tol=1e-2,
    )
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
...@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, ...@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
@dataclass @dataclass
class FlexAttentionMetadata: class FlexAttentionMetadata:
causal: bool
num_actual_tokens: int # Number of tokens excluding padding. num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int max_query_len: int
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
...@@ -177,10 +178,9 @@ class FlexAttentionMetadata: ...@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
num_blocks = 0 num_blocks = 0
block_mask: Optional[BlockMask] = None block_mask: Optional[BlockMask] = None
score_mod: Optional[_score_mod_signature] = None score_mod: Optional[_score_mod_signature] = None
mask_mod: Optional[_mask_mod_signature] = None
logical_mask_mod: _mask_mod_signature = causal_mask_mod logical_mask_mod: _mask_mod_signature = causal_mask_mod
def get_mask_mod(self) -> _mask_mod_signature: def get_causal_mask_mod(self) -> _mask_mod_signature:
"""Creates the mask_mod function for FlexAttention. """Creates the mask_mod function for FlexAttention.
This function creates the combined mask mod function that handles: This function creates the combined mask mod function that handles:
...@@ -233,14 +233,39 @@ class FlexAttentionMetadata: ...@@ -233,14 +233,39 @@ class FlexAttentionMetadata:
return final_mask_mod return final_mask_mod
def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
    """Build the mask_mod used for encoder (bidirectional) attention.

    Encoder bidirectional attention does not run with a KV cache, so the
    mask is derived purely from the packed query sequences: a query
    position may attend to a key/value position only when both map to
    the same entry of the lookup built from ``self.query_start_loc``.

    Returns:
        A callable ``(b, h, q_idx, kv_idx) -> torch.Tensor`` suitable for
        FlexAttention block-mask creation.
    """
    # Lookup from packed token index -> request number (per the helper's
    # name, derived from the query start offsets).
    doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)

    def final_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
    ) -> torch.Tensor:
        # b and h are unused: the mask is the same for every batch entry
        # and head.
        query_request = doc_ids[q_idx]
        kv_request = doc_ids[kv_idx]
        return query_request == kv_request

    return final_mask_mod
def build_block_mask(self) -> BlockMask: def build_block_mask(self) -> BlockMask:
assert self.mask_mod is not None if self.causal:
mask_mod = self.get_causal_mask_mod()
kv_len = self.total_cache_tokens
else:
mask_mod = self.get_bidirectional_mask_mod()
kv_len = self.num_actual_tokens
return create_block_mask_compiled( return create_block_mask_compiled(
self.mask_mod, mask_mod,
None, None,
None, None,
self.num_actual_tokens, self.num_actual_tokens,
self.total_cache_tokens, kv_len,
device=self.block_table.device, device=self.block_table.device,
) )
...@@ -251,7 +276,6 @@ class FlexAttentionMetadata: ...@@ -251,7 +276,6 @@ class FlexAttentionMetadata:
assert self.prefix_kv_lens is None, "Not implemented yet." assert self.prefix_kv_lens is None, "Not implemented yet."
assert self.suffix_kv_lens is None, "Not implemented yet." assert self.suffix_kv_lens is None, "Not implemented yet."
self.num_blocks = self.total_cache_tokens // self.block_size self.num_blocks = self.total_cache_tokens // self.block_size
self.mask_mod = self.get_mask_mod()
self.block_mask = self.build_block_mask() self.block_mask = self.build_block_mask()
...@@ -306,6 +330,7 @@ class FlexAttentionMetadataBuilder( ...@@ -306,6 +330,7 @@ class FlexAttentionMetadataBuilder(
self.device, non_blocking=True) self.device, non_blocking=True)
out = FlexAttentionMetadata( out = FlexAttentionMetadata(
causal=common_attn_metadata.causal,
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
...@@ -350,6 +375,12 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -350,6 +375,12 @@ class FlexAttentionImpl(AttentionImpl):
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
self.attn_type = attn_type
if attn_type not in (AttentionType.ENCODER_ONLY,
AttentionType.DECODER):
raise NotImplementedError(
f"FlexAttention does not support {attn_type} attention")
if alibi_slopes is not None: if alibi_slopes is not None:
raise NotImplementedError( raise NotImplementedError(
...@@ -425,26 +456,38 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -425,26 +456,38 @@ class FlexAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0) if not attn_metadata.causal:
assert self.attn_type == AttentionType.ENCODER_ONLY
torch.ops._C_cache_ops.reshape_and_cache_flash(
key, query, key_tensor, value_tensor = map(
value, lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
key_cache, (query, key, value),
value_cache, )
attn_metadata.slot_mapping,
self.kv_cache_dtype, else:
layer._k_scale, assert self.attn_type == AttentionType.DECODER
layer._v_scale, key_cache, value_cache = kv_cache.unbind(0)
)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
# View out the block_size dim
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
value_cache = value_cache.view(-1, self.num_kv_heads,
self.head_size)
query, key_tensor, value_tensor = map(
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
(query, key_cache, value_cache),
)
# View out the block_size dim
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
query, key_cache, value_cache = map(
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
(query, key_cache, value_cache),
)
query = query[:, :, :num_actual_tokens, :] query = query[:, :, :num_actual_tokens, :]
# Doesn't work for now -> constraint violation # Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2) # torch._dynamo.try_mark_dynamic(query, 2)
...@@ -465,8 +508,8 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -465,8 +508,8 @@ class FlexAttentionImpl(AttentionImpl):
out = flex_attention_compiled( out = flex_attention_compiled(
query, query,
key_cache, key_tensor,
value_cache, value_tensor,
attn_metadata.score_mod, attn_metadata.score_mod,
attn_metadata.block_mask, attn_metadata.block_mask,
self.scale, self.scale,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment