Unverified Commit 925f3332 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Core] Refactor Attention Take 2 (#3462)

parent b0dfa91d
...@@ -3,8 +3,7 @@ import pytest ...@@ -3,8 +3,7 @@ import pytest
import time import time
import torch import torch
from vllm.model_executor.layers.attention.ops.prefix_prefill import ( from vllm.attention.ops.prefix_prefill import context_attention_fwd
context_attention_fwd)
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
......
...@@ -2,7 +2,10 @@ ...@@ -2,7 +2,10 @@
Run `pytest tests/samplers/test_beam_search.py --forked`. Run `pytest tests/samplers/test_beam_search.py --forked`.
""" """
import gc
import pytest import pytest
import torch
# FIXME(zhuohan): The test can not pass if we: # FIXME(zhuohan): The test can not pass if we:
# 1. Increase max_tokens to 256. # 1. Increase max_tokens to 256.
...@@ -36,6 +39,10 @@ def test_beam_search_single_input( ...@@ -36,6 +39,10 @@ def test_beam_search_single_input(
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
max_tokens) max_tokens)
del vllm_model del vllm_model
# NOTE(woosuk): For some reason, the following GC is required to avoid
# GPU OOM errors in the following tests using `vllm_runner`.
gc.collect()
torch.cuda.empty_cache()
for i in range(len(example_prompts)): for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i] hf_output_ids, _ = hf_outputs[i]
......
...@@ -34,19 +34,19 @@ def test_prepare_prompt(batch_size): ...@@ -34,19 +34,19 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices.append(selected_token_start_idx + expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1) prompt_len - 1)
selected_token_start_idx += prompt_len selected_token_start_idx += prompt_len
(input_tokens, input_positions, input_metadata, return_prompt_lens, _, _, (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) _) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens assert return_prompt_lens == prompt_lens
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
assert input_metadata.is_prompt is True assert attn_metadata.is_prompt is True
assert torch.allclose(input_metadata.prompt_lens_tensor, assert torch.allclose(attn_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device)) torch.tensor(prompt_lens, device=device))
assert input_metadata.prompt_lens == prompt_lens assert attn_metadata.prompt_lens == prompt_lens
assert input_metadata.num_prompt_tokens == sum(prompt_lens) assert attn_metadata.num_prompt_tokens == sum(prompt_lens)
assert input_metadata.num_generation_tokens == 0 assert attn_metadata.num_generation_tokens == 0
assert input_metadata.max_seq_len == max(prompt_lens) assert attn_metadata.max_prompt_len == max(prompt_lens)
# Test subquery start locs. # Test subquery start locs.
start_idx = 0 start_idx = 0
...@@ -55,7 +55,7 @@ def test_prepare_prompt(batch_size): ...@@ -55,7 +55,7 @@ def test_prepare_prompt(batch_size):
start_idx += prompt_len start_idx += prompt_len
start_loc.append(start_idx) start_loc.append(start_idx)
assert torch.allclose( assert torch.allclose(
input_metadata.subquery_start_loc, attn_metadata.subquery_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device)) torch.tensor(start_loc, dtype=torch.int32, device=device))
# Test seq start locs. Note that for normal prefill it is # Test seq start locs. Note that for normal prefill it is
...@@ -67,22 +67,22 @@ def test_prepare_prompt(batch_size): ...@@ -67,22 +67,22 @@ def test_prepare_prompt(batch_size):
seq_start_loc.append(start_idx) seq_start_loc.append(start_idx)
assert torch.allclose( assert torch.allclose(
input_metadata.seq_start_loc, attn_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device)) torch.tensor(start_loc, dtype=torch.int32, device=device))
assert input_metadata.max_context_len is None assert attn_metadata.max_context_len is None
assert torch.allclose( assert torch.allclose(
input_metadata.context_lens, attn_metadata.context_lens,
torch.zeros(input_metadata.context_lens.shape[0], torch.zeros(attn_metadata.context_lens.shape[0],
dtype=torch.int, dtype=torch.int,
device=device)) device=device))
expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32, dtype=torch.int32,
device=model_runner.device) device=model_runner.device)
assert torch.allclose(input_metadata.block_tables, expected) assert torch.allclose(attn_metadata.block_tables, expected)
# Cuda graph should not be used for prefill. # Cuda graph should not be used for prefill.
assert input_metadata.use_cuda_graph is False assert attn_metadata.use_cuda_graph is False
assert input_metadata.kv_cache_dtype == "auto" assert attn_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (sum(prompt_lens), ) assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), ) assert input_positions.shape == (sum(prompt_lens), )
...@@ -140,34 +140,34 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -140,34 +140,34 @@ def test_prepare_decode_cuda_graph(batch_size):
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
input_tokens, input_positions, input_metadata, _, _, _ = ( input_tokens, input_positions, attn_metadata, _, _, _ = (
model_runner._prepare_decode(seq_group_metadata_list)) model_runner._prepare_decode(seq_group_metadata_list))
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
assert input_metadata.is_prompt is False assert attn_metadata.is_prompt is False
assert input_metadata.prompt_lens is None assert attn_metadata.prompt_lens is None
assert input_metadata.num_prompt_tokens == 0 assert attn_metadata.num_prompt_tokens == 0
assert input_metadata.num_generation_tokens == expected_bs assert attn_metadata.num_generation_tokens == expected_bs
assert input_metadata.max_seq_len is None assert attn_metadata.max_prompt_len is None
assert input_metadata.subquery_start_loc is None assert attn_metadata.subquery_start_loc is None
assert input_metadata.seq_start_loc is None assert attn_metadata.seq_start_loc is None
assert input_metadata.max_context_len == max(prompt_lens) assert attn_metadata.max_context_len == max(prompt_lens)
assert torch.allclose( assert torch.allclose(
input_metadata.context_lens[:len(prompt_lens)], attn_metadata.context_lens[:len(prompt_lens)],
torch.tensor(prompt_lens, dtype=torch.int, device=device)) torch.tensor(prompt_lens, dtype=torch.int, device=device))
# block table's first index corresponds to each batch, meaning in # block table's first index corresponds to each batch, meaning in
# decoding it is each token. # decoding it is each token.
assert input_metadata.block_tables.shape[0] == len(input_tokens) assert attn_metadata.block_tables.shape[0] == len(input_tokens)
# Block table's second dim corresponds to each token's block number. # Block table's second dim corresponds to each token's block number.
# It is padded up to `model_runner.get_max_block_per_batch()`. # It is padded up to `model_runner.get_max_block_per_batch()`.
assert input_metadata.block_tables.shape[1] == ( assert attn_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch()) model_runner.get_max_block_per_batch())
# Cuda graph should be used for decoding. # Cuda graph should be used for decoding.
assert input_metadata.use_cuda_graph is True assert attn_metadata.use_cuda_graph is True
assert input_metadata.kv_cache_dtype == "auto" assert attn_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (expected_bs, ) assert input_tokens.shape == (expected_bs, )
assert input_positions.shape == (expected_bs, ) assert input_positions.shape == (expected_bs, )
......
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
__all__ = [
"AttentionBackend",
"AttentionMetadata",
"Attention",
"get_attn_backend",
]
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend is a stateless bundle of static factory/utility hooks. Every
    hook is both a ``staticmethod`` and abstract, so concrete backends must
    override all of them and none of them receives ``self``/``cls``.
    """

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass paired with this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        """Build this backend's ``AttentionMetadata`` from the given args."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of the KV cache tensor for the given geometry."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Move cache blocks from ``src_kv_cache`` into ``dst_kv_cache``.

        NOTE(review): ``src_to_dst`` presumably maps a source block number to
        its destination block number -- confirm against the callers.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Copy cache blocks within each tensor of ``kv_caches``.

        NOTE(review): ``src_to_dists`` presumably maps a source block number
        to the list of destination block numbers -- confirm against callers.
        """
        raise NotImplementedError
@dataclass
class AttentionMetadata:
    """Base dataclass for per-backend attention metadata.

    It declares no fields itself; subclasses add them as dataclass fields.
    """

    def asdict_zerocopy(self) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying.

        Returns:
            A new dict mapping each dataclass field name to the very same
            object stored on this instance (values are shared, not copied),
            in field-declaration order.
        """
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self)
        }
