Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6e435de7
Unverified
Commit
6e435de7
authored
Mar 21, 2024
by
SangBin Cho
Committed by
GitHub
Mar 20, 2024
Browse files
[1/n][Chunked Prefill] Refactor input query shapes (#3236)
parent
426ec4ec
Changes
18
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
575 additions
and
259 deletions
+575
-259
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-2
tests/basic_correctness/test_basic_correctness.py
tests/basic_correctness/test_basic_correctness.py
+3
-1
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+9
-9
tests/lora/test_worker.py
tests/lora/test_worker.py
+1
-1
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+2
-2
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+153
-8
vllm/config.py
vllm/config.py
+0
-3
vllm/core/scheduler.py
vllm/core/scheduler.py
+3
-10
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-7
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+0
-1
vllm/model_executor/input_metadata.py
vllm/model_executor/input_metadata.py
+69
-13
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+2
-2
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+2
-1
vllm/model_executor/layers/attention/backends/flash_attn.py
vllm/model_executor/layers/attention/backends/flash_attn.py
+32
-14
vllm/model_executor/layers/attention/backends/xformers.py
vllm/model_executor/layers/attention/backends/xformers.py
+147
-85
vllm/model_executor/layers/attention/ops/paged_attn.py
vllm/model_executor/layers/attention/ops/paged_attn.py
+5
-4
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+0
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+144
-95
No files found.
.buildkite/test-pipeline.yaml
View file @
6e435de7
...
@@ -47,7 +47,7 @@ steps:
...
@@ -47,7 +47,7 @@ steps:
-
pytest -v -s prefix_caching
-
pytest -v -s prefix_caching
-
label
:
Samplers Test
-
label
:
Samplers Test
command
:
pytest -v -s samplers
--forked
command
:
pytest -v -s samplers
-
label
:
Worker Test
-
label
:
Worker Test
command
:
pytest -v -s worker
command
:
pytest -v -s worker
...
@@ -56,7 +56,7 @@ steps:
...
@@ -56,7 +56,7 @@ steps:
command
:
pytest -v -s spec_decode
command
:
pytest -v -s spec_decode
-
label
:
LoRA Test %N
-
label
:
LoRA Test %N
command
:
pytest -v -s lora
--forked
--shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
command
:
pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism
:
4
parallelism
:
4
-
label
:
Metrics Test
-
label
:
Metrics Test
...
...
tests/basic_correctness/test_basic_correctness.py
View file @
6e435de7
...
@@ -13,6 +13,7 @@ MODELS = [
...
@@ -13,6 +13,7 @@ MODELS = [
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
,
True
])
def
test_models
(
def
test_models
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
...
@@ -20,12 +21,13 @@ def test_models(
...
@@ -20,12 +21,13 @@ def test_models(
model
:
str
,
model
:
str
,
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
enforce_eager
:
bool
,
)
->
None
:
)
->
None
:
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
hf_model
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
)
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
enforce_eager
=
enforce_eager
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
vllm_model
del
vllm_model
...
...
tests/core/test_scheduler.py
View file @
6e435de7
...
@@ -10,7 +10,7 @@ from .utils import create_dummy_prompt
...
@@ -10,7 +10,7 @@ from .utils import create_dummy_prompt
def
test_scheduler_add_seq_group
():
def
test_scheduler_add_seq_group
():
block_size
=
4
block_size
=
4
scheduler_config
=
SchedulerConfig
(
100
,
64
,
1
,
256
)
scheduler_config
=
SchedulerConfig
(
100
,
64
,
1
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
4
cache_config
.
num_cpu_blocks
=
4
cache_config
.
num_gpu_blocks
=
4
cache_config
.
num_gpu_blocks
=
4
...
@@ -26,7 +26,7 @@ def test_scheduler_add_seq_group():
...
@@ -26,7 +26,7 @@ def test_scheduler_add_seq_group():
def
test_scheduler_abort_seq_group
():
def
test_scheduler_abort_seq_group
():
block_size
=
4
block_size
=
4
scheduler_config
=
SchedulerConfig
(
100
,
64
,
1
,
256
)
scheduler_config
=
SchedulerConfig
(
100
,
64
,
1
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
4
cache_config
.
num_cpu_blocks
=
4
cache_config
.
num_gpu_blocks
=
4
cache_config
.
num_gpu_blocks
=
4
...
@@ -50,7 +50,7 @@ def test_scheduler_schedule_simple():
...
@@ -50,7 +50,7 @@ def test_scheduler_schedule_simple():
block_size
=
4
block_size
=
4
num_seq_group
=
4
num_seq_group
=
4
max_model_len
=
16
max_model_len
=
16
scheduler_config
=
SchedulerConfig
(
64
,
num_seq_group
,
max_model_len
,
256
)
scheduler_config
=
SchedulerConfig
(
64
,
num_seq_group
,
max_model_len
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
8
cache_config
.
num_cpu_blocks
=
8
cache_config
.
num_gpu_blocks
=
8
cache_config
.
num_gpu_blocks
=
8
...
@@ -64,10 +64,10 @@ def test_scheduler_schedule_simple():
...
@@ -64,10 +64,10 @@ def test_scheduler_schedule_simple():
running
.
append
(
seq_group
)
running
.
append
(
seq_group
)
# Schedule seq groups prompts.
# Schedule seq groups prompts.
num_tokens
=
block_size
*
num_seq_group
seq_group_meta
,
out
=
scheduler
.
schedule
()
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
set
(
out
.
scheduled_seq_groups
)
==
set
(
running
)
assert
set
(
out
.
scheduled_seq_groups
)
==
set
(
running
)
assert
out
.
num_batched_tokens
==
num_seq_group
*
seq_group
.
get_seqs
(
assert
out
.
num_batched_tokens
==
num_tokens
)[
0
].
get_len
()
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
and
not
out
.
blocks_to_swap_out
)
and
not
out
.
blocks_to_swap_out
)
assert
len
(
seq_group_meta
)
==
num_seq_group
assert
len
(
seq_group_meta
)
==
num_seq_group
...
@@ -84,7 +84,7 @@ def test_scheduler_schedule_simple():
...
@@ -84,7 +84,7 @@ def test_scheduler_schedule_simple():
def
test_scheduler_schedule_preempt_abort
():
def
test_scheduler_schedule_preempt_abort
():
block_size
=
4
block_size
=
4
max_model_len
=
16
max_model_len
=
16
scheduler_config
=
SchedulerConfig
(
64
,
2
,
max_model_len
,
256
)
scheduler_config
=
SchedulerConfig
(
64
,
2
,
max_model_len
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
2
cache_config
.
num_cpu_blocks
=
2
cache_config
.
num_gpu_blocks
=
2
cache_config
.
num_gpu_blocks
=
2
...
@@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort():
...
@@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort():
# Schedule seq groups prompts.
# Schedule seq groups prompts.
seq_group_meta
,
out
=
scheduler
.
schedule
()
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
out
.
scheduled_seq_groups
==
[
seq_group_a
,
seq_group_b
]
assert
out
.
scheduled_seq_groups
==
[
seq_group_a
,
seq_group_b
]
assert
out
.
num_batched_tokens
==
seq_group_a
.
get_seqs
()[
0
].
get_len
()
*
2
assert
out
.
num_batched_tokens
==
block_size
*
2
# seq_a and seq_b
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
and
not
out
.
blocks_to_swap_out
)
and
not
out
.
blocks_to_swap_out
)
assert
len
(
seq_group_meta
)
==
2
assert
len
(
seq_group_meta
)
==
2
...
@@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort():
...
@@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort():
scheduler
.
abort_seq_group
(
"1"
)
scheduler
.
abort_seq_group
(
"1"
)
seq_group_meta
,
out
=
scheduler
.
schedule
()
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
out
.
scheduled_seq_groups
==
[
seq_group_b
]
assert
out
.
scheduled_seq_groups
==
[
seq_group_b
]
assert
out
.
num_batched_tokens
==
seq_group_b
.
get_seqs
()[
0
].
get_len
()
assert
out
.
num_batched_tokens
==
5
# 4 prompt + 1 generation.
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
and
not
out
.
blocks_to_swap_out
)
and
not
out
.
blocks_to_swap_out
)
assert
len
(
seq_group_meta
)
==
1
assert
len
(
seq_group_meta
)
==
1
...
@@ -136,7 +136,7 @@ def test_scheduler_max_seqs():
...
@@ -136,7 +136,7 @@ def test_scheduler_max_seqs():
num_seq_group
=
4
num_seq_group
=
4
max_seq_group
=
2
max_seq_group
=
2
max_model_len
=
16
max_model_len
=
16
scheduler_config
=
SchedulerConfig
(
64
,
max_seq_group
,
max_model_len
,
256
)
scheduler_config
=
SchedulerConfig
(
64
,
max_seq_group
,
max_model_len
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
8
cache_config
.
num_cpu_blocks
=
8
cache_config
.
num_gpu_blocks
=
8
cache_config
.
num_gpu_blocks
=
8
...
...
tests/lora/test_worker.py
View file @
6e435de7
...
@@ -25,7 +25,7 @@ def test_worker_apply_lora(sql_lora_files):
...
@@ -25,7 +25,7 @@ def test_worker_apply_lora(sql_lora_files):
revision
=
None
,
revision
=
None
,
),
),
parallel_config
=
ParallelConfig
(
1
,
1
,
False
),
parallel_config
=
ParallelConfig
(
1
,
1
,
False
),
scheduler_config
=
SchedulerConfig
(
32
,
32
,
32
,
256
),
scheduler_config
=
SchedulerConfig
(
32
,
32
,
32
),
device_config
=
DeviceConfig
(
"cuda"
),
device_config
=
DeviceConfig
(
"cuda"
),
local_rank
=
0
,
local_rank
=
0
,
rank
=
0
,
rank
=
0
,
...
...
tests/spec_decode/test_multi_step_worker.py
View file @
6e435de7
...
@@ -92,8 +92,8 @@ def test_same_output_for_single_step():
...
@@ -92,8 +92,8 @@ def test_same_output_for_single_step():
num_gpu_blocks
,
num_gpu_blocks
,
seed
,
seed
,
)
)
multi_step_worker
.
model_runner
=
worker
.
model_runner
#
multi_step_worker.model_runner = worker.model_runner
multi_step_worker
.
cache_engine
=
worker
.
cache_engine
#
multi_step_worker.cache_engine = worker.cache_engine
num_steps
=
1
num_steps
=
1
...
...
tests/worker/test_model_runner.py
View file @
6e435de7
import
random
import
random
import
torch
import
torch
from
vllm.config
import
ModelConfig
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.model_runner
import
ModelRunner
,
_BATCH_SIZE_ALIGNMENT
def
get_aligned_size
(
batch_size
:
int
,
alignment
:
int
):
return
((
batch_size
+
alignment
-
1
)
//
alignment
*
alignment
)
def
test_prepare_prompt
():
def
test_prepare_prompt
():
...
@@ -12,6 +17,7 @@ def test_prepare_prompt():
...
@@ -12,6 +17,7 @@ def test_prepare_prompt():
batch_size
=
random
.
randint
(
1
,
256
)
batch_size
=
random
.
randint
(
1
,
256
)
prompt_lens
=
[]
prompt_lens
=
[]
seq_group_metadata_list
=
[]
seq_group_metadata_list
=
[]
block_tables
=
{
0
:
[
1
]}
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
prompt_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
...
@@ -23,26 +29,165 @@ def test_prepare_prompt():
...
@@ -23,26 +29,165 @@ def test_prepare_prompt():
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
(
seq_data
)},
seq_data
=
{
0
:
SequenceData
(
seq_data
)},
sampling_params
=
SamplingParams
(
temperature
=
0
),
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
{
0
:
[
1
]}
,
block_tables
=
block_tables
,
))
))
expected_selected_token_indices
=
[]
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
selected_token_start_idx
=
0
max_seq_len
=
max
(
prompt_lens
)
for
prompt_len
in
prompt_lens
:
for
prompt_len
in
prompt_lens
:
expected_selected_token_indices
.
append
(
selected_token_start_idx
+
expected_selected_token_indices
.
append
(
selected_token_start_idx
+
prompt_len
-
1
)
prompt_len
-
1
)
selected_token_start_idx
+=
max_seq
_len
selected_token_start_idx
+=
prompt
_len
input_tokens
,
input_positions
,
_
,
return_prompt_lens
,
_
,
_
,
_
,
_
=
(
(
input_tokens
,
input_positions
,
input_metadata
,
return_prompt_lens
,
_
,
_
,
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
_
,
_
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
assert
return_prompt_lens
==
prompt_lens
assert
return_prompt_lens
==
prompt_lens
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
assert
input_metadata
.
is_prompt
is
True
assert
torch
.
allclose
(
input_metadata
.
prompt_lens_tensor
,
torch
.
tensor
(
prompt_lens
,
device
=
device
))
assert
input_metadata
.
prompt_lens
==
prompt_lens
assert
input_metadata
.
num_prompt_tokens
==
sum
(
prompt_lens
)
assert
input_metadata
.
num_generation_tokens
==
0
assert
input_metadata
.
max_seq_len
==
max
(
prompt_lens
)
# Test subquery start locs.
start_idx
=
0
start_loc
=
[
start_idx
]
for
prompt_len
in
prompt_lens
:
start_idx
+=
prompt_len
start_loc
.
append
(
start_idx
)
assert
torch
.
allclose
(
input_metadata
.
subquery_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
# Test seq start locs. Note that for normal prefill it is
# equivalent to subquery_start_loc.
start_idx
=
0
seq_start_loc
=
[
start_idx
]
for
prompt_len
in
prompt_lens
:
start_idx
+=
prompt_len
seq_start_loc
.
append
(
start_idx
)
assert
torch
.
allclose
(
input_metadata
.
seq_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
assert
input_metadata
.
max_context_len
is
None
assert
torch
.
allclose
(
input_metadata
.
context_lens
,
torch
.
zeros
(
input_metadata
.
context_lens
.
shape
[
0
],
dtype
=
torch
.
int
,
device
=
device
))
expected
=
torch
.
tensor
([[]
for
_
in
range
(
len
(
seq_group_metadata_list
))],
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
)
assert
torch
.
allclose
(
input_metadata
.
block_tables
,
expected
)
# Cuda graph should not be used for prerill.
assert
input_metadata
.
use_cuda_graph
is
False
assert
input_metadata
.
kv_cache_dtype
==
"auto"
assert
input_tokens
.
shape
==
(
sum
(
prompt_lens
),
)
assert
input_positions
.
shape
==
(
sum
(
prompt_lens
),
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
prompt_lens
,
subquery_lens
=
prompt_lens
)
subquery_lens
=
prompt_lens
)
assert
input_tokens
.
shape
==
(
batch_size
,
max_seq_len
)
assert
input_tokens
.
shape
==
(
sum
(
prompt_lens
),
)
assert
input_positions
.
shape
==
(
batch_size
,
max_seq_len
)
assert
input_positions
.
shape
==
(
sum
(
prompt_lens
),
)
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
dtype
=
actual
.
dtype
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
dtype
=
actual
.
dtype
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
def
test_prepare_decode_cuda_graph
():
model_config
=
ModelConfig
(
"facebook/opt-125m"
,
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
download_dir
=
None
,
load_format
=
"dummy"
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
enforce_eager
=
False
,
)
model_runner
=
ModelRunner
(
model_config
,
None
,
None
,
None
,
None
)
model_runner
.
set_block_size
(
16
)
batch_size
=
random
.
randint
(
1
,
256
)
prompt_lens
=
[]
seq_group_metadata_list
=
[]
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
prompt_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_lens
.
append
(
prompt_len
)
seq_data
=
list
(
range
(
prompt_len
))
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
seq_data
=
{
0
:
SequenceData
(
seq_data
)},
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
{
0
:
[
1
]},
))
input_tokens
,
input_positions
,
input_metadata
,
_
,
_
,
_
=
(
model_runner
.
_prepare_decode
(
seq_group_metadata_list
))
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
assert
input_metadata
.
is_prompt
is
False
assert
input_metadata
.
prompt_lens
is
None
assert
input_metadata
.
num_prompt_tokens
==
0
assert
input_metadata
.
num_generation_tokens
==
(
get_aligned_size
(
len
(
seq_group_metadata_list
),
_BATCH_SIZE_ALIGNMENT
))
assert
input_metadata
.
max_seq_len
is
None
assert
input_metadata
.
subquery_start_loc
is
None
assert
input_metadata
.
seq_start_loc
is
None
assert
input_metadata
.
max_context_len
==
max
(
prompt_lens
)
assert
torch
.
allclose
(
input_metadata
.
context_lens
[:
len
(
prompt_lens
)],
torch
.
tensor
(
prompt_lens
,
dtype
=
torch
.
int
,
device
=
device
))
# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
assert
input_metadata
.
block_tables
.
shape
[
0
]
==
len
(
input_tokens
)
# Block table's second dim correspondsd to each token's block number.
# It is padded up to
assert
input_metadata
.
block_tables
.
shape
[
1
]
==
(
model_runner
.
get_max_block_per_batch
())
# Cuda graph should not be used for prerill.
assert
input_metadata
.
use_cuda_graph
is
True
assert
input_metadata
.
kv_cache_dtype
==
"auto"
assert
input_tokens
.
shape
==
(
get_aligned_size
(
len
(
seq_group_metadata_list
),
_BATCH_SIZE_ALIGNMENT
),
)
assert
input_positions
.
shape
==
(
get_aligned_size
(
len
(
seq_group_metadata_list
),
_BATCH_SIZE_ALIGNMENT
),
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
# Verify Sampling
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
for
prompt_len
in
prompt_lens
:
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
selected_token_start_idx
+=
1
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
subquery_lens
=
prompt_lens
)
actual
=
sampling_metadata
.
selected_token_indices
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
device
=
actual
.
device
,
...
...
vllm/config.py
View file @
6e435de7
...
@@ -535,7 +535,6 @@ class SchedulerConfig:
...
@@ -535,7 +535,6 @@ class SchedulerConfig:
iteration.
iteration.
max_model_len: Maximum length of a sequence (including prompt
max_model_len: Maximum length of a sequence (including prompt
and generated text).
and generated text).
max_paddings: Maximum number of paddings to be added to a batch.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -543,7 +542,6 @@ class SchedulerConfig:
...
@@ -543,7 +542,6 @@ class SchedulerConfig:
max_num_batched_tokens
:
Optional
[
int
],
max_num_batched_tokens
:
Optional
[
int
],
max_num_seqs
:
int
,
max_num_seqs
:
int
,
max_model_len
:
int
,
max_model_len
:
int
,
max_paddings
:
int
,
)
->
None
:
)
->
None
:
if
max_num_batched_tokens
is
not
None
:
if
max_num_batched_tokens
is
not
None
:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_num_batched_tokens
=
max_num_batched_tokens
...
@@ -553,7 +551,6 @@ class SchedulerConfig:
...
@@ -553,7 +551,6 @@ class SchedulerConfig:
self
.
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
self
.
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
self
.
max_num_seqs
=
max_num_seqs
self
.
max_num_seqs
=
max_num_seqs
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
self
.
max_paddings
=
max_paddings
self
.
_verify_args
()
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
...
...
vllm/core/scheduler.py
View file @
6e435de7
...
@@ -173,12 +173,12 @@ class Scheduler:
...
@@ -173,12 +173,12 @@ class Scheduler:
curr_loras
=
set
(
curr_loras
=
set
(
seq_group
.
lora_int_id
seq_group
.
lora_int_id
for
seq_group
in
self
.
running
)
if
self
.
lora_enabled
else
None
for
seq_group
in
self
.
running
)
if
self
.
lora_enabled
else
None
seq_lens
:
List
[
int
]
=
[]
# Optimization: We do not sort the waiting queue since the preempted
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# sequence groups are added to the front and the new sequence groups
# are added to the back.
# are added to the back.
leftover_waiting_sequences
=
deque
()
leftover_waiting_sequences
=
deque
()
num_batched_tokens
=
0
while
self
.
waiting
:
while
self
.
waiting
:
seq_group
=
self
.
waiting
[
0
]
seq_group
=
self
.
waiting
[
0
]
waiting_seqs
=
seq_group
.
get_seqs
(
waiting_seqs
=
seq_group
.
get_seqs
(
...
@@ -223,8 +223,7 @@ class Scheduler:
...
@@ -223,8 +223,7 @@ class Scheduler:
continue
continue
# If the number of batched tokens exceeds the limit, stop.
# If the number of batched tokens exceeds the limit, stop.
new_seq_lens
=
seq_lens
+
[
num_prompt_tokens
]
num_batched_tokens
+=
num_prompt_tokens
num_batched_tokens
=
len
(
new_seq_lens
)
*
max
(
new_seq_lens
)
if
(
num_batched_tokens
>
if
(
num_batched_tokens
>
self
.
scheduler_config
.
max_num_batched_tokens
):
self
.
scheduler_config
.
max_num_batched_tokens
):
break
break
...
@@ -236,11 +235,6 @@ class Scheduler:
...
@@ -236,11 +235,6 @@ class Scheduler:
self
.
scheduler_config
.
max_num_seqs
):
self
.
scheduler_config
.
max_num_seqs
):
break
break
num_paddings
=
num_batched_tokens
-
sum
(
new_seq_lens
)
if
num_paddings
>
self
.
scheduler_config
.
max_paddings
:
break
seq_lens
=
new_seq_lens
if
lora_int_id
>
0
:
if
lora_int_id
>
0
:
curr_loras
.
add
(
lora_int_id
)
curr_loras
.
add
(
lora_int_id
)
self
.
waiting
.
popleft
()
self
.
waiting
.
popleft
()
...
@@ -255,8 +249,7 @@ class Scheduler:
...
@@ -255,8 +249,7 @@ class Scheduler:
scheduler_outputs
=
SchedulerOutputs
(
scheduler_outputs
=
SchedulerOutputs
(
scheduled_seq_groups
=
scheduled
,
scheduled_seq_groups
=
scheduled
,
prompt_run
=
True
,
prompt_run
=
True
,
num_batched_tokens
=
len
(
seq_lens
)
*
num_batched_tokens
=
num_batched_tokens
,
max
(
seq_lens
)
if
seq_lens
else
0
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
...
...
vllm/engine/arg_utils.py
View file @
6e435de7
...
@@ -31,7 +31,6 @@ class EngineArgs:
...
@@ -31,7 +31,6 @@ class EngineArgs:
gpu_memory_utilization
:
float
=
0.90
gpu_memory_utilization
:
float
=
0.90
max_num_batched_tokens
:
Optional
[
int
]
=
None
max_num_batched_tokens
:
Optional
[
int
]
=
None
max_num_seqs
:
int
=
256
max_num_seqs
:
int
=
256
max_paddings
:
int
=
256
max_logprobs
:
int
=
5
# OpenAI default value
max_logprobs
:
int
=
5
# OpenAI default value
disable_log_stats
:
bool
=
False
disable_log_stats
:
bool
=
False
revision
:
Optional
[
str
]
=
None
revision
:
Optional
[
str
]
=
None
...
@@ -213,10 +212,6 @@ class EngineArgs:
...
@@ -213,10 +212,6 @@ class EngineArgs:
type
=
int
,
type
=
int
,
default
=
EngineArgs
.
max_num_seqs
,
default
=
EngineArgs
.
max_num_seqs
,
help
=
'maximum number of sequences per iteration'
)
help
=
'maximum number of sequences per iteration'
)
parser
.
add_argument
(
'--max-paddings'
,
type
=
int
,
default
=
EngineArgs
.
max_paddings
,
help
=
'maximum number of paddings in a batch'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--max-logprobs'
,
'--max-logprobs'
,
type
=
int
,
type
=
int
,
...
@@ -347,8 +342,7 @@ class EngineArgs:
...
@@ -347,8 +342,7 @@ class EngineArgs:
),
self
.
ray_workers_use_nsight
)
),
self
.
ray_workers_use_nsight
)
scheduler_config
=
SchedulerConfig
(
self
.
max_num_batched_tokens
,
scheduler_config
=
SchedulerConfig
(
self
.
max_num_batched_tokens
,
self
.
max_num_seqs
,
self
.
max_num_seqs
,
model_config
.
max_model_len
,
model_config
.
max_model_len
)
self
.
max_paddings
)
lora_config
=
LoRAConfig
(
lora_config
=
LoRAConfig
(
max_lora_rank
=
self
.
max_lora_rank
,
max_lora_rank
=
self
.
max_lora_rank
,
max_loras
=
self
.
max_loras
,
max_loras
=
self
.
max_loras
,
...
...
vllm/engine/llm_engine.py
View file @
6e435de7
...
@@ -561,7 +561,6 @@ class LLMEngine:
...
@@ -561,7 +561,6 @@ class LLMEngine:
# Log stats.
# Log stats.
if
self
.
log_stats
:
if
self
.
log_stats
:
self
.
stat_logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
))
self
.
stat_logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
))
return
request_outputs
return
request_outputs
def
step
(
self
)
->
List
[
RequestOutput
]:
def
step
(
self
)
->
List
[
RequestOutput
]:
...
...
vllm/model_executor/input_metadata.py
View file @
6e435de7
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
typing
import
Optional
,
Any
,
Dict
from
typing
import
Optional
,
List
,
Any
,
Dict
import
torch
import
torch
from
xformers.ops.fmha.attn_bias
import
AttentionBias
@
dataclass
@
dataclass
class
InputMetadata
:
class
InputMetadata
:
"""Metadata for input sequences. Used in PagedAttention.
"""Metadata for input sequences. Used in PagedAttention.
Args:
NOTE: Any python object stored here is not updated when it is
prompt_lens: Lengths of prompts.
cuda-graph replayed. If you have values that need to be changed
slot_mapping: The address to write the new KV to of each token.
dynamically, it should be stored in tensor. The tensor has to be
max_context_len: The maximum context length.
updated from `CUDAGraphRunner.forward` API.
context_lens: the length of attention context for each sequence.
block_tables: The block tables. (Seq id -> list of physical block)
kv_cache_dtype: Data type to store kv cache.
"""
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
is_prompt
:
bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
prompt_lens
:
Optional
[
torch
.
Tensor
]
# (batch_size,). The prompt length per sequence. None if it is a decoding.
max_seq_len
:
Optional
[
int
]
prompt_lens
:
Optional
[
List
[
int
]]
start_loc
:
Optional
[
torch
.
Tensor
]
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens
:
int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens
:
int
"""
Definition of context_len, subquery_len, and seqlen.
|---------- N-1 iteration --------|
|---------------- N iteration ---------------------|
|- tokenA -|......................|-- newTokens ---|
|---------- context_len ----------|
|-------------------- seqlen ----------------------|
|- subquery_len -|
WARNING: context_len has different definition depending on if it is
prefill vs decoding. When it is prefill, it doesn't include new
tokens. When it is for decoding, it includes a new token.
"""
# Maximum subquery length in the batch.
max_subquery_len
:
Optional
[
int
]
# Maximum context length in the batch.
max_context_len
:
Optional
[
int
]
max_context_len
:
Optional
[
int
]
# FIXME: It is for flash attn.
# Maximum sequence length in the batch.
max_seq_len
:
Optional
[
int
]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc
:
Optional
[
torch
.
Tensor
]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens
:
Optional
[
torch
.
Tensor
]
context_lens
:
Optional
[
torch
.
Tensor
]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
block_tables
:
Optional
[
torch
.
Tensor
]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
use_cuda_graph
:
bool
use_cuda_graph
:
bool
kv_cache_dtype
:
str
kv_cache_dtype
:
str
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
# will not appear in the __repr__ and __init__
self
.
attn_bias
=
None
self
.
attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
# Cuda graph is only used for decoding now.
if
self
.
use_cuda_graph
:
assert
self
.
num_prompt_tokens
==
0
def
asdict_zerocopy
(
self
)
->
Dict
[
str
,
Any
]:
def
asdict_zerocopy
(
self
)
->
Dict
[
str
,
Any
]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
"""Similar to dataclasses.asdict, but avoids deepcopying."""
...
...
vllm/model_executor/layers/activation.py
View file @
6e435de7
...
@@ -20,8 +20,8 @@ class SiluAndMul(nn.Module):
...
@@ -20,8 +20,8 @@ class SiluAndMul(nn.Module):
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
Shapes:
x: (
batch_size, seq_len, 2 * d) or (num_tok
en
s
, 2 * d)
x: (
num_tokens, 2 * d) or (batch_size, seq_l
en, 2 * d)
return: (batch_size, seq_len, d)
or (num_tokens, d)
return:
(num_tokens, d) or
(batch_size, seq_len, d)
"""
"""
def
_forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/layers/attention/attention.py
View file @
6e435de7
...
@@ -17,11 +17,12 @@ class Attention(nn.Module):
...
@@ -17,11 +17,12 @@ class Attention(nn.Module):
This class takes query, key, and value tensors as input. The input tensors
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
can either contain prompt tokens or generation tokens.
The class does the following:
The class does the following:
1. Store the input key and value tensors in the KV cache.
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
2. Perform (multi-head/multi-query/grouped-query) attention.
3.
Return
the output tensor.
3.
Output
the output tensor.
"""
"""
def
__init__
(
def
__init__
(
...
...
vllm/model_executor/layers/attention/backends/flash_attn.py
View file @
6e435de7
"""Attention layer with Flash and PagedAttention."""
"""Attention layer with Flash and PagedAttention."""
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
from
flash_attn
import
flash_attn_func
from
flash_attn
import
flash_attn_
varlen_
func
import
torch
import
torch
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
...
@@ -10,6 +10,21 @@ from vllm.model_executor.layers.attention.ops.paged_attn import (
...
@@ -10,6 +10,21 @@ from vllm.model_executor.layers.attention.ops.paged_attn import (
class
FlashAttentionBackend
:
class
FlashAttentionBackend
:
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -52,18 +67,18 @@ class FlashAttentionBackend:
...
@@ -52,18 +67,18 @@ class FlashAttentionBackend:
"""Forward pass with FlashAttention and PagedAttention.
"""Forward pass with FlashAttention and PagedAttention.
Args:
Args:
query: shape = [
batch_size, seq_len
, num_heads * head_size]
query: shape = [
num_tokens
, num_heads * head_size]
key: shape = [
batch_size, seq_len
, num_kv_heads * head_size]
key: shape = [
num_tokens
, num_kv_heads * head_size]
value: shape = [
batch_size, seq_len
, num_kv_heads * head_size]
value: shape = [
num_tokens
, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
block_size]
input_metadata: metadata for the inputs.
input_metadata: metadata for the inputs.
Returns:
Returns:
shape = [
batch_size, seq_len
, num_heads * head_size]
shape = [
num_tokens
, num_heads * head_size]
"""
"""
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
...
@@ -82,13 +97,16 @@ class FlashAttentionBackend:
...
@@ -82,13 +97,16 @@ class FlashAttentionBackend:
if
(
key_cache
is
None
or
value_cache
is
None
if
(
key_cache
is
None
or
value_cache
is
None
or
input_metadata
.
block_tables
.
numel
()
==
0
):
or
input_metadata
.
block_tables
.
numel
()
==
0
):
# normal attention
# normal attention
query
=
query
.
unflatten
(
0
,
(
batch_size
,
seq_len
))
# When block_tables are not filled, it means q and k are the
key
=
key
.
unflatten
(
0
,
(
batch_size
,
seq_len
))
# prompt, and they have the same length.
value
=
value
.
unflatten
(
0
,
(
batch_size
,
seq_len
))
output
=
flash_attn_varlen_func
(
output
=
flash_attn_func
(
q
=
query
,
query
,
k
=
key
,
key
,
v
=
value
,
value
,
cu_seqlens_q
=
input_metadata
.
seq_start_loc
,
cu_seqlens_k
=
input_metadata
.
seq_start_loc
,
max_seqlen_q
=
input_metadata
.
max_seq_len
,
max_seqlen_k
=
input_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
...
@@ -118,4 +136,4 @@ class FlashAttentionBackend:
...
@@ -118,4 +136,4 @@ class FlashAttentionBackend:
)
)
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
batch_size
,
seq_len
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/model_executor/layers/attention/backends/xformers.py
View file @
6e435de7
...
@@ -14,6 +14,21 @@ from vllm.utils import is_hip
...
@@ -14,6 +14,21 @@ from vllm.utils import is_hip
class
XFormersBackend
:
class
XFormersBackend
:
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens --------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -55,19 +70,18 @@ class XFormersBackend:
...
@@ -55,19 +70,18 @@ class XFormersBackend:
"""Forward pass with xFormers and PagedAttention.
"""Forward pass with xFormers and PagedAttention.
Args:
Args:
query: shape = [
batch_size, seq_len
, num_heads * head_size]
query: shape = [
num_tokens
, num_heads * head_size]
key: shape = [
batch_size, seq_len
, num_kv_heads * head_size]
key: shape = [
num_tokens
, num_kv_heads * head_size]
value: shape = [
batch_size, seq_len
, num_kv_heads * head_size]
value: shape = [
num_tokens
, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
block_size]
input_metadata: metadata for the inputs.
input_metadata: metadata for the inputs.
Returns:
Returns:
shape = [
batch_size, seq_len
, num_heads * head_size]
shape = [
num_tokens
, num_heads * head_size]
"""
"""
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
...
@@ -82,9 +96,10 @@ class XFormersBackend:
...
@@ -82,9 +96,10 @@ class XFormersBackend:
if
input_metadata
.
is_prompt
:
if
input_metadata
.
is_prompt
:
# Prompt run.
# Prompt run.
# key_cache and value_cache are None when it is a profiling run.
# block tables are empty if the prompt has never been computed.
if
(
key_cache
is
None
or
value_cache
is
None
if
(
key_cache
is
None
or
value_cache
is
None
or
input_metadata
.
block_tables
.
numel
()
==
0
):
or
input_metadata
.
block_tables
.
numel
()
==
0
):
# normal attention
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# project the key and value tensors to the desired number of
...
@@ -103,61 +118,33 @@ class XFormersBackend:
...
@@ -103,61 +118,33 @@ class XFormersBackend:
self
.
num_queries_per_kv
,
self
.
num_queries_per_kv
,
value
.
shape
[
-
1
])
value
.
shape
[
-
1
])
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if
input_metadata
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
None
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
[
seq_len
]
*
batch_size
)
if
self
.
sliding_window
is
not
None
:
attn_bias
=
attn_bias
.
make_local_attention
(
self
.
sliding_window
)
input_metadata
.
attn_bias
=
attn_bias
else
:
input_metadata
.
attn_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
batch_size
,
seq_len
,
query
.
dtype
)
if
self
.
use_ref_attention
:
if
self
.
use_ref_attention
:
output
=
_ref_masked_attention
(
print
(
"ref attention used."
)
query
,
output
=
torch
.
empty_like
(
query
)
key
,
start
=
0
value
,
for
_
,
prompt_len
in
enumerate
(
input_metadata
.
prompt_lens
):
self
.
num_heads
,
end
=
start
+
prompt_len
self
.
num_kv_heads
,
out
=
_ref_masked_attention
(
self
.
head_size
,
query
[
None
,
start
:
end
],
self
.
scale
,
key
[
None
,
start
:
end
],
)
value
[
None
,
start
:
end
],
self
.
num_heads
,
self
.
num_kv_heads
,
self
.
head_size
,
self
.
scale
,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
)
start
+=
prompt_len
# Using view got RuntimeError: view size is not compatible
# Using view got RuntimeError: view size is not compatible
# with input tensor's size and stride (at least one
# with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# dimension spans across two contiguous subspaces).
# Use reshape instead.
# Use reshape instead.
return
output
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
return
output
.
reshape
(
num_tokens
,
hidden_size
)
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if
self
.
alibi_slopes
is
None
:
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
else
:
query
=
query
.
unflatten
(
0
,
(
batch_size
,
seq_len
))
key
=
key
.
unflatten
(
0
,
(
batch_size
,
seq_len
))
value
=
value
.
unflatten
(
0
,
(
batch_size
,
seq_len
))
out
=
xops
.
memory_efficient_attention_forward
(
query
,
key
,
value
,
attn_bias
=
input_metadata
.
attn_bias
,
p
=
0.0
,
scale
=
self
.
scale
,
op
=
xops
.
fmha
.
MemoryEfficientAttentionFlashAttentionOp
[
0
]
if
(
is_hip
())
else
None
,
)
output
=
out
.
view_as
(
query
)
output
=
self
.
_run_memory_efficient_xformer_forward
(
query
,
key
,
value
,
input_metadata
)
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
output
=
PagedAttentionImpl
.
forward_prefix
(
output
=
PagedAttentionImpl
.
forward_prefix
(
...
@@ -182,41 +169,117 @@ class XFormersBackend:
...
@@ -182,41 +169,117 @@ class XFormersBackend:
)
)
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
batch_size
,
seq_len
,
hidden_size
)
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
def
_run_memory_efficient_xformer_forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if
input_metadata
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
None
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
input_metadata
.
prompt_lens
)
if
self
.
sliding_window
is
not
None
:
attn_bias
=
attn_bias
.
make_local_attention
(
self
.
sliding_window
)
input_metadata
.
attn_bias
=
[
attn_bias
]
else
:
input_metadata
.
attn_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
query
.
dtype
,
input_metadata
)
op
=
xops
.
fmha
.
MemoryEfficientAttentionFlashAttentionOp
[
0
]
if
(
is_hip
())
else
None
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if
self
.
alibi_slopes
is
None
:
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
out
=
xops
.
memory_efficient_attention_forward
(
query
,
key
,
value
,
attn_bias
=
input_metadata
.
attn_bias
[
0
],
p
=
0.0
,
scale
=
self
.
scale
,
op
=
op
)
return
out
.
view_as
(
query
)
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
output
=
torch
.
empty_like
(
query
)
start
=
0
for
i
,
prompt_len
in
enumerate
(
input_metadata
.
prompt_lens
):
end
=
start
+
prompt_len
out
=
xops
.
memory_efficient_attention_forward
(
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
value
[
None
,
start
:
end
],
attn_bias
=
input_metadata
.
attn_bias
[
i
],
p
=
0.0
,
scale
=
self
.
scale
,
op
=
op
)
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
.
squeeze
(
0
))
start
+=
prompt_len
return
output
def
_make_alibi_bias
(
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
alibi_slopes
:
torch
.
Tensor
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
batch_size
:
int
,
seq_len
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
input_metadata
:
InputMetadata
,
)
->
LowerTriangularMaskWithTensorBias
:
)
->
LowerTriangularMaskWithTensorBias
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
attn_biases
=
[]
# NOTE(zhuohan): HF uses
for
prompt_len
in
input_metadata
.
prompt_lens
:
# `bias = bias[None, :].repeat(prompt_len, 1)`
bias
=
torch
.
arange
(
prompt_len
,
dtype
=
dtype
)
# here. We find that both biases give the same results, but
# NOTE(zhuohan): HF uses
# the bias below more accurately follows the original ALiBi
# `bias = bias[None, :].repeat(prompt_len, 1)`
# paper.
# here. We find that both biases give the same results, but
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
# the bias below more accurately follows the original ALiBi
# paper.
# When using custom attention bias, xformers requires the bias to
# Calculate a matrix where each element represents ith element- jth
# be sliced from a tensor whose length is a multiple of 8.
# element.
padded_len
=
(
seq_len
+
7
)
//
8
*
8
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
torch
.
empty
(
padded_len
=
(
prompt_len
+
7
)
//
8
*
8
batch_size
,
num_heads
=
alibi_slopes
.
shape
[
0
]
num_heads
,
bias
=
torch
.
empty
(
seq_len
,
1
,
# batch size
padded_len
,
num_heads
,
device
=
alibi_slopes
.
device
,
prompt_len
,
dtype
=
dtype
,
padded_len
,
)[:,
:,
:,
:
seq_len
].
copy_
(
bias
)
device
=
alibi_slopes
.
device
,
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
dtype
=
dtype
,
if
num_heads
!=
num_kv_heads
:
)[:,
:,
:,
:
prompt_len
].
copy_
(
bias
)
bias
=
bias
.
unflatten
(
1
,
(
num_kv_heads
,
num_heads
//
num_kv_heads
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
attn_bias
=
LowerTriangularMaskWithTensorBias
(
bias
)
if
num_heads
!=
num_kv_heads
:
return
attn_bias
bias
=
bias
.
unflatten
(
1
,
(
num_kv_heads
,
num_heads
//
num_kv_heads
))
attn_biases
.
append
(
LowerTriangularMaskWithTensorBias
(
bias
))
return
attn_biases
def
_check_use_ref_attention
()
->
bool
:
def
_check_use_ref_attention
()
->
bool
:
...
@@ -239,7 +302,6 @@ def _ref_masked_attention(
...
@@ -239,7 +302,6 @@ def _ref_masked_attention(
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
seq_len
,
_
,
_
=
query
.
shape
seq_len
,
_
,
_
=
query
.
shape
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
seq_len
,
...
...
vllm/model_executor/layers/attention/ops/paged_attn.py
View file @
6e435de7
...
@@ -128,11 +128,12 @@ class PagedAttentionImpl:
...
@@ -128,11 +128,12 @@ class PagedAttentionImpl:
output
,
output
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
input_metadata
.
block_tables
,
# [BS, max_block_per_request]
input_metadata
.
block_tables
,
input_metadata
.
start_loc
,
# subquery_start_loc is (batch_size + 1,)
input_metadata
.
prompt_lens
,
input_metadata
.
subquery_start_loc
[:
-
1
],
input_metadata
.
prompt_lens_tensor
,
input_metadata
.
context_lens
,
input_metadata
.
context_lens
,
input_metadata
.
max_s
eq
_len
,
input_metadata
.
max_s
ubquery
_len
,
alibi_slopes
,
alibi_slopes
,
)
)
return
output
return
output
vllm/model_executor/layers/sampler.py
View file @
6e435de7
...
@@ -128,7 +128,6 @@ def _prune_hidden_states(
...
@@ -128,7 +128,6 @@ def _prune_hidden_states(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
shape
[
-
1
])
return
hidden_states
.
index_select
(
0
,
return
hidden_states
.
index_select
(
0
,
sampling_metadata
.
selected_token_indices
)
sampling_metadata
.
selected_token_indices
)
...
...
vllm/worker/model_runner.py
View file @
6e435de7
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment