Unverified Commit 2daf23ab authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Separate attention backends (#3005)

parent cbf4c05b
...@@ -184,3 +184,6 @@ _build/ ...@@ -184,3 +184,6 @@ _build/
# Benchmark dataset # Benchmark dataset
*.json *.json
# Third-party Python packages.
vllm/thirdparty_files/
...@@ -3,6 +3,7 @@ import io ...@@ -3,6 +3,7 @@ import io
import os import os
import re import re
import subprocess import subprocess
import sys
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import List, Set from typing import List, Set
...@@ -14,6 +15,8 @@ import torch.utils.cpp_extension as torch_cpp_ext ...@@ -14,6 +15,8 @@ import torch.utils.cpp_extension as torch_cpp_ext
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
ROOT_DIR = os.path.dirname(__file__) ROOT_DIR = os.path.dirname(__file__)
# This is a temporary directory to store third-party packages.
THIRDPARTY_SUBDIR = "vllm/thirdparty_files"
# If you are developing the C++ backend of vLLM, consider building vLLM with # If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds. # `python setup.py develop` since it will give you incremental builds.
...@@ -324,7 +327,45 @@ if _is_cuda(): ...@@ -324,7 +327,45 @@ if _is_cuda():
"nvcc": NVCC_FLAGS_PUNICA, "nvcc": NVCC_FLAGS_PUNICA,
}, },
)) ))
elif _is_neuron():
# Download the FlashAttention package.
# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
flash_attn_version = "2.5.6"
install_dir = os.path.join(ROOT_DIR, THIRDPARTY_SUBDIR)
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"-q",
f"--target={install_dir}",
"einops", # Dependency of flash-attn.
f"flash-attn=={flash_attn_version}",
"--no-dependencies", # Required to avoid re-installing torch.
],
env=dict(os.environ, CC="gcc"),
)
# Copy the FlashAttention package into the vLLM package after build.
class build_ext(BuildExtension):
def run(self):
super().run()
target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
self.copy_tree(install_dir, target_dir)
class BinaryDistribution(setuptools.Distribution):
def has_ext_modules(self):
return True
else:
build_ext = BuildExtension
BinaryDistribution = setuptools.Distribution
if _is_neuron():
neuronxcc_version = get_neuronxcc_version() neuronxcc_version = get_neuronxcc_version()
vllm_extension_sources = [ vllm_extension_sources = [
...@@ -468,6 +509,7 @@ setuptools.setup( ...@@ -468,6 +509,7 @@ setuptools.setup(
python_requires=">=3.8", python_requires=">=3.8",
install_requires=get_requirements(), install_requires=get_requirements(),
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {}, cmdclass={"build_ext": build_ext} if not _is_neuron() else {},
distclass=BinaryDistribution,
package_data=package_data, package_data=package_data,
) )
...@@ -3,7 +3,7 @@ import pytest ...@@ -3,7 +3,7 @@ import pytest
import time import time
import torch import torch
from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( from vllm.model_executor.layers.attention.ops.prefix_prefill import (
context_attention_fwd) context_attention_fwd)
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
......
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs""" """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine # Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11
from vllm.engine.llm_engine import LLMEngine def _configure_system():
from vllm.engine.ray_utils import initialize_cluster import os
from vllm.entrypoints.llm import LLM import sys
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams # Importing flash-attn.
thirdparty_files = os.path.join(os.path.abspath(os.path.dirname(__file__)),
"thirdparty_files")
sys.path.insert(0, thirdparty_files)
_configure_system()
# Delete configuration function.
del _configure_system
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402
from vllm.engine.async_llm_engine import AsyncLLMEngine # noqa: E402
from vllm.engine.llm_engine import LLMEngine # noqa: E402
from vllm.engine.ray_utils import initialize_cluster # noqa: E402
from vllm.entrypoints.llm import LLM # noqa: E402
from vllm.outputs import CompletionOutput, RequestOutput # noqa: E402
from vllm.sampling_params import SamplingParams # noqa: E402
__version__ = "0.3.3" __version__ = "0.3.3"
......
from vllm.model_executor.layers.attention.attention import Attention
__all__ = [
"Attention",
]
"""Attention layer."""
from typing import List, Optional
import torch
import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip
class Attention(nn.Module):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and
torch.get_default_dtype() in (torch.float16, torch.bfloat16)):
# Ampere or later NVIDIA GPUs.
# NOTE(woosuk): FlashAttention does not support FP32.
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend
self.backend = FlashAttentionBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
else:
# Turing and Volta NVIDIA GPUs or AMD GPUs.
# Or FP32 on any GPU.
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend
self.backend = XFormersBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
return self.backend.forward(query, key, value, key_cache, value_cache,
input_metadata)
"""Attention layer with Flash and PagedAttention."""
from typing import List, Optional
# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
from flash_attn import flash_attn_func
import torch
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
PagedAttentionImpl)
class FlashAttentionBackend:
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
self.sliding_window = ((self.sliding_window, self.sliding_window) if
self.sliding_window is not None else (-1, -1))
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
batch_size, seq_len, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# Reshape the keys and values and store them in the cache.
# If key_cache and value_cache are not provided, the new key and value
# vectors will not be cached. This happens during the initial memory
# profiling run.
if key_cache is not None and value_cache is not None:
PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
value_cache, input_metadata)
if input_metadata.is_prompt:
# Prompt run.
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# normal attention
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
output = flash_attn_func(
query,
key,
value,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
input_metadata,
self.num_heads,
self.num_kv_heads,
self.alibi_slopes,
)
else:
# Decoding run.
output = PagedAttentionImpl.forward_decode(
query,
key_cache,
value_cache,
input_metadata,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
"""Multi-head attention.""" """Attention layer with xFormers and PagedAttention."""
import importlib
from typing import List, Optional from typing import List, Optional
import importlib
import torch import torch
import torch.nn as nn
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias) LowerTriangularMaskWithTensorBias)
from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( from vllm.model_executor.layers.attention.ops.paged_attn import (
context_attention_fwd) PagedAttentionImpl)
from vllm.utils import is_hip from vllm.utils import is_hip
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
class PagedAttention(nn.Module): class XFormersBackend:
"""MHA/MQA/GQA layer with PagedAttention.
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Reshape and store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention using either
xformers or the PagedAttention custom op.
3. Return the output tensor.
"""
def __init__( def __init__(
self, self,
...@@ -42,7 +24,6 @@ class PagedAttention(nn.Module): ...@@ -42,7 +24,6 @@ class PagedAttention(nn.Module):
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
) -> None: ) -> None:
super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
...@@ -50,48 +31,17 @@ class PagedAttention(nn.Module): ...@@ -50,48 +31,17 @@ class PagedAttention(nn.Module):
self.sliding_window = sliding_window self.sliding_window = sliding_window
if alibi_slopes is not None: if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) self.alibi_slopes = alibi_slopes
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
if self.head_size not in _SUPPORTED_HEAD_SIZES: self.use_ref_attention = _check_use_ref_attention()
raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
self.use_ref_attention = self.check_use_ref_attention()
def check_use_ref_attention(self) -> bool:
if not is_hip():
return False
# For ROCm, check whether flash attention is installed or not.
# if not, use_ref_attention needs to be True
return importlib.util.find_spec("flash_attn") is None
def ref_masked_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
seq_len, _, _ = query.shape
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
dtype=query.dtype,
device=query.device),
diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min
attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
key).float()
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out
def forward( def forward(
self, self,
...@@ -102,7 +52,7 @@ class PagedAttention(nn.Module): ...@@ -102,7 +52,7 @@ class PagedAttention(nn.Module):
value_cache: Optional[torch.Tensor], value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
"""PagedAttention forward pass. """Forward pass with xFormers and PagedAttention.
Args: Args:
query: shape = [batch_size, seq_len, num_heads * head_size] query: shape = [batch_size, seq_len, num_heads * head_size]
...@@ -127,19 +77,14 @@ class PagedAttention(nn.Module): ...@@ -127,19 +77,14 @@ class PagedAttention(nn.Module):
# vectors will not be cached. This happens during the initial memory # vectors will not be cached. This happens during the initial memory
# profiling run. # profiling run.
if key_cache is not None and value_cache is not None: if key_cache is not None and value_cache is not None:
cache_ops.reshape_and_cache( PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
key, value_cache, input_metadata)
value,
key_cache,
value_cache,
input_metadata.slot_mapping.flatten(),
input_metadata.kv_cache_dtype,
)
if input_metadata.is_prompt: if input_metadata.is_prompt:
# normal attention # Prompt run.
if (key_cache is None or value_cache is None if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0): or input_metadata.block_tables.numel() == 0):
# normal attention
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA, # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of # project the key and value tensors to the desired number of
...@@ -175,13 +120,19 @@ class PagedAttention(nn.Module): ...@@ -175,13 +120,19 @@ class PagedAttention(nn.Module):
seq_len, query.dtype) seq_len, query.dtype)
if self.use_ref_attention: if self.use_ref_attention:
output = self.ref_masked_attention( output = _ref_masked_attention(
query, query,
key, key,
value, value,
self.num_heads,
self.num_kv_heads,
self.head_size,
self.scale,
) )
# Using view got RuntimeError: view size is not compatible with input tensor's size and stride # Using view got RuntimeError: view size is not compatible
# (at least one dimension spans across two contiguous subspaces). Use reshape instead # with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# Use reshape instead.
return output.reshape(batch_size, seq_len, hidden_size) return output.reshape(batch_size, seq_len, hidden_size)
# TODO(woosuk): Too many view operations. Let's try to reduce # TODO(woosuk): Too many view operations. Let's try to reduce
...@@ -206,27 +157,21 @@ class PagedAttention(nn.Module): ...@@ -206,27 +157,21 @@ class PagedAttention(nn.Module):
(is_hip()) else None, (is_hip()) else None,
) )
output = out.view_as(query) output = out.view_as(query)
else: else:
# prefix-enabled attention # prefix-enabled attention
output = torch.empty_like(query) output = PagedAttentionImpl.forward_prefix(
context_attention_fwd(
query, query,
key, key,
value, value,
output,
key_cache, key_cache,
value_cache, value_cache,
input_metadata.block_tables, # [BS, max_block_per_request] input_metadata,
input_metadata.start_loc, self.alibi_slopes,
input_metadata.prompt_lens,
input_metadata.context_lens,
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
) )
else: else:
# Decoding run. # Decoding run.
output = _paged_attention( output = PagedAttentionImpl.forward_decode(
query, query,
key_cache, key_cache,
value_cache, value_cache,
...@@ -274,76 +219,37 @@ def _make_alibi_bias( ...@@ -274,76 +219,37 @@ def _make_alibi_bias(
return attn_bias return attn_bias
def _paged_attention( def _check_use_ref_attention() -> bool:
if not is_hip():
return False
# For ROCm, check whether flash attention is installed or not.
# if not, use_ref_attention needs to be True
return importlib.util.find_spec("flash_attn") is None
def _ref_masked_attention(
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key: torch.Tensor,
value_cache: torch.Tensor, value: torch.Tensor,
input_metadata: InputMetadata, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int,
scale: float, scale: float,
alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
output = torch.empty_like(query) query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
block_size = value_cache.shape[3] seq_len, _, _ = query.shape
num_seqs, num_heads, head_size = query.shape attn_mask = torch.triu(torch.ones(seq_len,
max_num_partitions = ( seq_len,
(input_metadata.max_context_len + _PARTITION_SIZE - 1) // dtype=query.dtype,
_PARTITION_SIZE) device=query.device),
# NOTE(woosuk): We use a simple heuristic to decide whether to use diagonal=1)
# PagedAttention V1 or V2. If the number of partitions is 1, we use attn_mask = attn_mask * torch.finfo(query.dtype).min
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
# to parallelize. attn_weights = attn_weights + attn_mask.float()
# TODO(woosuk): Tune this heuristic. attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
# For context len > 8192, use V2 kernel to avoid shared memory shortage. out = torch.einsum("hqk,khd->qhd", attn_weights, value)
use_v1 = input_metadata.max_context_len <= 8192 and ( return out
max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
alibi_slopes,
input_metadata.kv_cache_dtype,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
alibi_slopes,
input_metadata.kv_cache_dtype,
)
return output
from typing import List, Optional
import torch
from vllm._C import cache_ops
from vllm._C import ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.prefix_prefill import (
context_attention_fwd)
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
class PagedAttentionImpl:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 256]
@staticmethod
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
) -> None:
cache_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
input_metadata.slot_mapping.flatten(),
input_metadata.kv_cache_dtype,
)
@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
output = torch.empty_like(query)
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (
(input_metadata.max_context_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = input_metadata.max_context_len <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
alibi_slopes,
input_metadata.kv_cache_dtype,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
alibi_slopes,
input_metadata.kv_cache_dtype,
)
return output
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.context_lens,
input_metadata.max_seq_len,
alibi_slopes,
)
return output
...@@ -27,7 +27,7 @@ from transformers import PretrainedConfig ...@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -151,7 +151,7 @@ class BaiChuanAttention(nn.Module): ...@@ -151,7 +151,7 @@ class BaiChuanAttention(nn.Module):
alibi_slopes = alibi_slopes[head_start:head_end].tolist() alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5 scaling = self.head_dim**-0.5
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scaling, scaling,
alibi_slopes=alibi_slopes) alibi_slopes=alibi_slopes)
...@@ -163,8 +163,7 @@ class BaiChuanAttention(nn.Module): ...@@ -163,8 +163,7 @@ class BaiChuanAttention(nn.Module):
base=self.rope_theta, base=self.rope_theta,
) )
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = PagedAttention(self.num_heads, self.head_dim, self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
self.scaling)
def forward( def forward(
self, self,
......
...@@ -25,7 +25,7 @@ from transformers import BloomConfig ...@@ -25,7 +25,7 @@ from transformers import BloomConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -107,7 +107,7 @@ class BloomAttention(nn.Module): ...@@ -107,7 +107,7 @@ class BloomAttention(nn.Module):
alibi_slopes = alibi_slopes[head_start:head_end].tolist() alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5 scaling = self.head_dim**-0.5
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scaling, scaling,
alibi_slopes=alibi_slopes) alibi_slopes=alibi_slopes)
......
...@@ -10,7 +10,7 @@ from torch.nn import LayerNorm ...@@ -10,7 +10,7 @@ from torch.nn import LayerNorm
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -87,7 +87,7 @@ class GLMAttention(nn.Module): ...@@ -87,7 +87,7 @@ class GLMAttention(nn.Module):
base=10000 * rope_ratio, base=10000 * rope_ratio,
is_neox_style=False, is_neox_style=False,
) )
self.attn = PagedAttention( self.attn = Attention(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
......
...@@ -29,7 +29,7 @@ from transformers import PretrainedConfig ...@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
...@@ -229,7 +229,7 @@ class DeepseekAttention(nn.Module): ...@@ -229,7 +229,7 @@ class DeepseekAttention(nn.Module):
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
) )
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads)
......
...@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig ...@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -150,7 +150,7 @@ class FalconAttention(nn.Module): ...@@ -150,7 +150,7 @@ class FalconAttention(nn.Module):
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
) )
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.inv_norm_factor, self.inv_norm_factor,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads)
...@@ -161,13 +161,13 @@ class FalconAttention(nn.Module): ...@@ -161,13 +161,13 @@ class FalconAttention(nn.Module):
alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
self.inv_norm_factor) self.inv_norm_factor)
alibi_slopes = alibi_slopes[head_start:head_end].tolist() alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.inv_norm_factor, self.inv_norm_factor,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes) alibi_slopes=alibi_slopes)
else: else:
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.inv_norm_factor, scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads)
......
...@@ -23,7 +23,7 @@ from transformers import GemmaConfig ...@@ -23,7 +23,7 @@ from transformers import GemmaConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -123,7 +123,7 @@ class GemmaAttention(nn.Module): ...@@ -123,7 +123,7 @@ class GemmaAttention(nn.Module):
base=self.rope_theta, base=self.rope_theta,
is_neox_style=True, is_neox_style=True,
) )
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads)
......
...@@ -25,7 +25,7 @@ from transformers import GPT2Config ...@@ -25,7 +25,7 @@ from transformers import GPT2Config
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -73,9 +73,7 @@ class GPT2Attention(nn.Module): ...@@ -73,9 +73,7 @@ class GPT2Attention(nn.Module):
bias=True, bias=True,
linear_method=linear_method, linear_method=linear_method,
) )
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
self.head_dim,
scale=self.scale)
def forward( def forward(
self, self,
......
...@@ -26,7 +26,7 @@ from transformers import GPTBigCodeConfig ...@@ -26,7 +26,7 @@ from transformers import GPTBigCodeConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -85,7 +85,7 @@ class GPTBigCodeAttention(nn.Module): ...@@ -85,7 +85,7 @@ class GPTBigCodeAttention(nn.Module):
bias=True, bias=True,
linear_method=linear_method, linear_method=linear_method,
) )
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.scale, scale=self.scale,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment