norm/vllm · Commits · c1376e0f

Unverified commit c1376e0f, authored Oct 16, 2023 by Woosuk Kwon, committed via GitHub on Oct 16, 2023.
Change scheduler & input tensor shape (#1381)
Parent: 651c614a

Showing 13 changed files with 181 additions and 179 deletions (+181, -179).
csrc/activation_kernels.cu                           +14  -14
csrc/cache_kernels.cu                                 +7   -2
csrc/layernorm_kernels.cu                             +6   -6
csrc/pos_encoding_kernels.cu                         +11  -11
vllm/config.py                                        +3   -0
vllm/core/scheduler.py                               +11   -4
vllm/engine/arg_utils.py                              +7   -1
vllm/model_executor/input_metadata.py                 +5   -6
vllm/model_executor/layers/activation.py              +8  -12
vllm/model_executor/layers/attention.py              +71  -95
vllm/model_executor/layers/quantized_linear/awq.py    +2   -2
vllm/model_executor/layers/sampler.py                 +2   -1
vllm/worker/worker.py                                +34  -25
csrc/activation_kernels.cu

@@ -13,8 +13,8 @@ __device__ __forceinline__ T silu(const T& x) {
 template<typename scalar_t>
 __global__ void silu_and_mul_kernel(
-  scalar_t* __restrict__ out,               // [num_tokens, d]
-  const scalar_t* __restrict__ input,       // [num_tokens, 2, d]
+  scalar_t* __restrict__ out,               // [..., d]
+  const scalar_t* __restrict__ input,       // [..., 2, d]
   const int d) {
   const int token_idx = blockIdx.x;
   for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {

@@ -27,11 +27,11 @@ __global__ void silu_and_mul_kernel(
 } // namespace vllm

 void silu_and_mul(
-  torch::Tensor& out,      // [num_tokens, d]
-  torch::Tensor& input)    // [num_tokens, 2 * d]
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
 {
-  int num_tokens = input.size(0);
-  int d = input.size(1) / 2;
+  int num_tokens = input.numel() / input.size(-1);
+  int d = input.size(-1) / 2;
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));

@@ -52,8 +52,8 @@ namespace vllm {
 // Element-wise activation kernel template.
 template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void activation_kernel(
-  scalar_t* __restrict__ out,               // [num_tokens, d]
-  const scalar_t* __restrict__ input,       // [num_tokens, d]
+  scalar_t* __restrict__ out,               // [..., d]
+  const scalar_t* __restrict__ input,       // [..., d]
   const int d) {
   const int token_idx = blockIdx.x;
   for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {

@@ -66,8 +66,8 @@ __global__ void activation_kernel(
 // Launch element-wise activation kernel.
 #define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                 \
-  int num_tokens = input.size(0);                                        \
-  int d = input.size(1);                                                 \
+  int d = input.size(-1);                                                \
+  int num_tokens = input.numel() / d;                                    \
   dim3 grid(num_tokens);                                                 \
   dim3 block(std::min(d, 1024));                                         \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \

@@ -100,15 +100,15 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
 } // namespace vllm

 void gelu_new(
-  torch::Tensor& out,      // [num_tokens, d]
-  torch::Tensor& input)    // [num_tokens, d]
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
 }

 void gelu_fast(
-  torch::Tensor& out,      // [num_tokens, d]
-  torch::Tensor& input)    // [num_tokens, d]
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
 }
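The activation launchers now infer the token count from all leading dimensions instead of assuming a 2-D [num_tokens, d] input. A minimal PyTorch sketch of the same shape arithmetic (function and variable names here are illustrative, not from the repository):

import torch

def launch_shape(input: torch.Tensor) -> tuple:
    # Works for both [num_tokens, 2 * d] and [batch_size, seq_len, 2 * d]:
    # the last dimension is the feature dimension, everything else is tokens.
    d = input.size(-1) // 2
    num_tokens = input.numel() // input.size(-1)
    return num_tokens, d

x2d = torch.randn(10, 8)     # [num_tokens, 2 * d]
x3d = torch.randn(2, 5, 8)   # [batch_size, seq_len, 2 * d], 10 tokens total
assert launch_shape(x2d) == launch_shape(x3d) == (10, 4)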
csrc/cache_kernels.cu

@@ -154,6 +154,11 @@ __global__ void reshape_and_cache_kernel(
   const int x) {
   const int token_idx = blockIdx.x;
   const int slot_idx = slot_mapping[token_idx];
+  if (slot_idx < 0) {
+    // Padding token that should be ignored.
+    return;
+  }
   const int block_idx = slot_idx / block_size;
   const int block_offset = slot_idx % block_size;

@@ -176,8 +181,8 @@ __global__ void reshape_and_cache_kernel(
                                   + head_idx * head_size * block_size
                                   + head_offset * block_size
                                   + block_offset;
-    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
-    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
+    key_cache[tgt_key_idx] = key[src_key_idx];
+    value_cache[tgt_value_idx] = value[src_value_idx];
   }
 }
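With padded batches, some entries of slot_mapping are now -1 and must be skipped when writing to the KV cache. A simplified Python reference of that guard, treating the cache as a flat [num_slots, ...] tensor rather than vLLM's blocked layout (names are illustrative):

import torch

def reshape_and_cache_ref(key, value, key_cache, value_cache, slot_mapping):
    # slot_mapping: [num_tokens]; padding positions are marked with -1.
    for token_idx, slot in enumerate(slot_mapping.tolist()):
        if slot < 0:
            continue  # padding token that should be ignored
        key_cache[slot] = key[token_idx]
        value_cache[slot] = value[token_idx]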
csrc/layernorm_kernels.cu

@@ -9,8 +9,8 @@ namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template<typename scalar_t>
 __global__ void rms_norm_kernel(
-  scalar_t* __restrict__ out,             // [num_tokens, hidden_size]
-  const scalar_t* __restrict__ input,     // [num_tokens, hidden_size]
+  scalar_t* __restrict__ out,             // [..., hidden_size]
+  const scalar_t* __restrict__ input,     // [..., hidden_size]
   const scalar_t* __restrict__ weight,    // [hidden_size]
   const float epsilon,
   const int num_tokens,

@@ -37,12 +37,12 @@ __global__ void rms_norm_kernel(
 } // namespace vllm

 void rms_norm(
-  torch::Tensor& out,      // [num_tokens, hidden_size]
-  torch::Tensor& input,    // [num_tokens, hidden_size]
+  torch::Tensor& out,      // [..., hidden_size]
+  torch::Tensor& input,    // [..., hidden_size]
   torch::Tensor& weight,   // [hidden_size]
   float epsilon) {
-  int num_tokens = input.size(0);
-  int hidden_size = input.size(1);
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
csrc/pos_encoding_kernels.cu

@@ -37,9 +37,9 @@ inline __device__ void apply_rotary_embedding(
 template<typename scalar_t, bool IS_NEOX>
 __global__ void rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,        // [num_tokens]
-  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
+  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
+  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
   const int query_stride,

@@ -78,18 +78,18 @@ __global__ void rotary_embedding_kernel(
 } // namespace vllm

 void rotary_embedding(
-  torch::Tensor& positions,         // [num_tokens]
-  torch::Tensor& query,             // [num_tokens, num_heads * head_size]
-  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
+  torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens]
+  torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
+  torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
   int head_size,
   torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
   bool is_neox) {
-  int num_tokens = query.size(0);
+  int num_tokens = query.numel() / query.size(-1);
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(1) / head_size;
-  int num_kv_heads = key.size(1) / head_size;
-  int query_stride = query.stride(0);
-  int key_stride = key.stride(0);
+  int num_heads = query.size(-1) / head_size;
+  int num_kv_heads = key.size(-1) / head_size;
+  int query_stride = query.stride(-2);
+  int key_stride = key.stride(-2);
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
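The rotary-embedding entry point now accepts either the old flattened layout or the new padded [batch_size, seq_len, ...] layout; it relies only on the last dimension and on the stride of the token dimension. A small sketch of how identical launch parameters fall out of either layout (illustrative helper, not part of the repository):

import torch

def rotary_launch_params(query: torch.Tensor, head_size: int):
    # query: [num_tokens, num_heads * head_size] or
    #        [batch_size, seq_len, num_heads * head_size]
    num_tokens = query.numel() // query.size(-1)
    num_heads = query.size(-1) // head_size
    token_stride = query.stride(-2)  # stride between consecutive tokens
    return num_tokens, num_heads, token_stride

q2d = torch.randn(6, 4 * 8)
q3d = torch.randn(2, 3, 4 * 8)
assert rotary_launch_params(q2d, 8) == rotary_launch_params(q3d, 8)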
vllm/config.py

@@ -268,6 +268,7 @@ class SchedulerConfig:
             iteration.
         max_model_len: Maximum length of a sequence (including prompt
             and generated text).
+        max_paddings: Maximum number of paddings to be added to a batch.
     """

     def __init__(

@@ -275,6 +276,7 @@ class SchedulerConfig:
         max_num_batched_tokens: Optional[int],
         max_num_seqs: int,
         max_model_len: int,
+        max_paddings: int,
     ) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens

@@ -284,6 +286,7 @@ class SchedulerConfig:
             self.max_num_batched_tokens = max(max_model_len, 2048)
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self.max_paddings = max_paddings
         self._verify_args()

     def _verify_args(self) -> None:
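SchedulerConfig gains a max_paddings knob that bounds how much padded (wasted) work a prompt batch may contain. A hedged usage sketch, assuming only the constructor parameters shown in the diff (values are illustrative):

from vllm.config import SchedulerConfig

scheduler_config = SchedulerConfig(
    max_num_batched_tokens=None,  # falls back to max(max_model_len, 2048)
    max_num_seqs=256,
    max_model_len=2048,
    max_paddings=256,             # new: cap on padding tokens per batch
)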
vllm/core/scheduler.py

@@ -131,7 +131,8 @@ class Scheduler:
             # requests in the generation phase.
             num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                 for seq_group in self.running)
-            num_batched_tokens = 0
+            seq_lens: List[int] = []

             # Optimization: We do not sort the waiting queue since the preempted
             # sequence groups are added to the front and the new sequence groups
             # are added to the back.

@@ -157,7 +158,9 @@ class Scheduler:
                     break

                 # If the number of batched tokens exceeds the limit, stop.
-                if (num_batched_tokens + num_prompt_tokens >
+                new_seq_lens = seq_lens + [num_prompt_tokens]
+                num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
+                if (num_batched_tokens >
                         self.scheduler_config.max_num_batched_tokens):
                     break

@@ -168,10 +171,14 @@ class Scheduler:
                         self.scheduler_config.max_num_seqs):
                     break

+                num_paddings = num_batched_tokens - sum(new_seq_lens)
+                if num_paddings > self.scheduler_config.max_paddings:
+                    break
+                seq_lens = new_seq_lens
+
                 seq_group = self.waiting.pop(0)
                 self._allocate(seq_group)
                 self.running.append(seq_group)
-                num_batched_tokens += num_prompt_tokens
                 num_curr_seqs += num_new_seqs
                 scheduled.append(seq_group)

@@ -179,7 +186,7 @@ class Scheduler:
                 scheduler_outputs = SchedulerOutputs(
                     scheduled_seq_groups=scheduled,
                     prompt_run=True,
-                    num_batched_tokens=num_batched_tokens,
+                    num_batched_tokens=len(seq_lens) * max(seq_lens),
                     blocks_to_swap_in=blocks_to_swap_in,
                     blocks_to_swap_out=blocks_to_swap_out,
                     blocks_to_copy=blocks_to_copy,
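Because prompts in a batch are now padded to the longest prompt, the scheduler accounts for batch cost as len(seq_lens) * max(seq_lens) rather than the sum of lengths, and it refuses batches whose padding would exceed max_paddings. A standalone sketch of that admission check (the limits are hypothetical defaults):

from typing import List

def can_add_prompt(seq_lens: List[int], num_prompt_tokens: int,
                   max_num_batched_tokens: int = 2048,
                   max_paddings: int = 256) -> bool:
    new_seq_lens = seq_lens + [num_prompt_tokens]
    # Padded batch cost: every prompt is padded to the longest one.
    num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
    if num_batched_tokens > max_num_batched_tokens:
        return False
    num_paddings = num_batched_tokens - sum(new_seq_lens)
    return num_paddings <= max_paddings

assert can_add_prompt([100, 120], 110)
assert not can_add_prompt([10], 1000)  # 2 * 1000 - 1010 = 990 paddings > 256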
vllm/engine/arg_utils.py

@@ -27,6 +27,7 @@ class EngineArgs:
     gpu_memory_utilization: float = 0.90
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
+    max_paddings: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
     tokenizer_revision: Optional[str] = None

@@ -156,6 +157,10 @@ class EngineArgs:
                         type=int,
                         default=EngineArgs.max_num_seqs,
                         help='maximum number of sequences per iteration')
+    parser.add_argument('--max-paddings',
+                        type=int,
+                        default=EngineArgs.max_paddings,
+                        help='maximum number of paddings in a batch')
     parser.add_argument('--disable-log-stats',
                         action='store_true',
                         help='disable logging statistics')

@@ -193,7 +198,8 @@ class EngineArgs:
                                           self.worker_use_ray)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs,
-                                           model_config.max_model_len)
+                                           model_config.max_model_len,
+                                           self.max_paddings)
         return model_config, cache_config, parallel_config, scheduler_config
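The new limit is exposed as --max-paddings (default 256) and forwarded into SchedulerConfig. A rough sketch of setting it programmatically, assuming the EngineArgs fields shown above (the model name is purely illustrative):

from vllm.engine.arg_utils import EngineArgs

# Equivalent to passing --max-paddings 512 on the command line.
engine_args = EngineArgs(model="facebook/opt-125m",
                         max_num_seqs=256,
                         max_paddings=512)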
vllm/model_executor/input_metadata.py

@@ -39,11 +39,12 @@ class InputMetadata:
         self.max_context_len = max_context_len
         self.block_tables = block_tables

+        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.to_cache = None
         if sliding_window is not None:
             # We need to keep the positions of sliding windows within
             # the key / value tables, this is helpful to know which
-            # elements we need to cache and where
+            # elements we need to cache.
             to_cache, start_idx = [], 0
             for prompt_len in self.prompt_lens:
                 to_cache.extend(

@@ -51,16 +52,15 @@ class InputMetadata:
                         start_idx + max(0, prompt_len - sliding_window),
                         start_idx + prompt_len,
                     ))
-                start_idx += prompt_len
+                start_idx += self.max_prompt_len
             to_cache.extend(range(start_idx, slot_mapping.shape[0]))
             self.to_cache = torch.tensor(to_cache,
                                          dtype=torch.int32,
                                          device=self.slot_mapping.device)

         self.num_prompts = len(prompt_lens)
-        self.num_prompt_tokens = sum(prompt_lens)
+        self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
         self.num_generation_tokens = context_lens.shape[0]
-        self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:

@@ -69,12 +69,11 @@ class InputMetadata:
         assert context_lens.shape[0] == self.num_generation_tokens

         # Set during the execution of the first attention op.
-        self.attn_bias: List[AttentionBias] = []
+        self.attn_bias: Optional[AttentionBias] = None

     def __repr__(self) -> str:
         # Print only useful metadata.
         return (f'InputMetadata('
-                f'num_valid_tokens={self.num_valid_tokens}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
                 f'num_prompts={self.num_prompts}, '
                 f'prompt_lens={self.prompt_lens}, '
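InputMetadata now records max_prompt_len and defines num_prompt_tokens as num_prompts * max_prompt_len, i.e. the padded token count, while a single attn_bias covers the whole batch. A sketch of the new bookkeeping with hypothetical prompt lengths:

prompt_lens = [7, 3, 5]

num_prompts = len(prompt_lens)
max_prompt_len = max(prompt_lens) if prompt_lens else 0
# Padded count: every prompt occupies max_prompt_len slots in the input tensor.
num_prompt_tokens = num_prompts * max_prompt_len

assert (num_prompts, max_prompt_len, num_prompt_tokens) == (3, 7, 21)
# The 21 - (7 + 3 + 5) = 6 extra positions are padding tokens.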
vllm/model_executor/layers/activation.py

@@ -8,17 +8,17 @@ from vllm import activation_ops
 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.

-    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
+    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

     Shapes:
-        x: (num_tokens, 2 * d)
-        return: (num_tokens, d)
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
     """

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        num_tokens = x.shape[0]
-        d = x.shape[1] // 2
-        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         activation_ops.silu_and_mul(out, x)
         return out

@@ -26,9 +26,7 @@ class SiluAndMul(nn.Module):
 class NewGELU(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        num_tokens = x.shape[0]
-        d = x.shape[1]
-        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        out = torch.empty_like(x)
         activation_ops.gelu_new(out, x)
         return out

@@ -36,9 +34,7 @@ class NewGELU(nn.Module):
 class FastGELU(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        num_tokens = x.shape[0]
-        d = x.shape[1]
-        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        out = torch.empty_like(x)
         activation_ops.gelu_fast(out, x)
         return out
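The activation wrappers no longer assume 2-D inputs: SiluAndMul builds the output shape from all leading dimensions, and the GELU variants simply allocate with empty_like. A pure-PyTorch sketch of the shape logic (silu_and_mul_ref is a made-up reference name, not the vLLM kernel):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x: (..., 2 * d) -> (..., d), matching the kernel's contract.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(2, 5, 16)            # (batch_size, seq_len, 2 * d)
out = silu_and_mul_ref(x)
assert out.shape == x.shape[:-1] + (8, )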
vllm/model_executor/layers/attention.py

@@ -23,25 +23,9 @@ class PagedAttention(nn.Module):
     # pylint: disable=line-too-long
     """GPT-style multi-head PagedAttention.

-    This class takes flattened 1D query, key, and value tensors as input. The
-    input 1D tensors can either contain prompt tokens or generation tokens, in
-    addition to paddings.
-
-    If the input tensors contain prompt tokens, the layout is as follows:
-    |<---------------------- num_valid_tokens ---------------------->|
-    |<--------------- num_prompt_tokens -------------->|
-    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
-
-    Otherwise, the layout is as follows:
-    |<------------------ num_valid_tokens ------------------->|
-    |<------- num_generation_tokens (M) ------->|
-    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
-
-    The prompts might have different lengths, while the generation tokens always
-    have length 1. The paddings are appended to make the input length a multiple
-    of 8, which is desirable for Tensor Cores.
+    This class takes query, key, and value tensors as input. The input tensors
+    can either contain prompt tokens or generation tokens, in addition to
+    paddings.

     The class does the following:

     1. Perform multi_query_kv_attention for the prompts. This operation does

@@ -53,7 +37,7 @@ class PagedAttention(nn.Module):
     4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
-    5. Output a flattened 1D tensor.
+    5. Return the output tensor.
     """

     def __init__(self,

@@ -85,14 +69,15 @@ class PagedAttention(nn.Module):
         dtype: torch.dtype,
     ) -> None:
         del dtype  # Unused.
-        if input_metadata.attn_bias:
+        if input_metadata.attn_bias is not None:
             # Already set by a previous layer.
             return
-        prompt_lens = input_metadata.prompt_lens
+        prompt_lens = [input_metadata.max_prompt_len
+                       ] * input_metadata.num_prompts
         attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
         if self.sliding_window is not None:
             attn_bias = attn_bias.make_local_attention(self.sliding_window)
-        input_metadata.attn_bias.append(attn_bias)
+        input_metadata.attn_bias = attn_bias

     def multi_query_kv_attention(
         self,

@@ -111,7 +96,6 @@ class PagedAttention(nn.Module):
             value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
-
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
             key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)

@@ -124,7 +108,7 @@ class PagedAttention(nn.Module):
             query.unsqueeze(0),
             key.unsqueeze(0),
             value.unsqueeze(0),
-            attn_bias=input_metadata.attn_bias[0],
+            attn_bias=input_metadata.attn_bias,
             p=0.0,
             scale=self.scale,
         )

@@ -232,12 +216,12 @@ class PagedAttention(nn.Module):
         """PagedAttention forward pass.

         NOTE: The query, key, and value tensors must be sliced from a qkv
-        tensor of shape [num_tokens, 3 * num_heads * head_size].
+        tensor of shape [batch_size, seq_len, 3 * num_heads * head_size].

         Args:
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
             value_cache: shape = [num_blocks, num_kv_heads, head_size,

@@ -246,9 +230,9 @@ class PagedAttention(nn.Module):
             cache_event: event to wait for the cache operations to finish.

         Returns:
-            shape = [num_tokens, num_heads * head_size]
+            shape = [batch_size, seq_len, num_heads * head_size]
         """
+        batch_size, seq_len, _ = query.shape

         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)

@@ -264,10 +248,10 @@ class PagedAttention(nn.Module):
             assert input_metadata.num_generation_tokens == 0
             self.set_attn_bias(input_metadata, dtype=query.dtype)
             self.multi_query_kv_attention(
-                output[:num_prompt_tokens],
-                query[:num_prompt_tokens],
-                key[:num_prompt_tokens],
-                value[:num_prompt_tokens],
+                output,
+                query,
+                key,
+                value,
                 input_metadata,
             )

@@ -278,13 +262,10 @@ class PagedAttention(nn.Module):
         # Reshape the keys and values and store them in the cache.
         # When key_cache and value_cache are not provided, the new key
         # and value vectors will not be cached.
-        num_valid_tokens = input_metadata.num_valid_tokens
-        if (num_valid_tokens > 0 and key_cache is not None
-                and value_cache is not None):
-            # The stride is 3 because the key and value are sliced from qkv.
-            key_to_cache = key[:num_valid_tokens]
-            value_to_cache = value[:num_valid_tokens]
-            slot_mapping = input_metadata.slot_mapping
+        if key_cache is not None and value_cache is not None:
+            key_to_cache = key
+            value_to_cache = value
+            slot_mapping = input_metadata.slot_mapping.view(-1)
             if input_metadata.to_cache is not None:
                 key_to_cache = key_to_cache[input_metadata.to_cache]
                 value_to_cache = value_to_cache[input_metadata.to_cache]

@@ -305,14 +286,14 @@ class PagedAttention(nn.Module):
                     "key_cache and value_cache must be provided when "
                     "generating tokens.")
             # Compute the attention op for generation tokens.
             self.single_query_cached_kv_attention(
-                output[num_prompt_tokens:num_valid_tokens],
-                query[num_prompt_tokens:num_valid_tokens],
-                key_cache,
-                value_cache,
-                input_metadata,
+                output,
+                query,
+                key_cache,
+                value_cache,
+                input_metadata,
                 self.get_alibi_slopes())

         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
-        return output.view(-1, self.num_heads * self.head_size)
+        return output.view(batch_size, seq_len, self.num_heads * self.head_size)


 class PagedAttentionWithRoPE(PagedAttention):

@@ -368,10 +349,10 @@ class PagedAttentionWithRoPE(PagedAttention):
         """ PagedAttention forward pass with rotary embedding.

         Args:
-            positions: shape = [num_tokens]
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            positions: shape = [batch_size, seq_len]
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
             value_cache: shape = [num_blocks, num_kv_heads, head_size,

@@ -380,7 +361,7 @@ class PagedAttentionWithRoPE(PagedAttention):
             cache_event: event to wait for the cache operations to finish.

         Returns:
-            shape = [num_tokens, num_heads * head_size]
+            shape = [batch_size, seq_len, num_heads * head_size]
         """

         # Apply rotary embedding to the query and key before passing them

@@ -414,34 +395,34 @@ class PagedAttentionWithALiBi(PagedAttention):
     def set_attn_bias(self, input_metadata: InputMetadata,
                       dtype: torch.dtype) -> None:
-        if input_metadata.attn_bias:
+        if input_metadata.attn_bias is not None:
             # Already set by a previous layer.
             return
-        # Generates ALiBi mask for each prompt.
-        for prompt_len in input_metadata.prompt_lens:
-            bias = torch.arange(prompt_len, dtype=dtype)
-            # NOTE(zhuohan): HF uses
-            #     `bias = bias[None, :].repeat(prompt_len, 1)`
-            # here. We find that both biases give the same results, but
-            # the bias below more accurately follows the original ALiBi
-            # paper.
-            bias = bias[None, :] - bias[:, None]
-            bias = bias.to(self.alibi_slopes.device)
-
-            # When using custom attention bias, xformers requires the bias to
-            # be sliced from a tensor whose length is a multiple of 8.
-            padded_len = (prompt_len + 7) // 8 * 8
-            bias = torch.empty(
-                1,  # batch_size
-                self.num_heads,
-                prompt_len,
-                padded_len,
-                device=self.alibi_slopes.device,
-                dtype=dtype,
-            )[:, :, :, :prompt_len].copy_(bias)
-            bias.mul_(self.alibi_slopes[:, None, None])
-            attn_bias = LowerTriangularMaskWithTensorBias(bias)
-            input_metadata.attn_bias.append(attn_bias)
+        # Generates ALiBi mask based on the max prompt length.
+        max_prompt_len = input_metadata.max_prompt_len
+        bias = torch.arange(max_prompt_len, dtype=dtype)
+        # NOTE(zhuohan): HF uses
+        #     `bias = bias[None, :].repeat(prompt_len, 1)`
+        # here. We find that both biases give the same results, but
+        # the bias below more accurately follows the original ALiBi
+        # paper.
+        bias = bias[None, :] - bias[:, None]
+        bias = bias.to(self.alibi_slopes.device)
+
+        # When using custom attention bias, xformers requires the bias to
+        # be sliced from a tensor whose length is a multiple of 8.
+        padded_len = (max_prompt_len + 7) // 8 * 8
+        bias = torch.empty(
+            input_metadata.num_prompts,
+            self.num_heads,
+            max_prompt_len,
+            padded_len,
+            device=self.alibi_slopes.device,
+            dtype=dtype,
+        )[:, :, :, :max_prompt_len].copy_(bias)
+        bias.mul_(self.alibi_slopes[:, None, None])
+        attn_bias = LowerTriangularMaskWithTensorBias(bias)
+        input_metadata.attn_bias = attn_bias

     def multi_query_kv_attention(
         self,

@@ -466,24 +447,19 @@ class PagedAttentionWithALiBi(PagedAttention):
             value = torch.repeat_interleave(value,
                                             self.num_queries_per_kv,
                                             dim=1)
-        # FIXME(woosuk): Because xformers does not support dynamic sequence
-        # lengths with custom attention bias, we process each prompt one by
-        # one. This is inefficient, especially when we have many short prompts.
-        start = 0
-        for i, prompt_len in enumerate(input_metadata.prompt_lens):
-            end = start + prompt_len
-            out = xops.memory_efficient_attention_forward(
-                query[None, start:end],
-                key[None, start:end],
-                value[None, start:end],
-                attn_bias=input_metadata.attn_bias[i],
-                p=0.0,
-                scale=self.scale,
-            )
-            # TODO(woosuk): Unnecessary copy. Optimize.
-            output[start:end].copy_(out.squeeze(0))
-            start += prompt_len
+        batch_size = input_metadata.num_prompts
+        seq_len = input_metadata.max_prompt_len
+        out = xops.memory_efficient_attention_forward(
+            query.view(batch_size, seq_len, self.num_heads, self.head_size),
+            key.view(batch_size, seq_len, self.num_heads, self.head_size),
+            value.view(batch_size, seq_len, self.num_heads, self.head_size),
+            attn_bias=input_metadata.attn_bias,
+            p=0.0,
+            scale=self.scale,
+        )
+        # TODO(woosuk): Unnecessary copy. Optimize.
+        output.copy_(out.view(-1, self.num_heads, self.head_size))
         return output

     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
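The net effect in the attention layer is that a whole padded prompt batch goes through one xformers call with a single bias built from max_prompt_len, instead of a per-prompt loop over variable lengths. A simplified, standalone sketch of the padded-batch ALiBi bias construction (shapes only, outside of vLLM; names are illustrative):

import torch

def build_alibi_bias(alibi_slopes: torch.Tensor, num_prompts: int,
                     max_prompt_len: int) -> torch.Tensor:
    # Distance bias shared by every prompt in the padded batch.
    pos = torch.arange(max_prompt_len, dtype=torch.float32)
    bias = pos[None, :] - pos[:, None]                    # [L, L]
    # xformers wants the bias sliced from a tensor whose last dim is a multiple of 8.
    padded_len = (max_prompt_len + 7) // 8 * 8
    num_heads = alibi_slopes.numel()
    full = torch.empty(num_prompts, num_heads, max_prompt_len, padded_len)
    full[:, :, :, :max_prompt_len].copy_(bias)
    full[:, :, :, :max_prompt_len].mul_(alibi_slopes[:, None, None])
    return full[:, :, :, :max_prompt_len]

bias = build_alibi_bias(torch.ones(8), num_prompts=4, max_prompt_len=10)
assert bias.shape == (4, 8, 10, 10)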
vllm/model_executor/layers/quantized_linear/awq.py

@@ -50,7 +50,7 @@ class AWQColumnParallelLinear(ColumnParallelLinear):
         bias: Optional[torch.Tensor],
     ) -> torch.Tensor:
         pack_factor = self.quant_config.pack_factor
-        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
         out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
                                         self.qzeros, pack_factor)

@@ -95,7 +95,7 @@ class AWQRowParallelLinear(RowParallelLinear):
     def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
         pack_factor = self.quant_config.pack_factor
-        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
         out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
                                         self.qzeros, pack_factor)
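The AWQ linear layers still run the GEMM on a 2-D view internally, but now reshape the result back to all leading dimensions of the input instead of assuming a single token dimension. A shape-only sketch of that pattern, with a plain matmul standing in for the quantized GEMM (names are illustrative):

import torch

def linear_preserving_leading_dims(x: torch.Tensor, weight: torch.Tensor):
    # x: (..., in_features); weight: (in_features, out_features).
    out_shape = x.shape[:-1] + (weight.shape[-1], )
    reshaped_x = x.reshape(-1, x.shape[-1])   # flatten to 2-D for the GEMM
    out = reshaped_x @ weight                 # stand-in for awq_gemm
    return out.reshape(out_shape)

x = torch.randn(2, 5, 64)                     # (batch_size, seq_len, in_features)
w = torch.randn(64, 128)
assert linear_preserving_leading_dims(x, w).shape == (2, 5, 128)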
vllm/model_executor/layers/sampler.py

@@ -119,7 +119,7 @@ def _prune_hidden_states(
             selected_token_indices.extend(
                 range(start_idx, start_idx + prompt_len - 1))
             selected_token_indices.append(start_idx + prompt_len - 1)
-            start_idx += prompt_len
+            start_idx += input_metadata.max_prompt_len
         else:
             num_seqs = len(seq_ids)
             selected_token_indices.extend(

@@ -129,6 +129,7 @@ def _prune_hidden_states(
     selected_token_indices = torch.tensor(selected_token_indices,
                                           dtype=torch.long,
                                           device=hidden_states.device)
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
     return hidden_states.index_select(0, selected_token_indices)
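Because each prompt now occupies a fixed stripe of max_prompt_len rows in the flattened hidden states, the sampler advances start_idx by max_prompt_len rather than the actual prompt length, then flattens the [batch_size, seq_len, hidden] tensor before indexing. A sketch of the index arithmetic with hypothetical lengths (it collapses the extend/append pair above into one range for brevity):

prompt_lens = [4, 2, 3]
max_prompt_len = max(prompt_lens)

selected = []
start_idx = 0
for prompt_len in prompt_lens:
    # Token positions belonging to this prompt in the flattened hidden states.
    selected.extend(range(start_idx, start_idx + prompt_len))
    # Skip over the padding up to the next prompt's stripe.
    start_idx += max_prompt_len

# Prompt i lives at rows [i * max_prompt_len, i * max_prompt_len + prompt_lens[i]).
assert selected == [0, 1, 2, 3, 4, 5, 8, 9, 10]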
vllm/worker/worker.py

@@ -158,9 +158,9 @@ class Worker:
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        input_tokens: List[int] = []
-        input_positions: List[int] = []
-        slot_mapping: List[int] = []
+        input_tokens: List[List[int]] = []
+        input_positions: List[List[int]] = []
+        slot_mapping: List[List[int]] = []

         # Add prompt tokens.
         prompt_lens: List[int] = []

@@ -180,24 +180,25 @@ class Worker:
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)

-            input_tokens.extend(prompt_tokens)
+            input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
-            input_positions.extend(range(len(prompt_tokens)))
+            input_positions.append(list(range(prompt_len)))

             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
                 # yet. In this case, we just use a dummy slot mapping.
-                slot_mapping.extend([0] * prompt_len)
+                slot_mapping.append([0] * prompt_len)
                 continue

             # Compute the slot mapping.
+            slot_mapping.append([])
             block_table = seq_group_metadata.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
-                slot_mapping.append(slot)
+                slot_mapping[-1].append(slot)

         # Add generation tokens.
         max_context_len = 0

@@ -215,13 +216,13 @@ class Worker:
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
-                input_tokens.append(generation_token)
+                input_tokens.append([generation_token])

                 context_len = seq_data.get_len()
                 position = context_len - 1
                 if self.sliding_window is not None:
                     context_len = min(context_len, self.sliding_window)
-                input_positions.append(position)
+                input_positions.append([position])

                 block_table = seq_group_metadata.block_tables[seq_id]

@@ -233,7 +234,7 @@ class Worker:
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
-                slot_mapping.append(slot)
+                slot_mapping.append([slot])

                 if self.sliding_window is not None:
                     sliding_window_blocks = (self.sliding_window //

@@ -241,28 +242,36 @@ class Worker:
                     block_table = block_table[-sliding_window_blocks:]
                 generation_block_tables.append(block_table)

-        # Optimization: Pad the input length to be a multiple of 8.
-        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
-        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
-        input_positions = _pad_to_alignment(input_positions, multiple_of=8)
+        max_seq_len = max(prompt_lens) if prompt_lens else 1
+        padded_input_tokens = [
+            _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
+        ]
+        padded_input_positions = [
+            _pad_to_max(positions, max_seq_len, pad=0)
+            for positions in input_positions
+        ]
+        padded_slot_mapping = [
+            _pad_to_max(mapping, max_seq_len, pad=-1)
+            for mapping in slot_mapping
+        ]
+        padded_block_tables = [
+            _pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
+            for block_table in generation_block_tables
+        ]

         # Convert to tensors.
-        tokens_tensor = torch.tensor(input_tokens,
-                                     dtype=torch.long,
-                                     device="cuda")
-        positions_tensor = torch.tensor(input_positions,
-                                        dtype=torch.long,
-                                        device="cuda")
-        slot_mapping_tensor = torch.tensor(slot_mapping,
-                                           dtype=torch.int,
-                                           device="cuda")
+        tokens_tensor = torch.tensor(padded_input_tokens,
+                                     dtype=torch.long,
+                                     device="cuda")
+        positions_tensor = torch.tensor(padded_input_positions,
+                                        dtype=torch.long,
+                                        device="cuda")
+        slot_mapping_tensor = torch.tensor(padded_slot_mapping,
+                                           dtype=torch.int,
+                                           device="cuda")
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device="cuda")
-        padded_block_tables = [
-            _pad_to_max(block_table, max_num_blocks_per_seq)
-            for block_table in generation_block_tables
-        ]
         block_tables_tensor = torch.tensor(padded_block_tables,
                                            dtype=torch.int,
                                            device="cuda")

@@ -361,12 +370,12 @@ def _init_distributed_environment(
         parallel_config.pipeline_parallel_size)


-def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
-    return x + [0] * ((-len(x)) % multiple_of)
+def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
+    return x + [pad] * ((-len(x)) % multiple_of)


-def _pad_to_max(x: List[int], max_len: int) -> List[int]:
-    return x + [0] * (max_len - len(x))
+def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
+    return x + [pad] * (max_len - len(x))


 def _check_if_can_support_max_seq_len(max_seq_len: int,
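Taken together, the worker now builds one row per sequence and pads every row to the longest prompt, producing [batch_size, max_seq_len] input tensors with -1 marking padded KV slots. A compact, standalone sketch of that padding step using hypothetical token lists:

import torch
from typing import List

def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    return x + [pad] * (max_len - len(x))

input_tokens = [[11, 12, 13, 14], [21, 22]]   # one row per prompt
slot_mapping = [[0, 1, 2, 3], [16, 17]]

max_seq_len = max(len(t) for t in input_tokens)
tokens_tensor = torch.tensor(
    [_pad_to_max(t, max_seq_len, pad=0) for t in input_tokens])
slot_tensor = torch.tensor(
    [_pad_to_max(s, max_seq_len, pad=-1) for s in slot_mapping])

assert tokens_tensor.shape == (2, 4)
assert slot_tensor[1].tolist() == [16, 17, -1, -1]   # -1 = skipped in the KV cache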