Unverified Commit 925f3332 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Core] Refactor Attention Take 2 (#3462)

parent b0dfa91d
......@@ -3,8 +3,7 @@ import pytest
import time
import torch
from vllm.model_executor.layers.attention.ops.prefix_prefill import (
context_attention_fwd)
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
......
......@@ -2,7 +2,10 @@
Run `pytest tests/samplers/test_beam_search.py --forked`.
"""
import gc
import pytest
import torch
# FIXME(zhuohan): The test can not pass if we:
# 1. Increase max_tokens to 256.
......@@ -36,6 +39,10 @@ def test_beam_search_single_input(
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del vllm_model
# NOTE(woosuk): For some reason, the following GC is required to avoid
# GPU OOM errors in the following tests using `vllm_runner`.
gc.collect()
torch.cuda.empty_cache()
for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
......
......@@ -34,19 +34,19 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len
(input_tokens, input_positions, input_metadata, return_prompt_lens, _, _,
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
# Verify input metadata is correct for prompts.
device = model_runner.device
assert input_metadata.is_prompt is True
assert torch.allclose(input_metadata.prompt_lens_tensor,
assert attn_metadata.is_prompt is True
assert torch.allclose(attn_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device))
assert input_metadata.prompt_lens == prompt_lens
assert input_metadata.num_prompt_tokens == sum(prompt_lens)
assert input_metadata.num_generation_tokens == 0
assert input_metadata.max_seq_len == max(prompt_lens)
assert attn_metadata.prompt_lens == prompt_lens
assert attn_metadata.num_prompt_tokens == sum(prompt_lens)
assert attn_metadata.num_generation_tokens == 0
assert attn_metadata.max_prompt_len == max(prompt_lens)
# Test subquery start locs.
start_idx = 0
......@@ -55,7 +55,7 @@ def test_prepare_prompt(batch_size):
start_idx += prompt_len
start_loc.append(start_idx)
assert torch.allclose(
input_metadata.subquery_start_loc,
attn_metadata.subquery_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
# Test seq start locs. Note that for normal prefill it is
......@@ -67,22 +67,22 @@ def test_prepare_prompt(batch_size):
seq_start_loc.append(start_idx)
assert torch.allclose(
input_metadata.seq_start_loc,
attn_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
assert input_metadata.max_context_len is None
assert attn_metadata.max_context_len is None
assert torch.allclose(
input_metadata.context_lens,
torch.zeros(input_metadata.context_lens.shape[0],
attn_metadata.context_lens,
torch.zeros(attn_metadata.context_lens.shape[0],
dtype=torch.int,
device=device))
expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device)
assert torch.allclose(input_metadata.block_tables, expected)
assert torch.allclose(attn_metadata.block_tables, expected)
# Cuda graph should not be used for prerill.
assert input_metadata.use_cuda_graph is False
assert input_metadata.kv_cache_dtype == "auto"
assert attn_metadata.use_cuda_graph is False
assert attn_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
......@@ -140,34 +140,34 @@ def test_prepare_decode_cuda_graph(batch_size):
block_tables={0: [1]},
))
input_tokens, input_positions, input_metadata, _, _, _ = (
input_tokens, input_positions, attn_metadata, _, _, _ = (
model_runner._prepare_decode(seq_group_metadata_list))
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert input_metadata.is_prompt is False
assert input_metadata.prompt_lens is None
assert input_metadata.num_prompt_tokens == 0
assert input_metadata.num_generation_tokens == expected_bs
assert input_metadata.max_seq_len is None
assert input_metadata.subquery_start_loc is None
assert input_metadata.seq_start_loc is None
assert input_metadata.max_context_len == max(prompt_lens)
assert attn_metadata.is_prompt is False
assert attn_metadata.prompt_lens is None
assert attn_metadata.num_prompt_tokens == 0
assert attn_metadata.num_generation_tokens == expected_bs
assert attn_metadata.max_prompt_len is None
assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None
assert attn_metadata.max_context_len == max(prompt_lens)
assert torch.allclose(
input_metadata.context_lens[:len(prompt_lens)],
attn_metadata.context_lens[:len(prompt_lens)],
torch.tensor(prompt_lens, dtype=torch.int, device=device))
# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
assert input_metadata.block_tables.shape[0] == len(input_tokens)
assert attn_metadata.block_tables.shape[0] == len(input_tokens)
# Block table's second dim correspondsd to each token's block number.
# It is padded up to
assert input_metadata.block_tables.shape[1] == (
assert attn_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
# Cuda graph should not be used for prerill.
assert input_metadata.use_cuda_graph is True
assert input_metadata.kv_cache_dtype == "auto"
assert attn_metadata.use_cuda_graph is True
assert attn_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (expected_bs, )
assert input_positions.shape == (expected_bs, )
......
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
__all__ = [
"AttentionBackend",
"AttentionMetadata",
"Attention",
"get_attn_backend",
]
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
class AttentionBackend(ABC):
"""Abstract class for attention backends."""
@staticmethod
@abstractmethod
def get_impl_cls() -> Type["AttentionImpl"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
raise NotImplementedError
@staticmethod
@abstractmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod
@abstractmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
raise NotImplementedError
@staticmethod
@abstractmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
raise NotImplementedError
@dataclass
class AttentionMetadata:
def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self)
}
class AttentionImpl(ABC):
@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
raise NotImplementedError
@abstractmethod
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
raise NotImplementedError
"""Attention layer with Flash and PagedAttention."""
from typing import List, Optional
"""Attention layer with Flash and PagedAttention.
NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
XFormers backend. The duplicated code will be removed once we use flash-attn or
flashinfer for all the attention operations.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
from flash_attn import flash_attn_varlen_func
import torch
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
PagedAttentionImpl)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
from vllm.attention.ops.paged_attn import PagedAttention, PagedAttentionMetadata
class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]:
return FlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
return FlashAttentionMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
class FlashAttentionBackend:
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |- subquery_len -|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
class FlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
......@@ -39,30 +133,28 @@ class FlashAttentionBackend:
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
self.sliding_window = ((self.sliding_window, self.sliding_window) if
self.sliding_window is not None else (-1, -1))
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -70,11 +162,8 @@ class FlashAttentionBackend:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
......@@ -84,18 +173,21 @@ class FlashAttentionBackend:
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# Reshape the keys and values and store them in the cache.
# If key_cache and value_cache are not provided, the new key and value
# vectors will not be cached. This happens during the initial memory
# profiling run.
if key_cache is not None and value_cache is not None:
PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
value_cache, input_metadata)
if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype)
if input_metadata.is_prompt:
if attn_metadata.is_prompt:
# Prompt run.
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
......@@ -103,10 +195,10 @@ class FlashAttentionBackend:
q=query,
k=key,
v=value,
cu_seqlens_q=input_metadata.seq_start_loc,
cu_seqlens_k=input_metadata.seq_start_loc,
max_seqlen_q=input_metadata.max_seq_len,
max_seqlen_k=input_metadata.max_seq_len,
cu_seqlens_q=attn_metadata.seq_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_q=attn_metadata.max_prompt_len,
max_seqlen_k=attn_metadata.max_prompt_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
......@@ -114,22 +206,29 @@ class FlashAttentionBackend:
)
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(
output = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
input_metadata,
attn_metadata.block_tables,
attn_metadata.subquery_start_loc,
attn_metadata.prompt_lens_tensor,
attn_metadata.context_lens,
attn_metadata.max_subquery_len,
self.alibi_slopes,
)
else:
# Decoding run.
output = PagedAttentionImpl.forward_decode(
output = PagedAttention.forward_decode(
query,
key_cache,
value_cache,
input_metadata,
attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
......
"""Attention layer with xFormers and PagedAttention."""
import importlib
from typing import List, Optional
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
from xformers.ops.fmha.attn_bias import (AttentionBias,
BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
PagedAttentionImpl)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
from vllm.attention.ops.paged_attn import PagedAttention, PagedAttentionMetadata
from vllm.logger import init_logger
from vllm.utils import is_hip
logger = init_logger(__name__)
class XFormersBackend:
class XFormersBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["XFormersImpl"]:
return XFormersImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "XFormersMetadata":
return XFormersMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |- subquery_len -|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[AttentionBias]] = None
class XFormersImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens --------------->|
......@@ -50,22 +158,25 @@ class XFormersBackend:
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
self.use_ref_attention = _check_use_ref_attention()
# AMD Radeon 7900 series (gfx1100) currently does not support xFormers
# nor FlashAttention. As a temporary workaround, we use naive PyTorch
# implementation of attention.
self.use_naive_attention = _check_use_naive_attention()
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
kv_cache: Optional[torch.Tensor],
attn_metadata: XFormersMetadata,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
......@@ -73,11 +184,8 @@ class XFormersBackend:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
......@@ -86,20 +194,24 @@ class XFormersBackend:
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# Reshape the keys and values and store them in the cache.
# If key_cache and value_cache are not provided, the new key and value
# vectors will not be cached. This happens during the initial memory
# profiling run.
if key_cache is not None and value_cache is not None:
PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
value_cache, input_metadata)
if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype)
if input_metadata.is_prompt:
if attn_metadata.is_prompt:
# Prompt run.
# key_cache and value_cache are None when it is a profiling run.
# block tables are empty if the prompt has never been computed.
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
# normal attention.
# block tables are empty if the prompt does not have a cached
# prefix.
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
......@@ -118,13 +230,12 @@ class XFormersBackend:
self.num_queries_per_kv,
value.shape[-1])
if self.use_ref_attention:
print("ref attention used.")
if self.use_naive_attention:
output = torch.empty_like(query)
start = 0
for _, prompt_len in enumerate(input_metadata.prompt_lens):
for _, prompt_len in enumerate(attn_metadata.prompt_lens):
end = start + prompt_len
out = _ref_masked_attention(
out = _naive_masked_attention(
query[None, start:end],
key[None, start:end],
value[None, start:end],
......@@ -143,26 +254,33 @@ class XFormersBackend:
# Use reshape instead.
return output.reshape(num_tokens, hidden_size)
output = self._run_memory_efficient_xformer_forward(
query, key, value, input_metadata)
output = self._run_memory_efficient_xformers_forward(
query, key, value, attn_metadata)
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(
output = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
input_metadata,
attn_metadata.block_tables,
attn_metadata.subquery_start_loc,
attn_metadata.prompt_lens_tensor,
attn_metadata.context_lens,
attn_metadata.max_subquery_len,
self.alibi_slopes,
)
else:
# Decoding run.
output = PagedAttentionImpl.forward_decode(
output = PagedAttention.forward_decode(
query,
key_cache,
value_cache,
input_metadata,
attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
......@@ -171,12 +289,12 @@ class XFormersBackend:
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _run_memory_efficient_xformer_forward(
def _run_memory_efficient_xformers_forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_metadata: InputMetadata,
attn_metadata: XFormersMetadata,
) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
......@@ -186,23 +304,23 @@ class XFormersBackend:
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
attn_metadata: Metadata for attention.
"""
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if attn_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
input_metadata.prompt_lens)
attn_metadata.prompt_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = [attn_bias]
attn_metadata.attn_bias = [attn_bias]
else:
input_metadata.attn_bias = _make_alibi_bias(
attn_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype,
input_metadata)
attn_metadata.prompt_lens)
op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if (
is_hip()) else None
......@@ -217,7 +335,7 @@ class XFormersBackend:
query,
key,
value,
attn_bias=input_metadata.attn_bias[0],
attn_bias=attn_metadata.attn_bias[0],
p=0.0,
scale=self.scale,
op=op)
......@@ -230,13 +348,13 @@ class XFormersBackend:
# one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(query)
start = 0
for i, prompt_len in enumerate(input_metadata.prompt_lens):
for i, prompt_len in enumerate(attn_metadata.prompt_lens):
end = start + prompt_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=input_metadata.attn_bias[i],
attn_bias=attn_metadata.attn_bias[i],
p=0.0,
scale=self.scale,
op=op)
......@@ -250,10 +368,10 @@ def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
input_metadata: InputMetadata,
prompt_lens: List[int],
) -> LowerTriangularMaskWithTensorBias:
attn_biases = []
for prompt_len in input_metadata.prompt_lens:
for prompt_len in prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
......@@ -282,15 +400,19 @@ def _make_alibi_bias(
return attn_biases
def _check_use_ref_attention() -> bool:
def _check_use_naive_attention() -> bool:
if not is_hip():
return False
# For ROCm, check whether flash attention is installed or not.
# if not, use_ref_attention needs to be True
return importlib.util.find_spec("flash_attn") is None
has_flash_attn = importlib.util.find_spec("flash_attn") is None
if not has_flash_attn:
logger.warning("flash_attn is not installed. Using naive attention. "
"This will take significantly more GPU memory.")
return True
return False
def _ref_masked_attention(
def _naive_masked_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
......
"""Attention layer."""
from functools import lru_cache
from typing import List, Optional
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip
logger = init_logger(__name__)
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
class Attention(nn.Module):
......@@ -17,12 +13,11 @@ class Attention(nn.Module):
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Output the output tensor.
3. Return the output tensor.
"""
def __init__(
......@@ -35,51 +30,17 @@ class Attention(nn.Module):
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
if _use_flash_attn():
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501
self.backend = FlashAttentionBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
else:
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend # noqa: E501
self.backend = XFormersBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
self.backend = get_attn_backend(torch.get_default_dtype())
impl_cls = self.backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
return self.backend.forward(query, key, value, key_cache, value_cache,
input_metadata)
@lru_cache(maxsize=1)
def _use_flash_attn() -> bool:
try:
import flash_attn # noqa: F401
except ImportError:
logger.info("flash_attn is not found. Using xformers backend.")
return False
if is_hip():
# AMD GPUs.
return False
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("flash_attn is not supported on Turing or older GPUs. "
"Using xformers backend.")
return False
if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
logger.info(
"flash_attn only supports torch.float16 or torch.bfloat16. "
"Using xformers backend.")
return False
logger.info("Using flash_attn backend.")
return True
return self.impl.forward(query, key, value, kv_cache, attn_metadata)
from typing import List, Optional
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from vllm._C import cache_ops
from vllm._C import ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.prefix_prefill import (
context_attention_fwd)
from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
class PagedAttentionImpl:
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens: Optional[torch.Tensor]
# Maximum context length in the batch.
max_context_len: Optional[int]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
kv_cache_dtype: str
class PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 256]
@staticmethod
def reshape_and_cache(
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
) -> None:
cache_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
input_metadata.slot_mapping.flatten(),
input_metadata.kv_cache_dtype,
slot_mapping.flatten(),
kv_cache_dtype,
)
@staticmethod
......@@ -40,7 +89,10 @@ class PagedAttentionImpl:
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
......@@ -49,9 +101,8 @@ class PagedAttentionImpl:
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (
(input_metadata.max_context_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
......@@ -59,8 +110,8 @@ class PagedAttentionImpl:
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = input_metadata.max_context_len <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512)
use_v1 = (max_context_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
......@@ -70,12 +121,12 @@ class PagedAttentionImpl:
value_cache,
num_kv_heads,
scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_tables,
context_lens,
block_size,
input_metadata.max_context_len,
max_context_len,
alibi_slopes,
input_metadata.kv_cache_dtype,
kv_cache_dtype,
)
else:
# Run PagedAttention V2.
......@@ -101,12 +152,12 @@ class PagedAttentionImpl:
value_cache,
num_kv_heads,
scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_tables,
context_lens,
block_size,
input_metadata.max_context_len,
max_context_len,
alibi_slopes,
input_metadata.kv_cache_dtype,
kv_cache_dtype,
)
return output
......@@ -117,7 +168,11 @@ class PagedAttentionImpl:
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
prompt_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_subquery_len: int,
alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
output = torch.empty_like(query)
......@@ -128,12 +183,35 @@ class PagedAttentionImpl:
output,
key_cache,
value_cache,
input_metadata.block_tables,
block_tables,
# subquery_start_loc is (batch_size + 1,)
input_metadata.subquery_start_loc[:-1],
input_metadata.prompt_lens_tensor,
input_metadata.context_lens,
input_metadata.max_subquery_len,
subquery_start_loc[:-1],
prompt_lens_tensor,
context_lens,
max_subquery_len,
alibi_slopes,
)
return output
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
from functools import lru_cache
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_hip
logger = init_logger(__name__)
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> AttentionBackend:
if _can_use_flash_attn(dtype):
logger.info("Using FlashAttention backend.")
from vllm.attention.backends.flash_attn import FlashAttentionBackend # noqa: F401
return FlashAttentionBackend
else:
logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import XFormersBackend # noqa: F401
return XFormersBackend
def _can_use_flash_attn(dtype: torch.dtype) -> bool:
if is_hip():
# AMD GPUs.
logger.info("Cannot use FlashAttention backend for AMD GPUs.")
return False
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("Cannot use FlashAttention backend for Volta and Turing "
"GPUs.")
return False
if dtype not in (torch.float16, torch.bfloat16):
logger.info("Cannot use FlashAttention backend for dtype other than "
"torch.float16 or torch.bfloat16.")
return False
try:
import flash_attn # noqa: F401
except ImportError:
logger.info("flash_attn is not found.")
return False
return True
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
__all__ = [
"InputMetadata",
"SamplingMetadata",
"set_random_seed",
]
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Optional, List, Any, Dict
import torch
if TYPE_CHECKING:
from xformers.ops.fmha.attn_bias import AttentionBias
@dataclass
class InputMetadata:
"""Metadata for input sequences. Used in PagedAttention.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
"""
Definition of context_len, subquery_len, and seqlen.
|---------- N-1 iteration --------|
|---------------- N iteration ---------------------|
|- tokenA -|......................|-- newTokens ---|
|---------- context_len ----------|
|-------------------- seqlen ----------------------|
|- subquery_len -|
WARNING: context_len has different definition depending on if it is
prefill vs decoding. When it is prefill, it doesn't include new
tokens. When it is for decoding, it includes a new token.
"""
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum context length in the batch.
max_context_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
use_cuda_graph: bool
kv_cache_dtype: str
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List["AttentionBias"]] = None
# Cuda graph is only used for decoding now.
if self.use_cuda_graph:
assert self.num_prompt_tokens == 0
def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self)
}
from vllm.model_executor.layers.attention.attention import Attention
__all__ = [
"Attention",
]
......@@ -25,9 +25,8 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
......@@ -45,8 +44,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
......@@ -170,15 +167,14 @@ class BaiChuanAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
......@@ -217,8 +213,8 @@ class BaiChuanDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -232,7 +228,7 @@ class BaiChuanDecoderLayer(nn.Module):
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
attn_metadata=attn_metadata,
)
# Fully Connected
......@@ -267,8 +263,8 @@ class BaiChuanModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
......@@ -278,7 +274,7 @@ class BaiChuanModel(nn.Module):
positions,
hidden_states,
kv_caches[i],
input_metadata,
attn_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
......@@ -303,11 +299,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
......
......@@ -17,15 +17,14 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import List, Optional, Tuple
from typing import List, Optional
import torch
from torch import nn
from transformers import BloomConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
......@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
......@@ -117,14 +114,13 @@ class BloomAttention(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
del position_ids # Unused.
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.dense(attn_output)
return output
......@@ -181,8 +177,8 @@ class BloomBlock(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
......@@ -198,7 +194,7 @@ class BloomBlock(nn.Module):
position_ids=position_ids,
hidden_states=layernorm_output,
kv_cache=kv_cache,
input_metadata=input_metadata,
attn_metadata=attn_metadata,
)
attention_output = attention_output + residual
layernorm_output = self.post_attention_layernorm(attention_output)
......@@ -245,8 +241,8 @@ class BloomModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
......@@ -256,7 +252,7 @@ class BloomModel(nn.Module):
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
attn_metadata,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
......@@ -281,11 +277,11 @@ class BloomForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata)
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
......
......@@ -2,15 +2,14 @@
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import List, Optional, Tuple
from typing import List, Optional
import torch
from torch import nn
from torch.nn import LayerNorm
from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
......@@ -29,8 +28,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GLMAttention(nn.Module):
......@@ -99,20 +96,18 @@ class GLMAttention(nn.Module):
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
key_cache, value_cache = kv_cache
context_layer = self.attn(
q,
k,
v,
key_cache,
value_cache,
input_metadata,
kv_cache,
attn_metadata,
)
attn_output, _ = self.dense(context_layer)
return attn_output
......@@ -200,8 +195,8 @@ class GLMBlock(nn.Module):
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
......@@ -211,7 +206,7 @@ class GLMBlock(nn.Module):
hidden_states=layernorm_output,
position_ids=position_ids,
kv_cache=kv_cache,
input_metadata=input_metadata,
attn_metadata=attn_metadata,
)
# Residual connection.
......@@ -264,8 +259,8 @@ class GLMTransformer(nn.Module):
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
for i in range(self.num_layers):
layer = self.layers[i]
......@@ -273,7 +268,7 @@ class GLMTransformer(nn.Module):
hidden_states=hidden_states,
position_ids=position_ids,
kv_cache=kv_caches[i],
input_metadata=input_metadata,
attn_metadata=attn_metadata,
)
# Final layer norm.
if self.post_layer_norm:
......@@ -306,8 +301,8 @@ class ChatGLMModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
......@@ -316,7 +311,7 @@ class ChatGLMModel(nn.Module):
hidden_states=inputs_embeds,
position_ids=position_ids,
kv_caches=kv_caches,
input_metadata=input_metadata,
attn_metadata=attn_metadata,
)
return hidden_states
......@@ -340,11 +335,11 @@ class ChatGLMForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata)
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
......
......@@ -21,15 +21,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
......@@ -51,8 +50,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class DeepseekMLP(nn.Module):
......@@ -239,14 +236,13 @@ class DeepseekAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
......@@ -294,8 +290,8 @@ class DeepseekDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
......@@ -309,7 +305,7 @@ class DeepseekDecoderLayer(nn.Module):
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
attn_metadata=attn_metadata,
)
# Fully Connected
......@@ -346,15 +342,15 @@ class DeepseekModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], input_metadata,
kv_caches[i], attn_metadata,
residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
......@@ -379,11 +375,11 @@ class DeepseekForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment