Unverified commit 2daf23ab, authored by Woosuk Kwon, committed by GitHub
Browse files

Separate attention backends (#3005)

parent cbf4c05b
...@@ -24,7 +24,7 @@ from transformers import GPTJConfig ...@@ -24,7 +24,7 @@ from transformers import GPTJConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -86,7 +86,7 @@ class GPTJAttention(nn.Module): ...@@ -86,7 +86,7 @@ class GPTJAttention(nn.Module):
base=rope_theta, base=rope_theta,
is_neox_style=False, is_neox_style=False,
) )
self.attn = PagedAttention(self.num_heads, self.head_size, scaling) self.attn = Attention(self.num_heads, self.head_size, scaling)
def forward( def forward(
self, self,
......
...@@ -24,7 +24,7 @@ from transformers import GPTNeoXConfig ...@@ -24,7 +24,7 @@ from transformers import GPTNeoXConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -87,7 +87,7 @@ class GPTNeoXAttention(nn.Module): ...@@ -87,7 +87,7 @@ class GPTNeoXAttention(nn.Module):
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
) )
self.attn = PagedAttention(self.num_heads, self.head_size, scaling) self.attn = Attention(self.num_heads, self.head_size, scaling)
def forward( def forward(
self, self,
......
...@@ -7,7 +7,7 @@ from transformers import PretrainedConfig ...@@ -7,7 +7,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -114,7 +114,7 @@ class InternLM2Attention(nn.Module): ...@@ -114,7 +114,7 @@ class InternLM2Attention(nn.Module):
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
) )
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads)
......
...@@ -30,7 +30,7 @@ from transformers import LlamaConfig ...@@ -30,7 +30,7 @@ from transformers import LlamaConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -139,7 +139,7 @@ class LlamaAttention(nn.Module): ...@@ -139,7 +139,7 @@ class LlamaAttention(nn.Module):
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
) )
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
......
...@@ -29,7 +29,7 @@ from transformers import MixtralConfig ...@@ -29,7 +29,7 @@ from transformers import MixtralConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
...@@ -197,7 +197,7 @@ class MixtralAttention(nn.Module): ...@@ -197,7 +197,7 @@ class MixtralAttention(nn.Module):
base=int(self.rope_theta), base=int(self.rope_theta),
is_neox_style=True, is_neox_style=True,
) )
self.attn = PagedAttention( self.attn = Attention(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
......
...@@ -32,7 +32,7 @@ from torch import nn ...@@ -32,7 +32,7 @@ from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear, ReplicatedLinear,
...@@ -214,7 +214,7 @@ class MixtralAttention(nn.Module): ...@@ -214,7 +214,7 @@ class MixtralAttention(nn.Module):
base=int(self.rope_theta), base=int(self.rope_theta),
is_neox_style=True, is_neox_style=True,
) )
self.attn = PagedAttention( self.attn = Attention(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
......
...@@ -8,7 +8,7 @@ import torch.nn as nn ...@@ -8,7 +8,7 @@ import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -105,7 +105,7 @@ class MPTAttention(nn.Module): ...@@ -105,7 +105,7 @@ class MPTAttention(nn.Module):
self.head_dim = self.d_model // self.total_num_heads self.head_dim = self.d_model // self.total_num_heads
scaling = self.head_dim**-0.5 scaling = self.head_dim**-0.5
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scaling, scaling,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
......
...@@ -43,7 +43,7 @@ import torch.nn.functional as F ...@@ -43,7 +43,7 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
...@@ -126,7 +126,7 @@ class OlmoAttention(nn.Module): ...@@ -126,7 +126,7 @@ class OlmoAttention(nn.Module):
base=rope_theta, base=rope_theta,
) )
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.scaling) scale=self.scaling)
......
...@@ -25,7 +25,7 @@ from transformers import OPTConfig ...@@ -25,7 +25,7 @@ from transformers import OPTConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -89,7 +89,7 @@ class OPTAttention(nn.Module): ...@@ -89,7 +89,7 @@ class OPTAttention(nn.Module):
bias=bias, bias=bias,
linear_method=linear_method, linear_method=linear_method,
) )
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.scaling) scale=self.scaling)
......
...@@ -12,7 +12,7 @@ from transformers import PretrainedConfig ...@@ -12,7 +12,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -118,7 +118,7 @@ class OrionAttention(nn.Module): ...@@ -118,7 +118,7 @@ class OrionAttention(nn.Module):
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
) )
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads) num_kv_heads=self.num_kv_heads)
......
...@@ -43,7 +43,7 @@ from transformers import PretrainedConfig ...@@ -43,7 +43,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
...@@ -108,7 +108,7 @@ class PhiAttention(nn.Module): ...@@ -108,7 +108,7 @@ class PhiAttention(nn.Module):
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
) )
self.attn = PagedAttention(self.num_heads, self.head_size, scaling) self.attn = Attention(self.num_heads, self.head_size, scaling)
def forward( def forward(
self, self,
......
...@@ -12,7 +12,7 @@ from transformers import PretrainedConfig ...@@ -12,7 +12,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -104,7 +104,7 @@ class QWenAttention(nn.Module): ...@@ -104,7 +104,7 @@ class QWenAttention(nn.Module):
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
) )
self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling) self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
def forward( def forward(
self, self,
......
...@@ -30,7 +30,7 @@ from transformers import Qwen2Config ...@@ -30,7 +30,7 @@ from transformers import Qwen2Config
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -135,7 +135,7 @@ class Qwen2Attention(nn.Module): ...@@ -135,7 +135,7 @@ class Qwen2Attention(nn.Module):
max_position=max_position, max_position=max_position,
base=self.rope_theta, base=self.rope_theta,
) )
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
......
...@@ -25,7 +25,7 @@ from transformers import PretrainedConfig ...@@ -25,7 +25,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -122,7 +122,7 @@ class StablelmAttention(nn.Module): ...@@ -122,7 +122,7 @@ class StablelmAttention(nn.Module):
max_position=self.config.max_position_embeddings, max_position=self.config.max_position_embeddings,
base=self.config.rope_theta, base=self.config.rope_theta,
) )
self.attn = PagedAttention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_key_value_heads) num_kv_heads=self.num_key_value_heads)
......
...@@ -25,7 +25,7 @@ from torch import nn ...@@ -25,7 +25,7 @@ from torch import nn
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -103,7 +103,7 @@ class Starcoder2Attention(nn.Module): ...@@ -103,7 +103,7 @@ class Starcoder2Attention(nn.Module):
base=int(self.rope_theta), base=int(self.rope_theta),
is_neox_style=True, is_neox_style=True,
) )
self.attn = PagedAttention( self.attn = Attention(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment