Unverified Commit 2daf23ab authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Separate attention backends (#3005)

parent cbf4c05b
......@@ -24,7 +24,7 @@ from transformers import GPTJConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
......@@ -86,7 +86,7 @@ class GPTJAttention(nn.Module):
base=rope_theta,
is_neox_style=False,
)
self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
self.attn = Attention(self.num_heads, self.head_size, scaling)
def forward(
self,
......
......@@ -24,7 +24,7 @@ from transformers import GPTNeoXConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
......@@ -87,7 +87,7 @@ class GPTNeoXAttention(nn.Module):
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
self.attn = Attention(self.num_heads, self.head_size, scaling)
def forward(
self,
......
......@@ -7,7 +7,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
......@@ -114,10 +114,10 @@ class InternLM2Attention(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward(
self,
......
......@@ -30,7 +30,7 @@ from transformers import LlamaConfig
from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
......@@ -139,11 +139,11 @@ class LlamaAttention(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window)
def forward(
self,
......
......@@ -29,7 +29,7 @@ from transformers import MixtralConfig
from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
......@@ -197,7 +197,7 @@ class MixtralAttention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = PagedAttention(
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
......
......@@ -32,7 +32,7 @@ from torch import nn
from transformers import MixtralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear,
......@@ -214,7 +214,7 @@ class MixtralAttention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = PagedAttention(
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
......
......@@ -8,7 +8,7 @@ import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
......@@ -105,11 +105,11 @@ class MPTAttention(nn.Module):
self.head_dim = self.d_model // self.total_num_heads
scaling = self.head_dim**-0.5
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads)
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads)
def forward(
self,
......
......@@ -43,7 +43,7 @@ import torch.nn.functional as F
from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearMethodBase,
......@@ -126,9 +126,9 @@ class OlmoAttention(nn.Module):
base=rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling)
# Attention output projection.
self.attn_out = RowParallelLinear(
......
......@@ -25,7 +25,7 @@ from transformers import OPTConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
......@@ -89,9 +89,9 @@ class OPTAttention(nn.Module):
bias=bias,
linear_method=linear_method,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling)
def forward(
self,
......
......@@ -12,7 +12,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
......@@ -118,10 +118,10 @@ class OrionAttention(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward(
self,
......
......@@ -43,7 +43,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
......@@ -108,7 +108,7 @@ class PhiAttention(nn.Module):
max_position=max_position_embeddings,
base=rope_theta,
)
self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
self.attn = Attention(self.num_heads, self.head_size, scaling)
def forward(
self,
......
......@@ -12,7 +12,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
......@@ -104,7 +104,7 @@ class QWenAttention(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
def forward(
self,
......
......@@ -30,7 +30,7 @@ from transformers import Qwen2Config
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
......@@ -135,11 +135,11 @@ class Qwen2Attention(nn.Module):
max_position=max_position,
base=self.rope_theta,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window)
def forward(
self,
......
......@@ -25,7 +25,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
......@@ -122,10 +122,10 @@ class StablelmAttention(nn.Module):
max_position=self.config.max_position_embeddings,
base=self.config.rope_theta,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_key_value_heads)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_key_value_heads)
def forward(
self,
......
......@@ -25,7 +25,7 @@ from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -103,7 +103,7 @@ class Starcoder2Attention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = PagedAttention(
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment