norm / vllm / Commits / a9e45742

Unverified commit a9e45742, authored Nov 29, 2023 by Woosuk Kwon, committed by GitHub on Nov 29, 2023

Refactor Attention (#1840)
parent 0229c386
Showing 16 changed files with 360 additions and 498 deletions
vllm/model_executor/layers/attention.py          +191  -361
vllm/model_executor/layers/rotary_embedding.py     +2    -2
vllm/model_executor/models/aquila.py              +14   -10
vllm/model_executor/models/baichuan.py            +16   -16
vllm/model_executor/models/bloom.py                +5    -3
vllm/model_executor/models/chatglm.py             +10    -9
vllm/model_executor/models/falcon.py              +19   -20
vllm/model_executor/models/gpt_j.py               +11   -10
vllm/model_executor/models/gpt_neox.py            +10    -8
vllm/model_executor/models/internlm.py            +10    -8
vllm/model_executor/models/llama.py               +15   -10
vllm/model_executor/models/mistral.py             +17   -11
vllm/model_executor/models/mpt.py                  +5    -3
vllm/model_executor/models/phi_1_5.py             +10    -8
vllm/model_executor/models/qwen.py                +11    -9
vllm/model_executor/models/yi.py                  +14   -10
vllm/model_executor/layers/attention.py

 """Multi-head attention."""
-from typing import Any, Dict, List, Optional
+from typing import List, Optional

 import torch
 import torch.nn as nn
...
@@ -10,7 +10,6 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
 from vllm._C import ops
 from vllm._C import cache_ops
 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.rotary_embedding import get_rope

 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
...
@@ -18,37 +17,39 @@ _PARTITION_SIZE = 512
 class PagedAttention(nn.Module):
-    """GPT-style multi-head PagedAttention.
+    """MHA/MQA/GQA layer with PagedAttention.

     This class takes query, key, and value tensors as input. The input tensors
-    can either contain prompt tokens or generation tokens, in addition to
-    paddings.
+    can either contain prompt tokens or generation tokens.

     The class does the following:
-    1. Perform multi_query_kv_attention for the prompts. This operation does
-        not use the KV cache.
-    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
+    1. Wait for the cache operations (e.g., swap, copy) to finish. The cache
         operations are issued by the cache engine before executing the forward
         pass of the model, and they are executed asynchronously.
-    3. Reshape and store the input key and value tensors in the KV cache.
-    4. Perform single_query_cached_kv_attention for the generation tokens.
-        This operation reads the previous key and value tensors from the KV
-        cache.
-    5. Return the output tensor.
+    2. Reshape and store the input key and value tensors in the KV cache.
+    3. Perform (multi-head/multi-query/grouped-query) attention using either
+        xformers or the PagedAttention custom op.
+    4. Return the output tensor.
     """

-    def __init__(self,
-                 num_heads: int,
-                 head_size: int,
-                 scale: float,
-                 num_kv_heads: Optional[int] = None,
-                 sliding_window: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
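The RoPE- and ALiBi-specific subclasses are folded into this single constructor: rotary embedding moves out to the model code, and ALiBi is selected simply by passing slopes. A minimal construction sketch, assuming vLLM at this commit is importable; the head counts and slope values are hypothetical, not taken from the diff:

import torch
from vllm.model_executor.layers.attention import PagedAttention

head_size = 128
scale = head_size**-0.5

# Plain multi-head attention (MHA): num_kv_heads defaults to num_heads.
mha = PagedAttention(num_heads=32, head_size=head_size, scale=scale)

# Grouped-query attention (GQA): fewer KV heads than query heads.
gqa = PagedAttention(num_heads=32, head_size=head_size, scale=scale,
                     num_kv_heads=8)

# ALiBi: pass one (illustrative) slope per head instead of applying RoPE.
alibi = PagedAttention(num_heads=32, head_size=head_size, scale=scale,
                       alibi_slopes=[2.0**-(i + 1) for i in range(32)])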
...
@@ -60,153 +61,6 @@ class PagedAttention(nn.Module):
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

-    def set_attn_bias(
-        self,
-        input_metadata: InputMetadata,
-        dtype: torch.dtype,
-    ) -> None:
-        del dtype  # Unused.
-        if input_metadata.attn_bias is not None:
-            # Already set by a previous layer.
-            return
-        prompt_lens = [input_metadata.max_prompt_len
-                       ] * input_metadata.num_prompts
-        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
-        if self.sliding_window is not None:
-            attn_bias = attn_bias.make_local_attention(self.sliding_window)
-        input_metadata.attn_bias = attn_bias
-
-    def multi_query_kv_attention(
-        self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        """Normal attention for the prompt tokens.
-
-        Args:
-            output: shape = [num_prompt_tokens, num_heads, head_size]
-            query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
-            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
-            input_metadata: metadata for paged attention.
-        """
-        if self.num_kv_heads != self.num_heads:
-            # Project the key and value tensors to the desired number of heads.
-            query = query.view(query.shape[0], self.num_kv_heads,
-                               self.num_queries_per_kv, query.shape[-1])
-            key = key[:, :, None, :].expand(key.shape[0], self.num_kv_heads,
-                                            self.num_queries_per_kv,
-                                            key.shape[-1])
-            value = value[:, :, None, :].expand(value.shape[0],
-                                                self.num_kv_heads,
-                                                self.num_queries_per_kv,
-                                                value.shape[-1])
-
-        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
-        out = xops.memory_efficient_attention_forward(
-            query.unsqueeze(0),
-            key.unsqueeze(0),
-            value.unsqueeze(0),
-            attn_bias=input_metadata.attn_bias,
-            p=0.0,
-            scale=self.scale,
-        )
-        # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.view_as(output))
-        return output
-
-    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
-        """Returns the slopes for the alibi attention bias.
-
-        Returns:
-            slopes: shape = [num_heads]
-        """
-        return None
-
-    def single_query_cached_kv_attention(
-        self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        input_metadata: InputMetadata,
-        alibi_slopes: Optional[torch.Tensor],
-    ) -> None:
-        """PagedAttention for the generation tokens.
-
-        Args:
-            output: shape = [num_generation_tokens, num_heads, head_size]
-            query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for paged attention.
-            alibi_slopes: shape = [num_heads]
-        """
-        block_size = value_cache.shape[3]
-        num_seqs, num_heads, head_size = query.shape
-        max_num_partitions = (
-            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
-            _PARTITION_SIZE)
-        # NOTE(woosuk): We use a simple heuristic to decide whether to use
-        # PagedAttention V1 or V2. If the number of partitions is 1, we use
-        # V1 to avoid the overhead of reduction. Also, if the number of
-        # sequences or heads is large, we use V1 since there is enough work
-        # to parallelize.
-        # TODO(woosuk): Tune this heuristic.
-        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
-        use_v1 = input_metadata.max_context_len <= 8192 and (
-            max_num_partitions == 1 or num_seqs * num_heads > 512)
-        if use_v1:
-            # Run PagedAttention V1.
-            ops.paged_attention_v1(
-                output,
-                query,
-                key_cache,
-                value_cache,
-                self.head_mapping,
-                self.scale,
-                input_metadata.block_tables,
-                input_metadata.context_lens,
-                block_size,
-                input_metadata.max_context_len,
-                alibi_slopes,
-            )
-        else:
-            # Run PagedAttention V2.
-            assert _PARTITION_SIZE % block_size == 0
-            tmp_output = torch.empty(
-                size=(num_seqs, num_heads, max_num_partitions, head_size),
-                dtype=output.dtype,
-                device=output.device,
-            )
-            exp_sums = torch.empty(
-                size=(num_seqs, num_heads, max_num_partitions),
-                dtype=torch.float32,
-                device=output.device,
-            )
-            max_logits = torch.empty_like(exp_sums)
-            ops.paged_attention_v2(
-                output,
-                exp_sums,
-                max_logits,
-                tmp_output,
-                query,
-                key_cache,
-                value_cache,
-                self.head_mapping,
-                self.scale,
-                input_metadata.block_tables,
-                input_metadata.context_lens,
-                block_size,
-                input_metadata.max_context_len,
-                alibi_slopes,
-            )
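The V1/V2 choice described in the NOTE above (and carried over unchanged into the new `_paged_attention` helper later in this file) is pure integer arithmetic on the batch shape. A standalone sketch with illustrative numbers, not part of the commit:

_PARTITION_SIZE = 512  # must match PARTITION_SIZE in paged_attention_v2_launcher

def choose_kernel(max_context_len: int, num_seqs: int, num_heads: int) -> str:
    max_num_partitions = (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    use_v1 = max_context_len <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512)
    return "v1" if use_v1 else "v2"

# Short context fits in one partition -> V1, no reduction step needed.
print(choose_kernel(max_context_len=400, num_seqs=1, num_heads=32))     # v1
# Long context but many seqs*heads -> still V1, enough parallel work.
print(choose_kernel(max_context_len=4096, num_seqs=64, num_heads=32))   # v1
# Long context, tiny batch -> V2 splits the context into partitions.
print(choose_kernel(max_context_len=4096, num_seqs=1, num_heads=32))    # v2
# Beyond 8192 tokens V2 is always used to avoid shared-memory pressure.
print(choose_kernel(max_context_len=16384, num_seqs=64, num_heads=32))  # v2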
     def forward(
         self,
         query: torch.Tensor,
...
@@ -219,9 +73,6 @@ class PagedAttention(nn.Module):
     ) -> torch.Tensor:
         """PagedAttention forward pass.

-        NOTE: The query, key, and value tensors must be sliced from a qkv
-        tensor of shape [batch_size, seq_len, 3 * num_heads * head_size].
-
         Args:
             query: shape = [batch_size, seq_len, num_heads * head_size]
             key: shape = [batch_size, seq_len, num_kv_heads * head_size]
...
@@ -230,46 +81,28 @@ class PagedAttention(nn.Module):
                 block_size, x]
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
                 block_size]
-            input_metadata: metadata for paged attention.
+            input_metadata: metadata for the inputs.
             cache_event: event to wait for the cache operations to finish.

         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
         """
-        batch_size, seq_len, _ = query.shape
+        batch_size, seq_len, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
-        slot_mapping = input_metadata.slot_mapping.flatten()
-
-        # Pre-allocate the output tensor.
-        output = torch.empty_like(query)
-
-        # Compute the attention op for prompts.
-        num_prompt_tokens = input_metadata.num_prompt_tokens
-        if num_prompt_tokens > 0:
-            # Prompt run.
-            assert input_metadata.num_generation_tokens == 0
-            self.set_attn_bias(input_metadata, dtype=query.dtype)
-            self.multi_query_kv_attention(
-                output,
-                query,
-                key,
-                value,
-                input_metadata,
-            )

-        # Wait until the cache op is done.
         if cache_event is not None:
             cache_event.wait()

         # Reshape the keys and values and store them in the cache.
-        # When key_cache and value_cache are not provided, the new key
-        # and value vectors will not be cached.
+        # If key_cache and value_cache are not provided, the new key and value
+        # vectors will not be cached. This happens during the initial memory
+        # profiling run.
         if key_cache is not None and value_cache is not None:
             key_to_cache = key
             value_to_cache = value
+            slot_mapping = input_metadata.slot_mapping.view(-1)
             if input_metadata.to_cache is not None:
                 key_to_cache = key_to_cache[input_metadata.to_cache]
                 value_to_cache = value_to_cache[input_metadata.to_cache]
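`cache_ops.reshape_and_cache` is a CUDA op; as a rough mental model only, here is a pure-PyTorch sketch of what it stores, assuming the cache layouts given in the docstring above. This is not the vLLM kernel: the real op is fused, strided, and far faster.

import torch

def reshape_and_cache_ref(key, value, key_cache, value_cache, slot_mapping):
    """Slow reference: scatter per-token K/V into the paged KV cache.

    key, value:  [num_tokens, num_kv_heads, head_size]
    key_cache:   [num_blocks, num_kv_heads, head_size // x, block_size, x]
    value_cache: [num_blocks, num_kv_heads, head_size, block_size]
    slot_mapping: [num_tokens], slot = block_idx * block_size + offset
    """
    x = key_cache.shape[-1]
    block_size = value_cache.shape[-1]
    for i, slot in enumerate(slot_mapping.tolist()):
        block_idx, block_off = divmod(slot, block_size)
        # Keys are stored split into head_size // x chunks of width x.
        key_cache[block_idx, :, :, block_off, :] = key[i].view(
            key.shape[1], -1, x)
        value_cache[block_idx, :, :, block_off] = value[i]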
...
@@ -283,178 +116,175 @@ class PagedAttention(nn.Module):
                 slot_mapping,
             )

-        if input_metadata.num_generation_tokens > 0:
-            # Decoding run.
-            assert input_metadata.num_prompt_tokens == 0
-            assert key_cache is not None and value_cache is not None, (
-                "key_cache and value_cache must be provided when "
-                "generating tokens.")
-            # Compute the attention op for generation tokens.
-            self.single_query_cached_kv_attention(output, query, key_cache,
-                                                  value_cache, input_metadata,
-                                                  self.get_alibi_slopes())
+        is_prompt = len(input_metadata.prompt_lens) > 0
+        if is_prompt:
+            # Prompt run.
+            if self.num_kv_heads != self.num_heads:
+                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+                # project the key and value tensors to the desired number of
+                # heads.
+                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
+                query = query.view(query.shape[0], self.num_kv_heads,
+                                   self.num_queries_per_kv, query.shape[-1])
+                key = key[:, :, None, :].expand(key.shape[0],
+                                                self.num_kv_heads,
+                                                self.num_queries_per_kv,
+                                                key.shape[-1])
+                value = value[:, :, None, :].expand(value.shape[0],
+                                                    self.num_kv_heads,
+                                                    self.num_queries_per_kv,
+                                                    value.shape[-1])
+
+            # Set attention bias if not provided. This typically happens at the
+            # very attention layer of every iteration.
+            # FIXME(woosuk): This is a hack.
+            if input_metadata.attn_bias is None:
+                if self.alibi_slopes is None:
+                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                        [seq_len] * batch_size)
+                    if self.sliding_window is not None:
+                        attn_bias = attn_bias.make_local_attention(
+                            self.sliding_window)
+                    input_metadata.attn_bias = attn_bias
+                else:
+                    input_metadata.attn_bias = _make_alibi_bias(
+                        self.alibi_slopes, batch_size, seq_len, query.dtype)
+
+            # TODO(woosuk): Too many view operations. Let's try to reduce them
+            # in the future for code readability.
+            if self.alibi_slopes is None:
+                query = query.unsqueeze(0)
+                key = key.unsqueeze(0)
+                value = value.unsqueeze(0)
+            else:
+                query = query.unflatten(0, (batch_size, seq_len))
+                key = key.unflatten(0, (batch_size, seq_len))
+                value = value.unflatten(0, (batch_size, seq_len))
+
+            out = xops.memory_efficient_attention_forward(
+                query,
+                key,
+                value,
+                attn_bias=input_metadata.attn_bias,
+                p=0.0,
+                scale=self.scale,
+            )
+            output = out.view_as(query)
+        else:
+            # Decoding run.
+            output = _paged_attention(
+                query,
+                key_cache,
+                value_cache,
+                input_metadata,
+                self.head_mapping,
+                self.scale,
+                self.alibi_slopes,
+            )

         # Reshape the output tensor.
-        # NOTE(woosuk): The output tensor may include paddings.
-        return output.view(batch_size, seq_len,
-                           self.num_heads * self.head_size)
+        return output.view(batch_size, seq_len, hidden_size)


-class PagedAttentionWithRoPE(PagedAttention):
-    """PagedAttention with rotary positional embedding."""
-
-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        rotary_dim: int,
-        max_position: int = 8192,
-        base: int = 10000,
-        num_kv_heads: Optional[int] = None,
-        is_neox_style: bool = True,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        sliding_window: Optional[int] = None,
-    ) -> None:
-        super().__init__(num_heads,
-                         head_size,
-                         scale,
-                         num_kv_heads,
-                         sliding_window=sliding_window)
-        self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
-                                   is_neox_style, rope_scaling)
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
-        """ PagedAttention forward pass with rotary embedding.
-
-        Args:
-            positions: shape = [batch_size, seq_len]
-            query: shape = [batch_size, seq_len, num_heads * head_size]
-            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for paged attention.
-            cache_event: event to wait for the cache operations to finish.
-
-        Returns:
-            shape = [batch_size, seq_len, num_heads * head_size]
-        """
-
-        # Apply rotary embedding to the query and key before passing them
-        # to the attention op.
-        query, key = self.rotary_emb(positions, query, key)
-        return super().forward(
-            query,
-            key,
-            value,
-            key_cache,
-            value_cache,
-            input_metadata,
-            cache_event,
-        )
-
-
-class PagedAttentionWithALiBi(PagedAttention):
-    """PagedAttention with ALiBi attention bias."""
-
-    def __init__(self,
-                 num_heads: int,
-                 head_size: int,
-                 scale: float,
-                 slopes: List[float],
-                 num_kv_heads: Optional[int] = None) -> None:
-        super().__init__(num_heads, head_size, scale, num_kv_heads)
-        assert len(slopes) == num_heads
-
-        slopes = torch.tensor(slopes, dtype=torch.float32)
-        self.register_buffer("alibi_slopes", slopes, persistent=False)
-
-    def set_attn_bias(self, input_metadata: InputMetadata,
-                      dtype: torch.dtype) -> None:
-        if input_metadata.attn_bias is not None:
-            # Already set by a previous layer.
-            return
-        # Generates ALiBi mask based on the max prompt length.
-        max_prompt_len = input_metadata.max_prompt_len
-        bias = torch.arange(max_prompt_len, dtype=dtype)
-        # NOTE(zhuohan): HF uses
-        #     `bias = bias[None, :].repeat(prompt_len, 1)`
-        # here. We find that both biases give the same results, but
-        # the bias below more accurately follows the original ALiBi
-        # paper.
-        bias = bias[None, :] - bias[:, None]
-        bias = bias.to(self.alibi_slopes.device)
-
-        # When using custom attention bias, xformers requires the bias to
-        # be sliced from a tensor whose length is a multiple of 8.
-        padded_len = (max_prompt_len + 7) // 8 * 8
-        bias = torch.empty(
-            input_metadata.num_prompts,
-            self.num_heads,
-            max_prompt_len,
-            padded_len,
-            device=self.alibi_slopes.device,
-            dtype=dtype,
-        )[:, :, :, :max_prompt_len].copy_(bias)
-        bias.mul_(self.alibi_slopes[:, None, None])
-        attn_bias = LowerTriangularMaskWithTensorBias(bias)
-        input_metadata.attn_bias = attn_bias
-
-    def multi_query_kv_attention(
-        self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        """Attention with ALiBi bias for the prompt tokens.
-
-        Args:
-            output: shape = [num_prompt_tokens, num_heads, head_size]
-            query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
-            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
-            input_metadata: metadata for paged attention.
-        """
-        if self.num_kv_heads != self.num_heads:
-            # Project the key and value tensors to the desired number of heads.
-            query = query.view(query.shape[0], self.num_kv_heads,
-                               self.num_queries_per_kv, query.shape[-1])
-            key = key[:, :, None, :].expand(key.shape[0], self.num_kv_heads,
-                                            self.num_queries_per_kv,
-                                            key.shape[-1])
-            value = value[:, :, None, :].expand(value.shape[0],
-                                                self.num_kv_heads,
-                                                self.num_queries_per_kv,
-                                                value.shape[-1])
-
-        batch_size = input_metadata.num_prompts
-        seq_len = input_metadata.max_prompt_len
-        out = xops.memory_efficient_attention_forward(
-            query.view(batch_size, seq_len, self.num_heads, self.head_size),
-            key.view(batch_size, seq_len, self.num_heads, self.head_size),
-            value.view(batch_size, seq_len, self.num_heads, self.head_size),
-            attn_bias=input_metadata.attn_bias,
-            p=0.0,
-            scale=self.scale,
-        )
-        # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.view_as(output))
-        return output
-
-    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
-        return self.alibi_slopes
+def _make_alibi_bias(
+    alibi_slopes: torch.Tensor,
+    batch_size: int,
+    seq_len: int,
+    dtype: torch.dtype,
+) -> LowerTriangularMaskWithTensorBias:
+    bias = torch.arange(seq_len, dtype=dtype)
+    # NOTE(zhuohan): HF uses
+    #     `bias = bias[None, :].repeat(prompt_len, 1)`
+    # here. We find that both biases give the same results, but
+    # the bias below more accurately follows the original ALiBi
+    # paper.
+    bias = bias[None, :] - bias[:, None]
+    bias = bias.to(alibi_slopes.device)
+
+    # When using custom attention bias, xformers requires the bias to
+    # be sliced from a tensor whose length is a multiple of 8.
+    padded_len = (seq_len + 7) // 8 * 8
+    bias = torch.empty(
+        batch_size,
+        alibi_slopes.shape[0],
+        seq_len,
+        padded_len,
+        device=alibi_slopes.device,
+        dtype=dtype,
+    )[:, :, :, :seq_len].copy_(bias)
+    bias.mul_(alibi_slopes[:, None, None])
+    attn_bias = LowerTriangularMaskWithTensorBias(bias)
+    return attn_bias
+
+
+def _paged_attention(
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    input_metadata: InputMetadata,
+    head_mapping: torch.Tensor,
+    scale: float,
+    alibi_slopes: Optional[torch.Tensor],
+) -> torch.Tensor:
+    output = torch.empty_like(query)
+
+    block_size = value_cache.shape[3]
+    num_seqs, num_heads, head_size = query.shape
+    max_num_partitions = (
+        (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
+        _PARTITION_SIZE)
+    # NOTE(woosuk): We use a simple heuristic to decide whether to use
+    # PagedAttention V1 or V2. If the number of partitions is 1, we use
+    # V1 to avoid the overhead of reduction. Also, if the number of
+    # sequences or heads is large, we use V1 since there is enough work
+    # to parallelize.
+    # TODO(woosuk): Tune this heuristic.
+    # For context len > 8192, use V2 kernel to avoid shared memory shortage.
+    use_v1 = input_metadata.max_context_len <= 8192 and (
+        max_num_partitions == 1 or num_seqs * num_heads > 512)
+    if use_v1:
+        # Run PagedAttention V1.
+        ops.paged_attention_v1(
+            output,
+            query,
+            key_cache,
+            value_cache,
+            head_mapping,
+            scale,
+            input_metadata.block_tables,
+            input_metadata.context_lens,
+            block_size,
+            input_metadata.max_context_len,
+            alibi_slopes,
+        )
+    else:
+        # Run PagedAttention V2.
+        assert _PARTITION_SIZE % block_size == 0
+        tmp_output = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions, head_size),
+            dtype=output.dtype,
+            device=output.device,
+        )
+        exp_sums = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions),
+            dtype=torch.float32,
+            device=output.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
+        ops.paged_attention_v2(
+            output,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            head_mapping,
+            scale,
+            input_metadata.block_tables,
+            input_metadata.context_lens,
+            block_size,
+            input_metadata.max_context_len,
+            alibi_slopes,
+        )
+    return output
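The MQA/GQA branch in the new forward pass works around xformers' MHA-only kernel by viewing the query into [tokens, kv_heads, queries_per_kv, head_size] and broadcasting K/V across the extra axis. A standalone shape check with hypothetical sizes, not part of the commit:

import torch

num_tokens, num_heads, num_kv_heads, head_size = 16, 32, 8, 128
num_queries_per_kv = num_heads // num_kv_heads  # 4

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

# Group query heads by the KV head they share.
query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
# Broadcast each KV head across its query group (expand copies no data).
key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                num_queries_per_kv, head_size)
value = value[:, :, None, :].expand(num_tokens, num_kv_heads,
                                    num_queries_per_kv, head_size)

print(query.shape, key.shape, value.shape)
# torch.Size([16, 8, 4, 128]) for all three, ready for the MHA kernel.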
vllm/model_executor/layers/rotary_embedding.py

...
@@ -277,8 +277,8 @@ def get_rope(
     rotary_dim: int,
     max_position: int,
     base: int,
-    is_neox_style: bool,
-    rope_scaling: Optional[Dict[str, Any]],
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
...
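With the new defaults, callers that want standard NeoX-style RoPE and no scaling can drop the last two arguments. A hedged sketch of the call; parameter values are illustrative and assume vLLM at this commit is importable:

from vllm.model_executor.layers.rotary_embedding import get_rope

# Before: is_neox_style and rope_scaling had to be passed explicitly.
rope = get_rope(head_size=128, rotary_dim=128, max_position=4096, base=10000,
                is_neox_style=True, rope_scaling=None)

# After this change the defaults cover the common case.
rope = get_rope(head_size=128, rotary_dim=128, max_position=4096, base=10000)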
vllm/model_executor/models/aquila.py

...
@@ -28,11 +28,12 @@ from torch import nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
...
@@ -138,15 +139,17 @@ class AquilaAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            base=self.rope_theta,
-            max_position=self.max_position_embeddings,
-            rotary_dim=self.head_dim,
-            num_kv_heads=self.num_kv_heads,
-            rope_scaling=rope_scaling)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads)

     def forward(
         self,
...
@@ -158,9 +161,10 @@ class AquilaAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+        attn_output = self.attn(q, k, v, k_cache, v_cache,
                                 input_metadata, cache_event)
         output, _ = self.o_proj(attn_output)
         return output
...
vllm/model_executor/models/baichuan.py

...
@@ -26,13 +26,13 @@ from torch import nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
-                                                  PagedAttentionWithALiBi)
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
...
@@ -150,17 +150,20 @@ class BaiChuanAttention(nn.Module):
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()

             scaling = self.head_dim**-0.5
-            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                                scaling, alibi_slopes)
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       scaling,
+                                       alibi_slopes=alibi_slopes)
         else:
-            self.scaling = self.head_dim**-0.5
-            self.attn = PagedAttentionWithRoPE(
-                self.num_heads,
-                self.head_dim,
-                self.scaling,
-                rotary_dim=self.head_dim,
-                base=self.rope_theta,
-                max_position=self.max_position_embeddings)
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=self.max_position_embeddings,
+                base=self.rope_theta,
+            )
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttention(self.num_heads, self.head_dim,
+                                       self.scaling)

     def forward(
         self,
...
@@ -172,14 +175,11 @@ class BaiChuanAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        if self.postion_embedding != "ALIBI":
+            q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        if self.postion_embedding == "ALIBI":
-            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                    cache_event)
-        else:
-            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                    input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         output, _ = self.o_proj(attn_output)
         return output
...
vllm/model_executor/models/bloom.py

...
@@ -25,7 +25,7 @@ from transformers import BloomConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
...
@@ -106,8 +106,10 @@ class BloomAttention(nn.Module):
         alibi_slopes = alibi_slopes[head_start:head_end].tolist()

         scaling = self.head_dim**-0.5
-        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                            scaling, alibi_slopes)
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes)

     def forward(
         self,
...
vllm/model_executor/models/chatglm.py

...
@@ -10,12 +10,13 @@ from torch.nn import LayerNorm
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
...
@@ -78,16 +79,19 @@ class GLMAttention(nn.Module):
         # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
         rope_ratio = getattr(config, "rope_ratio", 1.0)
         max_positions = getattr(config, "seq_length", 8192)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            rotary_dim=self.head_dim // 2,
-            num_kv_heads=self.num_kv_heads,
-            max_position=max_positions,
-            base=10000 * rope_ratio,
-            is_neox_style=False,
-        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim // 2,
+            max_position=max_positions,
+            base=10000 * rope_ratio,
+            is_neox_style=False,
+        )
+        self.attn = PagedAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+        )

     def forward(
         self,
...
@@ -99,10 +103,9 @@ class GLMAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
         key_cache, value_cache = kv_cache
         context_layer = self.attn(
-            position_ids,
             q,
             k,
             v,
...
@@ -111,9 +114,7 @@ class GLMAttention(nn.Module):
             input_metadata,
             cache_event,
         )
         attn_output, _ = self.dense(context_layer)
         return attn_output
...
vllm/model_executor/models/falcon.py

...
@@ -28,13 +28,12 @@ from transformers import FalconConfig as HF_FalconConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import (PagedAttention,
-                                                  PagedAttentionWithALiBi,
-                                                  PagedAttentionWithRoPE)
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
...
@@ -144,14 +143,16 @@ class FalconAttention(nn.Module):
             rope_theta = getattr(config, "rope_theta", 10000)
             max_position_embeddings = getattr(config,
                                               "max_position_embeddings", 8192)
-            self.attn = PagedAttentionWithRoPE(
-                self.num_heads,
-                self.head_dim,
-                self.inv_norm_factor,
-                base=rope_theta,
-                max_position=max_position_embeddings,
-                rotary_dim=self.head_dim,
-                num_kv_heads=self.num_kv_heads)
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+            )
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       self.inv_norm_factor,
+                                       num_kv_heads=self.num_kv_heads)
         elif self.use_alibi:
             tp_rank = get_tensor_model_parallel_rank()
             head_start = tp_rank * self.num_heads
...
@@ -159,11 +160,11 @@ class FalconAttention(nn.Module):
             alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                            self.inv_norm_factor)
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
-            self.attn = PagedAttentionWithALiBi(self.num_heads,
-                                                self.head_dim,
-                                                self.inv_norm_factor,
-                                                alibi_slopes,
-                                                num_kv_heads=self.num_kv_heads)
+            self.attn = PagedAttention(self.num_heads,
+                                       self.head_dim,
+                                       self.inv_norm_factor,
+                                       num_kv_heads=self.num_kv_heads,
+                                       alibi_slopes=alibi_slopes)
         else:
             self.attn = PagedAttention(self.num_heads,
                                        self.head_dim,
...
@@ -182,13 +183,11 @@ class FalconAttention(nn.Module):
         if bias is not None:
             qkv += bias
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        k_cache, v_cache = kv_cache
         if self.use_rotary:
-            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                    input_metadata, cache_event)
-        else:
-            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                    cache_event)
+            q, k = self.rotary_emb(positions, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
         attn_output, bias = self.dense(attn_output)
         return attn_output, bias
...
vllm/model_executor/models/gpt_j.py

...
@@ -24,11 +24,12 @@ from transformers import GPTJConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
...
@@ -77,15 +78,14 @@ class GPTJAttention(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_size,
-            scaling,
-            config.rotary_dim,
-            base=rope_theta,
-            max_position=max_position_embeddings,
-            is_neox_style=False)
-        self.warmup = False
+        self.rotary_emb = get_rope(
+            self.head_size,
+            rotary_dim=config.rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            is_neox_style=False,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

     def forward(
         self,
...
@@ -97,9 +97,10 @@ class GPTJAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
+        attn_output = self.attn(q, k, v, k_cache, v_cache,
                                 input_metadata, cache_event)
         attn_output, _ = self.out_proj(attn_output)
         return attn_output
...
vllm/model_executor/models/gpt_neox.py

...
@@ -24,11 +24,12 @@ from transformers import GPTNeoXConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
...
@@ -77,13 +78,13 @@ class GPTNeoXAttention(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_size,
-            scaling,
-            rotary_dim,
-            base=rope_theta,
-            max_position=max_position_embeddings)
+        self.rotary_emb = get_rope(
+            self.head_size,
+            rotary_dim=rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

     def forward(
         self,
...
@@ -95,9 +96,10 @@ class GPTNeoXAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
+        attn_output = self.attn(q, k, v, k_cache, v_cache,
                                 input_metadata, cache_event)
         output, _ = self.dense(attn_output)
         return output
...
vllm/model_executor/models/internlm.py

...
@@ -7,12 +7,13 @@ from transformers import LlamaConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
...
@@ -92,13 +93,13 @@ class InternLMAttention(nn.Module):
             bias=bias,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            base=self.rope_theta,
-            max_position=self.max_position_embeddings,
-            rotary_dim=self.head_dim)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)

     def forward(
         self,
...
@@ -110,9 +111,10 @@ class InternLMAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+        attn_output = self.attn(q, k, v, k_cache, v_cache,
                                 input_metadata, cache_event)
         output, _ = self.o_proj(attn_output)
         return output
...
vllm/model_executor/models/llama.py

...
@@ -29,12 +29,13 @@ from transformers import LlamaConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
...
@@ -126,15 +127,18 @@ class LlamaAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           base=self.rope_theta,
-                                           max_position=self.max_position_embeddings,
-                                           rotary_dim=self.head_dim,
-                                           num_kv_heads=self.num_kv_heads,
-                                           rope_scaling=rope_scaling)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads)

     def forward(
         self,
...
@@ -146,9 +150,10 @@ class LlamaAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+        attn_output = self.attn(q, k, v, k_cache, v_cache,
                                 input_metadata, cache_event)
         output, _ = self.o_proj(attn_output)
         return output
...
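Every model file in this commit follows the same two-step pattern shown in llama.py above: the module owns a get_rope layer plus a plain PagedAttention, and forward applies them in sequence. A condensed, hypothetical skeleton of that pattern (class and parameter names are illustrative, not from the diff):

import torch
import torch.nn as nn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.rotary_embedding import get_rope


class AttentionBlock(nn.Module):
    """Illustrative sketch of the post-refactor per-model attention."""

    def __init__(self, num_heads: int, head_dim: int, num_kv_heads: int,
                 max_position: int, rope_theta: int):
        super().__init__()
        self.rotary_emb = get_rope(head_dim, rotary_dim=head_dim,
                                   max_position=max_position, base=rope_theta)
        self.attn = PagedAttention(num_heads, head_dim, head_dim**-0.5,
                                   num_kv_heads=num_kv_heads)

    def forward(self, positions, q, k, v, k_cache, v_cache, input_metadata,
                cache_event):
        # RoPE is now applied by the model, not inside the attention layer.
        q, k = self.rotary_emb(positions, q, k)
        return self.attn(q, k, v, k_cache, v_cache, input_metadata,
                         cache_event)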
vllm/model_executor/models/mistral.py

...
@@ -29,12 +29,13 @@ from transformers import MistralConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
...
@@ -124,14 +125,18 @@ class MistralAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           base=self.rope_theta,
-                                           max_position=max_position,
-                                           rotary_dim=self.head_dim,
-                                           num_kv_heads=self.num_kv_heads,
-                                           sliding_window=self.sliding_window)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=self.rope_theta,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads,
+                                   sliding_window=self.sliding_window)

     def forward(
         self,
...
@@ -143,9 +148,10 @@ class MistralAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+        attn_output = self.attn(q, k, v, k_cache, v_cache,
                                 input_metadata, cache_event)
         output, _ = self.o_proj(attn_output)
         return output
...
vllm/model_executor/models/mpt.py

...
@@ -8,7 +8,7 @@ import torch.nn as nn
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
+from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
...
@@ -87,8 +87,10 @@ class MPTAttention(nn.Module):
         self.head_dim = self.d_model // self.total_num_heads
         scaling = self.head_dim**-0.5
-        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
-                                            scaling, alibi_slopes)
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling,
+                                   alibi_slopes=alibi_slopes)

     def forward(
         self,
...
vllm/model_executor/models/phi_1_5.py View file @ a9e45742
...
@@ -43,11 +43,12 @@ from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
...
@@ -119,13 +120,13 @@ class PhiAttention(nn.Module):
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta = 10000
        max_position_embeddings = getattr(config, "n_positions", 2048)
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_size,
-            scaling,
-            rotary_dim,
-            base=rope_theta,
-            max_position=max_position_embeddings)
+        self.rotary_emb = get_rope(
+            self.head_size,
+            rotary_dim=rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
...
@@ -137,9 +138,10 @@ class PhiAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.Wqkv(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
        k_cache, v_cache = kv_cache
-        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
        output, _ = self.out_proj(attn_output)
        return output
...
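Phi-1.5 rotates only part of each head: rotary_dim is smaller than head_size, and the standalone get_rope carries that through its rotary_dim argument while PagedAttention itself is unchanged. A small sketch with assumed sizes; head_size = 64 and the half-width rotary dimension are illustrative values, not read from the Phi config:

from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.rotary_embedding import get_rope

num_heads, head_size = 32, 64
rotary_dim = head_size // 2  # rotate only the first half of each head

rotary_emb = get_rope(head_size,
                      rotary_dim=rotary_dim,
                      max_position=2048,
                      base=10000)
attn = PagedAttention(num_heads, head_size, head_size**-0.5)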
vllm/model_executor/models/qwen.py View file @ a9e45742
...
@@ -11,12 +11,13 @@ from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
...
@@ -95,14 +96,15 @@ class QWenAttention(nn.Module):
            linear_method=linear_method,
        )
        self.scaling = self.head_dim**-0.5
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           rotary_dim=self.head_dim,
-                                           base=rope_theta,
-                                           max_position=max_position_embeddings,
-                                           rope_scaling=rope_scaling)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)

    def forward(
        self,
...
@@ -114,10 +116,10 @@ class QWenAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
        output, _ = self.c_proj(attn_output)
        return output
...
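Qwen (like Yi below) also forwards its rope_scaling config straight into get_rope, so context-window scaling variants keep working with the plain PagedAttention layer. A sketch with an assumed HuggingFace-style scaling dict; the type, factor, and sizes are placeholders rather than values from any Qwen checkpoint:

from vllm.model_executor.layers.rotary_embedding import get_rope

rope_scaling = {"type": "linear", "factor": 2.0}  # assumed example config
rotary_emb = get_rope(128,
                      rotary_dim=128,
                      max_position=8192,
                      base=10000,
                      rope_scaling=rope_scaling)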
vllm/model_executor/models/yi.py View file @ a9e45742
...
@@ -29,12 +29,13 @@ from vllm.transformers_utils.configs.yi import YiConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
+from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
...
@@ -126,15 +127,17 @@ class YiAttention(nn.Module):
            bias=False,
            linear_method=linear_method,
        )
-        self.attn = PagedAttentionWithRoPE(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            base=self.rope_theta,
-            max_position=self.max_position_embeddings,
-            rotary_dim=self.head_dim,
-            num_kv_heads=self.num_kv_heads,
-            rope_scaling=rope_scaling)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads)

    def forward(
        self,
...
@@ -146,9 +149,10 @@ class YiAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
        output, _ = self.o_proj(attn_output)
        return output
...
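Yi uses grouped-query attention, which is the last thing the unified layer needs to know about: num_kv_heads is passed to PagedAttention so a smaller set of key/value heads serves all query heads, while the rotary embedding is unaffected. A sketch with assumed head counts (8 KV heads for 32 query heads is illustrative only):

from vllm.model_executor.layers.attention import PagedAttention

num_heads, num_kv_heads, head_dim = 32, 8, 128
attn = PagedAttention(num_heads,
                      head_dim,
                      head_dim**-0.5,
                      num_kv_heads=num_kv_heads)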