Unverified commit a9e45742, authored by Woosuk Kwon, committed by GitHub

Refactor Attention (#1840)

parent 0229c386
"""Multi-head attention.""" """Multi-head attention."""
from typing import Any, Dict, List, Optional from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -10,7 +10,6 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, ...@@ -10,7 +10,6 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
from vllm._C import ops from vllm._C import ops
from vllm._C import cache_ops from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import get_rope
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
...@@ -18,37 +17,39 @@ _PARTITION_SIZE = 512 ...@@ -18,37 +17,39 @@ _PARTITION_SIZE = 512
class PagedAttention(nn.Module):
-    """GPT-style multi-head PagedAttention.
+    """MHA/MQA/GQA layer with PagedAttention.

    This class takes query, key, and value tensors as input. The input tensors
-    can either contain prompt tokens or generation tokens, in addition to
-    paddings.
+    can either contain prompt tokens or generation tokens.

    The class does the following:
-    1. Perform multi_query_kv_attention for the prompts. This operation does
-       not use the KV cache.
-    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
+    1. Wait for the cache operations (e.g., swap, copy) to finish. The cache
       operations are issued by the cache engine before executing the forward
       pass of the model, and they are executed asynchronously.
-    3. Reshape and store the input key and value tensors in the KV cache.
-    4. Perform single_query_cached_kv_attention for the generation tokens.
-       This operation reads the previous key and value tensors from the KV
-       cache.
-    5. Return the output tensor.
+    2. Reshape and store the input key and value tensors in the KV cache.
+    3. Perform (multi-head/multi-query/grouped-query) attention using either
+       xformers or the PagedAttention custom op.
+    4. Return the output tensor.
    """
-    def __init__(self,
-                 num_heads: int,
-                 head_size: int,
-                 scale: float,
-                 num_kv_heads: Optional[int] = None,
-                 sliding_window: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

@@ -60,93 +61,173 @@ class PagedAttention(nn.Module):

            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
-    def set_attn_bias(
-        self,
-        input_metadata: InputMetadata,
-        dtype: torch.dtype,
-    ) -> None:
-        del dtype  # Unused.
-        if input_metadata.attn_bias is not None:
-            # Already set by a previous layer.
-            return
-        prompt_lens = [input_metadata.max_prompt_len
-                       ] * input_metadata.num_prompts
-        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
-        if self.sliding_window is not None:
-            attn_bias = attn_bias.make_local_attention(self.sliding_window)
-        input_metadata.attn_bias = attn_bias
-
-    def multi_query_kv_attention(
+    def forward(
        self,
-        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
+        key_cache: Optional[torch.Tensor],
+        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
-        """Normal attention for the prompt tokens.
+        """PagedAttention forward pass.

        Args:
-            output: shape = [num_prompt_tokens, num_heads, head_size]
-            query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
-            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
-            input_metadata: metadata for paged attention.
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
+            input_metadata: metadata for the inputs.
+            cache_event: event to wait for the cache operations to finish.
+
+        Returns:
+            shape = [batch_size, seq_len, num_heads * head_size]
        """
+        batch_size, seq_len, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+        slot_mapping = input_metadata.slot_mapping.flatten()
+
+        if cache_event is not None:
+            cache_event.wait()
+
+        # Reshape the keys and values and store them in the cache.
+        # If key_cache and value_cache are not provided, the new key and value
+        # vectors will not be cached. This happens during the initial memory
+        # profiling run.
+        if key_cache is not None and value_cache is not None:
+            key_to_cache = key
+            value_to_cache = value
+            if input_metadata.to_cache is not None:
+                key_to_cache = key_to_cache[input_metadata.to_cache]
+                value_to_cache = value_to_cache[input_metadata.to_cache]
+                slot_mapping = slot_mapping[input_metadata.to_cache]
+            cache_ops.reshape_and_cache(
+                key_to_cache,
+                value_to_cache,
+                key_cache,
+                value_cache,
+                slot_mapping,
+            )
+
+        is_prompt = len(input_metadata.prompt_lens) > 0
+        if is_prompt:
+            # Prompt run.
            if self.num_kv_heads != self.num_heads:
-                # Project the key and value tensors to the desired number of heads.
+                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+                # project the key and value tensors to the desired number of
+                # heads.
+                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
                query = query.view(query.shape[0], self.num_kv_heads,
                                   self.num_queries_per_kv, query.shape[-1])
                key = key[:, :,
                          None, :].expand(key.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          key.shape[-1])
                value = value[:, :, None, :].expand(value.shape[0],
                                                    self.num_kv_heads,
                                                    self.num_queries_per_kv,
                                                    value.shape[-1])
-            # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
+
+            # Set attention bias if not provided. This typically happens at the
+            # very first attention layer of every iteration.
+            # FIXME(woosuk): This is a hack.
+            if input_metadata.attn_bias is None:
+                if self.alibi_slopes is None:
+                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                        [seq_len] * batch_size)
+                    if self.sliding_window is not None:
+                        attn_bias = attn_bias.make_local_attention(
+                            self.sliding_window)
+                    input_metadata.attn_bias = attn_bias
+                else:
+                    input_metadata.attn_bias = _make_alibi_bias(
+                        self.alibi_slopes, batch_size, seq_len, query.dtype)
+
+            # TODO(woosuk): Too many view operations. Let's try to reduce them
+            # in the future for code readability.
+            if self.alibi_slopes is None:
+                query = query.unsqueeze(0)
+                key = key.unsqueeze(0)
+                value = value.unsqueeze(0)
+            else:
+                query = query.unflatten(0, (batch_size, seq_len))
+                key = key.unflatten(0, (batch_size, seq_len))
+                value = value.unflatten(0, (batch_size, seq_len))
+
            out = xops.memory_efficient_attention_forward(
-                query.unsqueeze(0),
-                key.unsqueeze(0),
-                value.unsqueeze(0),
+                query,
+                key,
+                value,
                attn_bias=input_metadata.attn_bias,
                p=0.0,
                scale=self.scale,
            )
-        # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.view_as(output))
-        return output
-
-    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
-        """Returns the slopes for the alibi attention bias.
-
-        Returns:
-            slopes: shape = [num_heads]
-        """
-        return None
+            output = out.view_as(query)
+        else:
+            # Decoding run.
+            output = _paged_attention(
+                query,
+                key_cache,
+                value_cache,
+                input_metadata,
+                self.head_mapping,
+                self.scale,
+                self.alibi_slopes,
+            )

+        # Reshape the output tensor.
+        return output.view(batch_size, seq_len, hidden_size)
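For readers unfamiliar with the paged KV-cache layout, the cache_ops.reshape_and_cache call above scatters each token's key/value vectors into the cache at the position given by slot_mapping. A rough pure-PyTorch sketch of the value-cache half of that operation (an illustration following the shapes in the docstring, not the actual CUDA kernel; the key cache additionally packs head_size into head_size/x by x):

import torch

def reference_cache_values(value: torch.Tensor,        # [num_tokens, num_kv_heads, head_size]
                           value_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
                           slot_mapping: torch.Tensor  # [num_tokens], flat slot index per token
                           ) -> None:
    block_size = value_cache.shape[3]
    for i in range(value.shape[0]):
        slot = int(slot_mapping[i])
        block_idx = slot // block_size    # which physical block holds this token
        block_off = slot % block_size     # position of the token inside the block
        value_cache[block_idx, :, :, block_off] = value[i]

# Tiny smoke test with made-up sizes.
num_blocks, num_kv_heads, head_size, block_size = 4, 2, 8, 16
value_cache = torch.zeros(num_blocks, num_kv_heads, head_size, block_size)
value = torch.randn(3, num_kv_heads, head_size)
slot_mapping = torch.tensor([0, 1, 35])   # third token lands in block 2, offset 3
reference_cache_values(value, value_cache, slot_mapping)
assert torch.allclose(value_cache[2, :, :, 3], value[2])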
-    def single_query_cached_kv_attention(
-        self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        input_metadata: InputMetadata,
-        alibi_slopes: Optional[torch.Tensor],
-    ) -> None:
-        """PagedAttention for the generation tokens.
-
-        Args:
-            output: shape = [num_generation_tokens, num_heads, head_size]
-            query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for paged attention.
-            alibi_slopes: shape = [num_heads]
-        """
+def _make_alibi_bias(
+    alibi_slopes: torch.Tensor,
+    batch_size: int,
+    seq_len: int,
+    dtype: torch.dtype,
+) -> LowerTriangularMaskWithTensorBias:
+    bias = torch.arange(seq_len, dtype=dtype)
+    # NOTE(zhuohan): HF uses
+    #     `bias = bias[None, :].repeat(prompt_len, 1)`
+    # here. We find that both biases give the same results, but
+    # the bias below more accurately follows the original ALiBi
+    # paper.
+    bias = bias[None, :] - bias[:, None]
+    bias = bias.to(alibi_slopes.device)
+
+    # When using custom attention bias, xformers requires the bias to
+    # be sliced from a tensor whose length is a multiple of 8.
+    padded_len = (seq_len + 7) // 8 * 8
+    bias = torch.empty(
+        batch_size,
+        alibi_slopes.shape[0],
+        seq_len,
+        padded_len,
+        device=alibi_slopes.device,
+        dtype=dtype,
+    )[:, :, :, :seq_len].copy_(bias)
+    bias.mul_(alibi_slopes[:, None, None])
+    attn_bias = LowerTriangularMaskWithTensorBias(bias)
+    return attn_bias
+
+
+def _paged_attention(
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    input_metadata: InputMetadata,
+    head_mapping: torch.Tensor,
+    scale: float,
+    alibi_slopes: Optional[torch.Tensor],
+) -> torch.Tensor:
+    output = torch.empty_like(query)
    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (

@@ -168,8 +249,8 @@ class PagedAttention(nn.Module):

        query,
        key_cache,
        value_cache,
-        self.head_mapping,
-        self.scale,
+        head_mapping,
+        scale,
        input_metadata.block_tables,
        input_metadata.context_lens,
        block_size,

@@ -198,263 +279,12 @@ class PagedAttention(nn.Module):

        query,
        key_cache,
        value_cache,
-        self.head_mapping,
-        self.scale,
+        head_mapping,
+        scale,
        input_metadata.block_tables,
        input_metadata.context_lens,
        block_size,
        input_metadata.max_context_len,
        alibi_slopes,
    )
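The _make_alibi_bias helper above reproduces the per-head linear distance penalty from the ALiBi paper. A tiny self-contained illustration of the tensor it builds, without the xformers wrapper or the pad-to-multiple-of-8 trick, assuming two heads and a four-token prompt (slope values are examples):

import torch

seq_len = 4
alibi_slopes = torch.tensor([0.5, 0.25])   # one slope per head

# Relative-position matrix: bias[i, j] = j - i.
pos = torch.arange(seq_len, dtype=torch.float32)
bias = pos[None, :] - pos[:, None]                       # [seq_len, seq_len]

# Scale per head -> [num_heads, seq_len, seq_len].
bias = bias[None, :, :] * alibi_slopes[:, None, None]

# The lower-triangular causal mask applied later removes j > i, so only the
# non-positive penalties (keys at or before the query) influence the softmax.
# Head 0 (slope 0.5); rows are query positions i, columns are key positions j:
#   [ 0.0,  0.5,  1.0,  1.5]
#   [-0.5,  0.0,  0.5,  1.0]
#   [-1.0, -0.5,  0.0,  0.5]
#   [-1.5, -1.0, -0.5,  0.0]
print(bias[0])

The block that follows is the code this commit deletes: the previous forward pass and the PagedAttentionWithRoPE and PagedAttentionWithALiBi subclasses.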
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
"""PagedAttention forward pass.
NOTE: The query, key, and value tensors must be sliced from a qkv
tensor of shape [batch_size, seq_len, 3 * num_heads * head_size].
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
batch_size, seq_len, _ = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# Pre-allocate the output tensor.
output = torch.empty_like(query)
# Compute the attention op for prompts.
num_prompt_tokens = input_metadata.num_prompt_tokens
if num_prompt_tokens > 0:
# Prompt run.
assert input_metadata.num_generation_tokens == 0
self.set_attn_bias(input_metadata, dtype=query.dtype)
self.multi_query_kv_attention(
output,
query,
key,
value,
input_metadata,
)
# Wait until the cache op is done.
if cache_event is not None:
cache_event.wait()
# Reshape the keys and values and store them in the cache.
# When key_cache and value_cache are not provided, the new key
# and value vectors will not be cached.
if key_cache is not None and value_cache is not None:
key_to_cache = key
value_to_cache = value
slot_mapping = input_metadata.slot_mapping.view(-1)
if input_metadata.to_cache is not None:
key_to_cache = key_to_cache[input_metadata.to_cache]
value_to_cache = value_to_cache[input_metadata.to_cache]
slot_mapping = slot_mapping[input_metadata.to_cache]
cache_ops.reshape_and_cache(
key_to_cache,
value_to_cache,
key_cache,
value_cache,
slot_mapping,
)
if input_metadata.num_generation_tokens > 0:
# Decoding run.
assert input_metadata.num_prompt_tokens == 0
assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when "
"generating tokens.")
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(output, query, key_cache,
value_cache, input_metadata,
self.get_alibi_slopes())
# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
return output.view(batch_size, seq_len,
self.num_heads * self.head_size)
class PagedAttentionWithRoPE(PagedAttention):
"""PagedAttention with rotary positional embedding."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
rotary_dim: int,
max_position: int = 8192,
base: int = 10000,
num_kv_heads: Optional[int] = None,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
sliding_window: Optional[int] = None,
) -> None:
super().__init__(num_heads,
head_size,
scale,
num_kv_heads,
sliding_window=sliding_window)
self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, rope_scaling)
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
""" PagedAttention forward pass with rotary embedding.
Args:
positions: shape = [batch_size, seq_len]
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
# Apply rotary embedding to the query and key before passing them
# to the attention op.
query, key = self.rotary_emb(positions, query, key)
return super().forward(
query,
key,
value,
key_cache,
value_cache,
input_metadata,
cache_event,
)
class PagedAttentionWithALiBi(PagedAttention):
"""PagedAttention with ALiBi attention bias."""
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
slopes: List[float],
num_kv_heads: Optional[int] = None) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads)
assert len(slopes) == num_heads
slopes = torch.tensor(slopes, dtype=torch.float32)
self.register_buffer("alibi_slopes", slopes, persistent=False)
def set_attn_bias(self, input_metadata: InputMetadata,
dtype: torch.dtype) -> None:
if input_metadata.attn_bias is not None:
# Already set by a previous layer.
return
# Generates ALiBi mask based on the max prompt length.
max_prompt_len = input_metadata.max_prompt_len
bias = torch.arange(max_prompt_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
bias = bias.to(self.alibi_slopes.device)
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (max_prompt_len + 7) // 8 * 8
bias = torch.empty(
input_metadata.num_prompts,
self.num_heads,
max_prompt_len,
padded_len,
device=self.alibi_slopes.device,
dtype=dtype,
)[:, :, :, :max_prompt_len].copy_(bias)
bias.mul_(self.alibi_slopes[:, None, None])
attn_bias = LowerTriangularMaskWithTensorBias(bias)
input_metadata.attn_bias = attn_bias
def multi_query_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Attention with ALiBi bias for the prompt tokens.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
if self.num_kv_heads != self.num_heads:
# Project the key and value tensors to the desired number of heads.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
batch_size = input_metadata.num_prompts
seq_len = input_metadata.max_prompt_len
out = xops.memory_efficient_attention_forward(
query.view(batch_size, seq_len, self.num_heads, self.head_size),
key.view(batch_size, seq_len, self.num_heads, self.head_size),
value.view(batch_size, seq_len, self.num_heads, self.head_size),
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output.copy_(out.view_as(output))
        return output
+    return output
def get_alibi_slopes(self) -> Optional[torch.Tensor]:
return self.alibi_slopes
@@ -277,8 +277,8 @@ def get_rope(

    rotary_dim: int,
    max_position: int,
    base: int,
-    is_neox_style: bool,
-    rope_scaling: Optional[Dict[str, Any]],
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
...
@@ -28,11 +28,12 @@ from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)

@@ -138,15 +139,17 @@ class AquilaAttention(nn.Module):

            bias=False,
            linear_method=linear_method,
        )
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            base=self.rope_theta,
-            max_position=self.max_position_embeddings,
-            rotary_dim=self.head_dim,
-            num_kv_heads=self.num_kv_heads,
-            rope_scaling=rope_scaling)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads)

    def forward(
        self,

@@ -158,9 +161,10 @@ class AquilaAttention(nn.Module):

    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
        output, _ = self.o_proj(attn_output)
        return output
...
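The same mechanical change repeats in every model file below (Baichuan, Bloom, ChatGLM, Falcon, GPT-J, GPT-NeoX, InternLM, Llama, Mistral, MPT, Phi, Qwen, Yi): rotary embedding moves out of the attention class and into the model layer, which now calls get_rope itself and passes already-rotated q/k to PagedAttention. A minimal sketch of the post-refactor calling convention, assuming vLLM at this commit and omitting the projection layers and configs of a real model:

import torch.nn as nn

from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.rotary_embedding import get_rope

class ToyAttention(nn.Module):
    """Illustrative layer showing only the new calling convention."""

    def __init__(self, num_heads: int = 8, head_dim: int = 64) -> None:
        super().__init__()
        # RoPE now lives in the model layer...
        self.rotary_emb = get_rope(head_dim,
                                   rotary_dim=head_dim,
                                   max_position=2048,
                                   base=10000)
        # ...and PagedAttention no longer takes positions or rotary arguments.
        self.attn = PagedAttention(num_heads, head_dim, head_dim**-0.5)

    def forward(self, positions, q, k, v, kv_cache, input_metadata, cache_event):
        q, k = self.rotary_emb(positions, q, k)   # previously done inside PagedAttentionWithRoPE
        k_cache, v_cache = kv_cache
        return self.attn(q, k, v, k_cache, v_cache, input_metadata, cache_event)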
@@ -26,13 +26,13 @@ from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
-                                                  PagedAttentionWithALiBi)
+from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)

@@ -150,17 +150,20 @@ class BaiChuanAttention(nn.Module):

            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            scaling = self.head_dim**-0.5
-            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                                scaling, alibi_slopes)
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       scaling,
+                                       alibi_slopes=alibi_slopes)
        else:
-            self.scaling = self.head_dim**-0.5
-            self.attn = PagedAttentionWithRoPE(
-                self.num_heads,
-                self.head_dim,
-                self.scaling,
-                rotary_dim=self.head_dim,
-                base=self.rope_theta,
-                max_position=self.max_position_embeddings)
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=self.max_position_embeddings,
+                base=self.rope_theta,
+            )
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttention(self.num_heads, self.head_dim,
+                                       self.scaling)

    def forward(
        self,

@@ -172,14 +175,11 @@ class BaiChuanAttention(nn.Module):

    ) -> torch.Tensor:
        qkv, _ = self.W_pack(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        if self.postion_embedding != "ALIBI":
+            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
-        if self.postion_embedding == "ALIBI":
-            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                    cache_event)
-        else:
-            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                    input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
        output, _ = self.o_proj(attn_output)
        return output
...
@@ -25,7 +25,7 @@ from transformers import BloomConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
+from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,

@@ -106,8 +106,10 @@ class BloomAttention(nn.Module):

        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
-        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                            scaling, alibi_slopes)
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes)

    def forward(
        self,
...
...@@ -10,12 +10,13 @@ from torch.nn import LayerNorm ...@@ -10,12 +10,13 @@ from torch.nn import LayerNorm
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -78,16 +79,19 @@ class GLMAttention(nn.Module): ...@@ -78,16 +79,19 @@ class GLMAttention(nn.Module):
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0) rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192) max_positions = getattr(config, "seq_length", 8192)
self.attn = PagedAttentionWithRoPE( self.rotary_emb = get_rope(
self.num_heads,
self.head_dim, self.head_dim,
self.scaling,
rotary_dim=self.head_dim // 2, rotary_dim=self.head_dim // 2,
num_kv_heads=self.num_kv_heads,
max_position=max_positions, max_position=max_positions,
base=10000 * rope_ratio, base=10000 * rope_ratio,
is_neox_style=False, is_neox_style=False,
) )
self.attn = PagedAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
)
def forward( def forward(
self, self,
...@@ -99,10 +103,9 @@ class GLMAttention(nn.Module): ...@@ -99,10 +103,9 @@ class GLMAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
context_layer = self.attn( context_layer = self.attn(
position_ids,
q, q,
k, k,
v, v,
...@@ -111,9 +114,7 @@ class GLMAttention(nn.Module): ...@@ -111,9 +114,7 @@ class GLMAttention(nn.Module):
input_metadata, input_metadata,
cache_event, cache_event,
) )
attn_output, _ = self.dense(context_layer) attn_output, _ = self.dense(context_layer)
return attn_output return attn_output
......
...@@ -28,13 +28,12 @@ from transformers import FalconConfig as HF_FalconConfig ...@@ -28,13 +28,12 @@ from transformers import FalconConfig as HF_FalconConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import (PagedAttention, from vllm.model_executor.layers.attention import PagedAttention
PagedAttentionWithALiBi,
PagedAttentionWithRoPE)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -144,13 +143,15 @@ class FalconAttention(nn.Module): ...@@ -144,13 +143,15 @@ class FalconAttention(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, max_position_embeddings = getattr(config,
"max_position_embeddings", 8192) "max_position_embeddings", 8192)
self.attn = PagedAttentionWithRoPE( self.rotary_emb = get_rope(
self.num_heads,
self.head_dim, self.head_dim,
self.inv_norm_factor,
base=rope_theta,
max_position=max_position_embeddings,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads)
elif self.use_alibi: elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -159,11 +160,11 @@ class FalconAttention(nn.Module): ...@@ -159,11 +160,11 @@ class FalconAttention(nn.Module):
alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
self.inv_norm_factor) self.inv_norm_factor)
alibi_slopes = alibi_slopes[head_start:head_end].tolist() alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.attn = PagedAttentionWithALiBi(self.num_heads, self.attn = PagedAttention(self.num_heads,
self.head_dim, self.head_dim,
self.inv_norm_factor, self.inv_norm_factor,
alibi_slopes, num_kv_heads=self.num_kv_heads,
num_kv_heads=self.num_kv_heads) alibi_slopes=alibi_slopes)
else: else:
self.attn = PagedAttention(self.num_heads, self.attn = PagedAttention(self.num_heads,
self.head_dim, self.head_dim,
...@@ -182,11 +183,9 @@ class FalconAttention(nn.Module): ...@@ -182,11 +183,9 @@ class FalconAttention(nn.Module):
if bias is not None: if bias is not None:
qkv += bias qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
k_cache, v_cache = kv_cache
if self.use_rotary: if self.use_rotary:
attn_output = self.attn(positions, q, k, v, k_cache, v_cache, q, k = self.rotary_emb(positions, q, k)
input_metadata, cache_event) k_cache, v_cache = kv_cache
else:
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event) cache_event)
attn_output, bias = self.dense(attn_output) attn_output, bias = self.dense(attn_output)
......
...@@ -24,11 +24,12 @@ from transformers import GPTJConfig ...@@ -24,11 +24,12 @@ from transformers import GPTJConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -77,15 +78,14 @@ class GPTJAttention(nn.Module): ...@@ -77,15 +78,14 @@ class GPTJAttention(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings", max_position_embeddings = getattr(config, "max_position_embeddings",
8192) 8192)
self.attn = PagedAttentionWithRoPE( self.rotary_emb = get_rope(
self.num_heads,
self.head_size, self.head_size,
scaling, rotary_dim=config.rotary_dim,
config.rotary_dim,
base=rope_theta,
max_position=max_position_embeddings, max_position=max_position_embeddings,
is_neox_style=False) base=rope_theta,
self.warmup = False is_neox_style=False,
)
self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
def forward( def forward(
self, self,
...@@ -97,9 +97,10 @@ class GPTJAttention(nn.Module): ...@@ -97,9 +97,10 @@ class GPTJAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache, attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
input_metadata, cache_event) cache_event)
attn_output, _ = self.out_proj(attn_output) attn_output, _ = self.out_proj(attn_output)
return attn_output return attn_output
......
...@@ -24,11 +24,12 @@ from transformers import GPTNeoXConfig ...@@ -24,11 +24,12 @@ from transformers import GPTNeoXConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -77,13 +78,13 @@ class GPTNeoXAttention(nn.Module): ...@@ -77,13 +78,13 @@ class GPTNeoXAttention(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings", max_position_embeddings = getattr(config, "max_position_embeddings",
8192) 8192)
self.attn = PagedAttentionWithRoPE( self.rotary_emb = get_rope(
self.num_heads,
self.head_size, self.head_size,
scaling, rotary_dim=rotary_dim,
rotary_dim, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
max_position=max_position_embeddings) )
self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
def forward( def forward(
self, self,
...@@ -95,9 +96,10 @@ class GPTNeoXAttention(nn.Module): ...@@ -95,9 +96,10 @@ class GPTNeoXAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache, attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
input_metadata, cache_event) cache_event)
output, _ = self.dense(attn_output) output, _ = self.dense(attn_output)
return output return output
......
...@@ -7,12 +7,13 @@ from transformers import LlamaConfig ...@@ -7,12 +7,13 @@ from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -92,13 +93,13 @@ class InternLMAttention(nn.Module): ...@@ -92,13 +93,13 @@ class InternLMAttention(nn.Module):
bias=bias, bias=bias,
linear_method=linear_method, linear_method=linear_method,
) )
self.attn = PagedAttentionWithRoPE( self.rotary_emb = get_rope(
self.num_heads,
self.head_dim, self.head_dim,
self.scaling, rotary_dim=self.head_dim,
base=self.rope_theta,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rotary_dim=self.head_dim) base=self.rope_theta,
)
self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
def forward( def forward(
self, self,
...@@ -110,9 +111,10 @@ class InternLMAttention(nn.Module): ...@@ -110,9 +111,10 @@ class InternLMAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache, attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
input_metadata, cache_event) cache_event)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
...@@ -29,12 +29,13 @@ from transformers import LlamaConfig ...@@ -29,12 +29,13 @@ from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -126,15 +127,18 @@ class LlamaAttention(nn.Module): ...@@ -126,15 +127,18 @@ class LlamaAttention(nn.Module):
bias=False, bias=False,
linear_method=linear_method, linear_method=linear_method,
) )
self.attn = PagedAttentionWithRoPE(
self.num_heads, self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
self.scaling,
base=self.rope_theta,
max_position=self.max_position_embeddings,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads, max_position=max_position_embeddings,
rope_scaling=rope_scaling) base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward( def forward(
self, self,
...@@ -146,9 +150,10 @@ class LlamaAttention(nn.Module): ...@@ -146,9 +150,10 @@ class LlamaAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache, attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
input_metadata, cache_event) cache_event)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
...@@ -29,12 +29,13 @@ from transformers import MistralConfig ...@@ -29,12 +29,13 @@ from transformers import MistralConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -124,12 +125,16 @@ class MistralAttention(nn.Module): ...@@ -124,12 +125,16 @@ class MistralAttention(nn.Module):
bias=False, bias=False,
linear_method=linear_method, linear_method=linear_method,
) )
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
self.scaling,
base=self.rope_theta,
max_position=max_position,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
max_position=max_position,
base=self.rope_theta,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window) sliding_window=self.sliding_window)
...@@ -143,9 +148,10 @@ class MistralAttention(nn.Module): ...@@ -143,9 +148,10 @@ class MistralAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache, attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
input_metadata, cache_event) cache_event)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
@@ -8,7 +8,7 @@ import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
+from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,

@@ -87,8 +87,10 @@ class MPTAttention(nn.Module):

        self.head_dim = self.d_model // self.total_num_heads
        scaling = self.head_dim**-0.5
-        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                            scaling, alibi_slopes)
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes)

    def forward(
        self,
...
...@@ -43,11 +43,12 @@ from transformers import PretrainedConfig ...@@ -43,11 +43,12 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -119,13 +120,13 @@ class PhiAttention(nn.Module): ...@@ -119,13 +120,13 @@ class PhiAttention(nn.Module):
# https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518 # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
rope_theta = 10000 rope_theta = 10000
max_position_embeddings = getattr(config, "n_positions", 2048) max_position_embeddings = getattr(config, "n_positions", 2048)
self.attn = PagedAttentionWithRoPE( self.rotary_emb = get_rope(
self.num_heads,
self.head_size, self.head_size,
scaling, rotary_dim=rotary_dim,
rotary_dim, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
max_position=max_position_embeddings) )
self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
def forward( def forward(
self, self,
...@@ -137,9 +138,10 @@ class PhiAttention(nn.Module): ...@@ -137,9 +138,10 @@ class PhiAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states) qkv, _ = self.Wqkv(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache, attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
input_metadata, cache_event) cache_event)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
......
...@@ -11,12 +11,13 @@ from torch import nn ...@@ -11,12 +11,13 @@ from torch import nn
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -95,14 +96,15 @@ class QWenAttention(nn.Module): ...@@ -95,14 +96,15 @@ class QWenAttention(nn.Module):
linear_method=linear_method, linear_method=linear_method,
) )
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(
self.num_heads, self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
self.scaling,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
base=rope_theta,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_scaling=rope_scaling) base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
def forward( def forward(
self, self,
...@@ -114,10 +116,10 @@ class QWenAttention(nn.Module): ...@@ -114,10 +116,10 @@ class QWenAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache, attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
input_metadata, cache_event) cache_event)
output, _ = self.c_proj(attn_output) output, _ = self.c_proj(attn_output)
return output return output
......
...@@ -29,12 +29,13 @@ from vllm.transformers_utils.configs.yi import YiConfig ...@@ -29,12 +29,13 @@ from vllm.transformers_utils.configs.yi import YiConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -126,15 +127,17 @@ class YiAttention(nn.Module): ...@@ -126,15 +127,17 @@ class YiAttention(nn.Module):
bias=False, bias=False,
linear_method=linear_method, linear_method=linear_method,
) )
self.attn = PagedAttentionWithRoPE( self.rotary_emb = get_rope(
self.num_heads,
self.head_dim, self.head_dim,
self.scaling,
base=self.rope_theta,
max_position=self.max_position_embeddings,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads, max_position=max_position_embeddings,
rope_scaling=rope_scaling) base=self.rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward( def forward(
self, self,
...@@ -146,9 +149,10 @@ class YiAttention(nn.Module): ...@@ -146,9 +149,10 @@ class YiAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache, attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
input_metadata, cache_event) cache_event)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......