Unverified Commit a9e45742 authored by Woosuk Kwon, committed by GitHub

Refactor Attention (#1840)

parent 0229c386
@@ -277,8 +277,8 @@ def get_rope(
     rotary_dim: int,
     max_position: int,
     base: int,
-    is_neox_style: bool,
-    rope_scaling: Optional[Dict[str, Any]],
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
...
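Note (not part of the commit): with `is_neox_style` and `rope_scaling` now defaulted, callers only pass the sizes they need, and the per-model hunks below all follow the same pattern — build a rotary module with `get_rope`, apply it explicitly in `forward`, and hand position-free q/k/v to `PagedAttention`. A minimal sketch, assuming only the signatures visible in this diff; the `ExampleAttention` module and its parameters are illustrative, not code from this commit.

```python
import torch
from torch import nn

from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.rotary_embedding import get_rope


class ExampleAttention(nn.Module):
    """Illustrative sketch: RoPE is applied explicitly, then the
    position-free PagedAttention consumes q/k/v and the KV cache."""

    def __init__(self, num_heads: int, head_dim: int, max_position: int,
                 rope_theta: float = 10000, rope_scaling=None):
        super().__init__()
        # Factory-built rotary embedding (is_neox_style defaults to True,
        # rope_scaling to None, per the signature change above).
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        # PagedAttention no longer takes positions or RoPE parameters.
        self.attn = PagedAttention(num_heads, head_dim, head_dim**-0.5)

    def forward(self, positions, q, k, v, kv_cache, input_metadata,
                cache_event) -> torch.Tensor:
        # Rotate q/k with the position ids, then run attention.
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        return self.attn(q, k, v, k_cache, v_cache, input_metadata,
                         cache_event)
```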
@@ -28,11 +28,12 @@ from torch import nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -138,15 +139,17 @@ class AquilaAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            base=self.rope_theta,
-            max_position=self.max_position_embeddings,
-            rotary_dim=self.head_dim,
-            num_kv_heads=self.num_kv_heads,
-            rope_scaling=rope_scaling)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -158,9 +161,10 @@ class AquilaAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output
...
@@ -26,13 +26,13 @@ from torch import nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
-                                                  PagedAttentionWithALiBi)
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -150,17 +150,20 @@ class BaiChuanAttention(nn.Module):
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
 
             scaling = self.head_dim**-0.5
-            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                                scaling, alibi_slopes)
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       scaling,
+                                       alibi_slopes=alibi_slopes)
         else:
-            self.scaling = self.head_dim**-0.5
-            self.attn = PagedAttentionWithRoPE(
-                self.num_heads,
-                self.head_dim,
-                self.scaling,
-                rotary_dim=self.head_dim,
-                base=self.rope_theta,
-                max_position=self.max_position_embeddings)
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=self.max_position_embeddings,
+                base=self.rope_theta,
+            )
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttention(self.num_heads, self.head_dim,
+                                       self.scaling)
 
     def forward(
         self,
@@ -172,14 +175,11 @@ class BaiChuanAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        if self.postion_embedding != "ALIBI":
+            q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        if self.postion_embedding == "ALIBI":
-            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                    cache_event)
-        else:
-            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                    input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output
...
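Note (not part of the commit): for ALiBi models the dedicated `PagedAttentionWithALiBi` class goes away as well; the slopes become a keyword argument of the unified `PagedAttention`, and the forward call is the same position-free `attn(q, k, v, k_cache, v_cache, input_metadata, cache_event)` with no rotary step. A hedged sketch of the constructor, using only arguments visible in this diff; the helper function is illustrative.

```python
from vllm.model_executor.layers.attention import PagedAttention


def build_alibi_attention(num_heads: int, head_dim: int,
                          alibi_slopes: list) -> PagedAttention:
    """Illustrative helper: ALiBi is selected purely by passing slopes.

    ``alibi_slopes`` is assumed to already be sliced to this tensor-parallel
    rank's heads, as in the model code above.
    """
    scaling = head_dim**-0.5
    return PagedAttention(num_heads,
                          head_dim,
                          scaling,
                          alibi_slopes=alibi_slopes)
```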
@@ -25,7 +25,7 @@ from transformers import BloomConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
@@ -106,8 +106,10 @@ class BloomAttention(nn.Module):
         alibi_slopes = alibi_slopes[head_start:head_end].tolist()
 
         scaling = self.head_dim**-0.5
-        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                            scaling, alibi_slopes)
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes)
 
     def forward(
         self,
...
@@ -10,12 +10,13 @@ from torch.nn import LayerNorm
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -78,16 +79,19 @@ class GLMAttention(nn.Module):
         # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
         rope_ratio = getattr(config, "rope_ratio", 1.0)
         max_positions = getattr(config, "seq_length", 8192)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            rotary_dim=self.head_dim // 2,
-            num_kv_heads=self.num_kv_heads,
-            max_position=max_positions,
-            base=10000 * rope_ratio,
-            is_neox_style=False,
-        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim // 2,
+            max_position=max_positions,
+            base=10000 * rope_ratio,
+            is_neox_style=False,
+        )
+        self.attn = PagedAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+        )
 
     def forward(
         self,
@@ -99,10 +103,9 @@ class GLMAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
         key_cache, value_cache = kv_cache
         context_layer = self.attn(
-            position_ids,
             q,
             k,
             v,
@@ -111,9 +114,7 @@ class GLMAttention(nn.Module):
             input_metadata,
             cache_event,
         )
-
         attn_output, _ = self.dense(context_layer)
         return attn_output
...
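Note (not part of the commit): the factory also covers partial and interleaved rotary embeddings, as ChatGLM above shows by rotating only half of each head and selecting the GPT-J interleaved layout with `is_neox_style=False`. A small sketch with illustrative values, assuming only the `get_rope` signature shown earlier in this diff.

```python
from vllm.model_executor.layers.rotary_embedding import get_rope

head_dim = 128        # illustrative value, not read from any config
rope_ratio = 1.0      # config.rope_ratio default used above
max_positions = 8192  # config.seq_length default used above

# Rotate only the first half of each head, using the interleaved (GPT-J
# style) layout instead of the NeoX half/half split.
rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim // 2,
    max_position=max_positions,
    base=10000 * rope_ratio,
    is_neox_style=False,
)
```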
@@ -28,13 +28,12 @@ from transformers import FalconConfig as HF_FalconConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import (PagedAttention,
-                                                  PagedAttentionWithALiBi,
-                                                  PagedAttentionWithRoPE)
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -144,13 +143,15 @@ class FalconAttention(nn.Module):
             rope_theta = getattr(config, "rope_theta", 10000)
             max_position_embeddings = getattr(config,
                                               "max_position_embeddings", 8192)
-            self.attn = PagedAttentionWithRoPE(
-                self.num_heads,
-                self.head_dim,
-                self.inv_norm_factor,
-                base=rope_theta,
-                max_position=max_position_embeddings,
-                rotary_dim=self.head_dim,
-                num_kv_heads=self.num_kv_heads)
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+            )
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       self.inv_norm_factor,
+                                       num_kv_heads=self.num_kv_heads)
         elif self.use_alibi:
             tp_rank = get_tensor_model_parallel_rank()
@@ -159,11 +160,11 @@ class FalconAttention(nn.Module):
             alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                             self.inv_norm_factor)
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
-            self.attn = PagedAttentionWithALiBi(self.num_heads,
-                                                self.head_dim,
-                                                self.inv_norm_factor,
-                                                alibi_slopes,
-                                                num_kv_heads=self.num_kv_heads)
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       self.inv_norm_factor,
+                                       num_kv_heads=self.num_kv_heads,
+                                       alibi_slopes=alibi_slopes)
         else:
             self.attn = PagedAttention(self.num_heads,
                                        self.head_dim,
@@ -182,11 +183,9 @@ class FalconAttention(nn.Module):
         if bias is not None:
             qkv += bias
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        k_cache, v_cache = kv_cache
         if self.use_rotary:
-            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                    input_metadata, cache_event)
-        else:
-            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                    cache_event)
+            q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         attn_output, bias = self.dense(attn_output)
...
@@ -24,11 +24,12 @@ from transformers import GPTJConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -77,15 +78,14 @@ class GPTJAttention(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_size,
-            scaling,
-            config.rotary_dim,
-            base=rope_theta,
-            max_position=max_position_embeddings,
-            is_neox_style=False)
-        self.warmup = False
+        self.rotary_emb = get_rope(
+            self.head_size,
+            rotary_dim=config.rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            is_neox_style=False,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
 
     def forward(
         self,
@@ -97,9 +97,10 @@ class GPTJAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         attn_output, _ = self.out_proj(attn_output)
         return attn_output
...
@@ -24,11 +24,12 @@ from transformers import GPTNeoXConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -77,13 +78,13 @@ class GPTNeoXAttention(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_size,
-            scaling,
-            rotary_dim,
-            base=rope_theta,
-            max_position=max_position_embeddings)
+        self.rotary_emb = get_rope(
+            self.head_size,
+            rotary_dim=rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
 
     def forward(
         self,
@@ -95,9 +96,10 @@ class GPTNeoXAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.dense(attn_output)
         return output
...
@@ -7,12 +7,13 @@ from transformers import LlamaConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -92,13 +93,13 @@ class InternLMAttention(nn.Module):
             bias=bias,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            base=self.rope_theta,
-            max_position=self.max_position_embeddings,
-            rotary_dim=self.head_dim)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
 
     def forward(
         self,
@@ -110,9 +111,10 @@ class InternLMAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output
...
@@ -29,12 +29,13 @@ from transformers import LlamaConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -126,15 +127,18 @@ class LlamaAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            base=self.rope_theta,
-            max_position=self.max_position_embeddings,
-            rotary_dim=self.head_dim,
-            num_kv_heads=self.num_kv_heads,
-            rope_scaling=rope_scaling)
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -146,9 +150,10 @@ class LlamaAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output
...
@@ -29,12 +29,13 @@ from transformers import MistralConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -124,12 +125,16 @@ class MistralAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           base=self.rope_theta,
-                                           max_position=max_position,
-                                           rotary_dim=self.head_dim,
-                                           num_kv_heads=self.num_kv_heads,
-                                           sliding_window=self.sliding_window)
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=self.rope_theta,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads,
+                                   sliding_window=self.sliding_window)
@@ -143,9 +148,10 @@ class MistralAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output
...
@@ -8,7 +8,7 @@ import torch.nn as nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
@@ -87,8 +87,10 @@ class MPTAttention(nn.Module):
         self.head_dim = self.d_model // self.total_num_heads
         scaling = self.head_dim**-0.5
-        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                            scaling, alibi_slopes)
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes)
 
     def forward(
         self,
...
@@ -43,11 +43,12 @@ from transformers import PretrainedConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -119,13 +120,13 @@ class PhiAttention(nn.Module):
         # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
         rope_theta = 10000
         max_position_embeddings = getattr(config, "n_positions", 2048)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_size,
-            scaling,
-            rotary_dim,
-            base=rope_theta,
-            max_position=max_position_embeddings)
+        self.rotary_emb = get_rope(
+            self.head_size,
+            rotary_dim=rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
 
     def forward(
         self,
@@ -137,9 +138,10 @@ class PhiAttention(nn.Module):
    ) -> torch.Tensor:
         qkv, _ = self.Wqkv(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.out_proj(attn_output)
         return output
...
@@ -11,12 +11,13 @@ from torch import nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -95,14 +96,15 @@ class QWenAttention(nn.Module):
             linear_method=linear_method,
         )
         self.scaling = self.head_dim**-0.5
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            rotary_dim=self.head_dim,
-            base=rope_theta,
-            max_position=max_position_embeddings,
-            rope_scaling=rope_scaling)
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
 
     def forward(
         self,
@@ -114,10 +116,10 @@ class QWenAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.c_proj(attn_output)
         return output
...
@@ -29,12 +29,13 @@ from vllm.transformers_utils.configs.yi import YiConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
@@ -126,15 +127,17 @@ class YiAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            base=self.rope_theta,
-            max_position=self.max_position_embeddings,
-            rotary_dim=self.head_dim,
-            num_kv_heads=self.num_kv_heads,
-            rope_scaling=rope_scaling)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -146,9 +149,10 @@ class YiAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output
...