norm / vllm · Commits · c1376e0f
"test/test_transforms_v2_functional.py" did not exist on "9851a69f6d294f5d672d973d8a1dbeebdd2aa04e"
Unverified commit c1376e0f, authored Oct 16, 2023 by Woosuk Kwon, committed by GitHub on Oct 16, 2023.
Change scheduler & input tensor shape (#1381)

Parent: 651c614a
Showing 13 changed files with 181 additions and 179 deletions.
csrc/activation_kernels.cu                           +14  -14
csrc/cache_kernels.cu                                 +7   -2
csrc/layernorm_kernels.cu                             +6   -6
csrc/pos_encoding_kernels.cu                         +11  -11
vllm/config.py                                        +3   -0
vllm/core/scheduler.py                               +11   -4
vllm/engine/arg_utils.py                              +7   -1
vllm/model_executor/input_metadata.py                 +5   -6
vllm/model_executor/layers/activation.py              +8  -12
vllm/model_executor/layers/attention.py              +71  -95
vllm/model_executor/layers/quantized_linear/awq.py    +2   -2
vllm/model_executor/layers/sampler.py                 +2   -1
vllm/worker/worker.py                                +34  -25
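The diff below moves model inputs from flattened 1D token tensors of shape [num_tokens, ...] (prompts concatenated and padded to a multiple of 8) to padded batches of shape [batch_size, seq_len, ...], where seq_len is the longest prompt in the batch. A minimal sketch of the two layouts, using made-up prompt lengths that are not taken from the diff:

import torch

prompt_lens = [3, 5, 2]          # three hypothetical prompts of different lengths
hidden = 4                       # toy hidden size

# Old layout: one flattened run of tokens, padded to a multiple of 8.
num_tokens = sum(prompt_lens)                    # 10
padded_1d = (num_tokens + 7) // 8 * 8            # 16
old_input = torch.zeros(padded_1d, hidden)       # [num_tokens (+ padding), hidden]

# New layout: every prompt padded to the max length in the batch.
batch_size = len(prompt_lens)
seq_len = max(prompt_lens)                       # 5
new_input = torch.zeros(batch_size, seq_len, hidden)   # [batch_size, seq_len, hidden]

print(old_input.shape, new_input.shape)  # torch.Size([16, 4]) torch.Size([3, 5, 4])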
csrc/activation_kernels.cu (view file @ c1376e0f)

@@ -13,8 +13,8 @@ __device__ __forceinline__ T silu(const T& x) {
 template<typename scalar_t>
 __global__ void silu_and_mul_kernel(
-  scalar_t* __restrict__ out,               // [num_tokens, d]
-  const scalar_t* __restrict__ input,       // [num_tokens, 2, d]
+  scalar_t* __restrict__ out,               // [..., d]
+  const scalar_t* __restrict__ input,       // [..., 2, d]
   const int d) {
   const int token_idx = blockIdx.x;
   for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {

@@ -27,11 +27,11 @@ __global__ void silu_and_mul_kernel(
 } // namespace vllm

 void silu_and_mul(
-  torch::Tensor& out,      // [num_tokens, d]
-  torch::Tensor& input)    // [num_tokens, 2 * d]
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
 {
-  int num_tokens = input.size(0);
-  int d = input.size(1) / 2;
+  int num_tokens = input.numel() / input.size(-1);
+  int d = input.size(-1) / 2;
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));

@@ -52,8 +52,8 @@ namespace vllm {
 // Element-wise activation kernel template.
 template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void activation_kernel(
-  scalar_t* __restrict__ out,               // [num_tokens, d]
-  const scalar_t* __restrict__ input,       // [num_tokens, d]
+  scalar_t* __restrict__ out,               // [..., d]
+  const scalar_t* __restrict__ input,       // [..., d]
   const int d) {
   const int token_idx = blockIdx.x;
   for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {

@@ -66,8 +66,8 @@ __global__ void activation_kernel(
 // Launch element-wise activation kernel.
 #define LAUNCH_ACTIVATION_KERNEL(KERNEL)                              \
-  int num_tokens = input.size(0);                                     \
-  int d = input.size(1);                                              \
+  int d = input.size(-1);                                             \
+  int num_tokens = input.numel() / d;                                 \
   dim3 grid(num_tokens);                                              \
   dim3 block(std::min(d, 1024));                                      \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();       \

@@ -100,15 +100,15 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
 } // namespace vllm

 void gelu_new(
-  torch::Tensor& out,      // [num_tokens, d]
-  torch::Tensor& input)    // [num_tokens, d]
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
 }

 void gelu_fast(
-  torch::Tensor& out,      // [num_tokens, d]
-  torch::Tensor& input)    // [num_tokens, d]
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
 }
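The launcher changes above replace input.size(0) with input.numel() / input.size(-1), so the same grid size works whether the activation input is a flattened [num_tokens, d] tensor or a padded [batch_size, seq_len, d] tensor. A quick check of that arithmetic in Python, with toy shapes that are not from the diff:

import torch

d = 8
flat = torch.zeros(10, 2 * d)        # [num_tokens, 2 * d]
padded = torch.zeros(2, 5, 2 * d)    # [batch_size, seq_len, 2 * d]

for x in (flat, padded):
    num_tokens = x.numel() // x.size(-1)   # rows the kernel grid iterates over
    print(num_tokens, x.size(-1) // 2)     # both print: 10 8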
csrc/cache_kernels.cu (view file @ c1376e0f)

@@ -154,6 +154,11 @@ __global__ void reshape_and_cache_kernel(
   const int x) {
   const int token_idx = blockIdx.x;
   const int slot_idx = slot_mapping[token_idx];
+  if (slot_idx < 0) {
+    // Padding token that should be ignored.
+    return;
+  }
   const int block_idx = slot_idx / block_size;
   const int block_offset = slot_idx % block_size;

@@ -176,8 +181,8 @@ __global__ void reshape_and_cache_kernel(
                                   + head_idx * head_size * block_size
                                   + head_offset * block_size
                                   + block_offset;
-    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
-    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
+    key_cache[tgt_key_idx] = key[src_key_idx];
+    value_cache[tgt_value_idx] = value[src_value_idx];
   }
 }
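The new early return in reshape_and_cache_kernel relies on padding positions carrying a negative slot index; the worker change later in this commit pads slot_mapping with -1 for exactly this reason. A sketch of the same guard in Python, with hypothetical slot values:

block_size = 16
slot_mapping = [35, 36, 37, -1, -1]   # last two entries are padding tokens

for token_idx, slot_idx in enumerate(slot_mapping):
    if slot_idx < 0:
        # Padding token that should be ignored.
        continue
    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size
    print(token_idx, block_idx, block_offset)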
csrc/layernorm_kernels.cu (view file @ c1376e0f)

@@ -9,8 +9,8 @@ namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template<typename scalar_t>
 __global__ void rms_norm_kernel(
-  scalar_t* __restrict__ out,             // [num_tokens, hidden_size]
-  const scalar_t* __restrict__ input,     // [num_tokens, hidden_size]
+  scalar_t* __restrict__ out,             // [..., hidden_size]
+  const scalar_t* __restrict__ input,     // [..., hidden_size]
   const scalar_t* __restrict__ weight,    // [hidden_size]
   const float epsilon,
   const int num_tokens,

@@ -37,12 +37,12 @@ __global__ void rms_norm_kernel(
 } // namespace vllm

 void rms_norm(
-  torch::Tensor& out,      // [num_tokens, hidden_size]
-  torch::Tensor& input,    // [num_tokens, hidden_size]
+  torch::Tensor& out,      // [..., hidden_size]
+  torch::Tensor& input,    // [..., hidden_size]
   torch::Tensor& weight,   // [hidden_size]
   float epsilon) {
-  int num_tokens = input.size(0);
-  int hidden_size = input.size(1);
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
csrc/pos_encoding_kernels.cu (view file @ c1376e0f)

@@ -37,9 +37,9 @@ inline __device__ void apply_rotary_embedding(
 template<typename scalar_t, bool IS_NEOX>
 __global__ void rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,        // [num_tokens]
-  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
+  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
+  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
   const int query_stride,

@@ -78,18 +78,18 @@ __global__ void rotary_embedding_kernel(
 } // namespace vllm

 void rotary_embedding(
-  torch::Tensor& positions,         // [num_tokens]
-  torch::Tensor& query,             // [num_tokens, num_heads * head_size]
-  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
+  torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens]
+  torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
+  torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
   int head_size,
   torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
   bool is_neox) {
-  int num_tokens = query.size(0);
+  int num_tokens = query.numel() / query.size(-1);
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(1) / head_size;
-  int num_kv_heads = key.size(1) / head_size;
-  int query_stride = query.stride(0);
-  int key_stride = key.stride(0);
+  int num_heads = query.size(-1) / head_size;
+  int num_kv_heads = key.size(-1) / head_size;
+  int query_stride = query.stride(-2);
+  int key_stride = key.stride(-2);
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
vllm/config.py (view file @ c1376e0f)

@@ -268,6 +268,7 @@ class SchedulerConfig:
             iteration.
         max_model_len: Maximum length of a sequence (including prompt
             and generated text).
+        max_paddings: Maximum number of paddings to be added to a batch.
     """

     def __init__(

@@ -275,6 +276,7 @@ class SchedulerConfig:
         max_num_batched_tokens: Optional[int],
         max_num_seqs: int,
         max_model_len: int,
+        max_paddings: int,
     ) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens

@@ -284,6 +286,7 @@ class SchedulerConfig:
             self.max_num_batched_tokens = max(max_model_len, 2048)
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self.max_paddings = max_paddings
         self._verify_args()

     def _verify_args(self) -> None:
vllm/core/scheduler.py (view file @ c1376e0f)

@@ -131,7 +131,8 @@ class Scheduler:
             # requests in the generation phase.
             num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                 for seq_group in self.running)
-            num_batched_tokens = 0
+            seq_lens: List[int] = []
+
             # Optimization: We do not sort the waiting queue since the preempted
             # sequence groups are added to the front and the new sequence groups
             # are added to the back.

@@ -157,7 +158,9 @@ class Scheduler:
                     break

                 # If the number of batched tokens exceeds the limit, stop.
-                if (num_batched_tokens + num_prompt_tokens >
+                new_seq_lens = seq_lens + [num_prompt_tokens]
+                num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
+                if (num_batched_tokens >
                         self.scheduler_config.max_num_batched_tokens):
                     break

@@ -168,10 +171,14 @@ class Scheduler:
                         self.scheduler_config.max_num_seqs):
                     break

+                num_paddings = num_batched_tokens - sum(new_seq_lens)
+                if num_paddings > self.scheduler_config.max_paddings:
+                    break
+                seq_lens = new_seq_lens
+
                 seq_group = self.waiting.pop(0)
                 self._allocate(seq_group)
                 self.running.append(seq_group)
-                num_batched_tokens += num_prompt_tokens
                 num_curr_seqs += num_new_seqs
                 scheduled.append(seq_group)

@@ -179,7 +186,7 @@ class Scheduler:
             scheduler_outputs = SchedulerOutputs(
                 scheduled_seq_groups=scheduled,
                 prompt_run=True,
-                num_batched_tokens=num_batched_tokens,
+                num_batched_tokens=len(seq_lens) * max(seq_lens),
                 blocks_to_swap_in=blocks_to_swap_in,
                 blocks_to_swap_out=blocks_to_swap_out,
                 blocks_to_copy=blocks_to_copy,
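With padded batches, the scheduler now counts batched tokens as len(seq_lens) * max(seq_lens) and bounds the wasted padding via the new max_paddings limit. A standalone sketch of that admission check; the variable names follow the diff, but the numbers are made up:

max_num_batched_tokens = 2560
max_paddings = 256

seq_lens = [512, 480]          # prompts already admitted to this batch
num_prompt_tokens = 1024       # candidate prompt

new_seq_lens = seq_lens + [num_prompt_tokens]
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)   # 3 * 1024 = 3072
num_paddings = num_batched_tokens - sum(new_seq_lens)        # 3072 - 2016 = 1056

admit = (num_batched_tokens <= max_num_batched_tokens
         and num_paddings <= max_paddings)
print(admit)   # False: both the token budget and the padding budget are exceeded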
vllm/engine/arg_utils.py (view file @ c1376e0f)

@@ -27,6 +27,7 @@ class EngineArgs:
     gpu_memory_utilization: float = 0.90
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
+    max_paddings: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
     tokenizer_revision: Optional[str] = None

@@ -156,6 +157,10 @@ class EngineArgs:
                             type=int,
                             default=EngineArgs.max_num_seqs,
                             help='maximum number of sequences per iteration')
+        parser.add_argument('--max-paddings',
+                            type=int,
+                            default=EngineArgs.max_paddings,
+                            help='maximum number of paddings in a batch')
         parser.add_argument('--disable-log-stats',
                             action='store_true',
                             help='disable logging statistics')

@@ -193,7 +198,8 @@ class EngineArgs:
                                          self.worker_use_ray)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs,
-                                           model_config.max_model_len)
+                                           model_config.max_model_len,
+                                           self.max_paddings)
         return model_config, cache_config, parallel_config, scheduler_config
vllm/model_executor/input_metadata.py (view file @ c1376e0f)

@@ -39,11 +39,12 @@ class InputMetadata:
         self.max_context_len = max_context_len
         self.block_tables = block_tables

+        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.to_cache = None
         if sliding_window is not None:
             # We need to keep the positions of sliding windows within
             # the key / value tables, this is helpful to know which
-            # elements we need to cache and where
+            # elements we need to cache.
             to_cache, start_idx = [], 0
             for prompt_len in self.prompt_lens:
                 to_cache.extend(

@@ -51,16 +52,15 @@ class InputMetadata:
                         start_idx + max(0, prompt_len - sliding_window),
                         start_idx + prompt_len,
                     ))
-                start_idx += prompt_len
+                start_idx += self.max_prompt_len
             to_cache.extend(range(start_idx, slot_mapping.shape[0]))
             self.to_cache = torch.tensor(to_cache,
                                          dtype=torch.int32,
                                          device=self.slot_mapping.device)

         self.num_prompts = len(prompt_lens)
-        self.num_prompt_tokens = sum(prompt_lens)
+        self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:

@@ -69,12 +69,11 @@ class InputMetadata:
         assert context_lens.shape[0] == self.num_generation_tokens

         # Set during the execution of the first attention op.
-        self.attn_bias: List[AttentionBias] = []
+        self.attn_bias: Optional[AttentionBias] = None

     def __repr__(self) -> str:
         # Print only useful metadata.
         return (f'InputMetadata('
                 f'num_valid_tokens={self.num_valid_tokens}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
                 f'num_prompts={self.num_prompts}, '
                 f'prompt_lens={self.prompt_lens}, '
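InputMetadata now measures prompt tokens in padded units: num_prompt_tokens becomes num_prompts * max_prompt_len instead of sum(prompt_lens). A toy comparison of the two quantities; the lengths are illustrative only:

prompt_lens = [7, 3, 5]

num_prompts = len(prompt_lens)
max_prompt_len = max(prompt_lens) if prompt_lens else 0

old_num_prompt_tokens = sum(prompt_lens)               # 15 (flattened layout)
new_num_prompt_tokens = num_prompts * max_prompt_len   # 21 (padded layout)
print(old_num_prompt_tokens, new_num_prompt_tokens)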
vllm/model_executor/layers/activation.py (view file @ c1376e0f)

@@ -8,17 +8,17 @@ from vllm import activation_ops
 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.

-    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
+    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

     Shapes:
-        x: (num_tokens, 2 * d)
-        return: (num_tokens, d)
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
     """

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        num_tokens = x.shape[0]
-        d = x.shape[1] // 2
-        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         activation_ops.silu_and_mul(out, x)
         return out

@@ -26,9 +26,7 @@ class SiluAndMul(nn.Module):
 class NewGELU(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        num_tokens = x.shape[0]
-        d = x.shape[1]
-        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        out = torch.empty_like(x)
         activation_ops.gelu_new(out, x)
         return out

@@ -36,9 +34,7 @@ class NewGELU(nn.Module):
 class FastGELU(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        num_tokens = x.shape[0]
-        d = x.shape[1]
-        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        out = torch.empty_like(x)
         activation_ops.gelu_fast(out, x)
         return out
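The rewritten SiluAndMul.forward builds the output shape from x.shape[:-1] + (d,), which works for both 2-D and 3-D inputs. A pure-PyTorch sketch of the same shape logic, without the custom activation_ops kernel:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Same shape logic as the new forward(): keep leading dims, halve the last one.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x2d = torch.randn(10, 2 * 8)        # (num_tokens, 2 * d)
x3d = torch.randn(2, 5, 2 * 8)      # (batch_size, seq_len, 2 * d)
print(silu_and_mul_reference(x2d).shape)  # torch.Size([10, 8])
print(silu_and_mul_reference(x3d).shape)  # torch.Size([2, 5, 8])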
vllm/model_executor/layers/attention.py (view file @ c1376e0f)

@@ -23,25 +23,9 @@ class PagedAttention(nn.Module):
     # pylint: disable=line-too-long
     """GPT-style multi-head PagedAttention.

-    This class takes flattened 1D query, key, and value tensors as input. The
-    input 1D tensors can either contain prompt tokens or generation tokens, in
-    addition to paddings.
-
-    If the input tensors contain prompt tokens, the layout is as follows:
-
-    |<---------------------- num_valid_tokens ---------------------->|
-    |<--------------- num_prompt_tokens -------------->|
-    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
-
-    Otherwise, the layout is as follows:
-
-    |<------------------ num_valid_tokens ------------------->|
-    |<------- num_generation_tokens (M) ------->|
-    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
-
-    The prompts might have different lengths, while the generation tokens always
-    have length 1. The paddings are appended to make the input length a multiple
-    of 8, which is desirable for Tensor Cores.
+    This class takes query, key, and value tensors as input. The input tensors
+    can either contain prompt tokens or generation tokens, in addition to
+    paddings.

     The class does the following:

     1. Perform multi_query_kv_attention for the prompts. This operation does

@@ -53,7 +37,7 @@ class PagedAttention(nn.Module):
     4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
-    5. Output a flattened 1D tensor.
+    5. Return the output tensor.
     """

     def __init__(
         self,

@@ -85,14 +69,15 @@ class PagedAttention(nn.Module):
         dtype: torch.dtype,
     ) -> None:
         del dtype  # Unused.
-        if input_metadata.attn_bias:
+        if input_metadata.attn_bias is not None:
             # Already set by a previous layer.
             return
-        prompt_lens = input_metadata.prompt_lens
+        prompt_lens = [input_metadata.max_prompt_len
+                       ] * input_metadata.num_prompts
         attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
         if self.sliding_window is not None:
             attn_bias = attn_bias.make_local_attention(self.sliding_window)
-        input_metadata.attn_bias.append(attn_bias)
+        input_metadata.attn_bias = attn_bias

     def multi_query_kv_attention(
         self,

@@ -111,7 +96,6 @@ class PagedAttention(nn.Module):
             value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
-
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
             key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)

@@ -124,7 +108,7 @@ class PagedAttention(nn.Module):
             query.unsqueeze(0),
             key.unsqueeze(0),
             value.unsqueeze(0),
-            attn_bias=input_metadata.attn_bias[0],
+            attn_bias=input_metadata.attn_bias,
             p=0.0,
             scale=self.scale,
         )

@@ -232,12 +216,12 @@ class PagedAttention(nn.Module):
         """PagedAttention forward pass.

         NOTE: The query, key, and value tensors must be sliced from a qkv
-        tensor of shape [num_tokens, 3 * num_heads * head_size].
+        tensor of shape [batch_size, seq_len, 3 * num_heads * head_size].

         Args:
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
             value_cache: shape = [num_blocks, num_kv_heads, head_size,

@@ -246,9 +230,9 @@ class PagedAttention(nn.Module):
             cache_event: event to wait for the cache operations to finish.

         Returns:
-            shape = [num_tokens, num_heads * head_size]
+            shape = [batch_size, seq_len, num_heads * head_size]
         """
+        batch_size, seq_len, _ = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)

@@ -264,10 +248,10 @@ class PagedAttention(nn.Module):
             assert input_metadata.num_generation_tokens == 0
             self.set_attn_bias(input_metadata, dtype=query.dtype)
             self.multi_query_kv_attention(
-                output[:num_prompt_tokens],
-                query[:num_prompt_tokens],
-                key[:num_prompt_tokens],
-                value[:num_prompt_tokens],
+                output,
+                query,
+                key,
+                value,
                 input_metadata,
             )

@@ -278,13 +262,10 @@ class PagedAttention(nn.Module):
         # Reshape the keys and values and store them in the cache.
         # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
-        num_valid_tokens = input_metadata.num_valid_tokens
-        if (num_valid_tokens > 0 and key_cache is not None
-                and value_cache is not None):
-            # The stride is 3 because the key and value are sliced from qkv.
-            key_to_cache = key[:num_valid_tokens]
-            value_to_cache = value[:num_valid_tokens]
-            slot_mapping = input_metadata.slot_mapping
+        if key_cache is not None and value_cache is not None:
+            key_to_cache = key
+            value_to_cache = value
+            slot_mapping = input_metadata.slot_mapping.view(-1)
             if input_metadata.to_cache is not None:
                 key_to_cache = key_to_cache[input_metadata.to_cache]
                 value_to_cache = value_to_cache[input_metadata.to_cache]

@@ -305,14 +286,14 @@ class PagedAttention(nn.Module):
                     "key_cache and value_cache must be provided when "
                     "generating tokens.")
             # Compute the attention op for generation tokens.
-            self.single_query_cached_kv_attention(
-                output[num_prompt_tokens:num_valid_tokens],
-                query[num_prompt_tokens:num_valid_tokens], key_cache,
-                value_cache, input_metadata, self.get_alibi_slopes())
+            self.single_query_cached_kv_attention(output, query, key_cache,
+                                                  value_cache, input_metadata,
+                                                  self.get_alibi_slopes())

         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
-        return output.view(-1, self.num_heads * self.head_size)
+        return output.view(batch_size, seq_len,
+                           self.num_heads * self.head_size)


 class PagedAttentionWithRoPE(PagedAttention):

@@ -368,10 +349,10 @@ class PagedAttentionWithRoPE(PagedAttention):
         """ PagedAttention forward pass with rotary embedding.

         Args:
-            positions: shape = [num_tokens]
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            positions: shape = [batch_size, seq_len]
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
             value_cache: shape = [num_blocks, num_kv_heads, head_size,

@@ -380,7 +361,7 @@ class PagedAttentionWithRoPE(PagedAttention):
             cache_event: event to wait for the cache operations to finish.

         Returns:
-            shape = [num_tokens, num_heads * head_size]
+            shape = [batch_size, seq_len, num_heads * head_size]
         """
         # Apply rotary embedding to the query and key before passing them

@@ -414,34 +395,34 @@ class PagedAttentionWithALiBi(PagedAttention):
     def set_attn_bias(self, input_metadata: InputMetadata,
                       dtype: torch.dtype) -> None:
-        if input_metadata.attn_bias:
+        if input_metadata.attn_bias is not None:
             # Already set by a previous layer.
             return
-        # Generates ALiBi mask for each prompt.
-        for prompt_len in input_metadata.prompt_lens:
-            bias = torch.arange(prompt_len, dtype=dtype)
-            # NOTE(zhuohan): HF uses
-            #     `bias = bias[None, :].repeat(prompt_len, 1)`
-            # here. We find that both biases give the same results, but
-            # the bias below more accurately follows the original ALiBi
-            # paper.
-            bias = bias[None, :] - bias[:, None]
-            bias = bias.to(self.alibi_slopes.device)
-
-            # When using custom attention bias, xformers requires the bias to
-            # be sliced from a tensor whose length is a multiple of 8.
-            padded_len = (prompt_len + 7) // 8 * 8
-            bias = torch.empty(
-                1,  # batch_size
-                self.num_heads,
-                prompt_len,
-                padded_len,
-                device=self.alibi_slopes.device,
-                dtype=dtype,
-            )[:, :, :, :prompt_len].copy_(bias)
-            bias.mul_(self.alibi_slopes[:, None, None])
-            attn_bias = LowerTriangularMaskWithTensorBias(bias)
-            input_metadata.attn_bias.append(attn_bias)
+        # Generates ALiBi mask based on the max prompt length.
+        max_prompt_len = input_metadata.max_prompt_len
+        bias = torch.arange(max_prompt_len, dtype=dtype)
+        # NOTE(zhuohan): HF uses
+        #     `bias = bias[None, :].repeat(prompt_len, 1)`
+        # here. We find that both biases give the same results, but
+        # the bias below more accurately follows the original ALiBi
+        # paper.
+        bias = bias[None, :] - bias[:, None]
+        bias = bias.to(self.alibi_slopes.device)
+
+        # When using custom attention bias, xformers requires the bias to
+        # be sliced from a tensor whose length is a multiple of 8.
+        padded_len = (max_prompt_len + 7) // 8 * 8
+        bias = torch.empty(
+            input_metadata.num_prompts,
+            self.num_heads,
+            max_prompt_len,
+            padded_len,
+            device=self.alibi_slopes.device,
+            dtype=dtype,
+        )[:, :, :, :max_prompt_len].copy_(bias)
+        bias.mul_(self.alibi_slopes[:, None, None])
+        attn_bias = LowerTriangularMaskWithTensorBias(bias)
+        input_metadata.attn_bias = attn_bias

     def multi_query_kv_attention(
         self,

@@ -466,24 +447,19 @@ class PagedAttentionWithALiBi(PagedAttention):
             value = torch.repeat_interleave(value,
                                             self.num_queries_per_kv,
                                             dim=1)

+        batch_size = input_metadata.num_prompts
+        seq_len = input_metadata.max_prompt_len
-        # FIXME(woosuk): Because xformers does not support dynamic sequence
-        # lengths with custom attention bias, we process each prompt one by
-        # one. This is inefficient, especially when we have many short prompts.
-        start = 0
-        for i, prompt_len in enumerate(input_metadata.prompt_lens):
-            end = start + prompt_len
-            out = xops.memory_efficient_attention_forward(
-                query[None, start:end],
-                key[None, start:end],
-                value[None, start:end],
-                attn_bias=input_metadata.attn_bias[i],
-                p=0.0,
-                scale=self.scale,
-            )
-            # TODO(woosuk): Unnecessary copy. Optimize.
-            output[start:end].copy_(out.squeeze(0))
-            start += prompt_len
+        out = xops.memory_efficient_attention_forward(
+            query.view(batch_size, seq_len, self.num_heads, self.head_size),
+            key.view(batch_size, seq_len, self.num_heads, self.head_size),
+            value.view(batch_size, seq_len, self.num_heads, self.head_size),
+            attn_bias=input_metadata.attn_bias,
+            p=0.0,
+            scale=self.scale,
+        )
+        # TODO(woosuk): Unnecessary copy. Optimize.
+        output.copy_(out.view(-1, self.num_heads, self.head_size))
         return output

     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
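The ALiBi path now builds a single bias tensor for the whole padded batch from max_prompt_len, instead of one bias per prompt. A self-contained sketch of that bias construction, dropping the xformers wrapper and using toy sizes that are not from the diff:

import torch

num_prompts = 2
num_heads = 4
max_prompt_len = 5
alibi_slopes = torch.rand(num_heads)     # placeholder slopes for illustration

bias = torch.arange(max_prompt_len, dtype=torch.float32)
bias = bias[None, :] - bias[:, None]     # relative positions, shape [5, 5]

# xformers expects the bias sliced from a tensor whose last dim is a multiple of 8.
padded_len = (max_prompt_len + 7) // 8 * 8
bias = torch.empty(num_prompts, num_heads, max_prompt_len, padded_len,
                   dtype=torch.float32)[:, :, :, :max_prompt_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])   # scale each head by its ALiBi slope
print(bias.shape)                        # torch.Size([2, 4, 5, 5])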
vllm/model_executor/layers/quantized_linear/awq.py (view file @ c1376e0f)

@@ -50,7 +50,7 @@ class AWQColumnParallelLinear(ColumnParallelLinear):
         bias: Optional[torch.Tensor],
     ) -> torch.Tensor:
         pack_factor = self.quant_config.pack_factor
-        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
         out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
                                         self.qzeros, pack_factor)

@@ -95,7 +95,7 @@ class AWQRowParallelLinear(RowParallelLinear):
     def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
         pack_factor = self.quant_config.pack_factor
-        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
         out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
                                         self.qzeros, pack_factor)
vllm/model_executor/layers/sampler.py (view file @ c1376e0f)

@@ -119,7 +119,7 @@ def _prune_hidden_states(
             selected_token_indices.extend(
                 range(start_idx, start_idx + prompt_len - 1))
             selected_token_indices.append(start_idx + prompt_len - 1)
-            start_idx += prompt_len
+            start_idx += input_metadata.max_prompt_len
         else:
             num_seqs = len(seq_ids)
             selected_token_indices.extend(

@@ -129,6 +129,7 @@ def _prune_hidden_states(
     selected_token_indices = torch.tensor(selected_token_indices,
                                           dtype=torch.long,
                                           device=hidden_states.device)
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
     return hidden_states.index_select(0, selected_token_indices)
vllm/worker/worker.py (view file @ c1376e0f)

@@ -158,9 +158,9 @@ class Worker:
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        input_tokens: List[int] = []
-        input_positions: List[int] = []
-        slot_mapping: List[int] = []
+        input_tokens: List[List[int]] = []
+        input_positions: List[List[int]] = []
+        slot_mapping: List[List[int]] = []

         # Add prompt tokens.
         prompt_lens: List[int] = []

@@ -180,24 +180,25 @@ class Worker:
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)

-            input_tokens.extend(prompt_tokens)
+            input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
-            input_positions.extend(range(len(prompt_tokens)))
+            input_positions.append(list(range(prompt_len)))

             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
                 # yet. In this case, we just use a dummy slot mapping.
-                slot_mapping.extend([0] * prompt_len)
+                slot_mapping.append([0] * prompt_len)
                 continue

             # Compute the slot mapping.
+            slot_mapping.append([])
             block_table = seq_group_metadata.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
-                slot_mapping.append(slot)
+                slot_mapping[-1].append(slot)

         # Add generation tokens.
         max_context_len = 0

@@ -215,13 +216,13 @@ class Worker:
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
-                input_tokens.append(generation_token)
+                input_tokens.append([generation_token])

                 context_len = seq_data.get_len()
                 position = context_len - 1
                 if self.sliding_window is not None:
                     context_len = min(context_len, self.sliding_window)
-                input_positions.append(position)
+                input_positions.append([position])

                 block_table = seq_group_metadata.block_tables[seq_id]

@@ -233,7 +234,7 @@ class Worker:
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
-                slot_mapping.append(slot)
+                slot_mapping.append([slot])

                 if self.sliding_window is not None:
                     sliding_window_blocks = (self.sliding_window //

@@ -241,28 +242,36 @@ class Worker:
                     block_table = block_table[-sliding_window_blocks:]
                 generation_block_tables.append(block_table)

-        # Optimization: Pad the input length to be a multiple of 8.
-        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
-        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
-        input_positions = _pad_to_alignment(input_positions, multiple_of=8)
+        max_seq_len = max(prompt_lens) if prompt_lens else 1
+        padded_input_tokens = [
+            _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
+        ]
+        padded_input_positions = [
+            _pad_to_max(positions, max_seq_len, pad=0)
+            for positions in input_positions
+        ]
+        padded_slot_mapping = [
+            _pad_to_max(mapping, max_seq_len, pad=-1)
+            for mapping in slot_mapping
+        ]
+        padded_block_tables = [
+            _pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
+            for block_table in generation_block_tables
+        ]

         # Convert to tensors.
-        tokens_tensor = torch.tensor(input_tokens,
+        tokens_tensor = torch.tensor(padded_input_tokens,
                                      dtype=torch.long,
                                      device="cuda")
-        positions_tensor = torch.tensor(input_positions,
+        positions_tensor = torch.tensor(padded_input_positions,
                                         dtype=torch.long,
                                         device="cuda")
-        slot_mapping_tensor = torch.tensor(slot_mapping,
+        slot_mapping_tensor = torch.tensor(padded_slot_mapping,
                                            dtype=torch.int,
                                            device="cuda")
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device="cuda")
-        padded_block_tables = [
-            _pad_to_max(block_table, max_num_blocks_per_seq)
-            for block_table in generation_block_tables
-        ]
         block_tables_tensor = torch.tensor(padded_block_tables,
                                            dtype=torch.int,
                                            device="cuda")

@@ -361,12 +370,12 @@ def _init_distributed_environment(
                                   parallel_config.pipeline_parallel_size)


-def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
-    return x + [0] * ((-len(x)) % multiple_of)
+def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
+    return x + [pad] * ((-len(x)) % multiple_of)


-def _pad_to_max(x: List[int], max_len: int) -> List[int]:
-    return x + [0] * (max_len - len(x))
+def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
+    return x + [pad] * (max_len - len(x))


 def _check_if_can_support_max_seq_len(max_seq_len: int,
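The worker now builds one token list per sequence and pads every row to the batch maximum with _pad_to_max (pad=0 for tokens and positions, pad=-1 for the slot mapping so the cache kernel can skip padding slots). A minimal sketch of that assembly, with hypothetical token ids and slot values:

from typing import List
import torch

def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    return x + [pad] * (max_len - len(x))

input_tokens = [[11, 12, 13], [21, 22, 23, 24, 25]]      # one row per sequence
slot_mapping = [[0, 1, 2], [16, 17, 18, 19, 20]]

max_seq_len = max(len(t) for t in input_tokens)
tokens_tensor = torch.tensor(
    [_pad_to_max(t, max_seq_len, pad=0) for t in input_tokens], dtype=torch.long)
slot_tensor = torch.tensor(
    [_pad_to_max(s, max_seq_len, pad=-1) for s in slot_mapping], dtype=torch.int)
print(tokens_tensor.shape, slot_tensor)   # torch.Size([2, 5]); -1 marks padding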