Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6e435de7
Unverified
Commit
6e435de7
authored
Mar 21, 2024
by
SangBin Cho
Committed by
GitHub
Mar 20, 2024
Browse files
[1/n][Chunked Prefill] Refactor input query shapes (#3236)
parent
426ec4ec
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
575 additions
and
259 deletions
+575
-259
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-2
tests/basic_correctness/test_basic_correctness.py
tests/basic_correctness/test_basic_correctness.py
+3
-1
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+9
-9
tests/lora/test_worker.py
tests/lora/test_worker.py
+1
-1
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+2
-2
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+153
-8
vllm/config.py
vllm/config.py
+0
-3
vllm/core/scheduler.py
vllm/core/scheduler.py
+3
-10
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-7
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+0
-1
vllm/model_executor/input_metadata.py
vllm/model_executor/input_metadata.py
+69
-13
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+2
-2
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+2
-1
vllm/model_executor/layers/attention/backends/flash_attn.py
vllm/model_executor/layers/attention/backends/flash_attn.py
+32
-14
vllm/model_executor/layers/attention/backends/xformers.py
vllm/model_executor/layers/attention/backends/xformers.py
+147
-85
vllm/model_executor/layers/attention/ops/paged_attn.py
vllm/model_executor/layers/attention/ops/paged_attn.py
+5
-4
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+0
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+144
-95
No files found.
.buildkite/test-pipeline.yaml
View file @
6e435de7
...
...
@@ -47,7 +47,7 @@ steps:
-
pytest -v -s prefix_caching
-
label
:
Samplers Test
command
:
pytest -v -s samplers
--forked
command
:
pytest -v -s samplers
-
label
:
Worker Test
command
:
pytest -v -s worker
...
...
@@ -56,7 +56,7 @@ steps:
command
:
pytest -v -s spec_decode
-
label
:
LoRA Test %N
command
:
pytest -v -s lora
--forked
--shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
command
:
pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism
:
4
-
label
:
Metrics Test
...
...
tests/basic_correctness/test_basic_correctness.py
View file @
6e435de7
...
...
@@ -13,6 +13,7 @@ MODELS = [
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
,
True
])
def
test_models
(
hf_runner
,
vllm_runner
,
...
...
@@ -20,12 +21,13 @@ def test_models(
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
enforce_eager
:
bool
,
)
->
None
:
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
)
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
enforce_eager
=
enforce_eager
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
vllm_model
...
...
tests/core/test_scheduler.py
View file @
6e435de7
...
...
@@ -10,7 +10,7 @@ from .utils import create_dummy_prompt
def
test_scheduler_add_seq_group
():
block_size
=
4
scheduler_config
=
SchedulerConfig
(
100
,
64
,
1
,
256
)
scheduler_config
=
SchedulerConfig
(
100
,
64
,
1
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
4
cache_config
.
num_gpu_blocks
=
4
...
...
@@ -26,7 +26,7 @@ def test_scheduler_add_seq_group():
def
test_scheduler_abort_seq_group
():
block_size
=
4
scheduler_config
=
SchedulerConfig
(
100
,
64
,
1
,
256
)
scheduler_config
=
SchedulerConfig
(
100
,
64
,
1
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
4
cache_config
.
num_gpu_blocks
=
4
...
...
@@ -50,7 +50,7 @@ def test_scheduler_schedule_simple():
block_size
=
4
num_seq_group
=
4
max_model_len
=
16
scheduler_config
=
SchedulerConfig
(
64
,
num_seq_group
,
max_model_len
,
256
)
scheduler_config
=
SchedulerConfig
(
64
,
num_seq_group
,
max_model_len
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
8
cache_config
.
num_gpu_blocks
=
8
...
...
@@ -64,10 +64,10 @@ def test_scheduler_schedule_simple():
running
.
append
(
seq_group
)
# Schedule seq groups prompts.
num_tokens
=
block_size
*
num_seq_group
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
set
(
out
.
scheduled_seq_groups
)
==
set
(
running
)
assert
out
.
num_batched_tokens
==
num_seq_group
*
seq_group
.
get_seqs
(
)[
0
].
get_len
()
assert
out
.
num_batched_tokens
==
num_tokens
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
and
not
out
.
blocks_to_swap_out
)
assert
len
(
seq_group_meta
)
==
num_seq_group
...
...
@@ -84,7 +84,7 @@ def test_scheduler_schedule_simple():
def
test_scheduler_schedule_preempt_abort
():
block_size
=
4
max_model_len
=
16
scheduler_config
=
SchedulerConfig
(
64
,
2
,
max_model_len
,
256
)
scheduler_config
=
SchedulerConfig
(
64
,
2
,
max_model_len
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
2
cache_config
.
num_gpu_blocks
=
2
...
...
@@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort():
# Schedule seq groups prompts.
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
out
.
scheduled_seq_groups
==
[
seq_group_a
,
seq_group_b
]
assert
out
.
num_batched_tokens
==
seq_group_a
.
get_seqs
()[
0
].
get_len
()
*
2
assert
out
.
num_batched_tokens
==
block_size
*
2
# seq_a and seq_b
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
and
not
out
.
blocks_to_swap_out
)
assert
len
(
seq_group_meta
)
==
2
...
...
@@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort():
scheduler
.
abort_seq_group
(
"1"
)
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
out
.
scheduled_seq_groups
==
[
seq_group_b
]
assert
out
.
num_batched_tokens
==
seq_group_b
.
get_seqs
()[
0
].
get_len
()
assert
out
.
num_batched_tokens
==
5
# 4 prompt + 1 generation.
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
and
not
out
.
blocks_to_swap_out
)
assert
len
(
seq_group_meta
)
==
1
...
...
@@ -136,7 +136,7 @@ def test_scheduler_max_seqs():
num_seq_group
=
4
max_seq_group
=
2
max_model_len
=
16
scheduler_config
=
SchedulerConfig
(
64
,
max_seq_group
,
max_model_len
,
256
)
scheduler_config
=
SchedulerConfig
(
64
,
max_seq_group
,
max_model_len
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
8
cache_config
.
num_gpu_blocks
=
8
...
...
tests/lora/test_worker.py
View file @
6e435de7
...
...
@@ -25,7 +25,7 @@ def test_worker_apply_lora(sql_lora_files):
revision
=
None
,
),
parallel_config
=
ParallelConfig
(
1
,
1
,
False
),
scheduler_config
=
SchedulerConfig
(
32
,
32
,
32
,
256
),
scheduler_config
=
SchedulerConfig
(
32
,
32
,
32
),
device_config
=
DeviceConfig
(
"cuda"
),
local_rank
=
0
,
rank
=
0
,
...
...
tests/spec_decode/test_multi_step_worker.py
View file @
6e435de7
...
...
@@ -92,8 +92,8 @@ def test_same_output_for_single_step():
num_gpu_blocks
,
seed
,
)
multi_step_worker
.
model_runner
=
worker
.
model_runner
multi_step_worker
.
cache_engine
=
worker
.
cache_engine
#
multi_step_worker.model_runner = worker.model_runner
#
multi_step_worker.cache_engine = worker.cache_engine
num_steps
=
1
...
...
tests/worker/test_model_runner.py
View file @
6e435de7
import
random
import
torch
from
vllm.config
import
ModelConfig
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.model_runner
import
ModelRunner
,
_BATCH_SIZE_ALIGNMENT
def
get_aligned_size
(
batch_size
:
int
,
alignment
:
int
):
return
((
batch_size
+
alignment
-
1
)
//
alignment
*
alignment
)
def
test_prepare_prompt
():
...
...
@@ -12,6 +17,7 @@ def test_prepare_prompt():
batch_size
=
random
.
randint
(
1
,
256
)
prompt_lens
=
[]
seq_group_metadata_list
=
[]
block_tables
=
{
0
:
[
1
]}
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
prompt_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
...
...
@@ -23,26 +29,165 @@ def test_prepare_prompt():
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
(
seq_data
)},
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
{
0
:
[
1
]}
,
block_tables
=
block_tables
,
))
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
max_seq_len
=
max
(
prompt_lens
)
for
prompt_len
in
prompt_lens
:
expected_selected_token_indices
.
append
(
selected_token_start_idx
+
prompt_len
-
1
)
selected_token_start_idx
+=
max_seq
_len
input_tokens
,
input_positions
,
_
,
return_prompt_lens
,
_
,
_
,
_
,
_
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
selected_token_start_idx
+=
prompt
_len
(
input_tokens
,
input_positions
,
input_metadata
,
return_prompt_lens
,
_
,
_
,
_
,
_
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
assert
return_prompt_lens
==
prompt_lens
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
assert
input_metadata
.
is_prompt
is
True
assert
torch
.
allclose
(
input_metadata
.
prompt_lens_tensor
,
torch
.
tensor
(
prompt_lens
,
device
=
device
))
assert
input_metadata
.
prompt_lens
==
prompt_lens
assert
input_metadata
.
num_prompt_tokens
==
sum
(
prompt_lens
)
assert
input_metadata
.
num_generation_tokens
==
0
assert
input_metadata
.
max_seq_len
==
max
(
prompt_lens
)
# Test subquery start locs.
start_idx
=
0
start_loc
=
[
start_idx
]
for
prompt_len
in
prompt_lens
:
start_idx
+=
prompt_len
start_loc
.
append
(
start_idx
)
assert
torch
.
allclose
(
input_metadata
.
subquery_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
# Test seq start locs. Note that for normal prefill it is
# equivalent to subquery_start_loc.
start_idx
=
0
seq_start_loc
=
[
start_idx
]
for
prompt_len
in
prompt_lens
:
start_idx
+=
prompt_len
seq_start_loc
.
append
(
start_idx
)
assert
torch
.
allclose
(
input_metadata
.
seq_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
assert
input_metadata
.
max_context_len
is
None
assert
torch
.
allclose
(
input_metadata
.
context_lens
,
torch
.
zeros
(
input_metadata
.
context_lens
.
shape
[
0
],
dtype
=
torch
.
int
,
device
=
device
))
expected
=
torch
.
tensor
([[]
for
_
in
range
(
len
(
seq_group_metadata_list
))],
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
)
assert
torch
.
allclose
(
input_metadata
.
block_tables
,
expected
)
# Cuda graph should not be used for prerill.
assert
input_metadata
.
use_cuda_graph
is
False
assert
input_metadata
.
kv_cache_dtype
==
"auto"
assert
input_tokens
.
shape
==
(
sum
(
prompt_lens
),
)
assert
input_positions
.
shape
==
(
sum
(
prompt_lens
),
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
subquery_lens
=
prompt_lens
)
assert
input_tokens
.
shape
==
(
batch_size
,
max_seq_len
)
assert
input_positions
.
shape
==
(
batch_size
,
max_seq_len
)
assert
input_tokens
.
shape
==
(
sum
(
prompt_lens
),
)
assert
input_positions
.
shape
==
(
sum
(
prompt_lens
),
)
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
dtype
=
actual
.
dtype
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
dtype
=
actual
.
dtype
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
def
test_prepare_decode_cuda_graph
():
model_config
=
ModelConfig
(
"facebook/opt-125m"
,
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
download_dir
=
None
,
load_format
=
"dummy"
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
enforce_eager
=
False
,
)
model_runner
=
ModelRunner
(
model_config
,
None
,
None
,
None
,
None
)
model_runner
.
set_block_size
(
16
)
batch_size
=
random
.
randint
(
1
,
256
)
prompt_lens
=
[]
seq_group_metadata_list
=
[]
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
prompt_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_lens
.
append
(
prompt_len
)
seq_data
=
list
(
range
(
prompt_len
))
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
seq_data
=
{
0
:
SequenceData
(
seq_data
)},
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
{
0
:
[
1
]},
))
input_tokens
,
input_positions
,
input_metadata
,
_
,
_
,
_
=
(
model_runner
.
_prepare_decode
(
seq_group_metadata_list
))
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
assert
input_metadata
.
is_prompt
is
False
assert
input_metadata
.
prompt_lens
is
None
assert
input_metadata
.
num_prompt_tokens
==
0
assert
input_metadata
.
num_generation_tokens
==
(
get_aligned_size
(
len
(
seq_group_metadata_list
),
_BATCH_SIZE_ALIGNMENT
))
assert
input_metadata
.
max_seq_len
is
None
assert
input_metadata
.
subquery_start_loc
is
None
assert
input_metadata
.
seq_start_loc
is
None
assert
input_metadata
.
max_context_len
==
max
(
prompt_lens
)
assert
torch
.
allclose
(
input_metadata
.
context_lens
[:
len
(
prompt_lens
)],
torch
.
tensor
(
prompt_lens
,
dtype
=
torch
.
int
,
device
=
device
))
# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
assert
input_metadata
.
block_tables
.
shape
[
0
]
==
len
(
input_tokens
)
# Block table's second dim correspondsd to each token's block number.
# It is padded up to
assert
input_metadata
.
block_tables
.
shape
[
1
]
==
(
model_runner
.
get_max_block_per_batch
())
# Cuda graph should not be used for prerill.
assert
input_metadata
.
use_cuda_graph
is
True
assert
input_metadata
.
kv_cache_dtype
==
"auto"
assert
input_tokens
.
shape
==
(
get_aligned_size
(
len
(
seq_group_metadata_list
),
_BATCH_SIZE_ALIGNMENT
),
)
assert
input_positions
.
shape
==
(
get_aligned_size
(
len
(
seq_group_metadata_list
),
_BATCH_SIZE_ALIGNMENT
),
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
# Verify Sampling
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
for
prompt_len
in
prompt_lens
:
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
selected_token_start_idx
+=
1
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
subquery_lens
=
prompt_lens
)
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
...
...
vllm/config.py
View file @
6e435de7
...
...
@@ -535,7 +535,6 @@ class SchedulerConfig:
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
max_paddings: Maximum number of paddings to be added to a batch.
"""
def
__init__
(
...
...
@@ -543,7 +542,6 @@ class SchedulerConfig:
max_num_batched_tokens
:
Optional
[
int
],
max_num_seqs
:
int
,
max_model_len
:
int
,
max_paddings
:
int
,
)
->
None
:
if
max_num_batched_tokens
is
not
None
:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
...
...
@@ -553,7 +551,6 @@ class SchedulerConfig:
self
.
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
self
.
max_num_seqs
=
max_num_seqs
self
.
max_model_len
=
max_model_len
self
.
max_paddings
=
max_paddings
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
...
...
vllm/core/scheduler.py
View file @
6e435de7
...
...
@@ -173,12 +173,12 @@ class Scheduler:
curr_loras
=
set
(
seq_group
.
lora_int_id
for
seq_group
in
self
.
running
)
if
self
.
lora_enabled
else
None
seq_lens
:
List
[
int
]
=
[]
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
leftover_waiting_sequences
=
deque
()
num_batched_tokens
=
0
while
self
.
waiting
:
seq_group
=
self
.
waiting
[
0
]
waiting_seqs
=
seq_group
.
get_seqs
(
...
...
@@ -223,8 +223,7 @@ class Scheduler:
continue
# If the number of batched tokens exceeds the limit, stop.
new_seq_lens
=
seq_lens
+
[
num_prompt_tokens
]
num_batched_tokens
=
len
(
new_seq_lens
)
*
max
(
new_seq_lens
)
num_batched_tokens
+=
num_prompt_tokens
if
(
num_batched_tokens
>
self
.
scheduler_config
.
max_num_batched_tokens
):
break
...
...
@@ -236,11 +235,6 @@ class Scheduler:
self
.
scheduler_config
.
max_num_seqs
):
break
num_paddings
=
num_batched_tokens
-
sum
(
new_seq_lens
)
if
num_paddings
>
self
.
scheduler_config
.
max_paddings
:
break
seq_lens
=
new_seq_lens
if
lora_int_id
>
0
:
curr_loras
.
add
(
lora_int_id
)
self
.
waiting
.
popleft
()
...
...
@@ -255,8 +249,7 @@ class Scheduler:
scheduler_outputs
=
SchedulerOutputs
(
scheduled_seq_groups
=
scheduled
,
prompt_run
=
True
,
num_batched_tokens
=
len
(
seq_lens
)
*
max
(
seq_lens
)
if
seq_lens
else
0
,
num_batched_tokens
=
num_batched_tokens
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
...
...
vllm/engine/arg_utils.py
View file @
6e435de7
...
...
@@ -31,7 +31,6 @@ class EngineArgs:
gpu_memory_utilization
:
float
=
0.90
max_num_batched_tokens
:
Optional
[
int
]
=
None
max_num_seqs
:
int
=
256
max_paddings
:
int
=
256
max_logprobs
:
int
=
5
# OpenAI default value
disable_log_stats
:
bool
=
False
revision
:
Optional
[
str
]
=
None
...
...
@@ -213,10 +212,6 @@ class EngineArgs:
type
=
int
,
default
=
EngineArgs
.
max_num_seqs
,
help
=
'maximum number of sequences per iteration'
)
parser
.
add_argument
(
'--max-paddings'
,
type
=
int
,
default
=
EngineArgs
.
max_paddings
,
help
=
'maximum number of paddings in a batch'
)
parser
.
add_argument
(
'--max-logprobs'
,
type
=
int
,
...
...
@@ -347,8 +342,7 @@ class EngineArgs:
),
self
.
ray_workers_use_nsight
)
scheduler_config
=
SchedulerConfig
(
self
.
max_num_batched_tokens
,
self
.
max_num_seqs
,
model_config
.
max_model_len
,
self
.
max_paddings
)
model_config
.
max_model_len
)
lora_config
=
LoRAConfig
(
max_lora_rank
=
self
.
max_lora_rank
,
max_loras
=
self
.
max_loras
,
...
...
vllm/engine/llm_engine.py
View file @
6e435de7
...
...
@@ -561,7 +561,6 @@ class LLMEngine:
# Log stats.
if
self
.
log_stats
:
self
.
stat_logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
))
return
request_outputs
def
step
(
self
)
->
List
[
RequestOutput
]:
...
...
vllm/model_executor/input_metadata.py
View file @
6e435de7
from
dataclasses
import
dataclass
,
fields
from
typing
import
Optional
,
Any
,
Dict
from
typing
import
Optional
,
List
,
Any
,
Dict
import
torch
from
xformers.ops.fmha.attn_bias
import
AttentionBias
@
dataclass
class
InputMetadata
:
"""Metadata for input sequences. Used in PagedAttention.
Args:
prompt_lens: Lengths of prompts.
slot_mapping: The address to write the new KV to of each token.
max_context_len: The maximum context length.
context_lens: the length of attention context for each sequence.
block_tables: The block tables. (Seq id -> list of physical block)
kv_cache_dtype: Data type to store kv cache.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping
:
torch
.
Tensor
prompt_lens
:
Optional
[
torch
.
Tensor
]
max_seq_len
:
Optional
[
int
]
start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens
:
Optional
[
List
[
int
]]
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens
:
int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens
:
int
"""
Definition of context_len, subquery_len, and seqlen.
|---------- N-1 iteration --------|
|---------------- N iteration ---------------------|
|- tokenA -|......................|-- newTokens ---|
|---------- context_len ----------|
|-------------------- seqlen ----------------------|
|- subquery_len -|
WARNING: context_len has different definition depending on if it is
prefill vs decoding. When it is prefill, it doesn't include new
tokens. When it is for decoding, it includes a new token.
"""
# Maximum subquery length in the batch.
max_subquery_len
:
Optional
[
int
]
# Maximum context length in the batch.
max_context_len
:
Optional
[
int
]
# FIXME: It is for flash attn.
# Maximum sequence length in the batch.
max_seq_len
:
Optional
[
int
]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc
:
Optional
[
torch
.
Tensor
]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens
:
Optional
[
torch
.
Tensor
]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
use_cuda_graph
:
bool
kv_cache_dtype
:
str
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self
.
attn_bias
=
None
self
.
attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
# Cuda graph is only used for decoding now.
if
self
.
use_cuda_graph
:
assert
self
.
num_prompt_tokens
==
0
def
asdict_zerocopy
(
self
)
->
Dict
[
str
,
Any
]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
...
...
vllm/model_executor/layers/activation.py
View file @
6e435de7
...
...
@@ -20,8 +20,8 @@ class SiluAndMul(nn.Module):
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (
batch_size, seq_len, 2 * d) or (num_tok
en
s
, 2 * d)
return: (batch_size, seq_len, d)
or (num_tokens, d)
x: (
num_tokens, 2 * d) or (batch_size, seq_l
en, 2 * d)
return:
(num_tokens, d) or
(batch_size, seq_len, d)
"""
def
_forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/layers/attention/attention.py
View file @
6e435de7
...
...
@@ -17,11 +17,12 @@ class Attention(nn.Module):
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3.
Return
the output tensor.
3.
Output
the output tensor.
"""
def
__init__
(
...
...
vllm/model_executor/layers/attention/backends/flash_attn.py
View file @
6e435de7
"""Attention layer with Flash and PagedAttention."""
from
typing
import
List
,
Optional
from
flash_attn
import
flash_attn_func
from
flash_attn
import
flash_attn_
varlen_
func
import
torch
from
vllm.model_executor.input_metadata
import
InputMetadata
...
...
@@ -10,6 +10,21 @@ from vllm.model_executor.layers.attention.ops.paged_attn import (
class
FlashAttentionBackend
:
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def
__init__
(
self
,
...
...
@@ -52,18 +67,18 @@ class FlashAttentionBackend:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [
batch_size, seq_len
, num_heads * head_size]
key: shape = [
batch_size, seq_len
, num_kv_heads * head_size]
value: shape = [
batch_size, seq_len
, num_kv_heads * head_size]
query: shape = [
num_tokens
, num_heads * head_size]
key: shape = [
num_tokens
, num_kv_heads * head_size]
value: shape = [
num_tokens
, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [
batch_size, seq_len
, num_heads * head_size]
shape = [
num_tokens
, num_heads * head_size]
"""
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
...
...
@@ -82,13 +97,16 @@ class FlashAttentionBackend:
if
(
key_cache
is
None
or
value_cache
is
None
or
input_metadata
.
block_tables
.
numel
()
==
0
):
# normal attention
query
=
query
.
unflatten
(
0
,
(
batch_size
,
seq_len
))
key
=
key
.
unflatten
(
0
,
(
batch_size
,
seq_len
))
value
=
value
.
unflatten
(
0
,
(
batch_size
,
seq_len
))
output
=
flash_attn_func
(
query
,
key
,
value
,
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
input_metadata
.
seq_start_loc
,
cu_seqlens_k
=
input_metadata
.
seq_start_loc
,
max_seqlen_q
=
input_metadata
.
max_seq_len
,
max_seqlen_k
=
input_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
...
...
@@ -118,4 +136,4 @@ class FlashAttentionBackend:
)
# Reshape the output tensor.
return
output
.
view
(
batch_size
,
seq_len
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/model_executor/layers/attention/backends/xformers.py
View file @
6e435de7
...
...
@@ -14,6 +14,21 @@ from vllm.utils import is_hip
class
XFormersBackend
:
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens --------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def
__init__
(
self
,
...
...
@@ -55,19 +70,18 @@ class XFormersBackend:
"""Forward pass with xFormers and PagedAttention.
Args:
query: shape = [
batch_size, seq_len
, num_heads * head_size]
key: shape = [
batch_size, seq_len
, num_kv_heads * head_size]
value: shape = [
batch_size, seq_len
, num_kv_heads * head_size]
query: shape = [
num_tokens
, num_heads * head_size]
key: shape = [
num_tokens
, num_kv_heads * head_size]
value: shape = [
num_tokens
, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [
batch_size, seq_len
, num_heads * head_size]
shape = [
num_tokens
, num_heads * head_size]
"""
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
...
...
@@ -82,9 +96,10 @@ class XFormersBackend:
if
input_metadata
.
is_prompt
:
# Prompt run.
# key_cache and value_cache are None when it is a profiling run.
# block tables are empty if the prompt has never been computed.
if
(
key_cache
is
None
or
value_cache
is
None
or
input_metadata
.
block_tables
.
numel
()
==
0
):
# normal attention
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
...
...
@@ -103,61 +118,33 @@ class XFormersBackend:
self
.
num_queries_per_kv
,
value
.
shape
[
-
1
])
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if
input_metadata
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
None
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
[
seq_len
]
*
batch_size
)
if
self
.
sliding_window
is
not
None
:
attn_bias
=
attn_bias
.
make_local_attention
(
self
.
sliding_window
)
input_metadata
.
attn_bias
=
attn_bias
else
:
input_metadata
.
attn_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
batch_size
,
seq_len
,
query
.
dtype
)
if
self
.
use_ref_attention
:
output
=
_ref_masked_attention
(
query
,
key
,
value
,
self
.
num_heads
,
self
.
num_kv_heads
,
self
.
head_size
,
self
.
scale
,
)
print
(
"ref attention used."
)
output
=
torch
.
empty_like
(
query
)
start
=
0
for
_
,
prompt_len
in
enumerate
(
input_metadata
.
prompt_lens
):
end
=
start
+
prompt_len
out
=
_ref_masked_attention
(
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
value
[
None
,
start
:
end
],
self
.
num_heads
,
self
.
num_kv_heads
,
self
.
head_size
,
self
.
scale
,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
)
start
+=
prompt_len
# Using view got RuntimeError: view size is not compatible
# with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# Use reshape instead.
return
output
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if
self
.
alibi_slopes
is
None
:
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
else
:
query
=
query
.
unflatten
(
0
,
(
batch_size
,
seq_len
))
key
=
key
.
unflatten
(
0
,
(
batch_size
,
seq_len
))
value
=
value
.
unflatten
(
0
,
(
batch_size
,
seq_len
))
out
=
xops
.
memory_efficient_attention_forward
(
query
,
key
,
value
,
attn_bias
=
input_metadata
.
attn_bias
,
p
=
0.0
,
scale
=
self
.
scale
,
op
=
xops
.
fmha
.
MemoryEfficientAttentionFlashAttentionOp
[
0
]
if
(
is_hip
())
else
None
,
)
output
=
out
.
view_as
(
query
)
return
output
.
reshape
(
num_tokens
,
hidden_size
)
output
=
self
.
_run_memory_efficient_xformer_forward
(
query
,
key
,
value
,
input_metadata
)
else
:
# prefix-enabled attention
output
=
PagedAttentionImpl
.
forward_prefix
(
...
...
@@ -182,41 +169,117 @@ class XFormersBackend:
)
# Reshape the output tensor.
return
output
.
view
(
batch_size
,
seq_len
,
hidden_size
)
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
def
_run_memory_efficient_xformer_forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if
input_metadata
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
None
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
input_metadata
.
prompt_lens
)
if
self
.
sliding_window
is
not
None
:
attn_bias
=
attn_bias
.
make_local_attention
(
self
.
sliding_window
)
input_metadata
.
attn_bias
=
[
attn_bias
]
else
:
input_metadata
.
attn_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
query
.
dtype
,
input_metadata
)
op
=
xops
.
fmha
.
MemoryEfficientAttentionFlashAttentionOp
[
0
]
if
(
is_hip
())
else
None
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if
self
.
alibi_slopes
is
None
:
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
out
=
xops
.
memory_efficient_attention_forward
(
query
,
key
,
value
,
attn_bias
=
input_metadata
.
attn_bias
[
0
],
p
=
0.0
,
scale
=
self
.
scale
,
op
=
op
)
return
out
.
view_as
(
query
)
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
output
=
torch
.
empty_like
(
query
)
start
=
0
for
i
,
prompt_len
in
enumerate
(
input_metadata
.
prompt_lens
):
end
=
start
+
prompt_len
out
=
xops
.
memory_efficient_attention_forward
(
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
value
[
None
,
start
:
end
],
attn_bias
=
input_metadata
.
attn_bias
[
i
],
p
=
0.0
,
scale
=
self
.
scale
,
op
=
op
)
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
.
squeeze
(
0
))
start
+=
prompt_len
return
output
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
num_kv_heads
:
int
,
batch_size
:
int
,
seq_len
:
int
,
dtype
:
torch
.
dtype
,
input_metadata
:
InputMetadata
,
)
->
LowerTriangularMaskWithTensorBias
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len
=
(
seq_len
+
7
)
//
8
*
8
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
torch
.
empty
(
batch_size
,
num_heads
,
seq_len
,
padded_len
,
device
=
alibi_slopes
.
device
,
dtype
=
dtype
,
)[:,
:,
:,
:
seq_len
].
copy_
(
bias
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
if
num_heads
!=
num_kv_heads
:
bias
=
bias
.
unflatten
(
1
,
(
num_kv_heads
,
num_heads
//
num_kv_heads
))
attn_bias
=
LowerTriangularMaskWithTensorBias
(
bias
)
return
attn_bias
attn_biases
=
[]
for
prompt_len
in
input_metadata
.
prompt_lens
:
bias
=
torch
.
arange
(
prompt_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where each element represents ith element- jth
# element.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
padded_len
=
(
prompt_len
+
7
)
//
8
*
8
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
torch
.
empty
(
1
,
# batch size
num_heads
,
prompt_len
,
padded_len
,
device
=
alibi_slopes
.
device
,
dtype
=
dtype
,
)[:,
:,
:,
:
prompt_len
].
copy_
(
bias
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
if
num_heads
!=
num_kv_heads
:
bias
=
bias
.
unflatten
(
1
,
(
num_kv_heads
,
num_heads
//
num_kv_heads
))
attn_biases
.
append
(
LowerTriangularMaskWithTensorBias
(
bias
))
return
attn_biases
def
_check_use_ref_attention
()
->
bool
:
...
...
@@ -239,7 +302,6 @@ def _ref_masked_attention(
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
seq_len
,
_
,
_
=
query
.
shape
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
...
...
vllm/model_executor/layers/attention/ops/paged_attn.py
View file @
6e435de7
...
...
@@ -128,11 +128,12 @@ class PagedAttentionImpl:
output
,
key_cache
,
value_cache
,
input_metadata
.
block_tables
,
# [BS, max_block_per_request]
input_metadata
.
start_loc
,
input_metadata
.
prompt_lens
,
input_metadata
.
block_tables
,
# subquery_start_loc is (batch_size + 1,)
input_metadata
.
subquery_start_loc
[:
-
1
],
input_metadata
.
prompt_lens_tensor
,
input_metadata
.
context_lens
,
input_metadata
.
max_s
eq
_len
,
input_metadata
.
max_s
ubquery
_len
,
alibi_slopes
,
)
return
output
vllm/model_executor/layers/sampler.py
View file @
6e435de7
...
...
@@ -128,7 +128,6 @@ def _prune_hidden_states(
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
shape
[
-
1
])
return
hidden_states
.
index_select
(
0
,
sampling_metadata
.
selected_token_indices
)
...
...
vllm/worker/model_runner.py
View file @
6e435de7
...
...
@@ -28,9 +28,12 @@ logger = init_logger(__name__)
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
_PAD_SLOT_ID
=
-
1
LORA_WARMUP_RANK
=
8
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
_BATCH_SIZE_ALIGNMENT
=
8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
8
*
i
for
i
in
range
(
1
,
33
)]
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
33
)
]
class
ModelRunner
:
...
...
@@ -107,8 +110,7 @@ class ModelRunner:
),
"Model does not have embedding_padding_modules"
self
.
lora_manager
=
LRUCacheWorkerLoRAManager
(
self
.
scheduler_config
.
max_num_seqs
,
self
.
scheduler_config
.
max_num_batched_tokens
+
self
.
scheduler_config
.
max_paddings
,
self
.
vocab_size
,
self
.
scheduler_config
.
max_num_batched_tokens
,
self
.
vocab_size
,
self
.
lora_config
,
self
.
device
,
self
.
model
.
embedding_modules
,
self
.
model
.
embedding_padding_modules
)
self
.
model
=
self
.
lora_manager
.
create_lora_manager
(
self
.
model
)
...
...
@@ -116,10 +118,13 @@ class ModelRunner:
def
set_block_size
(
self
,
block_size
:
int
)
->
None
:
self
.
block_size
=
block_size
max_num_blocks
=
(
self
.
max_context_len_to_capture
+
block_size
-
1
)
//
block_size
self
.
graph_block_tables
=
np
.
zeros
(
(
max
(
_BATCH_SIZES_TO_CAPTURE
),
max_num_blocks
),
dtype
=
np
.
int32
)
(
max
(
_BATCH_SIZES_TO_CAPTURE
),
self
.
get_max_block_per_batch
()),
dtype
=
np
.
int32
)
def
get_max_block_per_batch
(
self
)
->
int
:
block_size
=
self
.
block_size
return
(
self
.
max_context_len_to_capture
+
block_size
-
1
)
//
block_size
def
_prepare_prompt
(
self
,
...
...
@@ -127,9 +132,9 @@ class ModelRunner:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
InputMetadata
,
List
[
int
],
List
[
int
],
List
[
int
],
List
[
int
],
Set
[
LoRARequest
]]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]
]
=
[]
input_positions
:
List
[
List
[
int
]
]
=
[]
slot_mapping
:
List
[
List
[
int
]
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
lora_index_mapping
:
List
[
int
]
=
[]
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_requests
:
Set
[
LoRARequest
]
=
set
()
...
...
@@ -158,16 +163,18 @@ class ModelRunner:
computed_len
=
len
(
computed_block_nums
)
*
self
.
block_size
prompt_tokens
=
prompt_tokens
[
computed_len
:]
prefix_block_tables
.
append
(
computed_block_nums
)
context_len
=
computed_len
else
:
prefix_block_tables
.
append
([])
context_len
=
0
# actual prompt lens
context_lens
.
append
(
co
mputed
_len
)
context_lens
.
append
(
co
ntext
_len
)
subquery_lens
.
append
(
prompt_len
-
computed_len
)
input_tokens
.
app
end
(
prompt_tokens
)
input_tokens
.
ext
end
(
prompt_tokens
)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions
.
app
end
(
input_positions
.
ext
end
(
list
(
range
(
computed_len
,
computed_len
+
len
(
prompt_tokens
))))
lora_id
=
seq_group_metadata
.
lora_int_id
...
...
@@ -175,7 +182,7 @@ class ModelRunner:
if
lora_id
>
0
:
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_index_mapping
.
append
(
[
lora_id
]
*
(
prompt_len
-
computed_len
)
)
lora_index_mapping
+=
[
lora_id
]
*
(
prompt_len
-
computed_len
)
lora_prompt_mapping
.
extend
(
[
lora_id
]
*
(
prompt_len
-
computed_len
...
...
@@ -184,11 +191,10 @@ class ModelRunner:
if
seq_group_metadata
.
block_tables
is
None
:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping
.
app
end
([
_PAD_SLOT_ID
]
*
prompt_len
)
slot_mapping
.
ext
end
([
_PAD_SLOT_ID
]
*
prompt_len
)
continue
# Compute the slot mapping.
slot_mapping
.
append
([])
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
...
...
@@ -203,35 +209,30 @@ class ModelRunner:
start_idx
=
max
(
0
,
prompt_len
-
self
.
sliding_window
)
for
i
in
range
(
computed_len
,
prompt_len
):
if
i
<
start_idx
:
slot_mapping
[
-
1
]
.
append
(
_PAD_SLOT_ID
)
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
block_number
=
block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
[
-
1
].
append
(
slot
)
max_prompt_len
=
max
(
subquery_lens
)
assert
max_prompt_len
>
0
input_tokens
=
_make_tensor_with_pad
(
input_tokens
,
max_prompt_len
,
pad
=
0
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
_make_tensor_with_pad
(
input_positions
,
max_prompt_len
,
pad
=
0
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
_make_tensor_with_pad
(
slot_mapping
,
max_prompt_len
,
pad
=
_PAD_SLOT_ID
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
lora_index_mapping
=
[
_pad_to_max
(
mapping
,
max_prompt_len
,
pad
=
0
)
for
mapping
in
lora_index_mapping
]
slot_mapping
.
append
(
slot
)
max_subquery_len
=
max
(
subquery_lens
)
max_seq_len
=
max
(
prompt_lens
)
num_prompt_tokens
=
len
(
input_tokens
)
assert
max_subquery_len
>
0
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
lora_index_mapping
=
lora_index_mapping
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
...
...
@@ -244,22 +245,45 @@ class ModelRunner:
dtype
=
torch
.
int
,
device
=
self
.
device
,
)
start_loc_tensor
=
torch
.
arange
(
0
,
len
(
prompt_lens
)
*
max_prompt_len
,
max_prompt_len
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached.
subquery_lens_tensor
=
torch
.
tensor
(
subquery_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
subquery_start_loc
=
torch
.
zeros
(
subquery_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
prompt_lens_tensor
=
torch
.
tensor
(
prompt_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
seq_start_loc
=
torch
.
zeros
(
prompt_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
subquery_lens_tensor
,
dim
=
0
,
dtype
=
subquery_start_loc
.
dtype
,
out
=
subquery_start_loc
[
1
:])
torch
.
cumsum
(
prompt_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
input_metadata
=
InputMetadata
(
is_prompt
=
True
,
slot_mapping
=
slot_mapping
,
prompt_lens
=
prompt_lens_tensor
,
max_seq_len
=
max_prompt_len
,
start_loc
=
start_loc_tensor
,
prompt_lens
=
prompt_lens
,
prompt_lens_tensor
=
prompt_lens_tensor
,
num_prompt_tokens
=
num_prompt_tokens
,
num_generation_tokens
=
0
,
max_subquery_len
=
max_subquery_len
,
max_context_len
=
None
,
max_seq_len
=
max_seq_len
,
subquery_start_loc
=
subquery_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
...
...
@@ -275,9 +299,9 @@ class ModelRunner:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
InputMetadata
,
List
[
int
],
List
[
int
],
Set
[
LoRARequest
]]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]
]
=
[]
input_positions
:
List
[
List
[
int
]
]
=
[]
slot_mapping
:
List
[
List
[
int
]
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
lora_index_mapping
:
List
[
int
]
=
[]
...
...
@@ -296,11 +320,11 @@ class ModelRunner:
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
(
[
generation_token
]
)
input_tokens
.
append
(
generation_token
)
seq_len
=
seq_data
.
get_len
()
position
=
seq_len
-
1
input_positions
.
append
(
[
position
]
)
input_positions
.
append
(
position
)
context_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq_len
,
self
.
sliding_window
)
...
...
@@ -310,8 +334,8 @@ class ModelRunner:
block_number
=
block_table
[
position
//
self
.
block_size
]
block_offset
=
position
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
[
slot
]
)
lora_index_mapping
.
append
(
[
lora_id
]
)
slot_mapping
.
append
(
slot
)
lora_index_mapping
.
append
(
lora_id
)
lora_prompt_mapping
.
append
(
lora_id
)
if
self
.
sliding_window
is
not
None
:
...
...
@@ -320,6 +344,9 @@ class ModelRunner:
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
# vLLM uses cuda graph only for decoding requests.
# See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens.
batch_size
=
len
(
input_tokens
)
max_context_len
=
max
(
context_lens
)
use_captured_graph
=
(
...
...
@@ -327,38 +354,37 @@ class ModelRunner:
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_context_len
<=
self
.
max_context_len_to_capture
)
if
use_captured_graph
:
# Pad the input tokens, positions, and slot mapping to match the
# batch size of the captured graph.
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
for
_
in
range
(
graph_batch_size
-
batch_size
):
input_tokens
.
append
(
[]
)
input_positions
.
append
(
[]
)
slot_mapping
.
append
(
[]
)
input_tokens
.
append
(
0
)
input_positions
.
append
(
0
)
slot_mapping
.
append
(
_PAD_SLOT_ID
)
context_lens
.
append
(
1
)
block_tables
.
append
([])
lora_index_mapping
.
append
(
0
)
batch_size
=
graph_batch_size
input_tokens
=
_make_tensor_with_pad
(
input_tokens
,
max_len
=
1
,
pad
=
0
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
_make_tensor_with_pad
(
input_positions
,
max_len
=
1
,
pad
=
0
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
_make_tensor_with_pad
(
slot_mapping
,
max_len
=
1
,
pad
=
_PAD_SLOT_ID
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
if
use_captured_graph
:
# When using cuda-graph all these tensors should be
# padded.
assert
context_lens
.
shape
[
0
]
==
input_tokens
.
shape
[
0
]
assert
context_lens
.
shape
[
0
]
==
input_positions
.
shape
[
0
]
assert
context_lens
.
shape
[
0
]
==
slot_mapping
.
shape
[
0
]
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
self
.
graph_block_tables
[:
batch_size
]
...
...
@@ -377,17 +403,18 @@ class ModelRunner:
device
=
self
.
device
,
)
lora_index_mapping
=
[
_pad_to_max
(
mapping
,
1
,
pad
=
0
)
for
mapping
in
lora_index_mapping
]
input_metadata
=
InputMetadata
(
is_prompt
=
False
,
slot_mapping
=
slot_mapping
,
prompt_lens
=
None
,
max_seq_len
=
None
,
start_loc
=
None
,
prompt_lens_tensor
=
None
,
num_prompt_tokens
=
0
,
num_generation_tokens
=
len
(
input_tokens
),
max_subquery_len
=
None
,
max_context_len
=
max_context_len
,
max_seq_len
=
None
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens
=
context_lens
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
...
...
@@ -411,7 +438,6 @@ class ModelRunner:
categorized_sampled_token_indices_start_idx
=
0
pin_memory
=
not
self
.
in_wsl
and
not
self
.
device_config
.
is_neuron
max_subquery_len
=
max
(
subquery_lens
)
if
subquery_lens
else
1
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
sampling_params
=
seq_group_metadata
.
sampling_params
...
...
@@ -439,7 +465,7 @@ class ModelRunner:
selected_token_start_idx
+
subquery_len
-
1
))
selected_token_indices
.
append
(
selected_token_start_idx
+
subquery_len
-
1
)
selected_token_start_idx
+=
max_
subquery_len
selected_token_start_idx
+=
subquery_len
if
sampling_params
.
seed
is
not
None
:
seq_group_metadata
.
state
.
generator
=
torch
.
Generator
(
...
...
@@ -521,11 +547,8 @@ class ModelRunner:
subquery_lens
)
if
self
.
lora_config
:
flat_lora_index_mapping
=
[
item
for
sublist
in
lora_index_mapping
for
item
in
sublist
]
lora_mapping
=
LoRAMapping
(
flat_
lora_index_mapping
,
lora_index_mapping
,
lora_prompt_mapping
,
)
else
:
...
...
@@ -679,6 +702,18 @@ class ModelRunner:
@
torch
.
inference_mode
()
def
capture_model
(
self
,
kv_caches
:
List
[
KVCache
])
->
None
:
"""Cuda graph capture a model.
Note that CUDA graph's performance gain is negligible if number
of batched tokens are larger than 200. And since CUDA graph
requires fixed sized tensors, supporting large/variable batch
size requires high GPU memory overhead. Thus, vLLM only captures
decoding requests. Mixed batch (chunked prefill + decoding) or
prefill requests are not captured.
Since it is used for decoding-only, it assumes there's only 1 token
per sequence in the batch.
"""
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
# deleted before the CUDA graphs.
self
.
cupy_nccl_backend
=
cupy_utils
.
get_nccl_backend
()
...
...
@@ -697,10 +732,9 @@ class ModelRunner:
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size
=
max
(
_BATCH_SIZES_TO_CAPTURE
)
input_tokens
=
torch
.
zeros
(
max_batch_size
,
1
,
dtype
=
torch
.
long
).
cuda
()
input_positions
=
torch
.
zeros
(
max_batch_size
,
1
,
dtype
=
torch
.
long
).
cuda
()
slot_mapping
=
torch
.
empty
(
max_batch_size
,
1
,
dtype
=
torch
.
long
).
cuda
()
input_tokens
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
slot_mapping
=
torch
.
empty
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
slot_mapping
.
fill_
(
_PAD_SLOT_ID
)
context_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
).
cuda
()
block_tables
=
torch
.
from_numpy
(
self
.
graph_block_tables
).
cuda
()
...
...
@@ -726,9 +760,14 @@ class ModelRunner:
is_prompt
=
False
,
slot_mapping
=
slot_mapping
[:
batch_size
],
prompt_lens
=
None
,
max_seq_len
=
None
,
start_loc
=
None
,
prompt_lens_tensor
=
None
,
num_prompt_tokens
=
0
,
num_generation_tokens
=
batch_size
,
max_subquery_len
=
None
,
max_context_len
=
self
.
max_context_len_to_capture
,
max_seq_len
=
None
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens
=
context_lens
[:
batch_size
],
block_tables
=
block_tables
[:
batch_size
],
use_cuda_graph
=
True
,
...
...
@@ -845,7 +884,6 @@ class CUDAGraphRunner:
non_blocking
=
True
)
self
.
input_buffers
[
"block_tables"
].
copy_
(
input_metadata
.
block_tables
,
non_blocking
=
True
)
# Run the graph.
self
.
graph
.
replay
()
...
...
@@ -877,17 +915,28 @@ def _make_tensor_with_pad(
dtype
:
torch
.
dtype
,
device
:
Optional
[
Union
[
str
,
torch
.
device
]],
)
->
torch
.
Tensor
:
"""Make a padded tensor of a 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
padded_x
=
[
_pad_to_max
(
x_i
,
max_len
,
pad
)
for
x_i
in
x
]
return
torch
.
tensor
(
padded_x
,
dtype
=
dtype
,
device
=
device
)
def
_get_graph_batch_size
(
batch_size
:
int
)
->
int
:
"""Returns the padded batch size given actual batch size.
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if
batch_size
<=
2
:
return
batch_size
elif
batch_size
<=
4
:
return
4
else
:
return
(
batch_size
+
7
)
//
8
*
8
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
def
_async_h2d
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment