Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3597b06a
Unverified
Commit
3597b06a
authored
Jun 13, 2025
by
Luka Govedič
Committed by
GitHub
Jun 13, 2025
Browse files
[CUDA] Enable full cudagraph for FlashMLA (#18581)
Signed-off-by:
luka
<
luka@neuralmagic.com
>
parent
1015296b
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
453 additions
and
220 deletions
+453
-220
tests/compile/piecewise/test_full_cudagraph.py
tests/compile/piecewise/test_full_cudagraph.py
+107
-51
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+5
-11
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+21
-24
tests/utils.py
tests/utils.py
+21
-9
vllm/compilation/cuda_piecewise_backend.py
vllm/compilation/cuda_piecewise_backend.py
+5
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-1
vllm/forward_context.py
vllm/forward_context.py
+12
-6
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+8
-4
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+18
-11
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+7
-4
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+9
-13
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+36
-8
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+31
-3
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+1
-1
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+71
-2
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+6
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+93
-67
No files found.
tests/compile/piecewise/test_full_cudagraph.py
View file @
3597b06a
...
...
@@ -2,15 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
os
import
weakref
from
contextlib
import
ExitStack
import
pytest
from
tests.utils
import
wait_for_gpu_memory_to_clear
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationConfig
from
vllm.platforms
import
current_platform
MODEL
=
"Qwen/Qwen2-1.5B-Instruct"
@
contextlib
.
contextmanager
def
temporary_environ
(
env_vars
):
...
...
@@ -31,64 +32,119 @@ def temporary_environ(env_vars):
os
.
environ
[
k
]
=
v
@
pytest
.
fixture
(
scope
=
"module"
)
def
full_cudagraph_llm
():
with
temporary_environ
({
"VLLM_USE_V1"
:
"1"
,
"VLLM_FLASH_ATTN_VERSION"
:
"3"
}):
return
LLM
(
model
=
MODEL
,
gpu_memory_utilization
=
0.3
,
compilation_config
=
CompilationConfig
(
full_cuda_graph
=
True
))
@
pytest
.
fixture
(
scope
=
"class"
)
def
llm_pair
(
request
):
model
=
request
.
param
@
pytest
.
fixture
(
scope
=
"module"
)
def
piecewise_llm
():
with
temporary_environ
({
"VLLM_USE_V1"
:
"1"
,
"VLLM_FLASH_ATTN_VERSION"
:
"3"
}):
return
LLM
(
model
=
MODEL
,
gpu_memory_utilization
=
0.6
,
compilation_config
=
CompilationConfig
())
def
generate_text
(
llm
:
LLM
,
batch_size
:
int
,
max_tokens
:
int
):
prompts
=
[
"Hi my name is"
]
*
batch_size
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
,
top_p
=
0.95
)
return
llm
.
generate
(
prompts
,
sampling_params
)
full
=
LLM
(
model
=
model
,
gpu_memory_utilization
=
0.45
,
trust_remote_code
=
True
,
max_model_len
=
1024
,
compilation_config
=
CompilationConfig
(
full_cuda_graph
=
True
),
)
piecewise
=
LLM
(
model
=
model
,
gpu_memory_utilization
=
0.45
,
trust_remote_code
=
True
,
max_model_len
=
1024
,
compilation_config
=
CompilationConfig
(),
)
# PyTest caches the fixture values so we use weakref.proxy to enable GC
yield
weakref
.
proxy
(
full
),
weakref
.
proxy
(
piecewise
)
del
full
del
piecewise
wait_for_gpu_memory_to_clear
(
devices
=
[
0
],
threshold_ratio
=
0.1
,
)
@
pytest
.
mark
.
parametrize
(
"llm_pair"
,
[
# Model names for the llm_pair fixture
"deepseek-ai/DeepSeek-V2-Lite"
,
"Qwen/Qwen2-1.5B-Instruct"
],
indirect
=
True
)
@
pytest
.
mark
.
skipif
(
current_platform
.
get_device_capability
()
!=
(
9
,
0
),
reason
=
"Only Hopper GPUs support FlashAttention 3"
)
@
pytest
.
mark
.
parametrize
((
"batch_size"
,
"max_tokens"
),
[(
1
,
10
),
(
7
,
10
),
(
16
,
10
),
(
25
,
10
),
(
32
,
10
),
(
45
,
10
),
(
64
,
10
),
(
8
,
5
),
(
8
,
20
),
(
8
,
200
)])
def
test_full_cudagraph
(
batch_size
,
max_tokens
,
full_cudagraph_llm
,
piecewise_llm
):
reason
=
"Only Hopper GPUs support FA3 and FlashMLA"
)
class
TestFullCUDAGraph
:
"""
Load full cudagraph model and piecewise model once, and at the same time to
reuse them across var
io
u
s
test cases
.
Use a class such that an llm pair is constructed once for all
batch_size/max_tokens combinat
io
n
s
and released immediately after
.
Test various batch sizes and max_tokens to ensure that the full cudagraph
compilation works for padded cases too
.
Module-scope fixtures would stick around the whole time,
meaning there would be multiple LLM instances hogging memory simultaneously
.
"""
piecewise_responses
=
generate_text
(
piecewise_llm
,
batch_size
=
batch_size
,
max_tokens
=
max_tokens
)
full_cudagraph_responses
=
generate_text
(
full_cudagraph_llm
,
batch_size
=
batch_size
,
max_tokens
=
max_tokens
)
# Check that all responses are the same
for
i
in
range
(
len
(
piecewise_responses
)):
assert
piecewise_responses
[
i
].
outputs
[
0
].
text
==
full_cudagraph_responses
[
i
].
outputs
[
0
].
text
@
pytest
.
mark
.
parametrize
((
"batch_size"
,
"max_tokens"
),
[
(
1
,
10
),
(
7
,
10
),
(
16
,
10
),
(
25
,
10
),
(
32
,
10
),
(
45
,
10
),
(
64
,
10
),
(
123
,
10
),
(
8
,
5
),
(
8
,
30
),
])
def
test_full_cudagraph
(
self
,
batch_size
,
max_tokens
,
llm_pair
:
tuple
[
LLM
,
LLM
]):
"""
Test various batch sizes and max_tokens to ensure that the
full cudagraph compilation works for padded cases too.
"""
piecewise_llm
,
full_cudagraph_llm
=
llm_pair
prompts
=
[
"Hello, my name is"
]
*
batch_size
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
,
top_p
=
0.95
)
piecewise_responses
=
piecewise_llm
.
generate
(
prompts
,
sampling_params
)
full_responses
=
full_cudagraph_llm
.
generate
(
prompts
,
sampling_params
)
# Check that all responses are the same
for
piecewise_res
,
full_res
in
zip
(
piecewise_responses
,
full_responses
):
assert
piecewise_res
.
outputs
[
0
].
text
==
full_res
.
outputs
[
0
].
text
@
pytest
.
mark
.
parametrize
(
"model, supported"
,
[
(
"Qwen/Qwen2-1.5B-Instruct"
,
True
),
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
(
"deepseek-ai/DeepSeek-V2-Lite"
,
False
),
])
@
pytest
.
mark
.
skipif
(
current_platform
.
get_device_capability
()
!=
(
9
,
0
),
reason
=
"Only Hopper GPUs support FA3 and FlashMLA"
)
def
test_lower_max_num_seqs
(
model
,
supported
):
with
temporary_environ
({
"VLLM_USE_V1"
:
"1"
,
"VLLM_FLASH_ATTN_VERSION"
:
"3"
}),
ExitStack
()
as
stack
:
if
not
supported
:
stack
.
enter_context
(
pytest
.
raises
(
RuntimeError
))
llm
=
LLM
(
model
=
model
,
max_num_seqs
=
256
,
trust_remote_code
=
True
,
max_model_len
=
1024
,
compilation_config
=
CompilationConfig
(
full_cuda_graph
=
True
,
cudagraph_capture_sizes
=
[
64
,
256
,
512
]))
llm
.
generate
([
"Hello, my name is"
]
*
10
)
def
test_full_cudagraph_with_invalid_backend
():
...
...
@@ -97,5 +153,5 @@ def test_full_cudagraph_with_invalid_backend():
"VLLM_FLASH_ATTN_VERSION"
:
"2"
#FA2 not supported with full_cuda_graph
}),
pytest
.
raises
(
RuntimeError
):
LLM
(
model
=
MODEL
,
LLM
(
model
=
"Qwen/Qwen2-1.5B-Instruct"
,
compilation_config
=
CompilationConfig
(
full_cuda_graph
=
True
))
tests/compile/piecewise/test_simple.py
View file @
3597b06a
...
...
@@ -4,7 +4,7 @@
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import
pytest
import
torch
from
torch
import
nn
from
torch.library
import
Library
...
...
@@ -14,6 +14,7 @@ from vllm.compilation.decorators import support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
set_current_vllm_config
)
from
vllm.envs
import
VLLM_USE_V1
from
vllm.forward_context
import
set_forward_context
from
vllm.utils
import
direct_register_custom_op
global_counter
=
0
...
...
@@ -76,7 +77,8 @@ class SillyModel(nn.Module):
return
x
def
_test_simple_piecewise_compile
(
*
,
use_inductor
):
@
pytest
.
mark
.
parametrize
(
"use_inductor"
,
[
True
,
False
])
def
test_simple_piecewise_compile
(
use_inductor
):
assert
VLLM_USE_V1
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
...
...
@@ -99,7 +101,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
num_backend_compilations
=
3
,
# num_piecewise_capturable_graphs_seen
num_cudagraph_captured
=
6
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
),
set_forward_context
({},
vllm_config
=
vllm_config
):
model
(
inputs
)
...
...
@@ -112,11 +114,3 @@ def _test_simple_piecewise_compile(*, use_inductor):
output
=
model
(
input
)
assert
global_counter
==
2
assert
torch
.
allclose
(
output
.
cpu
(),
torch
.
tensor
([
3.
,
1.
]))
def
test_simple_piecewise_compile_inductor
():
_test_simple_piecewise_compile
(
use_inductor
=
True
)
def
test_simple_piecewise_compile_no_inductor
():
_test_simple_piecewise_compile
(
use_inductor
=
False
)
tests/compile/piecewise/test_toy_llama.py
View file @
3597b06a
...
...
@@ -11,6 +11,7 @@ initialized randomly with a fixed seed.
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
import
pytest
import
torch
from
torch
import
nn
from
torch.library
import
Library
...
...
@@ -19,6 +20,7 @@ from vllm.compilation.counter import compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
set_current_vllm_config
)
from
vllm.forward_context
import
set_forward_context
from
vllm.utils
import
direct_register_custom_op
# create a library to hold the custom op
...
...
@@ -285,29 +287,32 @@ def run_model(llama_config,
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
().
cuda
()
B
=
16
# max batch size
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
positions
=
torch
.
arange
(
B
).
cuda
()
with
set_forward_context
({},
vllm_config
=
vllm_config
):
B
=
16
# max batch size
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
positions
=
torch
.
arange
(
B
).
cuda
()
model
(
input_ids
,
positions
)
model
(
input_ids
[:
2
],
positions
[:
2
])
model
(
input_ids
[:
1
],
positions
[:
1
])
model
(
input_ids
,
positions
)
model
(
input_ids
[:
2
],
positions
[:
2
])
model
(
input_ids
[:
1
],
positions
[:
1
])
input_ids
[:
2
].
zero_
()
output
=
model
(
input_ids
[:
2
],
positions
[:
2
])
input_ids
[:
2
].
zero_
()
output
=
model
(
input_ids
[:
2
],
positions
[:
2
])
output
=
output
.
cpu
()
output
=
output
.
cpu
()
if
llama_config
.
tractable_init
:
expected_output
=
tractable_computation
(
input_ids
[:
2
],
positions
[:
2
],
llama_config
).
cpu
()
if
llama_config
.
tractable_init
:
expected_output
=
tractable_computation
(
input_ids
[:
2
],
positions
[:
2
],
llama_config
).
cpu
()
assert
torch
.
allclose
(
output
,
expected_output
)
else
:
return
output
.
cpu
()
assert
torch
.
allclose
(
output
,
expected_output
)
else
:
return
output
.
cpu
()
def
_test_toy_llama
(
*
,
use_inductor
):
@
pytest
.
mark
.
parametrize
(
"use_inductor"
,
[
True
,
False
])
def
test_toy_llama
(
use_inductor
:
bool
):
# compare output with and without piecewise compilation
llama_config
=
LlamaConfig
(
hidden_size
=
128
,
...
...
@@ -379,14 +384,6 @@ def _test_toy_llama(*, use_inductor):
assert
torch
.
allclose
(
outputs
[
0
],
outputs
[
i
])
def
test_toy_llama_inductor
():
_test_toy_llama
(
use_inductor
=
True
)
def
test_toy_no_inductor
():
_test_toy_llama
(
use_inductor
=
False
)
@
torch
.
inference_mode
def
benchmark
():
from
triton.testing
import
do_bench
...
...
tests/utils.py
View file @
3597b06a
...
...
@@ -667,42 +667,54 @@ def get_physical_device_indices(devices):
@
_nvml
()
def
wait_for_gpu_memory_to_clear
(
devices
:
list
[
int
],
threshold_bytes
:
int
,
def
wait_for_gpu_memory_to_clear
(
*
,
devices
:
list
[
int
],
threshold_bytes
:
Optional
[
int
]
=
None
,
threshold_ratio
:
Optional
[
float
]
=
None
,
timeout_s
:
float
=
120
)
->
None
:
assert
threshold_bytes
is
not
None
or
threshold_ratio
is
not
None
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
devices
=
get_physical_device_indices
(
devices
)
start_time
=
time
.
time
()
while
True
:
output
:
dict
[
int
,
str
]
=
{}
output_raw
:
dict
[
int
,
float
]
=
{}
output_raw
:
dict
[
int
,
tuple
[
float
,
float
]
]
=
{}
for
device
in
devices
:
if
current_platform
.
is_rocm
():
dev_handle
=
amdsmi_get_processor_handles
()[
device
]
mem_info
=
amdsmi_get_gpu_vram_usage
(
dev_handle
)
gb_used
=
mem_info
[
"vram_used"
]
/
2
**
10
gb_total
=
mem_info
[
"vram_total"
]
/
2
**
10
else
:
dev_handle
=
nvmlDeviceGetHandleByIndex
(
device
)
mem_info
=
nvmlDeviceGetMemoryInfo
(
dev_handle
)
gb_used
=
mem_info
.
used
/
2
**
30
output_raw
[
device
]
=
gb_used
output
[
device
]
=
f
'
{
gb_used
:.
02
f
}
'
gb_total
=
mem_info
.
total
/
2
**
30
output_raw
[
device
]
=
(
gb_used
,
gb_total
)
output
[
device
]
=
f
'
{
gb_used
:.
02
f
}
/
{
gb_total
:.
02
f
}
'
print
(
'gpu memory used (GB): '
,
end
=
''
)
print
(
'gpu memory used
/total
(G
i
B): '
,
end
=
''
)
for
k
,
v
in
output
.
items
():
print
(
f
'
{
k
}
=
{
v
}
; '
,
end
=
''
)
print
(
''
)
if
threshold_bytes
is
not
None
:
is_free
=
lambda
used
,
total
:
used
<=
threshold_bytes
/
2
**
30
threshold
=
f
"
{
threshold_bytes
/
2
**
30
}
GiB"
else
:
is_free
=
lambda
used
,
total
:
used
/
total
<=
threshold_ratio
threshold
=
f
"
{
threshold_ratio
:.
2
f
}
"
dur_s
=
time
.
time
()
-
start_time
if
all
(
v
<=
(
threshold_bytes
/
2
**
30
)
for
v
in
output_raw
.
values
()):
if
all
(
is_free
(
used
,
total
)
for
used
,
total
in
output_raw
.
values
()):
print
(
f
'Done waiting for free GPU memory on devices
{
devices
=
}
'
f
'(
{
threshold
_bytes
/
2
**
30
=
}
)
{
dur_s
=
:.
02
f
}
'
)
f
'(
{
threshold
=
}
)
{
dur_s
=
:.
02
f
}
'
)
break
if
dur_s
>=
timeout_s
:
raise
ValueError
(
f
'Memory of devices
{
devices
=
}
not free after '
f
'
{
dur_s
=
:.
02
f
}
(
{
threshold
_bytes
/
2
**
30
=
}
)'
)
f
'
{
dur_s
=
:.
02
f
}
(
{
threshold
=
}
)'
)
time
.
sleep
(
5
)
...
...
vllm/compilation/cuda_piecewise_backend.py
View file @
3597b06a
...
...
@@ -14,6 +14,7 @@ from vllm.compilation.backends import VllmBackend
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.monitor
import
end_monitoring_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.utils
import
weak_ref_tensors
...
...
@@ -137,7 +138,10 @@ class CUDAPiecewiseBackend:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
self
.
check_for_ending_compilation
()
if
not
entry
.
use_cudagraph
:
# Skip CUDA graphs if this entry doesn't use them OR
# if we're supposed to skip them globally
skip_cuda_graphs
=
get_forward_context
().
skip_cuda_graphs
if
not
entry
.
use_cudagraph
or
skip_cuda_graphs
:
return
entry
.
runnable
(
*
args
)
if
entry
.
cudagraph
is
None
:
...
...
vllm/entrypoints/llm.py
View file @
3597b06a
...
...
@@ -179,7 +179,8 @@ class LLM:
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
override_pooler_config
:
Optional
[
PoolerConfig
]
=
None
,
compilation_config
:
Optional
[
Union
[
int
,
dict
[
str
,
Any
]]]
=
None
,
compilation_config
:
Optional
[
Union
[
int
,
dict
[
str
,
Any
],
CompilationConfig
]]
=
None
,
**
kwargs
,
)
->
None
:
"""LLM constructor."""
...
...
vllm/forward_context.py
View file @
3597b06a
...
...
@@ -94,6 +94,7 @@ class ForwardContext:
virtual_engine
:
int
# set dynamically for each forward pass
# set dynamically for each forward pass
dp_metadata
:
Optional
[
DPMetadata
]
=
None
skip_cuda_graphs
:
bool
=
False
_forward_context
:
Optional
[
ForwardContext
]
=
None
...
...
@@ -108,11 +109,14 @@ def get_forward_context() -> ForwardContext:
@
contextmanager
def
set_forward_context
(
attn_metadata
:
Any
,
vllm_config
:
VllmConfig
,
virtual_engine
:
int
=
0
,
num_tokens
:
Optional
[
int
]
=
None
,
num_tokens_across_dp
:
Optional
[
torch
.
Tensor
]
=
None
):
def
set_forward_context
(
attn_metadata
:
Any
,
vllm_config
:
VllmConfig
,
virtual_engine
:
int
=
0
,
num_tokens
:
Optional
[
int
]
=
None
,
num_tokens_across_dp
:
Optional
[
torch
.
Tensor
]
=
None
,
skip_cuda_graphs
:
bool
=
False
,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
...
...
@@ -135,7 +139,9 @@ def set_forward_context(attn_metadata: Any,
static_forward_context
,
virtual_engine
=
virtual_engine
,
attn_metadata
=
attn_metadata
,
dp_metadata
=
dp_metadata
)
dp_metadata
=
dp_metadata
,
skip_cuda_graphs
=
skip_cuda_graphs
,
)
try
:
yield
...
...
vllm/v1/attention/backends/cpu_attn.py
View file @
3597b06a
...
...
@@ -7,7 +7,8 @@ from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
TorchSDPAMetadata
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.ipex_attn
import
PagedAttention
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
...
...
@@ -53,7 +54,7 @@ class TorchSDPABackend:
return
False
class
TorchSDPAMetadataBuilderV1
:
class
TorchSDPAMetadataBuilderV1
(
AttentionMetadataBuilder
[
TorchSDPAMetadata
])
:
def
__init__
(
self
,
runner
:
CPUModelRunner
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
)
->
None
:
...
...
@@ -118,9 +119,12 @@ class TorchSDPAMetadataBuilderV1:
return
True
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
,
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
):
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
runner
=
self
.
runner
block_table
=
self
.
block_table
seq_lens_np
=
runner
.
seq_lens_np
[:
num_reqs
]
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
3597b06a
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Optional
import
numpy
as
np
import
torch
...
...
@@ -21,13 +21,12 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
if
current_platform
.
is_cuda
():
...
...
@@ -306,7 +305,9 @@ def _get_sliding_window_configs(
return
sliding_window_configs
class
FlashAttentionMetadataBuilder
:
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
get_flash_attn_version
()
==
3
def
__init__
(
self
,
runner
:
"GPUModelRunner"
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
...
...
@@ -336,13 +337,14 @@ class FlashAttentionMetadataBuilder:
# populated on first build() call.
self
.
aot_sliding_window
:
Optional
[
tuple
[
int
,
int
]]
=
None
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
return
False
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
FlashAttentionMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
):
max_seq_len
=
int
(
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
())
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
...
...
@@ -496,6 +498,11 @@ class FlashAttentionMetadataBuilder:
)
return
attn_metadata
def
can_run_in_cudagraph
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
# Full CUDA Graph always supported (FA2 support checked separately)
return
True
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
return
use_cascade_attention
(
*
args
,
**
kwargs
)
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
3597b06a
...
...
@@ -18,7 +18,8 @@ from vllm.attention.layer import Attention
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
...
...
@@ -202,7 +203,7 @@ class FlashInferMetadata:
f
" received
{
self
.
head_dim
}
."
)
class
FlashInferMetadataBuilder
:
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
])
:
def
__init__
(
self
,
runner
:
GPUModelRunner
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
...
...
@@ -399,9 +400,11 @@ class FlashInferMetadataBuilder:
kv_data_type
=
attn_metadata
.
data_type
,
)
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
,
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
):
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
assert
self
.
_num_decodes
+
self
.
_num_prefills
==
num_reqs
assert
(
self
.
_num_decode_tokens
+
self
.
_num_prefill_tokens
==
num_actual_tokens
)
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
3597b06a
...
...
@@ -15,7 +15,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
is_quantized_kv_cache
)
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
...
...
@@ -25,8 +26,6 @@ if current_platform.is_cuda():
logger
=
init_logger
(
__name__
)
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
create_block_mask_compiled
=
torch
.
compile
(
create_block_mask
,
...
...
@@ -256,7 +255,8 @@ class FlexAttentionMetadata:
self
.
block_mask
=
self
.
build_block_mask
()
class
FlexAttentionMetadataBuilder
:
class
FlexAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlexAttentionMetadata
]):
def
__init__
(
self
,
runner
:
"GPUModelRunner"
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
...
...
@@ -272,13 +272,12 @@ class FlexAttentionMetadataBuilder:
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_table
=
block_table
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
return
False
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
,
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
):
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
...
...
@@ -332,9 +331,6 @@ class FlexAttentionMetadataBuilder:
)
return
out
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
return
False
class
FlexAttentionImpl
(
AttentionImpl
):
sliding_window
:
Optional
[
tuple
[
int
,
int
]]
...
...
vllm/v1/attention/backends/mla/common.py
View file @
3597b06a
...
...
@@ -207,7 +207,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
UnquantizedLinearMethod
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
,
round_down
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
...
...
@@ -329,7 +330,7 @@ class MLACommonMetadata(Generic[D]):
M
=
TypeVar
(
"M"
,
bound
=
MLACommonMetadata
)
class
MLACommonMetadataBuilder
(
Generic
[
M
]):
class
MLACommonMetadataBuilder
(
AttentionMetadataBuilder
[
M
]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
...
...
@@ -450,9 +451,32 @@ class MLACommonMetadataBuilder(Generic[M]):
seq_lens
=
seq_lens
,
)
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
,
def
build_for_cudagraph_capture
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with MLA.
"""
m
=
common_attn_metadata
assert
m
.
num_reqs
==
m
.
num_actual_tokens
,
\
"MLA only supports decode-only full CUDAGraph capture. "
\
"Make sure all cudagraph capture sizes <= max_num_seq."
m
.
max_query_len
=
1
# decode-only
# Update state usually set in reorder_batch.
self
.
_num_decodes
=
m
.
num_reqs
self
.
_num_decode_tokens
=
m
.
num_actual_tokens
self
.
_num_prefills
=
0
self
.
_num_prefill_tokens
=
0
return
self
.
build
(
0
,
m
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
assert
self
.
_num_decodes
+
self
.
_num_prefills
==
num_reqs
# Note(simon): be careful about the CPU <> GPU memory movement in this
...
...
@@ -461,8 +485,11 @@ class MLACommonMetadataBuilder(Generic[M]):
device
=
self
.
runner
.
device
block_table
=
self
.
block_table
block_table_tensor
=
block_table
.
get_device_tensor
()[:
num_reqs
]
slot_mapping
=
block_table
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
device
,
non_blocking
=
True
).
long
()
block_table
.
slot_mapping
[:
num_actual_tokens
].
copy_
(
block_table
.
slot_mapping_cpu
[:
num_actual_tokens
],
non_blocking
=
True
)
block_table
.
slot_mapping
[
num_actual_tokens
:].
fill_
(
-
1
)
slot_mapping
=
block_table
.
slot_mapping
[:
num_actual_tokens
]
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
...
...
@@ -564,8 +591,9 @@ class MLACommonMetadataBuilder(Generic[M]):
decode
=
decode_metadata
,
)
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
return
False
def
can_run_in_cudagraph
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
return
common_attn_metadata
.
max_query_len
==
1
class
MLACommonImpl
(
MLAAttentionImpl
[
M
],
Generic
[
M
]):
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
3597b06a
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
from
typing
import
Any
,
ClassVar
,
Optional
import
torch
...
...
@@ -44,7 +44,7 @@ class FlashMLABackend(MLACommonBackend):
@
dataclass
class
FlashMLADecodeMetadata
(
MLACommonDecodeMetadata
):
tile_scheduler_metadata
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
tile_scheduler_metadata
:
torch
.
Tensor
num_splits
:
torch
.
Tensor
...
...
@@ -54,14 +54,18 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
True
# Decode-only
def
__init__
(
self
,
runner
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
super
().
__init__
(
runner
,
kv_cache_spec
,
block_table
)
super
().
__init__
(
runner
,
kv_cache_spec
,
block_table
,
FlashMLAMetadata
)
self
.
num_q_heads
=
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
runner
.
parallel_config
)
self
.
cg_buf_tile_scheduler_metadata
=
None
self
.
cg_buf_num_splits
=
None
def
_build_decode
(
self
,
block_table_tensor
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
)
->
FlashMLADecodeMetadata
:
tile_scheduler_metadata
,
num_splits
=
\
...
...
@@ -71,6 +75,30 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
1
,
# MQA for the decode path
)
if
self
.
runner
.
full_cuda_graph
:
# First time around (CUDAGraph capture), allocate the static buffer
if
self
.
cg_buf_tile_scheduler_metadata
is
None
:
self
.
cg_buf_tile_scheduler_metadata
=
tile_scheduler_metadata
self
.
cg_buf_num_splits
=
num_splits
else
:
assert
self
.
cg_buf_num_splits
is
not
None
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
assert
(
self
.
cg_buf_tile_scheduler_metadata
.
size
()
==
tile_scheduler_metadata
.
size
())
self
.
cg_buf_tile_scheduler_metadata
.
\
copy_
(
tile_scheduler_metadata
)
tile_scheduler_metadata
=
self
.
cg_buf_tile_scheduler_metadata
# Num splits is per-batch, varying size (batch_size,)
n
=
num_splits
.
size
(
0
)
# make sure static buffer is large enough
assert
n
<=
self
.
cg_buf_num_splits
.
size
(
0
)
num_splits_view
=
self
.
cg_buf_num_splits
[:
n
]
num_splits_view
.
copy_
(
num_splits
)
self
.
cg_buf_num_splits
[
n
:].
fill_
(
0
)
# fill the rest with 0s
num_splits
=
num_splits_view
return
FlashMLADecodeMetadata
(
block_table
=
block_table_tensor
,
seq_lens
=
seq_lens
,
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
3597b06a
...
...
@@ -66,7 +66,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
def
__init__
(
self
,
runner
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
super
().
__init__
(
runner
,
kv_cache_spec
,
block_table
)
super
().
__init__
(
runner
,
kv_cache_spec
,
block_table
,
AiterMLAMetadata
)
assert
self
.
kv_cache_spec
.
block_size
==
1
,
"AITER MLA"
\
"only supports block size 1."
...
...
vllm/v1/attention/backends/utils.py
View file @
3597b06a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
abc
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Generic
,
TypeVar
import
numpy
as
np
import
torch
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
@
dataclass
class
CommonAttentionMetadata
:
"""
A
ttention metadata
attributes that can be
shared
by
layers
in different KV
cache groups and thus having different block table
.
Per-batch a
ttention metadata
,
shared
across
layers
and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata
.
"""
query_start_loc
:
torch
.
Tensor
...
...
@@ -18,6 +26,67 @@ class CommonAttentionMetadata:
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
num_reqs
:
int
"""Number of requests"""
num_actual_tokens
:
int
"""Total number of tokens in batch"""
max_query_len
:
int
"""Longest query in batch"""
M
=
TypeVar
(
"M"
)
class
AttentionMetadataBuilder
(
abc
.
ABC
,
Generic
[
M
]):
# Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported
:
ClassVar
[
bool
]
=
False
@
abstractmethod
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
"""
Central method that builds attention metadata.
Some builders (MLA) require reorder_batch to be called prior to build.
"""
raise
NotImplementedError
def
can_run_in_cudagraph
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
"""
Can this batch (with given metadata) use CUDA Graphs for attention.
"""
return
False
def
build_for_cudagraph_capture
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
"""
Build attention metadata for CUDA graph capture. Uses build by default.
Subclasses that override this method should call self.build or
super().build_for_cudagraph_capture.
"""
return
self
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
def
use_cascade_attention
(
self
,
common_prefix_len
:
int
,
query_lens
:
np
.
ndarray
,
num_query_heads
:
int
,
num_kv_heads
:
int
,
use_alibi
:
bool
,
use_sliding_window
:
bool
,
num_sms
:
int
,
)
->
bool
:
return
False
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
"""
This method can reorder the batch if desired by the backend.
:return: Has the batch been reordered (default False).
"""
return
False
def
validate_kv_sharing_target
(
current_layer_name
,
target_layer_name
,
static_forward_context
):
...
...
vllm/v1/spec_decode/eagle.py
View file @
3597b06a
...
...
@@ -138,15 +138,17 @@ class EagleProposer:
max_query_len
=
query_lens
.
max
().
item
()
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
cu_num_tokens
,
seq_lens
=
seq_lens
)
query_start_loc
=
cu_num_tokens
,
seq_lens
=
seq_lens
,
num_reqs
=
batch_size
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
)
assert
self
.
runner
is
not
None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata
=
self
.
runner
.
attn_metadata_builder
.
build
(
num_reqs
=
batch_size
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
3597b06a
...
...
@@ -16,10 +16,8 @@ from tqdm import tqdm
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionType
,
get_attn_backend
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadataBuilder
)
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.layer
import
Attention
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
)
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
...
...
@@ -41,7 +39,8 @@ from vllm.sequence import IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
...
...
@@ -89,6 +88,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
lora_config
=
vllm_config
.
lora_config
self
.
load_config
=
vllm_config
.
load_config
self
.
parallel_config
=
vllm_config
.
parallel_config
...
...
@@ -197,7 +197,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_sizes
=
[
self
.
cache_config
.
block_size
],
)
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
self
.
use_cuda_graph
=
(
self
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
and
not
self
.
model_config
.
enforce_eager
)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
...
...
@@ -205,8 +205,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# self.cudagraph_batch_sizes sorts in ascending order.
# The batch sizes in the config are in descending order.
self
.
cudagraph_batch_sizes
=
list
(
reversed
(
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
))
reversed
(
self
.
compilation_config
.
cudagraph_capture_sizes
))
self
.
full_cuda_graph
=
self
.
compilation_config
.
full_cuda_graph
# Cache the device properties.
self
.
_init_device_properties
()
...
...
@@ -555,7 +556,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
tuple
[
dict
[
str
,
Any
],
torch
.
Tensor
,
Optional
[
SpecDecodeMetadata
]]:
)
->
tuple
[
dict
[
str
,
Any
],
bool
,
torch
.
Tensor
,
Optional
[
SpecDecodeMetadata
]]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
attention_cuda_graphs: whether attention can run in cudagraph
logits_indices, spec_decode_metadata
]
"""
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
num_reqs
=
self
.
input_batch
.
num_reqs
...
...
@@ -669,7 +678,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
)
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
)
attn_metadata
:
dict
[
str
,
Any
]
=
{}
# Prepare the attention metadata for each KV cache group and make layers
...
...
@@ -679,25 +693,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
builder
=
self
.
attn_metadata_builders
[
kv_cache_group_id
]
if
self
.
cascade_attn_enabled
:
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
],
kv_cache_group_spec
.
kv_cache_spec
,
self
.
attn_metadata_builders
[
kv_cache_group_id
]
,
builder
,
)
attn_metadata_i
=
(
self
.
attn_metadata_builders
[
kv_cache_group_id
].
build
(
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
))
attn_metadata_i
=
(
builder
.
build
(
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
,
))
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
attention_cuda_graphs
=
all
(
b
.
can_run_in_cudagraph
(
common_attn_metadata
)
for
b
in
self
.
attn_metadata_builders
)
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
if
not
use_spec_decode
:
...
...
@@ -726,7 +743,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
lora_config
:
self
.
set_active_loras
(
self
.
input_batch
,
num_scheduled_tokens
)
return
attn_metadata
,
logits_indices
,
spec_decode_metadata
return
(
attn_metadata
,
attention_cuda_graphs
,
logits_indices
,
spec_decode_metadata
)
def
_compute_cascade_attn_prefix_len
(
self
,
...
...
@@ -1121,7 +1139,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
self
.
intermediate_tensors
is
not
None
tp
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
enabled_sp
=
self
.
vllm_config
.
compilation_config
.
pass_config
.
\
enabled_sp
=
self
.
compilation_config
.
pass_config
.
\
enable_sequence_parallelism
if
enabled_sp
:
# When sequence parallelism is enabled, we always pad num_tokens
...
...
@@ -1189,8 +1207,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
self
.
kv_connector_no_forward
(
scheduler_output
)
# Prepare the decoder inputs.
attn_metadata
,
logits_indices
,
spec_decode_metadata
=
(
self
.
_prepare_inputs
(
scheduler_output
))
(
attn_metadata
,
attention_cuda_graphs
,
logits_indices
,
spec_decode_metadata
)
=
(
self
.
_prepare_inputs
(
scheduler_output
))
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
...
...
@@ -1203,7 +1221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
self
.
vllm_config
.
compilation_config
.
pass_config
.
\
if
self
.
compilation_config
.
pass_config
.
\
enable_sequence_parallelism
and
tp_size
>
1
:
from
vllm.utils
import
round_up
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
tp_size
)
...
...
@@ -1255,12 +1273,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors
=
self
.
sync_and_slice_intermediate_tensors
(
num_input_tokens
,
intermediate_tensors
,
True
)
# Some attention backends only support CUDA Graphs in pure decode.
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
skip_cuda_graphs
=
skip_cuda_graphs
,
):
self
.
maybe_setup_kv_connector
(
scheduler_output
)
model_output
=
self
.
model
(
...
...
@@ -1769,7 +1795,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_dummy_run
(
self
,
num_tokens
:
int
,
skip_attn
:
bool
=
Tru
e
,
capture_attn_cudagraph
:
bool
=
Fals
e
,
)
->
torch
.
Tensor
:
# Padding for DP
...
...
@@ -1790,9 +1816,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens
=
np
.
array
(
num_scheduled_tokens_list
,
dtype
=
np
.
int32
)
if
skip_attn
:
attn_metadata
:
Optional
[
dict
[
str
,
Any
]]
=
None
else
:
attn_metadata
:
Optional
[
dict
[
str
,
Any
]]
=
None
if
capture_attn_cudagraph
:
attn_metadata
=
{}
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
# Make sure max_model_len is used at the graph capture time.
self
.
seq_lens_np
[:
num_reqs
]
=
self
.
max_model_len
...
...
@@ -1802,19 +1829,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
)
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
num_tokens
,
)
attn_metadata
=
{}
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
attn_metadata_i
=
(
self
.
attn_metadata_builders
[
kv_cache_group_id
].
build
(
num_reqs
=
num_reqs
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
num_tokens
,
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
))
attn_metadata_i
=
self
.
attn_metadata_builders
[
kv_cache_group_id
].
build_for_cudagraph_capture
(
common_attn_metadata
)
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
...
...
@@ -2039,14 +2066,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with
graph_capture
(
device
=
self
.
device
):
skip_attn
=
not
self
.
vllm_config
.
compilation_config
.
full_cuda_graph
full_cg
=
self
.
full_cuda_graph
for
num_tokens
in
tqdm
(
reversed
(
self
.
cudagraph_batch_sizes
),
desc
=
"Capturing CUDA graphs"
,
total
=
len
(
self
.
cudagraph_batch_sizes
)):
for
_
in
range
(
self
.
vllm_config
.
compilation_config
.
cudagraph_num_of_warmups
):
self
.
_dummy_run
(
num_tokens
,
skip_attn
=
skip_attn
)
self
.
_dummy_run
(
num_tokens
,
skip_attn
=
skip_attn
)
for
_
in
range
(
self
.
compilation_config
.
cudagraph_num_of_warmups
):
self
.
_dummy_run
(
num_tokens
,
capture_attn_cudagraph
=
full_cg
)
self
.
_dummy_run
(
num_tokens
,
capture_attn_cudagraph
=
full_cg
)
end_time
=
time
.
perf_counter
()
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
...
...
@@ -2089,20 +2116,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"Non-Attention backend is not supported by V1 "
"GPUModelRunner."
)
if
self
.
vllm_config
.
compilation_config
.
full_cuda_graph
:
attn_backend_name
=
attn_backend_i
.
__name__
flash_attn_version
=
get_flash_attn_version
()
if
attn_backend_name
!=
"FlashAttentionBackend"
or
\
flash_attn_version
!=
3
:
raise
ValueError
(
f
"full_cuda_graph is only supported with "
f
"FA3. Current attention backend is "
f
"
{
attn_backend_name
}
, FlashAttention version is "
f
"
{
flash_attn_version
}
."
)
block_table_i
=
self
.
input_batch
.
block_table
[
i
]
attn_metadata_builder_i
=
attn_backend_i
.
get_builder_cls
()(
weakref
.
proxy
(
self
),
kv_cache_spec
,
block_table_i
)
weakref
.
proxy
(
self
),
kv_cache_spec
,
block_table_i
,
)
if
(
self
.
full_cuda_graph
and
not
attn_metadata_builder_i
.
full_cudagraph_supported
):
raise
ValueError
(
f
"Full CUDAGraph not supported for "
f
"
{
attn_backend_i
.
__name__
}
. Turn off CompilationConfig."
f
"full_cuda_graph or use a different attention backend."
)
self
.
attn_backends
.
append
(
attn_backend_i
)
self
.
attn_metadata_builders
.
append
(
attn_metadata_builder_i
)
...
...
@@ -2142,9 +2169,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
to be reshaped to the desired shape before being used by the models.
Args:
kv_cache_config: The KV cache config
kv_cache_config: The KV cache config
Returns:
dict[str, torch.Tensor]: A map between layer names to their
dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
kv_cache_raw_tensors
:
dict
[
str
,
torch
.
Tensor
]
=
{}
...
...
@@ -2171,11 +2198,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Reshape the KV cache tensors to the desired shape and dtype.
Args:
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
correct size but uninitialized shape.
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
...
...
@@ -2227,7 +2254,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Args:
kv_cache_config: The KV cache config
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
# Initialize the memory buffer for KV cache
...
...
@@ -2245,10 +2272,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_caches
,
)
bind_kv_cache
(
kv_caches
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
bind_kv_cache
(
kv_caches
,
self
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
return
kv_caches
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment