Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
67b4221a
Unverified
Commit
67b4221a
authored
Apr 11, 2024
by
SangBin Cho
Committed by
GitHub
Apr 10, 2024
Browse files
[Core][5/N] Fully working chunked prefill e2e (#3884)
parent
63e7176f
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
666 additions
and
225 deletions
+666
-225
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-0
benchmarks/benchmark_latency.py
benchmarks/benchmark_latency.py
+1
-2
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+38
-24
tests/basic_correctness/test_chunked_prefill.py
tests/basic_correctness/test_chunked_prefill.py
+70
-0
tests/core/test_chunked_prefill_scheduler.py
tests/core/test_chunked_prefill_scheduler.py
+8
-8
tests/distributed/test_basic_distributed_correctness.py
tests/distributed/test_basic_distributed_correctness.py
+6
-1
tests/distributed/test_chunked_prefill_distributed.py
tests/distributed/test_chunked_prefill_distributed.py
+66
-0
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+1
-1
tests/models/test_models.py
tests/models/test_models.py
+1
-1
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+170
-19
vllm/attention/__init__.py
vllm/attention/__init__.py
+3
-1
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+39
-3
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+55
-30
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+63
-34
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+42
-25
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+79
-59
vllm/attention/layer.py
vllm/attention/layer.py
+3
-2
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+0
-6
vllm/config.py
vllm/config.py
+10
-3
vllm/core/scheduler.py
vllm/core/scheduler.py
+9
-6
No files found.
.buildkite/test-pipeline.yaml
View file @
67b4221a
...
@@ -29,6 +29,8 @@ steps:
...
@@ -29,6 +29,8 @@ steps:
-
pytest -v -s test_pynccl.py
-
pytest -v -s test_pynccl.py
-
TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
-
TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
-
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
-
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
-
TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
-
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
-
label
:
Engine Test
-
label
:
Engine Test
command
:
pytest -v -s engine tokenization test_sequence.py test_config.py
command
:
pytest -v -s engine tokenization test_sequence.py test_config.py
...
...
benchmarks/benchmark_latency.py
View file @
67b4221a
...
@@ -177,8 +177,7 @@ if __name__ == '__main__':
...
@@ -177,8 +177,7 @@ if __name__ == '__main__':
help
=
'block size of key/value cache'
)
help
=
'block size of key/value cache'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--enable-chunked-prefill'
,
'--enable-chunked-prefill'
,
type
=
bool
,
action
=
'store_true'
,
default
=
False
,
help
=
'If True, the prefill requests can be chunked based on the '
help
=
'If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens'
)
'max_num_batched_tokens'
)
parser
.
add_argument
(
parser
.
add_argument
(
...
...
benchmarks/benchmark_throughput.py
View file @
67b4221a
...
@@ -74,25 +74,31 @@ def run_vllm(
...
@@ -74,25 +74,31 @@ def run_vllm(
quantization_param_path
:
Optional
[
str
],
quantization_param_path
:
Optional
[
str
],
device
:
str
,
device
:
str
,
enable_prefix_caching
:
bool
,
enable_prefix_caching
:
bool
,
enable_chunked_prefill
:
bool
,
max_num_batched_tokens
:
int
,
gpu_memory_utilization
:
float
=
0.9
,
gpu_memory_utilization
:
float
=
0.9
,
download_dir
:
Optional
[
str
]
=
None
,
download_dir
:
Optional
[
str
]
=
None
,
)
->
float
:
)
->
float
:
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
llm
=
LLM
(
model
=
model
,
llm
=
LLM
(
tokenizer
=
tokenizer
,
model
=
model
,
quantization
=
quantization
,
tokenizer
=
tokenizer
,
tensor_parallel_size
=
tensor_parallel_size
,
quantization
=
quantization
,
seed
=
seed
,
tensor_parallel_size
=
tensor_parallel_size
,
trust_remote_code
=
trust_remote_code
,
seed
=
seed
,
dtype
=
dtype
,
trust_remote_code
=
trust_remote_code
,
max_model_len
=
max_model_len
,
dtype
=
dtype
,
gpu_memory_utilization
=
gpu_memory_utilization
,
max_model_len
=
max_model_len
,
enforce_eager
=
enforce_eager
,
gpu_memory_utilization
=
gpu_memory_utilization
,
kv_cache_dtype
=
kv_cache_dtype
,
enforce_eager
=
enforce_eager
,
quantization_param_path
=
quantization_param_path
,
kv_cache_dtype
=
kv_cache_dtype
,
device
=
device
,
quantization_param_path
=
quantization_param_path
,
enable_prefix_caching
=
enable_prefix_caching
,
device
=
device
,
download_dir
=
download_dir
)
enable_prefix_caching
=
enable_prefix_caching
,
download_dir
=
download_dir
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
)
# Add the requests to the engine.
# Add the requests to the engine.
for
prompt
,
_
,
output_len
in
requests
:
for
prompt
,
_
,
output_len
in
requests
:
...
@@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
...
@@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
args
.
output_len
)
args
.
output_len
)
if
args
.
backend
==
"vllm"
:
if
args
.
backend
==
"vllm"
:
elapsed_time
=
run_vllm
(
requests
,
args
.
model
,
args
.
tokenizer
,
elapsed_time
=
run_vllm
(
args
.
quantization
,
args
.
tensor_parallel_size
,
requests
,
args
.
model
,
args
.
tokenizer
,
args
.
quantization
,
args
.
seed
,
args
.
n
,
args
.
use_beam_search
,
args
.
tensor_parallel_size
,
args
.
seed
,
args
.
n
,
args
.
use_beam_search
,
args
.
trust_remote_code
,
args
.
dtype
,
args
.
trust_remote_code
,
args
.
dtype
,
args
.
max_model_len
,
args
.
max_model_len
,
args
.
enforce_eager
,
args
.
enforce_eager
,
args
.
kv_cache_dtype
,
args
.
kv_cache_dtyp
e
,
args
.
quantization_param_path
,
args
.
devic
e
,
args
.
quantization_param_path
,
args
.
device
,
args
.
enable_prefix_caching
,
args
.
enable_chunked_prefill
,
args
.
enable_prefix_caching
,
args
.
max_num_batched_tokens
,
args
.
gpu_memory_utilization
,
args
.
gpu_memory_utilization
,
args
.
download_dir
)
args
.
download_dir
)
elif
args
.
backend
==
"hf"
:
elif
args
.
backend
==
"hf"
:
assert
args
.
tensor_parallel_size
==
1
assert
args
.
tensor_parallel_size
==
1
elapsed_time
=
run_hf
(
requests
,
args
.
model
,
tokenizer
,
args
.
n
,
elapsed_time
=
run_hf
(
requests
,
args
.
model
,
tokenizer
,
args
.
n
,
...
@@ -335,6 +341,14 @@ if __name__ == "__main__":
...
@@ -335,6 +341,14 @@ if __name__ == "__main__":
"--enable-prefix-caching"
,
"--enable-prefix-caching"
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
"enable automatic prefix caching for vLLM backend."
)
help
=
"enable automatic prefix caching for vLLM backend."
)
parser
.
add_argument
(
"--enable-chunked-prefill"
,
action
=
'store_true'
,
help
=
"enable chunked prefill for vLLM backend."
)
parser
.
add_argument
(
'--max-num-batched-tokens'
,
type
=
int
,
default
=
None
,
help
=
'maximum number of batched tokens per '
'iteration'
)
parser
.
add_argument
(
'--download-dir'
,
parser
.
add_argument
(
'--download-dir'
,
type
=
str
,
type
=
str
,
default
=
None
,
default
=
None
,
...
...
tests/basic_correctness/test_chunked_prefill.py
0 → 100644
View file @
67b4221a
"""Compare the outputs of HF and vLLM when using greedy sampling.
It tests chunked prefill. Chunked prefill can be enabled by
enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens,
prefill requests are chunked.
Run `pytest tests/models/test_chunked_prefill.py`.
"""
import
pytest
MODELS
=
[
"facebook/opt-125m"
,
"meta-llama/Llama-2-7b-hf"
,
]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
16
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
,
True
])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
[
1
])
def
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
chunked_prefill_token_size
:
int
,
enforce_eager
:
bool
,
tensor_parallel_size
:
int
,
)
->
None
:
if
(
tensor_parallel_size
==
2
and
chunked_prefill_token_size
!=
16
and
not
enforce_eager
):
pytest
.
skip
(
f
"Skip
{
chunked_prefill_token_size
=
}
and
{
enforce_eager
=
}
"
"for high TP to save testing time."
)
max_num_seqs
=
min
(
chunked_prefill_token_size
,
256
)
enable_chunked_prefill
=
False
max_num_batched_tokens
=
None
if
chunked_prefill_token_size
!=
-
1
:
enable_chunked_prefill
=
True
max_num_batched_tokens
=
chunked_prefill_token_size
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
max_num_batched_tokens
=
max_num_batched_tokens
,
enable_chunked_prefill
=
enable_chunked_prefill
,
tensor_parallel_size
=
tensor_parallel_size
,
enforce_eager
=
enforce_eager
,
max_num_seqs
=
max_num_seqs
,
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
vllm_model
print
(
vllm_outputs
[
0
])
for
i
in
range
(
len
(
example_prompts
)):
hf_output_ids
,
hf_output_str
=
hf_outputs
[
i
]
vllm_output_ids
,
vllm_output_str
=
vllm_outputs
[
i
]
assert
hf_output_str
==
vllm_output_str
,
(
f
"Test
{
i
}
:
\n
HF:
{
hf_output_str
!
r
}
\n
vLLM:
{
vllm_output_str
!
r
}
"
)
assert
hf_output_ids
==
vllm_output_ids
,
(
f
"Test
{
i
}
:
\n
HF:
{
hf_output_ids
}
\n
vLLM:
{
vllm_output_ids
}
"
)
tests/core/test_chunked_prefill_scheduler.py
View file @
67b4221a
...
@@ -104,10 +104,10 @@ def test_chunk():
...
@@ -104,10 +104,10 @@ def test_chunk():
# One chunked prefill, and one decoding.
# One chunked prefill, and one decoding.
seq_group_meta
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
seq_group_meta
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
set
(
get_sequence_groups
(
out
))
==
set
(
running
)
assert
set
(
get_sequence_groups
(
out
))
==
set
(
running
)
# The first one is
decod
ing.
# The first one is
prefill. Scheduler guarantees order
ing.
assert
seq_group_meta
[
0
].
token_chunk_size
==
1
assert
seq_group_meta
[
0
].
token_chunk_size
==
56
# The second one is a chunked prefill.
# The second one is a chunked prefill.
assert
seq_group_meta
[
1
].
token_chunk_size
==
56
assert
seq_group_meta
[
1
].
token_chunk_size
==
1
assert
out
.
num_prefill_groups
==
1
assert
out
.
num_prefill_groups
==
1
assert
out
.
num_batched_tokens
==
57
assert
out
.
num_batched_tokens
==
57
...
@@ -157,12 +157,12 @@ def test_complex():
...
@@ -157,12 +157,12 @@ def test_complex():
# Decoding & chunked prefill & first chunk of 3rd request is scheduled.
# Decoding & chunked prefill & first chunk of 3rd request is scheduled.
seq_group_meta
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
seq_group_meta
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
get_sequence_groups
(
out
))
==
3
assert
len
(
get_sequence_groups
(
out
))
==
3
# The first one is
decoding
.
# The first one is
the first chunked prefill
.
assert
seq_group_meta
[
0
].
token_chunk_size
==
1
assert
seq_group_meta
[
0
].
token_chunk_size
==
7
# The second one is
a
chunked prefill.
# The second one is
the second new
chunked prefill.
assert
seq_group_meta
[
1
].
token_chunk_size
==
56
assert
seq_group_meta
[
1
].
token_chunk_size
==
56
# The
third
one is
also chunked
.
# The
last
one is
decode
.
assert
seq_group_meta
[
2
].
token_chunk_size
==
7
assert
seq_group_meta
[
2
].
token_chunk_size
==
1
# Two of them are in chunked prefill.
# Two of them are in chunked prefill.
assert
out
.
num_prefill_groups
==
2
assert
out
.
num_prefill_groups
==
2
assert
out
.
num_batched_tokens
==
64
assert
out
.
num_batched_tokens
==
64
...
...
tests/distributed/test_basic_distributed_correctness.py
View file @
67b4221a
...
@@ -33,11 +33,16 @@ def test_models(
...
@@ -33,11 +33,16 @@ def test_models(
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
)
->
None
:
)
->
None
:
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
hf_model
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
2
)
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
2
,
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
vllm_model
del
vllm_model
...
...
tests/distributed/test_chunked_prefill_distributed.py
0 → 100644
View file @
67b4221a
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.
Run:
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest
\
test_chunked_prefill_distributed.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf
\
test_chunked_prefill_distributed.py
```
"""
import
os
import
pytest
import
torch
MODELS
=
[
os
.
environ
[
"TEST_DIST_MODEL"
],
]
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
16
])
def
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
chunked_prefill_token_size
:
int
,
)
->
None
:
# Add a chunked prefill config.
max_num_seqs
=
min
(
chunked_prefill_token_size
,
256
)
assert
chunked_prefill_token_size
!=
-
1
enable_chunked_prefill
=
True
max_num_batched_tokens
=
chunked_prefill_token_size
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
2
,
max_num_seqs
=
max_num_seqs
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
vllm_model
for
i
in
range
(
len
(
example_prompts
)):
hf_output_ids
,
hf_output_str
=
hf_outputs
[
i
]
vllm_output_ids
,
vllm_output_str
=
vllm_outputs
[
i
]
assert
hf_output_str
==
vllm_output_str
,
(
f
"Test
{
i
}
:
\n
HF:
{
hf_output_str
!
r
}
\n
vLLM:
{
vllm_output_str
!
r
}
"
)
assert
hf_output_ids
==
vllm_output_ids
,
(
f
"Test
{
i
}
:
\n
HF:
{
hf_output_ids
}
\n
vLLM:
{
vllm_output_ids
}
"
)
tests/entrypoints/test_openai_server.py
View file @
67b4221a
...
@@ -141,7 +141,7 @@ def server(zephyr_lora_files):
...
@@ -141,7 +141,7 @@ def server(zephyr_lora_files):
"--max-cpu-loras"
,
"--max-cpu-loras"
,
"2"
,
"2"
,
"--max-num-seqs"
,
"--max-num-seqs"
,
"128"
"128"
,
])
])
ray
.
get
(
server_runner
.
ready
.
remote
())
ray
.
get
(
server_runner
.
ready
.
remote
())
yield
server_runner
yield
server_runner
...
...
tests/models/test_models.py
View file @
67b4221a
...
@@ -12,7 +12,7 @@ MODELS = [
...
@@ -12,7 +12,7 @@ MODELS = [
"gpt2"
,
"gpt2"
,
"bigcode/tiny_starcoder_py"
,
"bigcode/tiny_starcoder_py"
,
"EleutherAI/pythia-70m"
,
"EleutherAI/pythia-70m"
,
"bigscience/bloom-560m"
,
"bigscience/bloom-560m"
,
# Testing alibi slopes.
"microsoft/phi-2"
,
"microsoft/phi-2"
,
"stabilityai/stablelm-3b-4e1t"
,
"stabilityai/stablelm-3b-4e1t"
,
# "allenai/OLMo-1B", # Broken
# "allenai/OLMo-1B", # Broken
...
...
tests/worker/test_model_runner.py
View file @
67b4221a
import
pytest
import
pytest
import
torch
import
torch
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
,
SchedulerConfig
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
257
)))
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
257
)))
def
test_prepare_prompt
(
batch_size
):
def
test_prepare_prompt
(
batch_size
):
model_runner
=
ModelRunner
(
None
,
None
,
None
,
None
,
None
)
scheduler_config
=
SchedulerConfig
(
100000
,
100000
,
100000
,
enable_chunked_prefill
=
False
)
model_runner
=
ModelRunner
(
None
,
None
,
scheduler_config
,
None
,
None
)
model_runner
.
set_block_size
(
16
)
model_runner
.
set_block_size
(
16
)
prompt_lens
=
[]
prompt_lens
=
[]
...
@@ -36,8 +40,10 @@ def test_prepare_prompt(batch_size):
...
@@ -36,8 +40,10 @@ def test_prepare_prompt(batch_size):
prompt_len
-
1
)
prompt_len
-
1
)
selected_token_start_idx
+=
prompt_len
selected_token_start_idx
+=
prompt_len
(
input_tokens
,
input_positions
,
attn_metadata
,
return_prompt_lens
,
_
,
_
,
_
,
(
input_tokens
,
input_positions
,
attn_metadata
,
return_prompt_lens
,
_
,
_
,
_
,
_
,
_
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
_
,
_
,
slot_mapping
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
assert
return_prompt_lens
==
prompt_lens
assert
return_prompt_lens
==
prompt_lens
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
# Verify input metadata is correct for prompts.
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
device
=
model_runner
.
device
...
@@ -45,8 +51,6 @@ def test_prepare_prompt(batch_size):
...
@@ -45,8 +51,6 @@ def test_prepare_prompt(batch_size):
assert
torch
.
allclose
(
attn_metadata
.
prompt_lens_tensor
,
assert
torch
.
allclose
(
attn_metadata
.
prompt_lens_tensor
,
torch
.
tensor
(
prompt_lens
,
device
=
device
))
torch
.
tensor
(
prompt_lens
,
device
=
device
))
assert
attn_metadata
.
prompt_lens
==
prompt_lens
assert
attn_metadata
.
prompt_lens
==
prompt_lens
assert
attn_metadata
.
num_prompt_tokens
==
sum
(
prompt_lens
)
assert
attn_metadata
.
num_generation_tokens
==
0
assert
attn_metadata
.
max_prompt_len
==
max
(
prompt_lens
)
assert
attn_metadata
.
max_prompt_len
==
max
(
prompt_lens
)
# Test subquery start locs.
# Test subquery start locs.
...
@@ -83,23 +87,22 @@ def test_prepare_prompt(batch_size):
...
@@ -83,23 +87,22 @@ def test_prepare_prompt(batch_size):
assert
torch
.
allclose
(
attn_metadata
.
block_tables
,
expected
)
assert
torch
.
allclose
(
attn_metadata
.
block_tables
,
expected
)
# Cuda graph should not be used for prerill.
# Cuda graph should not be used for prerill.
assert
attn_metadata
.
use_cuda_graph
is
False
assert
attn_metadata
.
use_cuda_graph
is
False
assert
attn_metadata
.
kv_cache_dtype
==
"auto"
assert
input_tokens
.
shape
==
(
sum
(
prompt_lens
)
,
)
assert
len
(
input_tokens
)
==
sum
(
prompt_lens
)
assert
input_positions
.
shape
==
(
sum
(
prompt_lens
)
,
)
assert
len
(
input_positions
)
==
sum
(
prompt_lens
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
prompt_lens
,
subquery_lens
=
prompt_lens
)
subquery_lens
=
prompt_lens
)
assert
input_tokens
.
shape
==
(
sum
(
prompt_lens
)
,
)
assert
len
(
input_tokens
)
==
sum
(
prompt_lens
)
assert
input_positions
.
shape
==
(
sum
(
prompt_lens
)
,
)
assert
len
(
input_positions
)
==
sum
(
prompt_lens
)
actual
=
sampling_metadata
.
selected_token_indices
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
device
=
actual
.
device
,
dtype
=
actual
.
dtype
)
dtype
=
actual
.
dtype
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
assert
input_tokens
==
input_positions
actual
=
sampling_metadata
.
selected_token_indices
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
...
@@ -122,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -122,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size):
revision
=
None
,
revision
=
None
,
enforce_eager
=
False
,
enforce_eager
=
False
,
)
)
model_runner
=
ModelRunner
(
model_config
,
None
,
None
,
None
,
None
)
scheduler_config
=
SchedulerConfig
(
100000
,
100000
,
100000
,
enable_chunked_prefill
=
False
)
model_runner
=
ModelRunner
(
model_config
,
None
,
scheduler_config
,
None
,
None
)
model_runner
.
set_block_size
(
16
)
model_runner
.
set_block_size
(
16
)
prompt_lens
=
[]
prompt_lens
=
[]
...
@@ -143,16 +151,15 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -143,16 +151,15 @@ def test_prepare_decode_cuda_graph(batch_size):
assert
seq_group_metadata
.
token_chunk_size
==
1
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_group_metadata_list
.
append
(
seq_group_metadata
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
input_tokens
,
input_positions
,
attn_metadata
,
_
,
_
,
_
=
(
input_tokens
,
input_positions
,
attn_metadata
,
_
,
_
,
_
,
slot_mapping
=
(
model_runner
.
_prepare_decode
(
seq_group_metadata_list
))
model_runner
.
_prepare_decode
(
seq_group_metadata_list
))
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
expected_bs
=
_get_graph_batch_size
(
len
(
seq_group_metadata_list
))
expected_bs
=
_get_graph_batch_size
(
len
(
seq_group_metadata_list
))
# Verify input metadata is correct for prompts.
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
device
=
model_runner
.
device
assert
attn_metadata
.
is_prompt
is
False
assert
attn_metadata
.
is_prompt
is
False
assert
attn_metadata
.
prompt_lens
is
None
assert
attn_metadata
.
prompt_lens
is
None
assert
attn_metadata
.
num_prompt_tokens
==
0
assert
attn_metadata
.
num_generation_tokens
==
expected_bs
assert
attn_metadata
.
max_prompt_len
is
None
assert
attn_metadata
.
max_prompt_len
is
None
assert
attn_metadata
.
subquery_start_loc
is
None
assert
attn_metadata
.
subquery_start_loc
is
None
assert
attn_metadata
.
seq_start_loc
is
None
assert
attn_metadata
.
seq_start_loc
is
None
...
@@ -170,11 +177,10 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -170,11 +177,10 @@ def test_prepare_decode_cuda_graph(batch_size):
model_runner
.
get_max_block_per_batch
())
model_runner
.
get_max_block_per_batch
())
# Cuda graph should not be used for prerill.
# Cuda graph should not be used for prerill.
assert
attn_metadata
.
use_cuda_graph
is
True
assert
attn_metadata
.
use_cuda_graph
is
True
assert
attn_metadata
.
kv_cache_dtype
==
"auto"
assert
input_tokens
.
shape
==
(
expected_bs
,
)
assert
len
(
input_tokens
)
==
expected_bs
assert
input_positions
.
shape
==
(
expected_bs
,
)
assert
len
(
input_positions
)
==
expected_bs
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
assert
input_tokens
==
input_positions
# Verify Sampling
# Verify Sampling
expected_selected_token_indices
=
[]
expected_selected_token_indices
=
[]
...
@@ -190,3 +196,148 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -190,3 +196,148 @@ def test_prepare_decode_cuda_graph(batch_size):
device
=
actual
.
device
,
device
=
actual
.
device
,
dtype
=
actual
.
dtype
)
dtype
=
actual
.
dtype
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
def
test_empty_seq_group
():
"""Verify prepare prompt and decode returns empty output."""
model_config
=
ModelConfig
(
"facebook/opt-125m"
,
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
download_dir
=
None
,
load_format
=
"dummy"
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
enforce_eager
=
False
,
)
model_runner
=
ModelRunner
(
model_config
,
None
,
None
,
None
,
None
)
model_runner
.
set_block_size
(
16
)
seq_group_metadata_list
=
[]
input_tokens
,
input_positions
,
attn_metadata
,
_
,
_
,
_
,
slot_mapping
=
(
model_runner
.
_prepare_decode
(
seq_group_metadata_list
))
assert
len
(
input_tokens
)
==
0
assert
len
(
input_positions
)
==
0
assert
attn_metadata
is
None
assert
len
(
slot_mapping
)
==
0
(
input_tokens
,
input_positions
,
attn_metadata
,
return_prompt_lens
,
_
,
_
,
_
,
_
,
_
,
slot_mapping
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
assert
len
(
input_tokens
)
==
0
assert
len
(
input_positions
)
==
0
assert
attn_metadata
is
None
assert
len
(
slot_mapping
)
==
0
assert
len
(
return_prompt_lens
)
==
0
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
2
,
128
)))
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
test_hybrid_batches
(
batch_size
,
enforce_eager
,
monkeypatch
):
def
get_world_size
(
group
=
None
):
return
1
def
mock_get_process_group_ranks
(
group
=
None
):
return
[
0
]
monkeypatch
.
setattr
(
torch
.
distributed
,
"get_world_size"
,
get_world_size
)
monkeypatch
.
setattr
(
torch
.
distributed
,
"get_process_group_ranks"
,
mock_get_process_group_ranks
)
model_config
=
ModelConfig
(
"facebook/opt-125m"
,
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
download_dir
=
None
,
load_format
=
"dummy"
,
seed
=
0
,
dtype
=
"float16"
,
revision
=
None
,
enforce_eager
=
enforce_eager
,
)
scheduler_config
=
SchedulerConfig
(
100000
,
100000
,
100000
,
enable_chunked_prefill
=
True
)
model_runner
=
ModelRunner
(
model_config
,
None
,
scheduler_config
,
None
,
None
,
is_driver_worker
=
True
)
model_runner
.
set_block_size
(
16
)
# Add prefill requests.
prompt_lens
=
[]
seq_group_metadata_list
=
[]
prefill_metadata_list
=
[]
decode_metadata_list
=
[]
block_tables
=
{
0
:
[
1
]}
prefill_batch_size
=
batch_size
//
2
decode_batch_size
=
batch_size
-
prefill_batch_size
for
i
in
range
(
prefill_batch_size
):
# make sure all tokens fit into one block
prompt_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_lens
.
append
(
prompt_len
)
seq_data
=
SequenceData
(
list
(
range
(
prompt_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
seq_data
},
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
block_tables
,
)
assert
seq_group_metadata
.
token_chunk_size
==
seq_data
.
get_len
()
seq_group_metadata_list
.
append
(
seq_group_metadata
)
prefill_metadata_list
.
append
(
seq_group_metadata
)
# Add decode requests
for
i
in
range
(
prefill_batch_size
,
batch_size
):
# make sure all tokens fit into one block
prompt_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_toks
=
list
(
range
(
prompt_len
))
seq_data
=
SequenceData
(
prompt_toks
)
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
seq_data
=
{
0
:
seq_data
},
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
{
0
:
[
1
]},
)
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_group_metadata_list
.
append
(
seq_group_metadata
)
decode_metadata_list
.
append
(
seq_group_metadata
)
(
input_tokens
,
input_positions
,
attn_metadata
,
_
,
_
,
_
,
_
)
=
model_runner
.
prepare_input_tensors
(
seq_group_metadata_list
)
prefill_meta_actual
=
attn_metadata
.
prefill_metadata
decode_meta_actual
=
attn_metadata
.
decode_metadata
assert
len
(
attn_metadata
.
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
input_positions
)
==
len
(
input_tokens
)
assert
attn_metadata
.
kv_cache_dtype
==
"auto"
assert
attn_metadata
.
num_prefills
==
prefill_batch_size
if
enforce_eager
:
assert
attn_metadata
.
num_decode_tokens
==
decode_batch_size
else
:
assert
attn_metadata
.
num_decode_tokens
==
_get_graph_batch_size
(
decode_batch_size
)
assert
attn_metadata
.
num_prefill_tokens
==
sum
(
prompt_lens
)
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
prefill_meta
=
model_runner
.
_prepare_prompt
(
prefill_metadata_list
).
attn_metadata
decode_meta
=
model_runner
.
_prepare_decode
(
decode_metadata_list
).
attn_metadata
for
attr_expected
,
attr_actual
in
zip
(
vars
(
prefill_meta
),
vars
(
prefill_meta_actual
)):
assert
attr_expected
[
1
]
==
attr_actual
[
1
]
for
attr_expected
,
attr_actual
in
zip
(
vars
(
decode_meta
),
vars
(
decode_meta_actual
)):
assert
attr_expected
[
1
]
==
attr_actual
[
1
]
vllm/attention/__init__.py
View file @
67b4221a
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
)
AttentionMetadata
,
AttentionMetadataPerStage
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
...
@@ -8,4 +9,5 @@ __all__ = [
...
@@ -8,4 +9,5 @@ __all__ = [
"AttentionMetadata"
,
"AttentionMetadata"
,
"Attention"
,
"Attention"
,
"get_attn_backend"
,
"get_attn_backend"
,
"AttentionMetadataPerStage"
,
]
]
vllm/attention/backends/abstract.py
View file @
67b4221a
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
Generic
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
import
torch
import
torch
...
@@ -47,7 +47,8 @@ class AttentionBackend(ABC):
...
@@ -47,7 +47,8 @@ class AttentionBackend(ABC):
@
dataclass
@
dataclass
class
AttentionMetadata
:
class
AttentionMetadataPerStage
:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
def
asdict_zerocopy
(
self
)
->
Dict
[
str
,
Any
]:
def
asdict_zerocopy
(
self
)
->
Dict
[
str
,
Any
]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
"""Similar to dataclasses.asdict, but avoids deepcopying."""
...
@@ -59,6 +60,41 @@ class AttentionMetadata:
...
@@ -59,6 +60,41 @@ class AttentionMetadata:
}
}
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadataPerStage
)
@
dataclass
class
AttentionMetadata
(
Generic
[
T
]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills
:
int
# Number of prefill tokens.
num_prefill_tokens
:
int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens
:
int
# The attention metadata for prefill requests in a batch.
# None if there's no prefill requests in a batch.
prefill_metadata
:
Optional
[
T
]
# The attention metadata for decode requests in a batch.
# None if there's no decode requests in a batch.
decode_metadata
:
Optional
[
T
]
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping
:
torch
.
Tensor
# The kv cache's data type.
kv_cache_dtype
:
str
def
__post_init__
(
self
):
if
self
.
num_prefill_tokens
>
0
:
assert
self
.
num_prefills
>
0
assert
self
.
prefill_metadata
is
not
None
if
self
.
num_decode_tokens
>
0
:
assert
self
.
decode_metadata
is
not
None
class
AttentionImpl
(
ABC
):
class
AttentionImpl
(
ABC
):
@
abstractmethod
@
abstractmethod
...
@@ -80,7 +116,7 @@ class AttentionImpl(ABC):
...
@@ -80,7 +116,7 @@ class AttentionImpl(ABC):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
[
AttentionMetadataPerStage
]
,
kv_scale
:
float
,
kv_scale
:
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
vllm/attention/backends/flash_attn.py
View file @
67b4221a
...
@@ -11,7 +11,8 @@ import torch
...
@@ -11,7 +11,8 @@ import torch
from
flash_attn
import
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionMetadataPerStage
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
...
@@ -53,7 +54,8 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -53,7 +54,8 @@ class FlashAttentionBackend(AttentionBackend):
@
dataclass
@
dataclass
class
FlashAttentionMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
class
FlashAttentionMetadata
(
AttentionMetadataPerStage
,
PagedAttentionMetadata
):
"""Metadata for FlashAttentionBackend.
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
NOTE: Any python object stored here is not updated when it is
...
@@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
prompt_lens
:
Optional
[
List
[
int
]]
prompt_lens
:
Optional
[
List
[
int
]]
# prompt_lens stored as a tensor.
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens
:
int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens
:
int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
# |---------- N-1 iteration --------|
...
@@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
class
FlashAttentionImpl
(
AttentionImpl
):
class
FlashAttentionImpl
(
AttentionImpl
):
"""
"""
If the input tensors contain prompt tokens, the layout is as follows:
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_pr
ompt
_tokens -------------->|
|<--------------- num_pr
efill
_tokens
---
-------------->|
|<--pr
ompt
_0-->|<--pr
ompt
_1-->|...|<--pr
ompt
_N-1-->|
|<--pr
efill
_0-->|<--pr
efill
_1-->|...|<--pr
efill
_N-1--
-
>|
Otherwise, the layout is as follows:
Otherwise, the layout is as follows:
|<-----------------
-
num_
generation
_tokens
(M)
----------------->|
|<----------------- num_
decode
_tokens
-
----------------->|
|<--
generation
_0-->|..........|<--
generation
_M-1-->|<--padding-->|
|<--
decode
_0-->|..........|<--
decode
_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
The prompts might have different lengths, while the generation tokens
always have length 1.
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -155,7 +162,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -155,7 +162,7 @@ class FlashAttentionImpl(AttentionImpl):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
attn_metadata
:
AttentionMetadata
[
FlashAttentionMetadata
]
,
kv_scale
:
float
,
kv_scale
:
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
"""Forward pass with FlashAttention and PagedAttention.
...
@@ -188,52 +195,70 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -188,52 +195,70 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata
.
kv_cache_dtype
,
attn_metadata
.
kv_cache_dtype
,
kv_scale
)
kv_scale
)
if
attn_metadata
.
is_prompt
:
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
# Prompt run.
if
kv_cache
is
None
or
attn_metada
ta
.
block_tables
.
numel
()
==
0
:
if
kv_cache
is
None
or
prefill_me
ta
.
block_tables
.
numel
()
==
0
:
# normal attention
# normal attention
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
out
put
=
flash_attn_varlen_func
(
out
=
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
cu_seqlens_q
=
attn_metada
ta
.
seq_start_loc
,
cu_seqlens_q
=
prefill_me
ta
.
seq_start_loc
,
cu_seqlens_k
=
attn_metada
ta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_me
ta
.
seq_start_loc
,
max_seqlen_q
=
attn_metada
ta
.
max_prompt_len
,
max_seqlen_q
=
prefill_me
ta
.
max_prompt_len
,
max_seqlen_k
=
attn_metada
ta
.
max_prompt_len
,
max_seqlen_k
=
prefill_me
ta
.
max_prompt_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
# to be addressed separately.
output
=
PagedAttention
.
forward_prefix
(
output
[:
num_prefill_tokens
]
=
PagedAttention
.
forward_prefix
(
query
,
query
,
key
,
key
,
value
,
value
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
attn_metada
ta
.
block_tables
,
prefill_me
ta
.
block_tables
,
attn_metada
ta
.
subquery_start_loc
,
prefill_me
ta
.
subquery_start_loc
,
attn_metada
ta
.
prompt_lens_tensor
,
prefill_me
ta
.
prompt_lens_tensor
,
attn_metada
ta
.
context_lens
,
prefill_me
ta
.
context_lens
,
attn_metada
ta
.
max_subquery_len
,
prefill_me
ta
.
max_subquery_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
)
)
else
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Decoding run.
output
=
PagedAttention
.
forward_decode
(
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
query
,
decode_
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
attn_metada
ta
.
block_tables
,
decode_me
ta
.
block_tables
,
attn_metada
ta
.
context_lens
,
decode_me
ta
.
context_lens
,
attn_metada
ta
.
max_context_len
,
decode_me
ta
.
max_context_len
,
attn_metadata
.
kv_cache_dtype
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
67b4221a
...
@@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
...
@@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
import
torch
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionMetadataPerStage
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -51,7 +52,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
...
@@ -51,7 +52,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
@
dataclass
@
dataclass
class
ROCmFlashAttentionMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
class
ROCmFlashAttentionMetadata
(
AttentionMetadataPerStage
,
PagedAttentionMetadata
):
"""Metadata for FlashAttentionBackend.
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
NOTE: Any python object stored here is not updated when it is
...
@@ -66,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -66,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
prompt_lens
:
Optional
[
List
[
int
]]
prompt_lens
:
Optional
[
List
[
int
]]
# prompt_lens stored as a tensor.
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens
:
int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens
:
int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
# |---------- N-1 iteration --------|
...
@@ -117,6 +115,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -117,6 +115,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
The prompts might have different lengths, while the generation tokens
The prompts might have different lengths, while the generation tokens
always have length 1.
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
|<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -181,7 +188,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -181,7 +188,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
attn_metadata
:
AttentionMetadata
[
ROCmFlashAttentionMetadata
]
,
kv_scale
:
float
=
1.0
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
"""Forward pass with FlashAttention and PagedAttention.
...
@@ -218,9 +225,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -218,9 +225,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_scale
,
kv_scale
,
)
)
if
attn_metadata
.
is_prompt
:
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
# Prompt run.
if
kv_cache
is
None
or
attn_metada
ta
.
block_tables
.
numel
()
==
0
:
if
kv_cache
is
None
or
prefill_me
ta
.
block_tables
.
numel
()
==
0
:
# triton attention
# triton attention
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
...
@@ -230,63 +253,69 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -230,63 +253,69 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
if
self
.
use_naive_attn
:
if
self
.
use_naive_attn
:
out
put
=
self
.
attn_fuc
(
out
=
self
.
attn_fuc
(
query
,
query
,
key
,
key
,
value
,
value
,
attn_metada
ta
.
prompt_lens
,
prefill_me
ta
.
prompt_lens
,
self
.
scale
,
self
.
scale
,
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
else
:
out
put
,
_
=
self
.
attn_func
(
out
,
_
=
self
.
attn_func
(
query
,
query
,
key
,
key
,
value
,
value
,
None
,
None
,
attn_metada
ta
.
seq_start_loc
,
prefill_me
ta
.
seq_start_loc
,
attn_metada
ta
.
seq_start_loc
,
prefill_me
ta
.
seq_start_loc
,
attn_metada
ta
.
max_prompt_len
,
prefill_me
ta
.
max_prompt_len
,
attn_metada
ta
.
max_prompt_len
,
prefill_me
ta
.
max_prompt_len
,
True
,
True
,
self
.
scale
,
self
.
scale
,
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
else
:
out
put
=
self
.
attn_func
(
out
=
self
.
attn_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
cu_seqlens_q
=
attn_metada
ta
.
seq_start_loc
,
cu_seqlens_q
=
prefill_me
ta
.
seq_start_loc
,
cu_seqlens_k
=
attn_metada
ta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_me
ta
.
seq_start_loc
,
max_seqlen_q
=
attn_metada
ta
.
max_prompt_len
,
max_seqlen_q
=
prefill_me
ta
.
max_prompt_len
,
max_seqlen_k
=
attn_metada
ta
.
max_prompt_len
,
max_seqlen_k
=
prefill_me
ta
.
max_prompt_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
output
=
PagedAttention
.
forward_prefix
(
output
[:
num_prefill_tokens
]
=
PagedAttention
.
forward_prefix
(
query
,
query
,
key
,
key
,
value
,
value
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
attn_metada
ta
.
block_tables
,
prefill_me
ta
.
block_tables
,
attn_metada
ta
.
subquery_start_loc
,
prefill_me
ta
.
subquery_start_loc
,
attn_metada
ta
.
prompt_lens_tensor
,
prefill_me
ta
.
prompt_lens_tensor
,
attn_metada
ta
.
context_lens
,
prefill_me
ta
.
context_lens
,
attn_metada
ta
.
max_subquery_len
,
prefill_me
ta
.
max_subquery_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
)
)
else
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Decoding run.
output
=
PagedAttention
.
forward_decode
(
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
query
,
decode_
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
attn_metada
ta
.
block_tables
,
decode_me
ta
.
block_tables
,
attn_metada
ta
.
context_lens
,
decode_me
ta
.
context_lens
,
attn_metada
ta
.
max_context_len
,
decode_me
ta
.
max_context_len
,
attn_metadata
.
kv_cache_dtype
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
...
...
vllm/attention/backends/torch_sdpa.py
View file @
67b4221a
...
@@ -7,7 +7,8 @@ import torch
...
@@ -7,7 +7,8 @@ import torch
from
torch.nn.functional
import
scaled_dot_product_attention
from
torch.nn.functional
import
scaled_dot_product_attention
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionMetadataPerStage
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
...
@@ -49,17 +50,14 @@ class TorchSDPABackend(AttentionBackend):
...
@@ -49,17 +50,14 @@ class TorchSDPABackend(AttentionBackend):
@
dataclass
@
dataclass
class
TorchSDPAMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
class
TorchSDPAMetadata
(
AttentionMetadata
PerStage
,
PagedAttentionMetadata
):
"""Metadata for TorchSDPABackend.
"""Metadata for TorchSDPABackend.
"""
"""
# Currently, input sequences can only contain all prompts
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
is_prompt
:
bool
slot_mapping
:
torch
.
Tensor
prompt_lens
:
Optional
[
List
[
int
]]
prompt_lens
:
Optional
[
List
[
int
]]
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
num_prompt_tokens
:
int
num_generation_tokens
:
int
max_subquery_len
:
Optional
[
int
]
=
None
max_subquery_len
:
Optional
[
int
]
=
None
max_prompt_len
:
Optional
[
int
]
=
None
max_prompt_len
:
Optional
[
int
]
=
None
...
@@ -113,7 +111,7 @@ class TorchSDPABackendImpl(AttentionImpl):
...
@@ -113,7 +111,7 @@ class TorchSDPABackendImpl(AttentionImpl):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
TorchSDPAMetadata
,
attn_metadata
:
AttentionMetadata
[
TorchSDPAMetadata
]
,
kv_scale
:
float
,
kv_scale
:
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
"""Forward pass with torch SDPA and PagedAttention.
...
@@ -142,36 +140,51 @@ class TorchSDPABackendImpl(AttentionImpl):
...
@@ -142,36 +140,51 @@ class TorchSDPABackendImpl(AttentionImpl):
attn_metadata
.
kv_cache_dtype
,
attn_metadata
.
kv_cache_dtype
,
kv_scale
)
kv_scale
)
if
attn_metadata
.
is_prompt
:
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
if
(
kv_cache
is
None
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
(
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
):
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
dim
=
1
)
if
attn_metada
ta
.
attn_bias
is
None
:
if
prefill_me
ta
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
not
None
:
if
self
.
alibi_slopes
is
not
None
:
att_masks
=
_make_alibi_bias
(
att_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
self
.
alibi_slopes
,
query
.
dtype
,
attn_metada
ta
.
prompt_lens
)
# type: ignore
prefill_me
ta
.
prompt_lens
)
# type: ignore
elif
self
.
sliding_window
is
not
None
:
elif
self
.
sliding_window
is
not
None
:
att_masks
=
_make_sliding_window_bias
(
att_masks
=
_make_sliding_window_bias
(
attn_metada
ta
.
prompt_lens
,
self
.
sliding_window
,
prefill_me
ta
.
prompt_lens
,
self
.
sliding_window
,
query
.
dtype
)
# type: ignore
query
.
dtype
)
# type: ignore
else
:
else
:
att_masks
=
[
None
]
*
len
(
attn_metada
ta
.
prompt_lens
)
att_masks
=
[
None
]
*
len
(
prefill_me
ta
.
prompt_lens
)
attn_metada
ta
.
attn_bias
=
att_masks
prefill_me
ta
.
attn_bias
=
att_masks
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
start
=
0
start
=
0
output
=
torch
.
empty
(
out
=
torch
.
empty
((
num_tokens
,
self
.
num_heads
,
self
.
head_size
),
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
),
dtype
=
query
.
dtype
)
dtype
=
query
.
dtype
)
for
prompt_len
,
mask
in
zip
(
prefill_meta
.
prompt_lens
,
for
prompt_len
,
mask
in
zip
(
attn_metadata
.
prompt_lens
,
prefill_meta
.
attn_bias
):
attn_metadata
.
attn_bias
):
end
=
start
+
prompt_len
end
=
start
+
prompt_len
sub_out
=
scaled_dot_product_attention
(
sub_out
=
scaled_dot_product_attention
(
query
[:,
start
:
end
,
:],
query
[:,
start
:
end
,
:],
...
@@ -181,28 +194,32 @@ class TorchSDPABackendImpl(AttentionImpl):
...
@@ -181,28 +194,32 @@ class TorchSDPABackendImpl(AttentionImpl):
dropout_p
=
0.0
,
dropout_p
=
0.0
,
is_causal
=
not
self
.
need_mask
,
is_causal
=
not
self
.
need_mask
,
scale
=
self
.
scale
).
movedim
(
query
.
dim
()
-
2
,
0
)
scale
=
self
.
scale
).
movedim
(
query
.
dim
()
-
2
,
0
)
out
put
[
start
:
end
,
:,
:]
=
sub_out
out
[
start
:
end
,
:,
:]
=
sub_out
start
=
end
start
=
end
assert
out
.
shape
==
output
[:
num_prefill_tokens
].
shape
output
[:
num_prefill_tokens
]
=
out
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
raise
RuntimeError
(
raise
RuntimeError
(
"Torch SDPA backend doesn't support prefix decoding."
)
"Torch SDPA backend doesn't support prefix decoding."
)
else
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Decoding run.
out
put
=
PagedAttention
.
forward_decode
(
out
=
PagedAttention
.
forward_decode
(
query
,
decode_
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
attn_metada
ta
.
block_tables
,
decode_me
ta
.
block_tables
,
attn_metada
ta
.
context_lens
,
decode_me
ta
.
context_lens
,
attn_metada
ta
.
max_context_len
,
decode_me
ta
.
max_context_len
,
attn_metadata
.
kv_cache_dtype
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
kv_scale
,
kv_scale
,
)
)
assert
out
.
shape
==
output
[
num_prefill_tokens
:].
shape
output
[
num_prefill_tokens
:]
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
...
...
vllm/attention/backends/xformers.py
View file @
67b4221a
...
@@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
...
@@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
LowerTriangularMaskWithTensorBias
)
LowerTriangularMaskWithTensorBias
)
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionMetadataPerStage
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -54,7 +55,7 @@ class XFormersBackend(AttentionBackend):
...
@@ -54,7 +55,7 @@ class XFormersBackend(AttentionBackend):
@
dataclass
@
dataclass
class
XFormersMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
class
XFormersMetadata
(
AttentionMetadata
PerStage
,
PagedAttentionMetadata
):
"""Metadata for XFormersbackend.
"""Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is
NOTE: Any python object stored here is not updated when it is
...
@@ -65,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -65,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
is_prompt
:
bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping
:
torch
.
Tensor
# (batch_size,). The prompt length per sequence. None if it is a decoding.
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens
:
Optional
[
List
[
int
]]
prompt_lens
:
Optional
[
List
[
int
]]
# prompt_lens stored as a tensor.
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens
:
int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens
:
int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
# |---------- N-1 iteration --------|
...
@@ -123,18 +115,27 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -123,18 +115,27 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
class
XFormersImpl
(
AttentionImpl
):
class
XFormersImpl
(
AttentionImpl
):
"""
"""
If the input tensors contain prompt tokens, the layout is as follows:
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_pr
ompt
_tokens --------------->|
|<--------------- num_pr
efill
_tokens
--
--------------->|
|<--pr
ompt
_0-->|<--pr
ompt
_1-->|...|<--pr
ompt
_N-1--->|
|<--pr
efill
_0-->|<--pr
efill
_1-->|...|<--pr
efill
_N-1--->|
Otherwise, the layout is as follows:
Otherwise, the layout is as follows:
|<-----------------
-
num_
generation
_tokens
(M)
----------------->|
|<----------------- num_
decode
_tokens
-
----------------->|
|<--
generation
_0-->|..........|<--
generation
_M-1-->|<--padding-->|
|<--
decode
_0-->|..........|<--
decode
_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
The prompts might have different lengths, while the generation tokens
always have length 1.
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -170,7 +171,7 @@ class XFormersImpl(AttentionImpl):
...
@@ -170,7 +171,7 @@ class XFormersImpl(AttentionImpl):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
XFormersMetadata
,
attn_metadata
:
AttentionMetadata
[
XFormersMetadata
]
,
kv_scale
:
float
,
kv_scale
:
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
"""Forward pass with xFormers and PagedAttention.
...
@@ -202,59 +203,61 @@ class XFormersImpl(AttentionImpl):
...
@@ -202,59 +203,61 @@ class XFormersImpl(AttentionImpl):
attn_metadata
.
kv_cache_dtype
,
attn_metadata
.
kv_cache_dtype
,
kv_scale
)
kv_scale
)
if
attn_metadata
.
is_prompt
:
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
# Prompt run.
if
kv_cache
is
None
or
attn_metada
ta
.
block_tables
.
numel
()
==
0
:
if
kv_cache
is
None
or
prefill_me
ta
.
block_tables
.
numel
()
==
0
:
# normal attention.
# normal attention.
# block tables are empty if the prompt does not have a cached
# block tables are empty if the prompt does not have a cached
# prefix.
# prefix.
if
self
.
num_kv_heads
!=
self
.
num_heads
:
out
=
self
.
_run_memory_efficient_xformers_forward
(
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
query
,
key
,
value
,
prefill_meta
)
# project the key and value tensors to the desired number of
assert
out
.
shape
==
output
[:
num_prefill_tokens
].
shape
# heads.
output
[:
num_prefill_tokens
]
=
out
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query
=
query
.
view
(
query
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
value
.
shape
[
-
1
])
output
=
self
.
_run_memory_efficient_xformers_forward
(
query
,
key
,
value
,
attn_metadata
)
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
# to be addressed separately.
out
put
=
PagedAttention
.
forward_prefix
(
out
=
PagedAttention
.
forward_prefix
(
query
,
query
,
key
,
key
,
value
,
value
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
attn_metada
ta
.
block_tables
,
prefill_me
ta
.
block_tables
,
attn_metada
ta
.
subquery_start_loc
,
prefill_me
ta
.
subquery_start_loc
,
attn_metada
ta
.
prompt_lens_tensor
,
prefill_me
ta
.
prompt_lens_tensor
,
attn_metada
ta
.
context_lens
,
prefill_me
ta
.
context_lens
,
attn_metada
ta
.
max_subquery_len
,
prefill_me
ta
.
max_subquery_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
)
)
else
:
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
# Decoding run.
output
[:
num_prefill_tokens
]
=
out
output
=
PagedAttention
.
forward_decode
(
query
,
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
decode_query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
attn_metada
ta
.
block_tables
,
decode_me
ta
.
block_tables
,
attn_metada
ta
.
context_lens
,
decode_me
ta
.
context_lens
,
attn_metada
ta
.
max_context_len
,
decode_me
ta
.
max_context_len
,
attn_metadata
.
kv_cache_dtype
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
...
@@ -275,13 +278,30 @@ class XFormersImpl(AttentionImpl):
...
@@ -275,13 +278,30 @@ class XFormersImpl(AttentionImpl):
"""Attention for 1D query of multiple prompts. Multiple prompt
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
tokens are flattened in to `query` input.
See https://facebookresearch.github.io/xformers/components/ops.html
for API spec.
Args:
Args:
output: shape = [num_pr
ompt
_tokens, num_heads, head_size]
output: shape = [num_pr
efill
_tokens, num_heads, head_size]
query: shape = [num_pr
ompt
_tokens, num_heads, head_size]
query: shape = [num_pr
efill
_tokens, num_heads, head_size]
key: shape = [num_pr
ompt
_tokens, num_kv_heads, head_size]
key: shape = [num_pr
efill
_tokens, num_kv_heads, head_size]
value: shape = [num_pr
ompt
_tokens, num_kv_heads, head_size]
value: shape = [num_pr
efill
_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
attn_metadata: Metadata for attention.
"""
"""
original_query
=
query
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# GQA/MQA requires the shape [B, M, G, H, K].
# Note that the output also has the same shape (which is different
# from a spec from the doc).
query
=
query
.
view
(
query
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
value
.
shape
[
-
1
])
# Set attention bias if not provided. This typically happens at
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
# FIXME(woosuk): This is a hack.
...
@@ -302,6 +322,7 @@ class XFormersImpl(AttentionImpl):
...
@@ -302,6 +322,7 @@ class XFormersImpl(AttentionImpl):
# TODO(woosuk): Too many view operations. Let's try to reduce
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
# them in the future for code readability.
if
self
.
alibi_slopes
is
None
:
if
self
.
alibi_slopes
is
None
:
# Add the batch dimension.
query
=
query
.
unsqueeze
(
0
)
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
...
@@ -312,14 +333,13 @@ class XFormersImpl(AttentionImpl):
...
@@ -312,14 +333,13 @@ class XFormersImpl(AttentionImpl):
attn_bias
=
attn_metadata
.
attn_bias
[
0
],
attn_bias
=
attn_metadata
.
attn_bias
[
0
],
p
=
0.0
,
p
=
0.0
,
scale
=
self
.
scale
)
scale
=
self
.
scale
)
return
out
.
view_as
(
original_query
)
return
out
.
view_as
(
query
)
# Attention with alibi slopes.
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# one. This is inefficient, especially when we have many short prompts.
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
original_
query
)
start
=
0
start
=
0
for
i
,
prompt_len
in
enumerate
(
attn_metadata
.
prompt_lens
):
for
i
,
prompt_len
in
enumerate
(
attn_metadata
.
prompt_lens
):
end
=
start
+
prompt_len
end
=
start
+
prompt_len
...
@@ -331,7 +351,7 @@ class XFormersImpl(AttentionImpl):
...
@@ -331,7 +351,7 @@ class XFormersImpl(AttentionImpl):
p
=
0.0
,
p
=
0.0
,
scale
=
self
.
scale
)
scale
=
self
.
scale
)
# TODO(woosuk): Unnecessary copy. Optimize.
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
.
squeeze
(
0
))
output
[
start
:
end
].
copy_
(
out
.
view_as
(
original_query
[
start
:
end
]
))
start
+=
prompt_len
start
+=
prompt_len
return
output
return
output
...
...
vllm/attention/layer.py
View file @
67b4221a
...
@@ -4,7 +4,8 @@ from typing import List, Optional
...
@@ -4,7 +4,8 @@ from typing import List, Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
(
AttentionMetadata
,
AttentionMetadataPerStage
)
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
...
@@ -41,7 +42,7 @@ class Attention(nn.Module):
...
@@ -41,7 +42,7 @@ class Attention(nn.Module):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
[
AttentionMetadataPerStage
]
,
kv_scale
:
float
=
1.0
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
...
...
vllm/attention/ops/paged_attn.py
View file @
67b4221a
...
@@ -13,11 +13,6 @@ _PARTITION_SIZE = 512
...
@@ -13,11 +13,6 @@ _PARTITION_SIZE = 512
@
dataclass
@
dataclass
class
PagedAttentionMetadata
:
class
PagedAttentionMetadata
:
"""Metadata for PagedAttention."""
"""Metadata for PagedAttention."""
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping
:
torch
.
Tensor
# (batch_size,). The length of context (tokens stored in KV cache) per
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
# tokens. When it is for decoding, it includes a new token.
...
@@ -31,7 +26,6 @@ class PagedAttentionMetadata:
...
@@ -31,7 +26,6 @@ class PagedAttentionMetadata:
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
block_tables
:
Optional
[
torch
.
Tensor
]
kv_cache_dtype
:
str
class
PagedAttention
:
class
PagedAttention
:
...
...
vllm/config.py
View file @
67b4221a
...
@@ -565,9 +565,16 @@ class SchedulerConfig:
...
@@ -565,9 +565,16 @@ class SchedulerConfig:
if
max_num_batched_tokens
is
not
None
:
if
max_num_batched_tokens
is
not
None
:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_num_batched_tokens
=
max_num_batched_tokens
else
:
else
:
# If max_model_len is too short, use 2048 as the default value for
if
enable_chunked_prefill
:
# higher throughput.
# For chunked prefill, choose the well-tuned batch size.
self
.
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
self
.
max_num_batched_tokens
=
768
else
:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
self
.
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
if
enable_chunked_prefill
:
logger
.
info
(
"Chunked prefill is enabled (EXPERIMENTAL)."
)
self
.
max_num_seqs
=
max_num_seqs
self
.
max_num_seqs
=
max_num_seqs
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
self
.
use_v2_block_manager
=
use_v2_block_manager
self
.
use_v2_block_manager
=
use_v2_block_manager
...
...
vllm/core/scheduler.py
View file @
67b4221a
...
@@ -140,7 +140,11 @@ class SchedulerOutputs:
...
@@ -140,7 +140,11 @@ class SchedulerOutputs:
@
property
@
property
def
lora_requests
(
self
)
->
Set
[
LoRARequest
]:
def
lora_requests
(
self
)
->
Set
[
LoRARequest
]:
return
{
g
.
seq_group
.
lora_request
for
g
in
self
.
scheduled_seq_groups
}
return
{
g
.
seq_group
.
lora_request
for
g
in
self
.
scheduled_seq_groups
if
g
.
seq_group
.
lora_request
is
not
None
}
@
dataclass
@
dataclass
...
@@ -826,13 +830,12 @@ class Scheduler:
...
@@ -826,13 +830,12 @@ class Scheduler:
# Update swapped requests.
# Update swapped requests.
self
.
swapped
=
remaining_swapped
self
.
swapped
=
remaining_swapped
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
self
.
swapped
.
extend
(
running_scheduled
.
swapped_out
)
return
SchedulerOutputs
(
return
SchedulerOutputs
(
scheduled_seq_groups
=
(
prefills
.
seq_groups
+
scheduled_seq_groups
=
(
prefills
.
seq_groups
+
running_scheduled
.
decode_seq_groups
+
running_scheduled
.
prefill_seq_groups
+
running_scheduled
.
prefill_seq_groups
+
swapped_in
.
decode_seq_groups
+
swapped_in
.
prefill_seq_groups
+
swapped_in
.
prefill_seq_groups
),
running_scheduled
.
decode_seq_groups
+
swapped_in
.
decode_seq_groups
),
num_prefill_groups
=
(
len
(
prefills
.
seq_groups
)
+
num_prefill_groups
=
(
len
(
prefills
.
seq_groups
)
+
len
(
swapped_in
.
prefill_seq_groups
)
+
len
(
swapped_in
.
prefill_seq_groups
)
+
len
(
running_scheduled
.
prefill_seq_groups
)),
len
(
running_scheduled
.
prefill_seq_groups
)),
...
@@ -907,7 +910,7 @@ class Scheduler:
...
@@ -907,7 +910,7 @@ class Scheduler:
# It assumes the scheduled_seq_groups is ordered by
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
# prefill < decoding.
is_prompt
=
i
<
scheduler_outputs
.
num
_prefill
_groups
is_prompt
=
seq_group
.
is
_prefill
()
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
seq_group
.
request_id
,
request_id
=
seq_group
.
request_id
,
is_prompt
=
is_prompt
,
is_prompt
=
is_prompt
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment