Commit 705f6a35 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1

parents af837396 4cf256ae
import dataclasses
from typing import List, Tuple, Type
import torch
from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
class MockAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
raise NotImplementedError
@staticmethod
def get_impl_cls():
raise NotImplementedError
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return AttentionMetadata
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
pass
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
pass
def test_model_runner_input():
sampling_metadata = SamplingMetadata(
["seq_group"],
"selected_token_indices",
"categorized_sample_indices",
"num_prompts",
)
attn_metadata = AttentionMetadata(
num_prefills=1,
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
)
model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
input_positions=torch.ones(10),
sampling_metadata=sampling_metadata,
attn_metadata=attn_metadata)
assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)
# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (
ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
# Check that received copy has correct values.
assert isinstance(received_model_input,
ModelInputForGPUWithSamplingMetadata)
assert received_model_input.input_tokens is not None
assert (
received_model_input.input_tokens == model_input.input_tokens).all()
assert received_model_input.input_positions is not None
assert (received_model_input.input_positions == model_input.input_positions
).all()
assert received_model_input.multi_modal_kwargs is None
assert (received_model_input.multi_modal_kwargs ==
model_input.multi_modal_kwargs)
assert received_model_input.lora_requests is None
assert received_model_input.lora_requests == model_input.lora_requests
assert received_model_input.lora_mapping is None
assert received_model_input.lora_mapping == model_input.lora_mapping
for field in dataclasses.fields(AttentionMetadata):
assert getattr(received_model_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# For sampling metadata, only selected_token_indices is copied.
assert (received_model_input.sampling_metadata.selected_token_indices ==
sampling_metadata.selected_token_indices)
assert received_model_input.sampling_metadata.seq_groups is None
def test_embedding_model_runner_input():
pooling_metadata = PoolingMetadata(
seq_groups=[[0]],
seq_data={},
prompt_lens=[1],
)
attn_metadata = AttentionMetadata(
num_prefills=1,
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
input_positions=torch.ones(10),
pooling_metadata=pooling_metadata,
attn_metadata=attn_metadata)
assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)
# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (
ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
# Check that received copy has correct values.
assert isinstance(received_model_input,
ModelInputForGPUWithPoolingMetadata)
assert received_model_input.input_tokens is not None
assert (
received_model_input.input_tokens == model_input.input_tokens).all()
assert received_model_input.input_positions is not None
assert (received_model_input.input_positions == model_input.input_positions
).all()
assert received_model_input.multi_modal_kwargs is None
assert (received_model_input.multi_modal_kwargs ==
model_input.multi_modal_kwargs)
assert received_model_input.lora_requests is None
assert received_model_input.lora_requests == model_input.lora_requests
assert received_model_input.lora_mapping is None
assert received_model_input.lora_mapping == model_input.lora_mapping
for field in dataclasses.fields(AttentionMetadata):
assert getattr(received_model_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# Pooling metadata is not broadcast.
assert received_model_input.pooling_metadata is None
from typing import List
import pytest
import torch
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
......@@ -20,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
is_driver_worker=True,
)
return model_runner
......@@ -34,8 +38,8 @@ def test_prepare_prompt(batch_size):
enable_chunked_prefill=False,
)
seq_lens = []
seq_group_metadata_list = []
seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}
for i in range(batch_size):
# make sure all tokens fit into one block
......@@ -58,12 +62,13 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices.append(selected_token_start_idx +
seq_len - 1)
selected_token_start_idx += seq_len
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = model_input.slot_mapping
slot_mapping = attn_metadata.slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)
......@@ -150,15 +155,14 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill=False,
)
context_lens = []
seq_group_metadata_list = []
context_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
# Assume each seq group finishes prefill.
for i in range(batch_size):
# make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len)
seq_data = list(range(context_len))
seq_data = SequenceData(seq_data)
seq_data = SequenceData(list(range(context_len)))
seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0)
......@@ -172,10 +176,11 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_positions,
model_input.attn_metadata, model_input.slot_mapping)
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
assert len(slot_mapping) == len(input_tokens)
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
......@@ -256,33 +261,30 @@ def test_empty_seq_group():
dtype="float16",
enforce_eager=False,
)
seq_group_metadata_list = []
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
seq_group_metadata_list: List[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert input_tokens is None
assert input_positions is None
assert attn_metadata is None
assert len(slot_mapping) == 0
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, slot_mapping,
return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
model_input.seq_lens,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.seq_lens,
)
assert input_tokens is None
assert input_positions is None
assert attn_metadata is None
assert len(slot_mapping) == 0
assert len(return_seq_lens) == 0
assert return_seq_lens is None
@pytest.fixture
......@@ -292,6 +294,7 @@ def distributed_init():
rank=0,
distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
local_rank=0)
ensure_model_parallel_initialized(1, 1)
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
......@@ -308,10 +311,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
)
# Add prefill requests.
seq_lens = []
seq_group_metadata_list = []
prefill_metadata_list = []
decode_metadata_list = []
seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
prefill_metadata_list: List[SequenceGroupMetadata] = []
decode_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}
prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size
......@@ -350,8 +353,12 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
seq_group_metadata_list.append(seq_group_metadata)
decode_metadata_list.append(seq_group_metadata)
(input_tokens, input_positions, attn_metadata, _, _, _,
_) = model_runner.prepare_input_tensors(seq_group_metadata_list)
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
)
prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata
......@@ -364,7 +371,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
attn_metadata = model_runner._prepare_model_input(
attn_metadata = model_runner._prepare_model_input_tensors(
seq_group_metadata_list).attn_metadata
for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
......
......@@ -39,8 +39,8 @@ def test_swap() -> None:
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
# Randomly initialize the cache.
gpu_cache = worker.cache_engine.gpu_cache
cpu_cache = worker.cache_engine.cpu_cache
gpu_cache = worker.cache_engine[0].gpu_cache
cpu_cache = worker.cache_engine[0].cpu_cache
num_layers = len(gpu_cache)
for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i]
......
......@@ -13,9 +13,11 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.version import __dcu_version__
__version__ = "0.5.0"
from .version import __commit__, __version__
__all__ = [
"__commit__",
"__version__",
"LLM",
"ModelRegistry",
"PromptStrictInputs",
......
import contextlib
import functools
from typing import List, Optional, Tuple, Type
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
import vllm._C
except ImportError as e:
from vllm.logger import init_logger
logger = init_logger(__name__)
logger.warning("Failed to import from vllm._C with %r", e)
with contextlib.suppress(ImportError):
......@@ -23,6 +26,25 @@ def is_custom_op_supported(op_name: str) -> bool:
return op is not None
def hint_on_error(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)
except AttributeError as e:
msg = (
"Error in calling custom op %s: %s\n"
"Possibly you have built or installed an obsolete version of vllm.\n"
"Please try a clean build and install of vllm,"
"or remove old built files such as vllm/*cpython*.so and build/ ."
)
logger.error(msg, fn.__name__, e)
raise e
return wrapper
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul(out, x)
......@@ -44,6 +66,10 @@ def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_new(out, x)
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_quick(out, x)
# page attention ops
def paged_attention_v1(
out: torch.Tensor,
......@@ -190,9 +216,16 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# cutlass
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype]) -> torch.Tensor:
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
def cutlass_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
......@@ -200,7 +233,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
return out
......@@ -238,6 +271,15 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_k, is_k_full)
# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
num_bits, size_m, size_n, size_k)
# fp8
# def scaled_fp8_quant(
# input: torch.Tensor,
......@@ -352,7 +394,8 @@ def reshape_and_cache_flash(
kv_cache_dtype)
def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
block_mapping: torch.Tensor) -> None:
torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
......@@ -459,3 +502,25 @@ def dispatch_bgmv_low_level(
h_out,
y_offset,
)
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
names_and_values = globals()
names_and_values_to_update = {}
# prepare variables to avoid dict size change during iteration
k, v, arg = None, None, None
fn_type = type(lambda x: x)
for k, v in names_and_values.items():
# find functions that are defined in this file and have torch.Tensor
# in their annotations. `arg == "torch.Tensor"` is used to handle
# the case when users use `import __annotations__` to turn type
# hints into strings.
if isinstance(v, fn_type) \
and v.__code__.co_filename == __file__ \
and any(arg is torch.Tensor or arg == "torch.Tensor"
for arg in v.__annotations__.values()):
names_and_values_to_update[k] = hint_on_error(v)
names_and_values.update(names_and_values_to_update)
del names_and_values_to_update, names_and_values, v, k, fn_type
from typing import List, Optional, Tuple
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
import intel_extension_for_pytorch as ipex
except ImportError as e:
logger.warning("Import error msg: %s", e.msg)
class ipex_ops:
@staticmethod
def _reshape_activation_tensor(
x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
num = x.size(0)
d = x.size(1) // 2
x = x.reshape(num, 2, d)
x1, x2 = torch.chunk(x, chunks=2, dim=1)
x1 = x1.reshape(num, d)
x2 = x2.reshape(num, d)
return x1, x2
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.silu_mul(x1, x2, out)
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.gelu_mul(x1, x2, out, "none")
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x))
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x))
# TODO add implementation of gelu_quick here
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
def paged_attention_v1(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_size: int,
max_context_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
assert kv_cache_dtype == "auto"
num_heads = out.size(1)
num_queries_per_tokens = num_heads // num_kv_heads
head_mapping = torch.arange(
0,
num_kv_heads,
device=query.device,
dtype=torch.int32,
).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace
torch.xpu.paged_attention_v1(out, query.contiguous(),
key_cache.view_as(value_cache),
value_cache, head_mapping, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes)
def paged_attention_v2(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_size: int,
max_context_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
assert kv_cache_dtype == "auto"
num_heads = out.size(1)
num_queries_per_tokens = num_heads // num_kv_heads
head_mapping = torch.arange(
0,
num_kv_heads,
dtype=torch.int32,
device=query.device,
).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace
torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache, head_mapping, block_tables,
context_lens, scale, block_size,
max_context_len, alibi_slopes)
def rotary_embedding(
positions: torch.Tensor, # [batch_size, seq_len]
query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size]
key: torch.Tensor, # [batch_size, seq_len, num_kv_heads*head_size]
head_size: int,
cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim]
is_neox: bool,
) -> None:
if positions.dim() == 1:
positions = positions.unsqueeze(0)
query = query.unsqueeze(0)
key = key.unsqueeze(0)
rotary_dim = cos_sin_cache.size(1)
query = query.view(*query.shape[:-1], -1, head_size)
key = key.view(*key.shape[:-1], -1, head_size)
query_rot = query[..., :rotary_dim]
key_rot = key[..., :rotary_dim]
cos_sin = cos_sin_cache[positions.long()]
cos, sin = cos_sin.chunk(2, dim=-1)
if is_neox:
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions)
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
if positions.dim() == 1:
positions = positions.unsqueeze(0)
query = query.unsqueeze(0)
key = key.unsqueeze(0)
cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
rotary_dim = cos_sin_cache.size(1)
query = query.view(*query.shape[:-1], -1, head_size)
key = key.view(*key.shape[:-1], -1, head_size)
query_rot = query[..., :rotary_dim]
key_rot = key[..., :rotary_dim]
cos_sin = cos_sin_cache[torch.add(positions,
cos_sin_cache_offsets).long()]
cos, sin = cos_sin.chunk(2, dim=-1)
if is_neox:
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions)
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
out.copy_(tmp)
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
epsilon, True)
input.copy_(tmp)
def varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
seqlen_q: torch.Tensor,
seqlen_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
pdropout: float,
softmax_scale: float,
zero_tensors: bool,
is_causal: bool,
return_softmax: bool,
gen_: torch.Generator,
) -> None:
ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
seqlen_k, max_seqlen_q,
max_seqlen_k, pdropout,
softmax_scale, zero_tensors,
is_causal, return_softmax, gen_)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
kv_scale: float,
) -> None:
assert kv_cache_dtype == "auto"
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slot_mapping)
@staticmethod
def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
block_mapping: torch.Tensor) -> None:
torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None:
torch.xpu.swap_blocks(src, dst, block_mapping)
from dataclasses import dataclass
from typing import Tuple
@dataclass
class AdapterMapping:
# Per every token in input_ids:
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
\ No newline at end of file
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
from torch import nn
from vllm.logger import init_logger
from vllm.utils import LRUCache
logger = init_logger(__name__)
class AdapterModel(ABC):
def __init__(self, model_id=None):
self.id = model_id
@abstractmethod
def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
# Common initialization code
# Load weights or embeddings from local checkpoint
raise NotImplementedError("Subclasses must implement this method.")
T = TypeVar('T')
class AdapterLRUCache(LRUCache[T]):
def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
None]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn
def _on_remove(self, key: Hashable, value: T):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
class AdapterModelManager(ABC):
def __init__(
self,
model: nn.Module,
):
"""Create a AdapterModelManager and adapter for a given model.
Args:
model: the model to be adapted.
"""
self.model: nn.Module = model
self._registered_adapters: Dict[int, Any] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_adapters: Dict[int, None] = {}
self.adapter_type = 'Adapter'
self._last_mapping = None
def __len__(self) -> int:
return len(self._registered_adapters)
@property
@abstractmethod
def adapter_slots(self):
...
@property
@abstractmethod
def capacity(self):
...
@abstractmethod
def activate_adapter(self, adapter_id: int) -> bool:
...
@abstractmethod
def deactivate_adapter(self, adapter_id: int) -> bool:
...
@abstractmethod
def add_adapter(self, adapter: Any) -> bool:
...
@abstractmethod
def set_adapter_mapping(self, mapping: Any) -> None:
...
@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...
@abstractmethod
def remove_all_adapters(self):
...
@abstractmethod
def get_adapter(self, adapter_id: int) -> Optional[Any]:
...
@abstractmethod
def list_adapters(self) -> Dict[int, Any]:
...
@abstractmethod
def pin_adapter(self, adapter_id: int) -> bool:
...
from abc import abstractmethod
from dataclasses import dataclass
@dataclass
class AdapterRequest:
"""
Base class for adapter requests.
"""
@property
@abstractmethod
def adapter_id(self):
...
def __post_init__(self):
if self.adapter_id < 1:
raise ValueError(f"id must be > 0, got {self.adapter_id}")
def __eq__(self, value: object) -> bool:
return isinstance(
value, self.__class__) and self.adapter_id == value.adapter_id
def __hash__(self) -> int:
return hash(self.adapter_id)
from typing import Any, Callable, Dict, Optional, Set
## model functions
def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
deactivate_func: Callable) -> bool:
if adapter_id in active_adapters:
deactivate_func(adapter_id)
active_adapters.pop(adapter_id)
return True
return False
def add_adapter(adapter: Any, registered_adapters: Dict[int, Any],
capacity: int, add_func: Callable) -> bool:
if adapter.id not in registered_adapters:
if len(registered_adapters) >= capacity:
raise RuntimeError('No free adapter slots.')
add_func(adapter)
registered_adapters[adapter.id] = adapter
return True
return False
def set_adapter_mapping(mapping: Any, last_mapping: Any,
set_mapping_func: Callable) -> Any:
if last_mapping != mapping:
set_mapping_func(mapping)
return mapping
return last_mapping
def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any],
deactivate_func: Callable) -> bool:
deactivate_func(adapter_id)
return bool(registered_adapters.pop(adapter_id, None))
def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]:
return dict(registered_adapters)
def get_adapter(adapter_id: int,
registered_adapters: Dict[int, Any]) -> Optional[Any]:
return registered_adapters.get(adapter_id, None)
## worker functions
def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any],
apply_adapters_func,
set_adapter_mapping_func) -> None:
apply_adapters_func(requests)
set_adapter_mapping_func(mapping)
def add_adapter_worker(adapter_request: Any, list_adapters_func,
load_adapter_func, add_adapter_func,
activate_adapter_func) -> bool:
if adapter_request.adapter_id in list_adapters_func():
return False
loaded_adapter = load_adapter_func(adapter_request)
loaded = add_adapter_func(loaded_adapter)
activate_adapter_func(loaded_adapter.id)
return loaded
def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func,
adapter_slots: int, remove_adapter_func,
add_adapter_func) -> None:
models_that_exist = list_adapters_func()
models_map = {
adapter_request.adapter_id: adapter_request
for adapter_request in adapter_requests if adapter_request
}
if len(models_map) > adapter_slots:
raise RuntimeError(
f"Number of requested models ({len(models_map)}) is greater "
f"than the number of GPU model slots "
f"({adapter_slots}).")
new_models = set(models_map)
models_to_add = new_models - models_that_exist
models_to_remove = models_that_exist - new_models
for adapter_id in models_to_remove:
remove_adapter_func(adapter_id)
for adapter_id in models_to_add:
add_adapter_func(models_map[adapter_id])
def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]:
return set(adapter_manager_list_adapters_func())
from abc import ABC, abstractmethod
from typing import Any, Optional, Set
import torch
class AbstractWorkerManager(ABC):
def __init__(self, device: torch.device):
self.device = device
@property
@abstractmethod
def is_enabled(self) -> bool:
...
@abstractmethod
def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None:
...
@abstractmethod
def add_adapter(self, adapter_request: Any) -> bool:
...
@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...
@abstractmethod
def remove_all_adapters(self):
...
@abstractmethod
def list_adapters(self) -> Set[int]:
...
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
TypeVar)
import torch
class AttentionType(Enum):
DECODER = auto() # Decoder attention between previous layer Q/K/V
ENCODER = auto() # Encoder attention between previous layer Q/K/V
ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V
class AttentionBackend(ABC):
"""Abstract class for attention backends."""
......@@ -21,9 +28,13 @@ class AttentionBackend(ABC):
@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError
@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_kv_cache_shape(
......@@ -124,5 +135,6 @@ class AttentionImpl(ABC, Generic[T]):
kv_cache: torch.Tensor,
attn_metadata: T,
kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
raise NotImplementedError
......@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
AttentionMetadata, AttentionType)
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
......@@ -90,8 +90,8 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
return BlocksparseFlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
return BlocksparseFlashAttentionMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return BlocksparseFlashAttentionMetadata
@staticmethod
def get_kv_cache_shape(
......@@ -328,6 +328,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata,
kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -340,6 +341,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
......
......@@ -7,7 +7,7 @@ from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
AttentionMetadata, AttentionType)
class FlashAttentionBackend(AttentionBackend):
......@@ -25,8 +25,8 @@ class FlashAttentionBackend(AttentionBackend):
return FlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
return FlashAttentionMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata
@staticmethod
def get_kv_cache_shape(
......@@ -83,7 +83,7 @@ class FlashAttentionMetadata(AttentionMetadata):
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
......@@ -257,6 +257,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
......@@ -269,6 +270,12 @@ class FlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."
......@@ -317,7 +324,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
flash_attn_varlen_func(
out = flash_attn_varlen_func(
q=query,
k=key,
v=value,
......@@ -329,13 +336,14 @@ class FlashAttentionImpl(AttentionImpl):
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
out=output[:num_prefill_tokens],
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
flash_attn_varlen_func(
output[:num_prefill_tokens] = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
......@@ -347,12 +355,11 @@ class FlashAttentionImpl(AttentionImpl):
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
out=output[:num_prefill_tokens],
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
flash_attn_with_kvcache(
output[num_prefill_tokens:] = flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
......@@ -361,8 +368,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
out=output[num_prefill_tokens:].unsqueeze(1),
)
).squeeze(1)
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type
import flashinfer
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
AttentionMetadata, AttentionType)
class FlashInferBackend(AttentionBackend):
......@@ -22,8 +28,8 @@ class FlashInferBackend(AttentionBackend):
return FlashInferImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
return FlashInferMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashInferMetadata
@staticmethod
def get_kv_cache_shape(
......@@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata):
# requests only.
max_prefill_seq_len: int
use_cuda_graph: bool = False
use_cuda_graph: bool = True
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
# Metadata for the prefill stage since we still
# use flash attention for prefill.
# Metadata for the prefill stage
seq_start_loc: Optional[torch.Tensor] = None
query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None
# Metadata for the decode stage
# Workspace buffer required by the kernel, the buffer should not
# be allocated/deacollated by the FalshInfermetadata object.
workspace_buffer: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
......@@ -98,6 +101,9 @@ class FlashInferMetadata(AttentionMetadata):
page_size: Optional[int] = None
# The data type of the paged kv cache
data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
# Only used by gemma2 model
logits_soft_cap: Optional[float] = None
def __post_init__(self):
# Refer to
......@@ -109,13 +115,37 @@ class FlashInferMetadata(AttentionMetadata):
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
# When using flashinfer, we are also creating the FlashInferMetadata,
# which will also call post_init by default, here we want to skip the
# post_init if it's the prefill phase.
if self.num_prefills == 0:
assert self.num_decode_tokens > 0
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD")
def begin_forward(self):
if self.num_prefill_tokens > 0:
if self.paged_kv_indices is None:
return
assert self.prefill_wrapper is not None
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len,
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
else:
if not self.use_cuda_graph:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
self.decode_wrapper.begin_forward(
self.paged_kv_indptr,
self.paged_kv_indices,
......@@ -133,8 +163,9 @@ class FlashInferMetadata(AttentionMetadata):
) -> Dict[str, Any]:
if skip_fields is None:
skip_fields = set()
# We need to skip the decode_wrapper field since it cannot be
# We need to skip the prefill/decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields.add('prefill_wrapper')
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)
......@@ -168,6 +199,7 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
......@@ -192,8 +224,14 @@ class FlashInferImpl(AttentionImpl):
kv_cache: Optional[torch.Tensor],
attn_metadata: FlashInferMetadata,
kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
......@@ -217,10 +255,14 @@ class FlashInferImpl(AttentionImpl):
self.kv_cache_dtype,
)
query = query.contiguous(
) # Flashinfer requires query to be contiguous
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.block_tables is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
output = flash_attn_varlen_func(
q=query,
k=key,
......@@ -235,16 +277,19 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
)
else:
raise NotImplementedError(
"Prefix caching is not supported with flashinfer yet.")
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(
query,
kv_cache,
logits_soft_cap=attn_metadata.logits_soft_cap,
causal=True)
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
query = query.contiguous(
) # Flashinfer requires query to be contiguous
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
sm_scale=self.scale,
)
logits_soft_cap=attn_metadata.logits_soft_cap)
return output.view(num_tokens, hidden_size)
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
_PARTITION_SIZE = 512
class IpexAttnBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "ipex-attn"
@staticmethod
def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
return IpexAttnBackendImpl
@staticmethod
def get_metadata_cls() -> Type["IpexAttnMetadata"]:
return IpexAttnMetadata
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for IpexAttnBackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
seq_lens: Optional[List[int]]
seqlen_q: Optional[torch.Tensor]
max_seqlen: Optional[int]
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[torch.Tensor]] = None
@property
def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
assert self.num_prefills > 0
return self
return None
@property
def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefills > 0:
assert self.num_decode_tokens == 0
return None
return self
class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
assert blocksparse_params is None, ValueError(
"Torch SPDA does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if kv_cache_dtype != "auto":
raise NotImplementedError(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
def split_kv_cache(
self,
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
x = 1
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: IpexAttnMetadata, # type: ignore
kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
key_cache, value_cache = self.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
ipex_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
kv_scale,
)
if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=1)
if attn_metadata.attn_bias is None:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, None, dtype=query.dtype)
attn_metadata.attn_bias = att_masks
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype,
device=query.device)
ipex_ops.varlen_attention(query,
key,
value,
output,
attn_metadata.seqlen_q,
attn_metadata.seqlen_q,
attn_metadata.max_seqlen,
attn_metadata.max_seqlen,
pdropout=0.0,
softmax_scale=self.scale,
zero_tensors=False,
is_causal=True,
return_softmax=False,
gen_=None)
else:
# prefix-enabled attention
raise RuntimeError(
"IPEX backend doesn't support prefix decoding.")
else:
# Decoding run.
max_seq_len = attn_metadata.max_decode_seq_len
output = torch.empty_like(query)
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory
# shortage.
use_v1 = (max_seq_len <= 8192 and
(max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1:
# Run PagedAttention V1.
ipex_ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
self.num_kv_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens_tensor,
block_size,
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
kv_scale,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ipex_ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
self.num_kv_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens_tensor,
block_size,
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
kv_scale,
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: List[int],
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None])
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype,
device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))
return attn_biases
def _make_sliding_window_bias(
seq_lens: List[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
mask = torch.log(mask)
attn_biases.append(mask.to(dtype))
return attn_biases
from dataclasses import dataclass
from typing import List, Tuple
import openvino as ov
import torch
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
class OpenVINOAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "openvino"
@staticmethod
def get_impl_cls():
# OpenVINO implements PagedAttention as part of the Optimum
# exported model
raise NotImplementedError
@staticmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
raise NotImplementedError
@staticmethod
def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
return OpenVINOAttentionMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, num_kv_heads, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: ov.Tensor,
dst_kv_cache: ov.Tensor,
src_to_dst: torch.Tensor,
) -> None:
# OpenVINO currently supports only CPU, which does not require
# swap of KV cache blocks
raise NotImplementedError
@staticmethod
def copy_blocks(
kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
src_to_dists: List[Tuple[int, int]],
) -> None:
for src, dst in src_to_dists:
for key_cache, value_cache in kv_caches:
key_cache.data[dst, :] = key_cache.data[src, :]
value_cache.data[dst, :] = value_cache.data[src, :]
@dataclass
class OpenVINOAttentionMetadata:
"""Metadata for OpenVINOAttentionBackend.
Basic terms used below:
- batch_size_in_sequences - total number of sequences to execute​
- prompt_lens – per sequence size number of scheduled tokens​
- batch_size_in_tokens = sum(prompt_lens)​
- max_context_len = max(context_lens)​
- max_num_blocks = div_up(max_context_len / BLOCK_SIZE)​
- num_blocks – total number of blocks in block_indices​
"""
# Describes past KV cache size for each sequence within a batch
# Shape: [batch_size_in_sequences]
# Type: i32​
past_lens: torch.Tensor
# Describes start indices of input / speculative tokens from
# current sequences within a batch sequence​
# Shape: [batch_size_in_sequences + 1]​
# Type: i32
subsequence_begins: torch.Tensor
# Describes block tables for each sequence within a batch​ -
# indices along 0th dimension in key_cache and value_cache inputs​
# Shape: [num_blocks]
# Type: i32​
block_indices: torch.Tensor
# Describes block tables for each sequence within a batch​ -
# for i-th element, it is an index in block_indices with the
# first block belonging to i-th sequence​
# Shape: [batch_size_in_sequences + 1]
# Type: i32​
block_indices_begins: torch.Tensor
# Describes max context length
# Shape: scalar
# Type: i32
max_context_len: torch.Tensor
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
import torch_xla.experimental.dynamo_set_buffer_donor
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
class PallasAttentionBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
return PallasAttentionBackendImpl
@staticmethod
def get_metadata_cls() -> Type["PallasMetadata"]:
return PallasMetadata
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_kv_heads, num_blocks, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
raise RuntimeError("swap_blocks is not used for the TPU backend.")
@torch.compile(backend="openxla")
@staticmethod
def copy_blocks(
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
) -> None:
src_indices, dst_indices = src_to_dists
for k_cache, v_cache in kv_caches:
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
k_cache[:, dst_indices] = k_cache[:, src_indices]
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices]
@dataclass
class PallasMetadata(AttentionMetadata):
# Currently, input sequences can only contain all prefills
# or all decoding.
block_tables: Optional[torch.Tensor]
context_lens: Optional[torch.Tensor]
@property
def prefill_metadata(self) -> Optional["PallasMetadata"]:
if self.num_prefills == 0:
return None
assert self.num_decode_tokens == 0
assert self.block_tables is None
assert self.context_lens is None
return self
@property
def decode_metadata(self) -> Optional["PallasMetadata"]:
if self.num_decode_tokens == 0:
return None
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.block_tables is not None
assert self.context_lens is not None
return self
class PallasAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if head_size % 128 != 0:
raise NotImplementedError("Head size must be a multiple of 128.")
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if sliding_window is not None:
raise NotImplementedError("Sliding window is not supported.")
if kv_cache_dtype != "auto":
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")
if torch_xla.tpu.version() < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
self.megacore_mode = None
tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
if "lite" not in tpu_type:
if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head"
else:
# NOTE(woosuk): If the batch size is not a multiple of 2, the
# megacore mode will be None.
self.megacore_mode = "batch"
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
attn_metadata: PallasMetadata,
kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache = [num_kv_heads, num_blocks, block_size, head_size]
value_cache = [num_kv_heads, num_blocks, block_size, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert kv_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size)
if kv_cache[0] is not None:
slot_mapping = attn_metadata.slot_mapping
key_cache, value_cache = kv_cache
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
query = query * self.scale
if attn_metadata.num_prefills > 0:
assert seq_len % 16 == 0, (
"Pallas FlashAttention kernel requires seq_len to be a "
f"multiple of 16 but got {seq_len}")
# Handle GQA/MQA.
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=-2)
key = key.view(batch_size, seq_len, self.num_heads,
self.head_size)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=-2)
value = value.view(batch_size, seq_len, self.num_heads,
self.head_size)
# FlashAttention requires [batch_size, num_heads, seq_len, d_model]
# while the input is [batch_size, seq_len, num_heads, d_model].
# Permute the input to match the required format.
output = torch.ops.xla.flash_attention(
query.permute(0, 2, 1, 3),
key.permute(0, 2, 1, 3),
value.permute(0, 2, 1, 3),
True,
)
output = output.permute(0, 2, 1, 3)
else:
# Decoding run.
assert kv_cache is not None
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
if self.megacore_mode == "batch" and batch_size % 2 != 0:
megacore_mode = None
else:
megacore_mode = self.megacore_mode
# NOTE(woosuk): A temporary workaround to avoid the error:
# "xla::paged_attention() Expected a value of type 'str' for
# argument 'megacore_mode' but instead found type 'NoneType'."
if megacore_mode is not None:
output = torch.ops.xla.paged_attention(
query.squeeze(dim=1),
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
pages_per_compute_block,
megacore_mode=megacore_mode,
)
else:
output = torch.ops.xla.paged_attention(
query.squeeze(dim=1),
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
pages_per_compute_block,
)
# Reshape the output tensor.
return output.reshape(batch_size, seq_len, hidden_size)
def write_to_kv_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
key = key.flatten(0, 2)
value = value.flatten(0, 2)
key_cache = key_cache.flatten(0, 2)
value_cache = value_cache.flatten(0, 2)
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)
......@@ -6,7 +6,7 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
......@@ -25,8 +25,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
return ROCmFlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
return ROCmFlashAttentionMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return ROCmFlashAttentionMetadata
@staticmethod
def get_kv_cache_shape(
......@@ -166,6 +166,37 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
return self._cached_decode_metadata
def _make_alibi_bias(alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: Optional[List[int]],
make_attn_mask: bool = True) -> List[torch.Tensor]:
attn_biases = []
if seq_lens:
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat(
(num_heads, 1, 1)).to(alibi_slopes.device)
bias.mul_(alibi_slopes[:, None, None])
if make_attn_mask:
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
alibi_slopes.device)
attn_biases.append((bias + inf_mask).to(dtype))
else:
attn_biases.append(bias.to(dtype))
return attn_biases
class ROCmFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
......@@ -229,11 +260,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func)
self.attn_func = flash_attn_varlen_func # triton_attention
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
# from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
# flash_attn_varlen_func)
self.attn_func = triton_attention # flash_attn_varlen_func
logger.debug("Using Triton FA in ROCmBackend")
else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
......@@ -268,6 +299,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata,
kv_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -280,6 +312,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
......@@ -326,34 +364,36 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
attn_masks = None
if self.use_triton_flash_attn:
# out, _ = self.attn_func(
# query,
# key,
# value,
# None,
# prefill_meta.seq_start_loc,
# prefill_meta.seq_start_loc,
# prefill_meta.max_prefill_seq_len,
# prefill_meta.max_prefill_seq_len,
# True,
# self.scale,
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes,
query.dtype,
attn_metadata.seq_lens,
make_attn_mask=False) # type: ignore
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
query,
key,
value,
prefill_meta.seq_lens,
num_tokens,
self.num_heads,
self.head_size,
self.scale,
attn_masks,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes,
query.dtype,
attn_metadata.seq_lens,
make_attn_mask=True) # type: ignore
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
......@@ -367,6 +407,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.num_heads,
self.head_size,
self.scale,
attn_masks,
)
else:
out = self.attn_func(
......@@ -430,13 +471,14 @@ def _sdpa_attention(
num_heads: int,
head_size: int,
scale: float,
attn_masks: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
start = 0
output = torch.empty((num_tokens, num_heads, head_size),
dtype=query.dtype,
device=query.device)
for seq_len in seq_lens:
for i, seq_len in enumerate(seq_lens):
end = start + seq_len
with torch.backends.cuda.sdp_kernel(enable_math=True,
enable_flash=False,
......@@ -446,7 +488,8 @@ def _sdpa_attention(
key[:, start:end, :],
value[:, start:end, :],
dropout_p=0.0,
is_causal=True,
is_causal=attn_masks is None,
attn_mask=attn_masks[i] if attn_masks else None,
scale=scale).movedim(query.dim() - 2, 0)
output[start:end, :, :] = sub_out
start = end
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment