Refactor Attention (#1840)

Commit a9e45742 (unverified), authored Nov 29, 2023 by Woosuk Kwon; committed by GitHub on Nov 29, 2023.
Parent commit: 0229c386
Showing 16 changed files with 360 additions and 498 deletions (+360, -498).
vllm/model_executor/layers/attention.py          +191  -361
vllm/model_executor/layers/rotary_embedding.py     +2    -2
vllm/model_executor/models/aquila.py              +14   -10
vllm/model_executor/models/baichuan.py            +16   -16
vllm/model_executor/models/bloom.py                +5    -3
vllm/model_executor/models/chatglm.py             +10    -9
vllm/model_executor/models/falcon.py              +19   -20
vllm/model_executor/models/gpt_j.py               +11   -10
vllm/model_executor/models/gpt_neox.py            +10    -8
vllm/model_executor/models/internlm.py            +10    -8
vllm/model_executor/models/llama.py               +15   -10
vllm/model_executor/models/mistral.py             +17   -11
vllm/model_executor/models/mpt.py                  +5    -3
vllm/model_executor/models/phi_1_5.py             +10    -8
vllm/model_executor/models/qwen.py                +11    -9
vllm/model_executor/models/yi.py                  +14   -10
vllm/model_executor/layers/attention.py  (+191, -361)

(This diff is collapsed in the captured view; its contents are not shown.)
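Because the attention.py diff is not expanded above, the shape of the new unified layer can only be inferred from the call sites in the model diffs below. The stub that follows is a sketch of that inferred surface, not the actual contents of attention.py; the parameter names, defaults, and types are assumptions.

    from typing import List, Optional

    import torch
    from torch import nn


    class PagedAttention(nn.Module):
        """Interface sketch inferred from the call sites below.

        This is not the real vllm/model_executor/layers/attention.py; it only
        records how the model files appear to use the class after this commit.
        """

        def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: Optional[int] = None,          # assumed default
            alibi_slopes: Optional[List[float]] = None,  # folds in PagedAttentionWithALiBi
            sliding_window: Optional[int] = None,        # kept by the Mistral call site
        ) -> None:
            super().__init__()
            ...

        def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    key_cache: torch.Tensor, value_cache: torch.Tensor,
                    input_metadata, cache_event) -> torch.Tensor:
            # No `positions` argument any more: rotary embeddings are applied by
            # the caller via get_rope(...) before this method is invoked.
            ...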
vllm/model_executor/layers/rotary_embedding.py

@@ -277,8 +277,8 @@ def get_rope(
     rotary_dim: int,
     max_position: int,
     base: int,
-    is_neox_style: bool,
-    rope_scaling: Optional[Dict[str, Any]],
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
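The only change visible here is that is_neox_style and rope_scaling become optional, defaulting to True and None. A minimal usage sketch, assuming vLLM at this commit is importable and using placeholder sizes:

    from vllm.model_executor.layers.rotary_embedding import get_rope

    rotary_emb = get_rope(
        128,                  # head_size (placeholder)
        rotary_dim=128,
        max_position=4096,
        base=10000,
        # is_neox_style and rope_scaling may now be omitted; they default to
        # True and None respectively.
    )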
vllm/model_executor/models/aquila.py

@@ -28,11 +28,12 @@ from torch import nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -138,15 +139,17 @@ class AquilaAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            base=self.rope_theta,
-            max_position=self.max_position_embeddings,
-            rotary_dim=self.head_dim,
-            num_kv_heads=self.num_kv_heads,
-            rope_scaling=rope_scaling)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads)

     def forward(
         self,

@@ -158,9 +161,10 @@ class AquilaAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output
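The Aquila diff above shows the migration pattern that repeats, with minor variations, in every model file that follows: build a rotary embedding with get_rope() and a plain PagedAttention in __init__, apply the rotary embedding explicitly in forward(), and drop positions from the attention call. A condensed sketch of the pattern, assuming vLLM at commit a9e45742; the class name and config values below are placeholders, not taken from the diff:

    import torch
    from torch import nn

    from vllm.model_executor.layers.attention import PagedAttention
    from vllm.model_executor.layers.rotary_embedding import get_rope


    class ExampleAttention(nn.Module):  # hypothetical module, not from the diff
        def __init__(self, num_heads: int = 32, head_dim: int = 128) -> None:
            super().__init__()
            self.rotary_emb = get_rope(head_dim,
                                       rotary_dim=head_dim,
                                       max_position=4096,
                                       base=10000)
            self.attn = PagedAttention(num_heads, head_dim, head_dim**-0.5)

        def forward(self, positions, q, k, v, k_cache, v_cache, input_metadata,
                    cache_event) -> torch.Tensor:
            # RoPE is now applied by the model, not inside the attention op.
            q, k = self.rotary_emb(positions, q, k)
            return self.attn(q, k, v, k_cache, v_cache, input_metadata,
                             cache_event)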
vllm/model_executor/models/baichuan.py

@@ -26,13 +26,13 @@ from torch import nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
-                                                  PagedAttentionWithALiBi)
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -150,17 +150,20 @@ class BaiChuanAttention(nn.Module):
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
             scaling = self.head_dim**-0.5
-            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                                scaling, alibi_slopes)
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       scaling,
+                                       alibi_slopes=alibi_slopes)
         else:
-            self.scaling = self.head_dim**-0.5
-            self.attn = PagedAttentionWithRoPE(
-                self.num_heads,
-                self.head_dim,
-                self.scaling,
-                rotary_dim=self.head_dim,
-                base=self.rope_theta,
-                max_position=self.max_position_embeddings)
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=self.max_position_embeddings,
+                base=self.rope_theta,
+            )
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttention(self.num_heads, self.head_dim,
+                                       self.scaling)

     def forward(
         self,

@@ -172,14 +175,11 @@ class BaiChuanAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        if self.postion_embedding != "ALIBI":
+            q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        if self.postion_embedding == "ALIBI":
-            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                    cache_event)
-        else:
-            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                    input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output
vllm/model_executor/models/bloom.py

@@ -25,7 +25,7 @@ from transformers import BloomConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,

@@ -106,8 +106,10 @@ class BloomAttention(nn.Module):
         alibi_slopes = alibi_slopes[head_start:head_end].tolist()
         scaling = self.head_dim**-0.5
-        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                            scaling, alibi_slopes)
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes)

     def forward(
         self,
vllm/model_executor/models/chatglm.py

@@ -10,12 +10,13 @@ from torch.nn import LayerNorm
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -78,16 +79,19 @@ class GLMAttention(nn.Module):
         # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
         rope_ratio = getattr(config, "rope_ratio", 1.0)
         max_positions = getattr(config, "seq_length", 8192)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            rotary_dim=self.head_dim // 2,
-            num_kv_heads=self.num_kv_heads,
-            max_position=max_positions,
-            base=10000 * rope_ratio,
-            is_neox_style=False,
-        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim // 2,
+            max_position=max_positions,
+            base=10000 * rope_ratio,
+            is_neox_style=False,
+        )
+        self.attn = PagedAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+        )

     def forward(
         self,

@@ -99,10 +103,9 @@ class GLMAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
         key_cache, value_cache = kv_cache
         context_layer = self.attn(
-            position_ids,
             q,
             k,
             v,

@@ -111,9 +114,7 @@ class GLMAttention(nn.Module):
             input_metadata,
             cache_event,
         )
         attn_output, _ = self.dense(context_layer)
         return attn_output
vllm/model_executor/models/falcon.py

@@ -28,13 +28,12 @@ from transformers import FalconConfig as HF_FalconConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import (PagedAttention,
-                                                  PagedAttentionWithALiBi,
-                                                  PagedAttentionWithRoPE)
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -144,14 +143,16 @@ class FalconAttention(nn.Module):
             rope_theta = getattr(config, "rope_theta", 10000)
             max_position_embeddings = getattr(config,
                                               "max_position_embeddings", 8192)
-            self.attn = PagedAttentionWithRoPE(
-                self.num_heads,
-                self.head_dim,
-                self.inv_norm_factor,
-                base=rope_theta,
-                max_position=max_position_embeddings,
-                rotary_dim=self.head_dim,
-                num_kv_heads=self.num_kv_heads)
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+            )
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       self.inv_norm_factor,
+                                       num_kv_heads=self.num_kv_heads)
         elif self.use_alibi:
             tp_rank = get_tensor_model_parallel_rank()
             head_start = tp_rank * self.num_heads

@@ -159,11 +160,11 @@ class FalconAttention(nn.Module):
             alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                             self.inv_norm_factor)
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
-            self.attn = PagedAttentionWithALiBi(self.num_heads,
-                                                self.head_dim,
-                                                self.inv_norm_factor,
-                                                alibi_slopes,
-                                                num_kv_heads=self.num_kv_heads)
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       self.inv_norm_factor,
+                                       num_kv_heads=self.num_kv_heads,
+                                       alibi_slopes=alibi_slopes)
         else:
             self.attn = PagedAttention(self.num_heads,
                                        self.head_dim,

@@ -182,13 +183,11 @@ class FalconAttention(nn.Module):
         if bias is not None:
             qkv += bias
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        k_cache, v_cache = kv_cache
         if self.use_rotary:
-            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                    input_metadata, cache_event)
-        else:
-            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                    cache_event)
+            q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         attn_output, bias = self.dense(attn_output)
         return attn_output, bias
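Falcon previously mixed all three attention classes (PagedAttention, PagedAttentionWithALiBi, PagedAttentionWithRoPE). After this change, the rotary and ALiBi branches construct the same PagedAttention that the plain branch already used, differing only in keyword arguments, roughly as below. This is a sketch with placeholder head counts and slope values, assuming vLLM at this commit; it is not copied from the diff:

    from vllm.model_executor.layers.attention import PagedAttention

    num_heads, head_dim = 8, 64
    scale = head_dim**-0.5

    # Rotary branch: RoPE is handled separately via get_rope(); only plain
    # keyword arguments remain here.
    rope_attn = PagedAttention(num_heads, head_dim, scale, num_kv_heads=1)

    # ALiBi branch: slopes passed as a keyword instead of a dedicated subclass.
    alibi_attn = PagedAttention(num_heads, head_dim, scale, num_kv_heads=1,
                                alibi_slopes=[0.5] * num_heads)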
vllm/model_executor/models/gpt_j.py

@@ -24,11 +24,12 @@ from transformers import GPTJConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -77,15 +78,14 @@ class GPTJAttention(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_size,
-            scaling,
-            config.rotary_dim,
-            base=rope_theta,
-            max_position=max_position_embeddings,
-            is_neox_style=False)
-        self.warmup = False
+        self.rotary_emb = get_rope(
+            self.head_size,
+            rotary_dim=config.rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            is_neox_style=False,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

     def forward(
         self,

@@ -97,9 +97,10 @@ class GPTJAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         attn_output, _ = self.out_proj(attn_output)
         return attn_output
vllm/model_executor/models/gpt_neox.py

@@ -24,11 +24,12 @@ from transformers import GPTNeoXConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -77,13 +78,13 @@ class GPTNeoXAttention(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_size,
-            scaling,
-            rotary_dim,
-            base=rope_theta,
-            max_position=max_position_embeddings)
+        self.rotary_emb = get_rope(
+            self.head_size,
+            rotary_dim=rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

     def forward(
         self,

@@ -95,9 +96,10 @@ class GPTNeoXAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.dense(attn_output)
         return output
vllm/model_executor/models/internlm.py

@@ -7,12 +7,13 @@ from transformers import LlamaConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -92,13 +93,13 @@ class InternLMAttention(nn.Module):
             bias=bias,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            base=self.rope_theta,
-            max_position=self.max_position_embeddings,
-            rotary_dim=self.head_dim)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)

     def forward(
         self,

@@ -110,9 +111,10 @@ class InternLMAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output
vllm/model_executor/models/llama.py

@@ -29,12 +29,13 @@ from transformers import LlamaConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -126,15 +127,18 @@ class LlamaAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            base=self.rope_theta,
-            max_position=self.max_position_embeddings,
-            rotary_dim=self.head_dim,
-            num_kv_heads=self.num_kv_heads,
-            rope_scaling=rope_scaling)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads)

     def forward(
         self,

@@ -146,9 +150,10 @@ class LlamaAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output
vllm/model_executor/models/mistral.py

@@ -29,12 +29,13 @@ from transformers import MistralConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -124,14 +125,18 @@ class MistralAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           base=self.rope_theta,
-                                           max_position=max_position,
-                                           rotary_dim=self.head_dim,
-                                           num_kv_heads=self.num_kv_heads,
-                                           sliding_window=self.sliding_window)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=self.rope_theta,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads,
+                                   sliding_window=self.sliding_window)

     def forward(
         self,

@@ -143,9 +148,10 @@ class MistralAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output
vllm/model_executor/models/mpt.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,

@@ -87,8 +87,10 @@ class MPTAttention(nn.Module):
         self.head_dim = self.d_model // self.total_num_heads
         scaling = self.head_dim**-0.5
-        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                            scaling, alibi_slopes)
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes)

     def forward(
         self,
vllm/model_executor/models/phi_1_5.py

@@ -43,11 +43,12 @@ from transformers import PretrainedConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -119,13 +120,13 @@ class PhiAttention(nn.Module):
         # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
         rope_theta = 10000
         max_position_embeddings = getattr(config, "n_positions", 2048)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_size,
-            scaling,
-            rotary_dim,
-            base=rope_theta,
-            max_position=max_position_embeddings)
+        self.rotary_emb = get_rope(
+            self.head_size,
+            rotary_dim=rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

     def forward(
         self,

@@ -137,9 +138,10 @@ class PhiAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.Wqkv(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.out_proj(attn_output)
         return output
vllm/model_executor/models/qwen.py

@@ -11,12 +11,13 @@ from torch import nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -95,14 +96,15 @@ class QWenAttention(nn.Module):
             linear_method=linear_method,
         )
         self.scaling = self.head_dim**-0.5
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            rotary_dim=self.head_dim,
-            base=rope_theta,
-            max_position=max_position_embeddings,
-            rope_scaling=rope_scaling)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)

     def forward(
         self,

@@ -114,10 +116,10 @@ class QWenAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.c_proj(attn_output)
         return output
vllm/model_executor/models/yi.py

@@ -29,12 +29,13 @@ from vllm.transformers_utils.configs.yi import YiConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -126,15 +127,17 @@ class YiAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            base=self.rope_theta,
-            max_position=self.max_position_embeddings,
-            rotary_dim=self.head_dim,
-            num_kv_heads=self.num_kv_heads,
-            rope_scaling=rope_scaling)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads)

     def forward(
         self,

@@ -146,9 +149,10 @@ class YiAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output