class AttentionImpl(ABC):
    """Abstract interface for an attention implementation.

    Both the constructor and ``forward`` are abstract, so a concrete
    implementation must define the two of them.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Configure the attention implementation.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Attention scaling factor.
            num_kv_heads: Number of key/value heads, if given.
            alibi_slopes: Optional ALiBi slopes.
            sliding_window: Optional sliding-window size.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``query``/``key``/``value``.

        NOTE(review): tensor shapes are not fixed by this interface; each
        concrete implementation documents its own -- confirm there.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            kv_cache: Key/value cache tensor.
            attn_metadata: Metadata for attention.

        Returns:
            The attention output tensor.
        """
        raise NotImplementedError
"""Attention layer with Flash and PagedAttention.""" """Attention layer with Flash and PagedAttention.
from typing import List, Optional
NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
XFormers backend. The duplicated code will be removed once we use flash-attn or
flashinfer for all the attention operations.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
import torch import torch
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.model_executor.layers.attention.ops.paged_attn import ( AttentionMetadata)
PagedAttentionImpl) from vllm.attention.ops.paged_attn import PagedAttention, PagedAttentionMetadata
class FlashAttentionBackend(AttentionBackend):
    """Attention backend that pairs FlashAttention with PagedAttention.

    The KV-cache utilities are thin pass-throughs to ``PagedAttention``.
    """

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        """Return ``FlashAttentionImpl``."""
        return FlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
        """Construct a ``FlashAttentionMetadata`` from the given arguments."""
        return FlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape as computed by ``PagedAttention``."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Delegate the block swap to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Delegate the block copy to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.

    NOTE(review): attributes such as `slot_mapping`, `block_tables`,
    `context_lens`, `max_context_len` and `kv_cache_dtype` are read from this
    metadata by the attention forward pass but are not declared here;
    presumably they are inherited from `PagedAttentionMetadata` -- confirm.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The prompt length per sequence. None if it is a decoding.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]
    # The number of prompt tokens. Doesn't include padding.
    num_prompt_tokens: int
    # The number of generation tokens. Doesn't include padding.
    num_generation_tokens: int

    # NOTE(sang): Definition of context_len, subquery_len, and seqlen.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|

    # WARNING(sang): context_len has different definition depending on if it is
    # prefill vs decoding. When it is prefill, it doesn't include new tokens.
    # When it is for decoding, it includes a new token.

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # Maximum prompt length in the batch.
    max_prompt_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
class FlashAttentionImpl(AttentionImpl):
""" """
If the input tensors contain prompt tokens, the layout is as follows: If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->| |<--------------- num_prompt_tokens -------------->|
...@@ -39,30 +133,28 @@ class FlashAttentionBackend: ...@@ -39,30 +133,28 @@ class FlashAttentionBackend:
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
if alibi_slopes is not None: if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes self.alibi_slopes = alibi_slopes
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes: if head_size not in suppored_head_sizes:
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. " f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.") f"Supported head sizes are: {suppored_head_sizes}.")
self.sliding_window = ((self.sliding_window, self.sliding_window) if
self.sliding_window is not None else (-1, -1))
def forward( def forward(
self, self,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
key_cache: Optional[torch.Tensor], kv_cache: torch.Tensor,
value_cache: Optional[torch.Tensor], attn_metadata: FlashAttentionMetadata,
input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
...@@ -70,11 +162,8 @@ class FlashAttentionBackend: ...@@ -70,11 +162,8 @@ class FlashAttentionBackend:
query: shape = [num_tokens, num_heads * head_size] query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x, kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
block_size, x] attn_metadata: Metadata for attention.
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
...@@ -84,18 +173,21 @@ class FlashAttentionBackend: ...@@ -84,18 +173,21 @@ class FlashAttentionBackend:
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
# Reshape the keys and values and store them in the cache. if kv_cache is not None:
# If key_cache and value_cache are not provided, the new key and value key_cache, value_cache = PagedAttention.split_kv_cache(
# vectors will not be cached. This happens during the initial memory kv_cache, self.num_kv_heads, self.head_size)
# profiling run.
if key_cache is not None and value_cache is not None: # Reshape the input keys and values and store them in the cache.
PagedAttentionImpl.reshape_and_cache(key, value, key_cache, # If kv_cache is not provided, the new key and value tensors are
value_cache, input_metadata) # not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype)
if input_metadata.is_prompt: if attn_metadata.is_prompt:
# Prompt run. # Prompt run.
if (key_cache is None or value_cache is None if kv_cache is None or attn_metadata.block_tables.numel() == 0:
or input_metadata.block_tables.numel() == 0):
# normal attention # normal attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
...@@ -103,10 +195,10 @@ class FlashAttentionBackend: ...@@ -103,10 +195,10 @@ class FlashAttentionBackend:
q=query, q=query,
k=key, k=key,
v=value, v=value,
cu_seqlens_q=input_metadata.seq_start_loc, cu_seqlens_q=attn_metadata.seq_start_loc,
cu_seqlens_k=input_metadata.seq_start_loc, cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_q=input_metadata.max_seq_len, max_seqlen_q=attn_metadata.max_prompt_len,
max_seqlen_k=input_metadata.max_seq_len, max_seqlen_k=attn_metadata.max_prompt_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
window_size=self.sliding_window, window_size=self.sliding_window,
...@@ -114,22 +206,29 @@ class FlashAttentionBackend: ...@@ -114,22 +206,29 @@ class FlashAttentionBackend:
) )
else: else:
# prefix-enabled attention # prefix-enabled attention
output = PagedAttentionImpl.forward_prefix( output = PagedAttention.forward_prefix(
query, query,
key, key,
value, value,
key_cache, key_cache,
value_cache, value_cache,
input_metadata, attn_metadata.block_tables,
attn_metadata.subquery_start_loc,
attn_metadata.prompt_lens_tensor,
attn_metadata.context_lens,
attn_metadata.max_subquery_len,
self.alibi_slopes, self.alibi_slopes,
) )
else: else:
# Decoding run. # Decoding run.
output = PagedAttentionImpl.forward_decode( output = PagedAttention.forward_decode(
query, query,
key_cache, key_cache,
value_cache, value_cache,
input_metadata, attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
......
"""Attention layer with xFormers and PagedAttention.""" """Attention layer with xFormers and PagedAttention."""
import importlib import importlib
from typing import List, Optional from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch import torch
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, from xformers.ops.fmha.attn_bias import (AttentionBias,
BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias) LowerTriangularMaskWithTensorBias)
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.model_executor.layers.attention.ops.paged_attn import ( AttentionMetadata)
PagedAttentionImpl) from vllm.attention.ops.paged_attn import PagedAttention, PagedAttentionMetadata
from vllm.logger import init_logger
from vllm.utils import is_hip from vllm.utils import is_hip
logger = init_logger(__name__)
class XFormersBackend:
class XFormersBackend(AttentionBackend):
    """Attention backend that pairs xFormers with PagedAttention.

    The KV-cache utilities are thin pass-throughs to ``PagedAttention``.
    """

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        """Return ``XFormersImpl``."""
        return XFormersImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "XFormersMetadata":
        """Construct an ``XFormersMetadata`` from the given arguments."""
        return XFormersMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape as computed by ``PagedAttention``."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Delegate the block swap to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Delegate the block copy to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for XFormersBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor
    # (batch_size,). The prompt length per sequence. None if it is a decoding.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]
    # The number of prompt tokens. Doesn't include padding.
    num_prompt_tokens: int
    # The number of generation tokens. Doesn't include padding.
    num_generation_tokens: int

    # NOTE(sang): Definition of context_len, subquery_len, and seqlen.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|

    # WARNING(sang): context_len has different definition depending on if it is
    # prefill vs decoding. When it is prefill, it doesn't include new tokens.
    # When it is for decoding, it includes a new token.

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # FIXME: It is for flash attn.
    # Maximum prompt length in the batch.
    max_prompt_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    def __post_init__(self):
        """Initialise the lazily-populated attention bias to ``None``."""
        # Set during the execution of the first attention op.
        # It is a list because it needs to be set per prompt
        # when alibi slopes are used, due to a limitation of the
        # xformers API.
        # Assigned here rather than declared as a dataclass field, so it
        # will not appear in the __repr__ and __init__.
        self.attn_bias: Optional[List[AttentionBias]] = None
class XFormersImpl(AttentionImpl):
""" """
If the input tensors contain prompt tokens, the layout is as follows: If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens --------------->| |<--------------- num_prompt_tokens --------------->|
...@@ -50,22 +158,25 @@ class XFormersBackend: ...@@ -50,22 +158,25 @@ class XFormersBackend:
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes: if head_size not in suppored_head_sizes:
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. " f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.") f"Supported head sizes are: {suppored_head_sizes}.")
self.use_ref_attention = _check_use_ref_attention() # AMD Radeon 7900 series (gfx1100) currently does not support xFormers
# nor FlashAttention. As a temporary workaround, we use naive PyTorch
# implementation of attention.
self.use_naive_attention = _check_use_naive_attention()
def forward( def forward(
self, self,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
key_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor], attn_metadata: XFormersMetadata,
input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
...@@ -73,11 +184,8 @@ class XFormersBackend: ...@@ -73,11 +184,8 @@ class XFormersBackend:
query: shape = [num_tokens, num_heads * head_size] query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x, kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
block_size, x] attn_metadata: Metadata for attention.
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
...@@ -86,20 +194,24 @@ class XFormersBackend: ...@@ -86,20 +194,24 @@ class XFormersBackend:
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
# Reshape the keys and values and store them in the cache. if kv_cache is not None:
# If key_cache and value_cache are not provided, the new key and value key_cache, value_cache = PagedAttention.split_kv_cache(
# vectors will not be cached. This happens during the initial memory kv_cache, self.num_kv_heads, self.head_size)
# profiling run.
if key_cache is not None and value_cache is not None: # Reshape the input keys and values and store them in the cache.
PagedAttentionImpl.reshape_and_cache(key, value, key_cache, # If kv_cache is not provided, the new key and value tensors are
value_cache, input_metadata) # not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype)
if input_metadata.is_prompt: if attn_metadata.is_prompt:
# Prompt run. # Prompt run.
# key_cache and value_cache are None when it is a profiling run. if kv_cache is None or attn_metadata.block_tables.numel() == 0:
# block tables are empty if the prompt has never been computed. # normal attention.
if (key_cache is None or value_cache is None # block tables are empty if the prompt does not have a cached
or input_metadata.block_tables.numel() == 0): # prefix.
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA, # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of # project the key and value tensors to the desired number of
...@@ -118,13 +230,12 @@ class XFormersBackend: ...@@ -118,13 +230,12 @@ class XFormersBackend:
self.num_queries_per_kv, self.num_queries_per_kv,
value.shape[-1]) value.shape[-1])
if self.use_ref_attention: if self.use_naive_attention:
print("ref attention used.")
output = torch.empty_like(query) output = torch.empty_like(query)
start = 0 start = 0
for _, prompt_len in enumerate(input_metadata.prompt_lens): for _, prompt_len in enumerate(attn_metadata.prompt_lens):
end = start + prompt_len end = start + prompt_len
out = _ref_masked_attention( out = _naive_masked_attention(
query[None, start:end], query[None, start:end],
key[None, start:end], key[None, start:end],
value[None, start:end], value[None, start:end],
...@@ -143,26 +254,33 @@ class XFormersBackend: ...@@ -143,26 +254,33 @@ class XFormersBackend:
# Use reshape instead. # Use reshape instead.
return output.reshape(num_tokens, hidden_size) return output.reshape(num_tokens, hidden_size)
output = self._run_memory_efficient_xformer_forward( output = self._run_memory_efficient_xformers_forward(
query, key, value, input_metadata) query, key, value, attn_metadata)
else: else:
# prefix-enabled attention # prefix-enabled attention
output = PagedAttentionImpl.forward_prefix( output = PagedAttention.forward_prefix(
query, query,
key, key,
value, value,
key_cache, key_cache,
value_cache, value_cache,
input_metadata, attn_metadata.block_tables,
attn_metadata.subquery_start_loc,
attn_metadata.prompt_lens_tensor,
attn_metadata.context_lens,
attn_metadata.max_subquery_len,
self.alibi_slopes, self.alibi_slopes,
) )
else: else:
# Decoding run. # Decoding run.
output = PagedAttentionImpl.forward_decode( output = PagedAttention.forward_decode(
query, query,
key_cache, key_cache,
value_cache, value_cache,
input_metadata, attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
...@@ -171,12 +289,12 @@ class XFormersBackend: ...@@ -171,12 +289,12 @@ class XFormersBackend:
# Reshape the output tensor. # Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size) return output.view(-1, self.num_heads * self.head_size)
def _run_memory_efficient_xformer_forward( def _run_memory_efficient_xformers_forward(
self, self,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: XFormersMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt """Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input. tokens are flattened in to `query` input.
...@@ -186,23 +304,23 @@ class XFormersBackend: ...@@ -186,23 +304,23 @@ class XFormersBackend:
query: shape = [num_prompt_tokens, num_heads, head_size] query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size] key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size] value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention. attn_metadata: Metadata for attention.
""" """
# Set attention bias if not provided. This typically happens at # Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration. # the very attention layer of every iteration.
# FIXME(woosuk): This is a hack. # FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None: if attn_metadata.attn_bias is None:
if self.alibi_slopes is None: if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens( attn_bias = BlockDiagonalCausalMask.from_seqlens(
input_metadata.prompt_lens) attn_metadata.prompt_lens)
if self.sliding_window is not None: if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention( attn_bias = attn_bias.make_local_attention(
self.sliding_window) self.sliding_window)
input_metadata.attn_bias = [attn_bias] attn_metadata.attn_bias = [attn_bias]
else: else:
input_metadata.attn_bias = _make_alibi_bias( attn_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype, self.alibi_slopes, self.num_kv_heads, query.dtype,
input_metadata) attn_metadata.prompt_lens)
op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if ( op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if (
is_hip()) else None is_hip()) else None
...@@ -217,7 +335,7 @@ class XFormersBackend: ...@@ -217,7 +335,7 @@ class XFormersBackend:
query, query,
key, key,
value, value,
attn_bias=input_metadata.attn_bias[0], attn_bias=attn_metadata.attn_bias[0],
p=0.0, p=0.0,
scale=self.scale, scale=self.scale,
op=op) op=op)
...@@ -230,13 +348,13 @@ class XFormersBackend: ...@@ -230,13 +348,13 @@ class XFormersBackend:
# one. This is inefficient, especially when we have many short prompts. # one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(query) output = torch.empty_like(query)
start = 0 start = 0
for i, prompt_len in enumerate(input_metadata.prompt_lens): for i, prompt_len in enumerate(attn_metadata.prompt_lens):
end = start + prompt_len end = start + prompt_len
out = xops.memory_efficient_attention_forward( out = xops.memory_efficient_attention_forward(
query[None, start:end], query[None, start:end],
key[None, start:end], key[None, start:end],
value[None, start:end], value[None, start:end],
attn_bias=input_metadata.attn_bias[i], attn_bias=attn_metadata.attn_bias[i],
p=0.0, p=0.0,
scale=self.scale, scale=self.scale,
op=op) op=op)
...@@ -250,10 +368,10 @@ def _make_alibi_bias( ...@@ -250,10 +368,10 @@ def _make_alibi_bias(
alibi_slopes: torch.Tensor, alibi_slopes: torch.Tensor,
num_kv_heads: int, num_kv_heads: int,
dtype: torch.dtype, dtype: torch.dtype,
input_metadata: InputMetadata, prompt_lens: List[int],
) -> LowerTriangularMaskWithTensorBias: ) -> LowerTriangularMaskWithTensorBias:
attn_biases = [] attn_biases = []
for prompt_len in input_metadata.prompt_lens: for prompt_len in prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype) bias = torch.arange(prompt_len, dtype=dtype)
# NOTE(zhuohan): HF uses # NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)` # `bias = bias[None, :].repeat(prompt_len, 1)`
...@@ -282,15 +400,19 @@ def _make_alibi_bias( ...@@ -282,15 +400,19 @@ def _make_alibi_bias(
return attn_biases return attn_biases
def _check_use_naive_attention() -> bool:
    """Return True when the naive attention fallback has to be used.

    The fallback is only considered on ROCm (`is_hip()`); on every other
    platform this returns False immediately. On ROCm the fallback is
    selected when the `flash_attn` package cannot be located.

    Returns:
        True if naive attention must be used, False otherwise.
    """
    if not is_hip():
        return False
    # For ROCm, check whether flash attention is installed or not.
    # `find_spec` returns None when the package cannot be found, so the
    # package is available exactly when the result is *not* None.
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    if not has_flash_attn:
        logger.warning("flash_attn is not installed. Using naive attention. "
                       "This will take significantly more GPU memory.")
        return True
    return False
def _ref_masked_attention( def _naive_masked_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
......
"""Attention layer.""" """Attention layer."""
from functools import lru_cache
from typing import List, Optional from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.logger import init_logger from vllm.attention.backends.abstract import AttentionMetadata
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention.selector import get_attn_backend
from vllm.utils import is_hip
logger = init_logger(__name__)
class Attention(nn.Module): class Attention(nn.Module):
...@@ -17,12 +13,11 @@ class Attention(nn.Module): ...@@ -17,12 +13,11 @@ class Attention(nn.Module):
This class takes query, key, and value tensors as input. The input tensors This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens. can either contain prompt tokens or generation tokens.
The class does the following: The class does the following:
1. Store the input key and value tensors in the KV cache. 1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention. 2. Perform (multi-head/multi-query/grouped-query) attention.
3. Output the output tensor. 3. Return the output tensor.
""" """
def __init__( def __init__(
...@@ -35,51 +30,17 @@ class Attention(nn.Module): ...@@ -35,51 +30,17 @@ class Attention(nn.Module):
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if _use_flash_attn(): self.backend = get_attn_backend(torch.get_default_dtype())
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501 impl_cls = self.backend.get_impl_cls()
self.backend = FlashAttentionBackend(num_heads, head_size, scale, self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
num_kv_heads, alibi_slopes, alibi_slopes, sliding_window)
sliding_window)
else:
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend # noqa: E501
self.backend = XFormersBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
def forward( def forward(
self, self,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
key_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata,
input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
return self.backend.forward(query, key, value, key_cache, value_cache, return self.impl.forward(query, key, value, kv_cache, attn_metadata)
input_metadata)
@lru_cache(maxsize=1)
def _use_flash_attn() -> bool:
    """Decide whether the flash_attn backend can be used.

    The answer is memoized by `lru_cache`. False is returned (and, except
    for the AMD case, a reason is logged) when flash_attn cannot be
    imported, when running on AMD GPUs, when the GPU compute capability
    major version is below 8, or when the default torch dtype is neither
    float16 nor bfloat16.
    """
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info("flash_attn is not found. Using xformers backend.")
        return False

    # AMD GPUs never take the flash_attn path.
    if is_hip():
        return False

    # Volta and Turing NVIDIA GPUs report a major capability below 8.
    capability_major = torch.cuda.get_device_capability()[0]
    if capability_major < 8:
        logger.info("flash_attn is not supported on Turing or older GPUs. "
                    "Using xformers backend.")
        return False

    half_precision_dtypes = (torch.float16, torch.bfloat16)
    default_dtype = torch.get_default_dtype()
    if default_dtype not in half_precision_dtypes:
        logger.info(
            "flash_attn only supports torch.float16 or torch.bfloat16. "
            "Using xformers backend.")
        return False

    logger.info("Using flash_attn backend.")
    return True
from typing import List, Optional from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch import torch
from vllm._C import cache_ops from vllm._C import cache_ops
from vllm._C import ops from vllm._C import ops
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.model_executor.layers.attention.ops.prefix_prefill import (
context_attention_fwd)
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
class PagedAttentionImpl: @dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention.

    Bundles the slot mapping, context lengths, block tables and KV cache
    dtype fields declared below.
    """
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor
    # (batch_size,). The length of context (tokens stored in KV cache) per
    # sequence. WARNING: When it is a prefill request, it doesn't include new
    # tokens. When it is for decoding, it includes a new token.
    context_lens: Optional[torch.Tensor]
    # Maximum context length in the batch.
    max_context_len: Optional[int]
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]
    # KV cache dtype given as a string.
    # NOTE(review): the set of accepted values is not visible here; confirm
    # against the cache/attention ops that receive it.
    kv_cache_dtype: str
class PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
    """Return a fresh list of the supported attention head sizes."""
    supported_head_sizes = [64, 80, 96, 112, 128, 256]
    return supported_head_sizes
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> Tuple[int, ...]:
    """Return the shape of a combined key/value cache tensor.

    The result is `(2, num_blocks, block_size * num_kv_heads * head_size)`.
    The leading dimension of size 2 holds the key cache at index 0 and the
    value cache at index 1 (this is how `split_kv_cache` indexes it).
    """
    elements_per_block = block_size * num_kv_heads * head_size
    cache_shape = (2, num_blocks, elements_per_block)
    return cache_shape
@staticmethod
def split_kv_cache(
    kv_cache: torch.Tensor,
    num_kv_heads: int,
    head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a combined KV cache tensor into key and value cache views.

    Args:
        kv_cache: Tensor whose index 0 along the first dimension is the key
            cache and index 1 is the value cache; its second dimension is
            the number of blocks.
        num_kv_heads: Number of key/value heads.
        head_size: Size of each head.

    Returns:
        `(key_cache, value_cache)`, both produced with `Tensor.view`:
        the key cache has shape
        `(num_blocks, num_kv_heads, head_size // x, -1, x)` and the value
        cache has shape `(num_blocks, num_kv_heads, head_size, -1)`.
    """
    # x is the number of elements of this dtype that fit in 16 bytes.
    elems_per_16_bytes = 16 // kv_cache.element_size()
    block_count = kv_cache.shape[1]

    key_cache = kv_cache[0].view(
        block_count,
        num_kv_heads,
        head_size // elems_per_16_bytes,
        -1,
        elems_per_16_bytes,
    )
    value_cache = kv_cache[1].view(block_count, num_kv_heads, head_size, -1)
    return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
) -> None:
    """Store `key` and `value` into the paged key/value caches.

    Thin wrapper around `cache_ops.reshape_and_cache`; the slot mapping is
    flattened before it is handed to the op.
    """
    flat_slot_mapping = slot_mapping.flatten()
    cache_ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        flat_slot_mapping,
        kv_cache_dtype,
    )
@staticmethod @staticmethod
...@@ -40,7 +89,10 @@ class PagedAttentionImpl: ...@@ -40,7 +89,10 @@ class PagedAttentionImpl:
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
input_metadata: InputMetadata, block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
kv_cache_dtype: str,
num_kv_heads: int, num_kv_heads: int,
scale: float, scale: float,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
...@@ -49,8 +101,7 @@ class PagedAttentionImpl: ...@@ -49,8 +101,7 @@ class PagedAttentionImpl:
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape num_seqs, num_heads, head_size = query.shape
max_num_partitions = ( max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) //
(input_metadata.max_context_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE) _PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use # NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use # PagedAttention V1 or V2. If the number of partitions is 1, we use
...@@ -59,8 +110,8 @@ class PagedAttentionImpl: ...@@ -59,8 +110,8 @@ class PagedAttentionImpl:
# to parallelize. # to parallelize.
# TODO(woosuk): Tune this heuristic. # TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage. # For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = input_metadata.max_context_len <= 8192 and ( use_v1 = (max_context_len <= 8192
max_num_partitions == 1 or num_seqs * num_heads > 512) and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1: if use_v1:
# Run PagedAttention V1. # Run PagedAttention V1.
ops.paged_attention_v1( ops.paged_attention_v1(
...@@ -70,12 +121,12 @@ class PagedAttentionImpl: ...@@ -70,12 +121,12 @@ class PagedAttentionImpl:
value_cache, value_cache,
num_kv_heads, num_kv_heads,
scale, scale,
input_metadata.block_tables, block_tables,
input_metadata.context_lens, context_lens,
block_size, block_size,
input_metadata.max_context_len, max_context_len,
alibi_slopes, alibi_slopes,
input_metadata.kv_cache_dtype, kv_cache_dtype,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
...@@ -101,12 +152,12 @@ class PagedAttentionImpl: ...@@ -101,12 +152,12 @@ class PagedAttentionImpl:
value_cache, value_cache,
num_kv_heads, num_kv_heads,
scale, scale,
input_metadata.block_tables, block_tables,
input_metadata.context_lens, context_lens,
block_size, block_size,
input_metadata.max_context_len, max_context_len,
alibi_slopes, alibi_slopes,
input_metadata.kv_cache_dtype, kv_cache_dtype,
) )
return output return output
...@@ -117,7 +168,11 @@ class PagedAttentionImpl: ...@@ -117,7 +168,11 @@ class PagedAttentionImpl:
value: torch.Tensor, value: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
input_metadata: InputMetadata, block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
prompt_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_subquery_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
output = torch.empty_like(query) output = torch.empty_like(query)
...@@ -128,12 +183,35 @@ class PagedAttentionImpl: ...@@ -128,12 +183,35 @@ class PagedAttentionImpl:
output, output,
key_cache, key_cache,
value_cache, value_cache,
input_metadata.block_tables, block_tables,
# subquery_start_loc is (batch_size + 1,) # subquery_start_loc is (batch_size + 1,)
input_metadata.subquery_start_loc[:-1], subquery_start_loc[:-1],
input_metadata.prompt_lens_tensor, prompt_lens_tensor,
input_metadata.context_lens, context_lens,
input_metadata.max_subquery_len, max_subquery_len,
alibi_slopes, alibi_slopes,
) )
return output return output
@staticmethod
def swap_blocks(
    src_kv_cache: torch.Tensor,
    dst_kv_cache: torch.Tensor,
    src_to_dst: Dict[int, int],
) -> None:
    """Swap blocks from the source KV cache into the destination KV cache.

    `cache_ops.swap_blocks` is invoked twice with the same `src_to_dst`
    mapping: first on index 0 (keys) of both caches, then on index 1
    (values).
    """
    # Index 0 is the key cache, index 1 is the value cache.
    for kv_index in (0, 1):
        cache_ops.swap_blocks(src_kv_cache[kv_index], dst_kv_cache[kv_index],
                              src_to_dst)
@staticmethod
def copy_blocks(
    kv_caches: List[torch.Tensor],
    src_to_dists: Dict[int, List[int]],
) -> None:
    """Copy blocks within every KV cache in `kv_caches`.

    Collects the key caches (index 0) and the value caches (index 1) of
    each entry into two lists and passes them, together with
    `src_to_dists`, to `cache_ops.copy_blocks`.
    """

    def _select(kv_index: int) -> List[torch.Tensor]:
        # One pass over kv_caches per call, picking keys (0) or values (1).
        return [layer_cache[kv_index] for layer_cache in kv_caches]

    key_caches = _select(0)
    value_caches = _select(1)
    cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
from functools import lru_cache
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_hip
logger = init_logger(__name__)
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> AttentionBackend:
    """Select the attention backend for `dtype`.

    The result is memoized per `dtype` by `lru_cache`. The chosen backend
    module is imported lazily, after the selection has been logged.

    NOTE(review): the imported backend object itself is returned (it looks
    like the class, not an instance) — confirm against the return
    annotation.
    """
    if not _can_use_flash_attn(dtype):
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import XFormersBackend  # noqa: F401
        return XFormersBackend

    logger.info("Using FlashAttention backend.")
    from vllm.attention.backends.flash_attn import FlashAttentionBackend  # noqa: F401
    return FlashAttentionBackend
def _can_use_flash_attn(dtype: torch.dtype) -> bool:
    """Report whether the FlashAttention backend is usable for `dtype`.

    Checks, in order: not running on AMD GPUs, GPU compute capability major
    version of at least 8, `dtype` being float16 or bfloat16, and the
    `flash_attn` package being importable. The first failing check logs its
    reason and yields False.
    """
    # AMD GPUs.
    if is_hip():
        logger.info("Cannot use FlashAttention backend for AMD GPUs.")
        return False

    # Volta and Turing NVIDIA GPUs.
    capability_major = torch.cuda.get_device_capability()[0]
    if capability_major < 8:
        logger.info("Cannot use FlashAttention backend for Volta and Turing "
                    "GPUs.")
        return False

    half_precision_dtypes = (torch.float16, torch.bfloat16)
    if dtype not in half_precision_dtypes:
        logger.info("Cannot use FlashAttention backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return False

    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info("flash_attn is not found.")
        return False
    else:
        return True
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
__all__ = [ __all__ = [
"InputMetadata",
"SamplingMetadata", "SamplingMetadata",
"set_random_seed", "set_random_seed",
] ]
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Optional, List, Any, Dict
import torch
if TYPE_CHECKING:
from xformers.ops.fmha.attn_bias import AttentionBias
@dataclass
class InputMetadata:
    """Per-batch metadata for the input sequences. Used in PagedAttention.

    NOTE: Plain Python objects stored here are not refreshed when a CUDA
    graph is replayed. A value that has to change dynamically must live in
    a tensor, and that tensor has to be updated through the
    `CUDAGraphRunner.forward` API.
    """
    # A batch currently contains either only prompts or only decoding
    # sequences. True when every sequence is a prompt.
    is_prompt: bool
    # (num_tokens,). Token slot indices the input tokens are written to.
    # Example: with `slot_mapping` == [35, 2, 17] and a block size of 16, the
    # tokens land in the 3rd slot of block 2, the 2nd slot of block 0 and the
    # 1st slot of block 1, respectively.
    slot_mapping: torch.Tensor
    # (batch_size,). Prompt length of each sequence. None when decoding.
    prompt_lens: Optional[List[int]]
    # The same prompt lengths, stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]
    # Number of prompt tokens, padding excluded.
    num_prompt_tokens: int
    # Number of generation tokens, padding excluded.
    num_generation_tokens: int

    # Definition of context_len, subquery_len, and seqlen.
    #
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|
    #
    # WARNING: context_len is defined differently for prefill and decoding.
    # For prefill it does not include the new tokens; for decoding it
    # includes the new token.

    # Largest subquery length in the batch.
    max_subquery_len: Optional[int]
    # Largest context length in the batch.
    max_context_len: Optional[int]
    # FIXME: It is for flash attn.
    # Largest sequence length in the batch.
    max_seq_len: Optional[int]
    # (batch_size + 1,). Cumulative subquery lengths of the sequences in the
    # batch, used to index into subquery. Subquery lengths [4, 6] give
    # [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # FIXME: It is for flash attn.
    # (batch_size + 1,). Cumulative sequence lengths of the sequences in the
    # batch, used to index into sequence. Sequence lengths [4, 6] give
    # [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). Context length (tokens already stored in the KV cache)
    # of each sequence. WARNING: for a prefill request it does not include
    # the new tokens; for decoding it includes the new token.
    context_lens: Optional[torch.Tensor]
    # (batch_size, max_blocks_per_seq).
    # Physical block addresses per sequence (seq id -> list of blocks).
    # E.g. [0, 1, 2] means the tokens are stored in blocks 0, 1 and 2 of the
    # KV cache; each block holds up to block_size tokens. The 2nd dimension
    # is padded up to max_blocks_per_seq when it is cuda-graph captured.
    block_tables: Optional[torch.Tensor]
    # Whether cuda graph is enabled. Currently enabled for decoding only.
    use_cuda_graph: bool
    kv_cache_dtype: str

    def __post_init__(self):
        # `attn_bias` is filled in while the first attention op executes.
        # It is a list because, when alibi slopes are used, one bias per
        # prompt is required (a limitation of the xformers API). Being
        # assigned here rather than declared as a field, it does not show
        # up in __repr__ or __init__.
        self.attn_bias: Optional[List["AttentionBias"]] = None
        # Cuda graph is only used for decoding now, so a cuda-graph batch
        # must not carry any prompt tokens.
        if self.use_cuda_graph:
            assert self.num_prompt_tokens == 0

    def asdict_zerocopy(self) -> Dict[str, Any]:
        """Like `dataclasses.asdict`, but without deep-copying the values."""
        # Only one level is flattened: if a field ever holds another
        # dataclass, it will need the same treatment.
        shallow: Dict[str, Any] = {}
        for spec in fields(self):
            shallow[spec.name] = getattr(self, spec.name)
        return shallow
from vllm.model_executor.layers.attention.attention import Attention
__all__ = [
"Attention",
]
...@@ -25,9 +25,8 @@ import torch ...@@ -25,9 +25,8 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -45,8 +44,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -45,8 +44,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
...@@ -170,15 +167,14 @@ class BaiChuanAttention(nn.Module): ...@@ -170,15 +167,14 @@ class BaiChuanAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states) qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI": if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -217,8 +213,8 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -217,8 +213,8 @@ class BaiChuanDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -232,7 +228,7 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -232,7 +228,7 @@ class BaiChuanDecoderLayer(nn.Module):
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -267,8 +263,8 @@ class BaiChuanModel(nn.Module): ...@@ -267,8 +263,8 @@ class BaiChuanModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
...@@ -278,7 +274,7 @@ class BaiChuanModel(nn.Module): ...@@ -278,7 +274,7 @@ class BaiChuanModel(nn.Module):
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
residual, residual,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
...@@ -303,11 +299,11 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -303,11 +299,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -17,15 +17,14 @@ ...@@ -17,15 +17,14 @@
# limitations under the License. # limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights.""" """Inference-only BLOOM model compatible with HuggingFace weights."""
import math import math
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import BloomConfig from transformers import BloomConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
...@@ -117,14 +114,13 @@ class BloomAttention(nn.Module): ...@@ -117,14 +114,13 @@ class BloomAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
del position_ids # Unused. del position_ids # Unused.
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.dense(attn_output) output, _ = self.dense(attn_output)
return output return output
...@@ -181,8 +177,8 @@ class BloomBlock(nn.Module): ...@@ -181,8 +177,8 @@ class BloomBlock(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# Layer norm at the beginning of the transformer layer. # Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states) layernorm_output = self.input_layernorm(hidden_states)
...@@ -198,7 +194,7 @@ class BloomBlock(nn.Module): ...@@ -198,7 +194,7 @@ class BloomBlock(nn.Module):
position_ids=position_ids, position_ids=position_ids,
hidden_states=layernorm_output, hidden_states=layernorm_output,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
attention_output = attention_output + residual attention_output = attention_output + residual
layernorm_output = self.post_attention_layernorm(attention_output) layernorm_output = self.post_attention_layernorm(attention_output)
...@@ -245,8 +241,8 @@ class BloomModel(nn.Module): ...@@ -245,8 +241,8 @@ class BloomModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids) hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states) hidden_states = self.word_embeddings_layernorm(hidden_states)
...@@ -256,7 +252,7 @@ class BloomModel(nn.Module): ...@@ -256,7 +252,7 @@ class BloomModel(nn.Module):
position_ids, position_ids,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, attn_metadata,
) )
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
...@@ -281,11 +277,11 @@ class BloomForCausalLM(nn.Module): ...@@ -281,11 +277,11 @@ class BloomForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -2,15 +2,14 @@ ...@@ -2,15 +2,14 @@
# Adapted from # Adapted from
# https://github.com/THUDM/ChatGLM2-6B # https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights.""" """Inference-only ChatGLM model compatible with THUDM weights."""
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -29,8 +28,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -29,8 +28,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GLMAttention(nn.Module): class GLMAttention(nn.Module):
...@@ -99,20 +96,18 @@ class GLMAttention(nn.Module): ...@@ -99,20 +96,18 @@ class GLMAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
key_cache, value_cache = kv_cache
context_layer = self.attn( context_layer = self.attn(
q, q,
k, k,
v, v,
key_cache, kv_cache,
value_cache, attn_metadata,
input_metadata,
) )
attn_output, _ = self.dense(context_layer) attn_output, _ = self.dense(context_layer)
return attn_output return attn_output
...@@ -200,8 +195,8 @@ class GLMBlock(nn.Module): ...@@ -200,8 +195,8 @@ class GLMBlock(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# hidden_states: [num_tokens, h] # hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer. # Layer norm at the beginning of the transformer layer.
...@@ -211,7 +206,7 @@ class GLMBlock(nn.Module): ...@@ -211,7 +206,7 @@ class GLMBlock(nn.Module):
hidden_states=layernorm_output, hidden_states=layernorm_output,
position_ids=position_ids, position_ids=position_ids,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# Residual connection. # Residual connection.
...@@ -264,8 +259,8 @@ class GLMTransformer(nn.Module): ...@@ -264,8 +259,8 @@ class GLMTransformer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
for i in range(self.num_layers): for i in range(self.num_layers):
layer = self.layers[i] layer = self.layers[i]
...@@ -273,7 +268,7 @@ class GLMTransformer(nn.Module): ...@@ -273,7 +268,7 @@ class GLMTransformer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
position_ids=position_ids, position_ids=position_ids,
kv_cache=kv_caches[i], kv_cache=kv_caches[i],
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# Final layer norm. # Final layer norm.
if self.post_layer_norm: if self.post_layer_norm:
...@@ -306,8 +301,8 @@ class ChatGLMModel(nn.Module): ...@@ -306,8 +301,8 @@ class ChatGLMModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids) inputs_embeds = self.embedding(input_ids)
...@@ -316,7 +311,7 @@ class ChatGLMModel(nn.Module): ...@@ -316,7 +311,7 @@ class ChatGLMModel(nn.Module):
hidden_states=inputs_embeds, hidden_states=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,
kv_caches=kv_caches, kv_caches=kv_caches,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
return hidden_states return hidden_states
...@@ -340,11 +335,11 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -340,11 +335,11 @@ class ChatGLMForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -21,15 +21,14 @@ ...@@ -21,15 +21,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Deepseek model.""" """Inference-only Deepseek model."""
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
...@@ -51,8 +50,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader, ...@@ -51,8 +50,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class DeepseekMLP(nn.Module): class DeepseekMLP(nn.Module):
...@@ -239,14 +236,13 @@ class DeepseekAttention(nn.Module): ...@@ -239,14 +236,13 @@ class DeepseekAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -294,8 +290,8 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -294,8 +290,8 @@ class DeepseekDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: KVCache, kv_cache: torch.Tensor,
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -309,7 +305,7 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -309,7 +305,7 @@ class DeepseekDecoderLayer(nn.Module):
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -346,15 +342,15 @@ class DeepseekModel(nn.Module): ...@@ -346,15 +342,15 @@ class DeepseekModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], input_metadata, kv_caches[i], attn_metadata,
residual) residual)
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -379,11 +375,11 @@ class DeepseekForCausalLM(nn.Module): ...@@ -379,11 +375,11 @@ class DeepseekForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment