Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3597b06a
Unverified
Commit
3597b06a
authored
Jun 13, 2025
by
Luka Govedič
Committed by
GitHub
Jun 13, 2025
Browse files
[CUDA] Enable full cudagraph for FlashMLA (#18581)
Signed-off-by:
luka
<
luka@neuralmagic.com
>
parent
1015296b
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
453 additions
and
220 deletions
+453
-220
tests/compile/piecewise/test_full_cudagraph.py
tests/compile/piecewise/test_full_cudagraph.py
+107
-51
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+5
-11
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+21
-24
tests/utils.py
tests/utils.py
+21
-9
vllm/compilation/cuda_piecewise_backend.py
vllm/compilation/cuda_piecewise_backend.py
+5
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-1
vllm/forward_context.py
vllm/forward_context.py
+12
-6
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+8
-4
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+18
-11
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+7
-4
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+9
-13
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+36
-8
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+31
-3
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+1
-1
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+71
-2
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+6
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+93
-67
No files found.
tests/compile/piecewise/test_full_cudagraph.py
View file @
3597b06a
...
@@ -2,15 +2,16 @@
...
@@ -2,15 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
contextlib
import
os
import
os
import
weakref
from
contextlib
import
ExitStack
import
pytest
import
pytest
from
tests.utils
import
wait_for_gpu_memory_to_clear
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationConfig
from
vllm.config
import
CompilationConfig
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
MODEL
=
"Qwen/Qwen2-1.5B-Instruct"
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
temporary_environ
(
env_vars
):
def
temporary_environ
(
env_vars
):
...
@@ -31,64 +32,119 @@ def temporary_environ(env_vars):
...
@@ -31,64 +32,119 @@ def temporary_environ(env_vars):
os
.
environ
[
k
]
=
v
os
.
environ
[
k
]
=
v
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"class"
)
def
full_cudagraph_llm
():
def
llm_pair
(
request
):
with
temporary_environ
({
model
=
request
.
param
"VLLM_USE_V1"
:
"1"
,
"VLLM_FLASH_ATTN_VERSION"
:
"3"
}):
return
LLM
(
model
=
MODEL
,
gpu_memory_utilization
=
0.3
,
compilation_config
=
CompilationConfig
(
full_cuda_graph
=
True
))
@
pytest
.
fixture
(
scope
=
"module"
)
def
piecewise_llm
():
with
temporary_environ
({
with
temporary_environ
({
"VLLM_USE_V1"
:
"1"
,
"VLLM_USE_V1"
:
"1"
,
"VLLM_FLASH_ATTN_VERSION"
:
"3"
"VLLM_FLASH_ATTN_VERSION"
:
"3"
}):
}):
return
LLM
(
model
=
MODEL
,
full
=
LLM
(
gpu_memory_utilization
=
0.6
,
model
=
model
,
compilation_config
=
CompilationConfig
())
gpu_memory_utilization
=
0.45
,
trust_remote_code
=
True
,
max_model_len
=
1024
,
def
generate_text
(
llm
:
LLM
,
batch_size
:
int
,
max_tokens
:
int
):
compilation_config
=
CompilationConfig
(
full_cuda_graph
=
True
),
prompts
=
[
"Hi my name is"
]
*
batch_size
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
piecewise
=
LLM
(
max_tokens
=
max_tokens
,
model
=
model
,
top_p
=
0.95
)
gpu_memory_utilization
=
0.45
,
trust_remote_code
=
True
,
return
llm
.
generate
(
prompts
,
sampling_params
)
max_model_len
=
1024
,
compilation_config
=
CompilationConfig
(),
)
# PyTest caches the fixture values so we use weakref.proxy to enable GC
yield
weakref
.
proxy
(
full
),
weakref
.
proxy
(
piecewise
)
del
full
del
piecewise
wait_for_gpu_memory_to_clear
(
devices
=
[
0
],
threshold_ratio
=
0.1
,
)
@
pytest
.
mark
.
parametrize
(
"llm_pair"
,
[
# Model names for the llm_pair fixture
"deepseek-ai/DeepSeek-V2-Lite"
,
"Qwen/Qwen2-1.5B-Instruct"
],
indirect
=
True
)
@
pytest
.
mark
.
skipif
(
current_platform
.
get_device_capability
()
!=
(
9
,
0
),
@
pytest
.
mark
.
skipif
(
current_platform
.
get_device_capability
()
!=
(
9
,
0
),
reason
=
"Only Hopper GPUs support FlashAttention 3"
)
reason
=
"Only Hopper GPUs support FA3 and FlashMLA"
)
@
pytest
.
mark
.
parametrize
((
"batch_size"
,
"max_tokens"
),
[(
1
,
10
),
(
7
,
10
),
class
TestFullCUDAGraph
:
(
16
,
10
),
(
25
,
10
),
(
32
,
10
),
(
45
,
10
),
(
64
,
10
),
(
8
,
5
),
(
8
,
20
),
(
8
,
200
)])
def
test_full_cudagraph
(
batch_size
,
max_tokens
,
full_cudagraph_llm
,
piecewise_llm
):
"""
"""
Load full cudagraph model and piecewise model once, and at the same time to
Use a class such that an llm pair is constructed once for all
reuse them across var
io
u
s
test cases
.
batch_size/max_tokens combinat
io
n
s
and released immediately after
.
Test various batch sizes and max_tokens to ensure that the full cudagraph
Module-scope fixtures would stick around the whole time,
compilation works for padded cases too
.
meaning there would be multiple LLM instances hogging memory simultaneously
.
"""
"""
piecewise_responses
=
generate_text
(
piecewise_llm
,
batch_size
=
batch_size
,
max_tokens
=
max_tokens
)
full_cudagraph_responses
=
generate_text
(
full_cudagraph_llm
,
batch_size
=
batch_size
,
max_tokens
=
max_tokens
)
# Check that all responses are the same
@
pytest
.
mark
.
parametrize
((
"batch_size"
,
"max_tokens"
),
[
for
i
in
range
(
len
(
piecewise_responses
)):
(
1
,
10
),
assert
piecewise_responses
[
i
].
outputs
[
(
7
,
10
),
0
].
text
==
full_cudagraph_responses
[
i
].
outputs
[
0
].
text
(
16
,
10
),
(
25
,
10
),
(
32
,
10
),
(
45
,
10
),
(
64
,
10
),
(
123
,
10
),
(
8
,
5
),
(
8
,
30
),
])
def
test_full_cudagraph
(
self
,
batch_size
,
max_tokens
,
llm_pair
:
tuple
[
LLM
,
LLM
]):
"""
Test various batch sizes and max_tokens to ensure that the
full cudagraph compilation works for padded cases too.
"""
piecewise_llm
,
full_cudagraph_llm
=
llm_pair
prompts
=
[
"Hello, my name is"
]
*
batch_size
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
,
top_p
=
0.95
)
piecewise_responses
=
piecewise_llm
.
generate
(
prompts
,
sampling_params
)
full_responses
=
full_cudagraph_llm
.
generate
(
prompts
,
sampling_params
)
# Check that all responses are the same
for
piecewise_res
,
full_res
in
zip
(
piecewise_responses
,
full_responses
):
assert
piecewise_res
.
outputs
[
0
].
text
==
full_res
.
outputs
[
0
].
text
@
pytest
.
mark
.
parametrize
(
"model, supported"
,
[
(
"Qwen/Qwen2-1.5B-Instruct"
,
True
),
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
(
"deepseek-ai/DeepSeek-V2-Lite"
,
False
),
])
@
pytest
.
mark
.
skipif
(
current_platform
.
get_device_capability
()
!=
(
9
,
0
),
reason
=
"Only Hopper GPUs support FA3 and FlashMLA"
)
def
test_lower_max_num_seqs
(
model
,
supported
):
with
temporary_environ
({
"VLLM_USE_V1"
:
"1"
,
"VLLM_FLASH_ATTN_VERSION"
:
"3"
}),
ExitStack
()
as
stack
:
if
not
supported
:
stack
.
enter_context
(
pytest
.
raises
(
RuntimeError
))
llm
=
LLM
(
model
=
model
,
max_num_seqs
=
256
,
trust_remote_code
=
True
,
max_model_len
=
1024
,
compilation_config
=
CompilationConfig
(
full_cuda_graph
=
True
,
cudagraph_capture_sizes
=
[
64
,
256
,
512
]))
llm
.
generate
([
"Hello, my name is"
]
*
10
)
def
test_full_cudagraph_with_invalid_backend
():
def
test_full_cudagraph_with_invalid_backend
():
...
@@ -97,5 +153,5 @@ def test_full_cudagraph_with_invalid_backend():
...
@@ -97,5 +153,5 @@ def test_full_cudagraph_with_invalid_backend():
"VLLM_FLASH_ATTN_VERSION"
:
"VLLM_FLASH_ATTN_VERSION"
:
"2"
#FA2 not supported with full_cuda_graph
"2"
#FA2 not supported with full_cuda_graph
}),
pytest
.
raises
(
RuntimeError
):
}),
pytest
.
raises
(
RuntimeError
):
LLM
(
model
=
MODEL
,
LLM
(
model
=
"Qwen/Qwen2-1.5B-Instruct"
,
compilation_config
=
CompilationConfig
(
full_cuda_graph
=
True
))
compilation_config
=
CompilationConfig
(
full_cuda_graph
=
True
))
tests/compile/piecewise/test_simple.py
View file @
3597b06a
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
Test the piecewise compilation with a simple model so that we
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
can exactly calculate the expected output and side effects.
"""
"""
import
pytest
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.library
import
Library
from
torch.library
import
Library
...
@@ -14,6 +14,7 @@ from vllm.compilation.decorators import support_torch_compile
...
@@ -14,6 +14,7 @@ from vllm.compilation.decorators import support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
set_current_vllm_config
)
set_current_vllm_config
)
from
vllm.envs
import
VLLM_USE_V1
from
vllm.envs
import
VLLM_USE_V1
from
vllm.forward_context
import
set_forward_context
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
global_counter
=
0
global_counter
=
0
...
@@ -76,7 +77,8 @@ class SillyModel(nn.Module):
...
@@ -76,7 +77,8 @@ class SillyModel(nn.Module):
return
x
return
x
def
_test_simple_piecewise_compile
(
*
,
use_inductor
):
@
pytest
.
mark
.
parametrize
(
"use_inductor"
,
[
True
,
False
])
def
test_simple_piecewise_compile
(
use_inductor
):
assert
VLLM_USE_V1
assert
VLLM_USE_V1
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
...
@@ -99,7 +101,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
...
@@ -99,7 +101,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
num_backend_compilations
=
3
,
# num_piecewise_capturable_graphs_seen
num_backend_compilations
=
3
,
# num_piecewise_capturable_graphs_seen
num_cudagraph_captured
=
num_cudagraph_captured
=
6
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
6
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
),
set_forward_context
({},
vllm_config
=
vllm_config
):
model
(
inputs
)
model
(
inputs
)
...
@@ -112,11 +114,3 @@ def _test_simple_piecewise_compile(*, use_inductor):
...
@@ -112,11 +114,3 @@ def _test_simple_piecewise_compile(*, use_inductor):
output
=
model
(
input
)
output
=
model
(
input
)
assert
global_counter
==
2
assert
global_counter
==
2
assert
torch
.
allclose
(
output
.
cpu
(),
torch
.
tensor
([
3.
,
1.
]))
assert
torch
.
allclose
(
output
.
cpu
(),
torch
.
tensor
([
3.
,
1.
]))
def
test_simple_piecewise_compile_inductor
():
_test_simple_piecewise_compile
(
use_inductor
=
True
)
def
test_simple_piecewise_compile_no_inductor
():
_test_simple_piecewise_compile
(
use_inductor
=
False
)
tests/compile/piecewise/test_toy_llama.py
View file @
3597b06a
...
@@ -11,6 +11,7 @@ initialized randomly with a fixed seed.
...
@@ -11,6 +11,7 @@ initialized randomly with a fixed seed.
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
import
pytest
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.library
import
Library
from
torch.library
import
Library
...
@@ -19,6 +20,7 @@ from vllm.compilation.counter import compilation_counter
...
@@ -19,6 +20,7 @@ from vllm.compilation.counter import compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
set_current_vllm_config
)
set_current_vllm_config
)
from
vllm.forward_context
import
set_forward_context
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
# create a library to hold the custom op
# create a library to hold the custom op
...
@@ -285,29 +287,32 @@ def run_model(llama_config,
...
@@ -285,29 +287,32 @@ def run_model(llama_config,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
().
cuda
()
prefix
=
""
).
eval
().
cuda
()
B
=
16
# max batch size
with
set_forward_context
({},
vllm_config
=
vllm_config
):
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
B
=
16
# max batch size
positions
=
torch
.
arange
(
B
).
cuda
()
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
positions
=
torch
.
arange
(
B
).
cuda
()
model
(
input_ids
,
positions
)
model
(
input_ids
,
positions
)
model
(
input_ids
[:
2
],
positions
[:
2
])
model
(
input_ids
[:
2
],
positions
[:
2
])
model
(
input_ids
[:
1
],
positions
[:
1
])
model
(
input_ids
[:
1
],
positions
[:
1
])
input_ids
[:
2
].
zero_
()
input_ids
[:
2
].
zero_
()
output
=
model
(
input_ids
[:
2
],
positions
[:
2
])
output
=
model
(
input_ids
[:
2
],
positions
[:
2
])
output
=
output
.
cpu
()
output
=
output
.
cpu
()
if
llama_config
.
tractable_init
:
if
llama_config
.
tractable_init
:
expected_output
=
tractable_computation
(
input_ids
[:
2
],
positions
[:
2
],
expected_output
=
tractable_computation
(
input_ids
[:
2
],
llama_config
).
cpu
()
positions
[:
2
],
llama_config
).
cpu
()
assert
torch
.
allclose
(
output
,
expected_output
)
assert
torch
.
allclose
(
output
,
expected_output
)
else
:
else
:
return
output
.
cpu
()
return
output
.
cpu
()
def
_test_toy_llama
(
*
,
use_inductor
):
@
pytest
.
mark
.
parametrize
(
"use_inductor"
,
[
True
,
False
])
def
test_toy_llama
(
use_inductor
:
bool
):
# compare output with and without piecewise compilation
# compare output with and without piecewise compilation
llama_config
=
LlamaConfig
(
hidden_size
=
128
,
llama_config
=
LlamaConfig
(
hidden_size
=
128
,
...
@@ -379,14 +384,6 @@ def _test_toy_llama(*, use_inductor):
...
@@ -379,14 +384,6 @@ def _test_toy_llama(*, use_inductor):
assert
torch
.
allclose
(
outputs
[
0
],
outputs
[
i
])
assert
torch
.
allclose
(
outputs
[
0
],
outputs
[
i
])
def
test_toy_llama_inductor
():
_test_toy_llama
(
use_inductor
=
True
)
def
test_toy_no_inductor
():
_test_toy_llama
(
use_inductor
=
False
)
@
torch
.
inference_mode
@
torch
.
inference_mode
def
benchmark
():
def
benchmark
():
from
triton.testing
import
do_bench
from
triton.testing
import
do_bench
...
...
tests/utils.py
View file @
3597b06a
...
@@ -667,42 +667,54 @@ def get_physical_device_indices(devices):
...
@@ -667,42 +667,54 @@ def get_physical_device_indices(devices):
@
_nvml
()
@
_nvml
()
def
wait_for_gpu_memory_to_clear
(
devices
:
list
[
int
],
def
wait_for_gpu_memory_to_clear
(
*
,
threshold_bytes
:
int
,
devices
:
list
[
int
],
threshold_bytes
:
Optional
[
int
]
=
None
,
threshold_ratio
:
Optional
[
float
]
=
None
,
timeout_s
:
float
=
120
)
->
None
:
timeout_s
:
float
=
120
)
->
None
:
assert
threshold_bytes
is
not
None
or
threshold_ratio
is
not
None
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
# context.
devices
=
get_physical_device_indices
(
devices
)
devices
=
get_physical_device_indices
(
devices
)
start_time
=
time
.
time
()
start_time
=
time
.
time
()
while
True
:
while
True
:
output
:
dict
[
int
,
str
]
=
{}
output
:
dict
[
int
,
str
]
=
{}
output_raw
:
dict
[
int
,
float
]
=
{}
output_raw
:
dict
[
int
,
tuple
[
float
,
float
]
]
=
{}
for
device
in
devices
:
for
device
in
devices
:
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
dev_handle
=
amdsmi_get_processor_handles
()[
device
]
dev_handle
=
amdsmi_get_processor_handles
()[
device
]
mem_info
=
amdsmi_get_gpu_vram_usage
(
dev_handle
)
mem_info
=
amdsmi_get_gpu_vram_usage
(
dev_handle
)
gb_used
=
mem_info
[
"vram_used"
]
/
2
**
10
gb_used
=
mem_info
[
"vram_used"
]
/
2
**
10
gb_total
=
mem_info
[
"vram_total"
]
/
2
**
10
else
:
else
:
dev_handle
=
nvmlDeviceGetHandleByIndex
(
device
)
dev_handle
=
nvmlDeviceGetHandleByIndex
(
device
)
mem_info
=
nvmlDeviceGetMemoryInfo
(
dev_handle
)
mem_info
=
nvmlDeviceGetMemoryInfo
(
dev_handle
)
gb_used
=
mem_info
.
used
/
2
**
30
gb_used
=
mem_info
.
used
/
2
**
30
output_raw
[
device
]
=
gb_used
gb_total
=
mem_info
.
total
/
2
**
30
output
[
device
]
=
f
'
{
gb_used
:.
02
f
}
'
output_raw
[
device
]
=
(
gb_used
,
gb_total
)
output
[
device
]
=
f
'
{
gb_used
:.
02
f
}
/
{
gb_total
:.
02
f
}
'
print
(
'gpu memory used (GB): '
,
end
=
''
)
print
(
'gpu memory used
/total
(G
i
B): '
,
end
=
''
)
for
k
,
v
in
output
.
items
():
for
k
,
v
in
output
.
items
():
print
(
f
'
{
k
}
=
{
v
}
; '
,
end
=
''
)
print
(
f
'
{
k
}
=
{
v
}
; '
,
end
=
''
)
print
(
''
)
print
(
''
)
if
threshold_bytes
is
not
None
:
is_free
=
lambda
used
,
total
:
used
<=
threshold_bytes
/
2
**
30
threshold
=
f
"
{
threshold_bytes
/
2
**
30
}
GiB"
else
:
is_free
=
lambda
used
,
total
:
used
/
total
<=
threshold_ratio
threshold
=
f
"
{
threshold_ratio
:.
2
f
}
"
dur_s
=
time
.
time
()
-
start_time
dur_s
=
time
.
time
()
-
start_time
if
all
(
v
<=
(
threshold_bytes
/
2
**
30
)
for
v
in
output_raw
.
values
()):
if
all
(
is_free
(
used
,
total
)
for
used
,
total
in
output_raw
.
values
()):
print
(
f
'Done waiting for free GPU memory on devices
{
devices
=
}
'
print
(
f
'Done waiting for free GPU memory on devices
{
devices
=
}
'
f
'(
{
threshold
_bytes
/
2
**
30
=
}
)
{
dur_s
=
:.
02
f
}
'
)
f
'(
{
threshold
=
}
)
{
dur_s
=
:.
02
f
}
'
)
break
break
if
dur_s
>=
timeout_s
:
if
dur_s
>=
timeout_s
:
raise
ValueError
(
f
'Memory of devices
{
devices
=
}
not free after '
raise
ValueError
(
f
'Memory of devices
{
devices
=
}
not free after '
f
'
{
dur_s
=
:.
02
f
}
(
{
threshold
_bytes
/
2
**
30
=
}
)'
)
f
'
{
dur_s
=
:.
02
f
}
(
{
threshold
=
}
)'
)
time
.
sleep
(
5
)
time
.
sleep
(
5
)
...
...
vllm/compilation/cuda_piecewise_backend.py
View file @
3597b06a
...
@@ -14,6 +14,7 @@ from vllm.compilation.backends import VllmBackend
...
@@ -14,6 +14,7 @@ from vllm.compilation.backends import VllmBackend
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.monitor
import
end_monitoring_torch_compile
from
vllm.compilation.monitor
import
end_monitoring_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
weak_ref_tensors
from
vllm.utils
import
weak_ref_tensors
...
@@ -137,7 +138,10 @@ class CUDAPiecewiseBackend:
...
@@ -137,7 +138,10 @@ class CUDAPiecewiseBackend:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
if
self
.
is_last_graph
and
not
self
.
to_be_compiled_sizes
:
self
.
check_for_ending_compilation
()
self
.
check_for_ending_compilation
()
if
not
entry
.
use_cudagraph
:
# Skip CUDA graphs if this entry doesn't use them OR
# if we're supposed to skip them globally
skip_cuda_graphs
=
get_forward_context
().
skip_cuda_graphs
if
not
entry
.
use_cudagraph
or
skip_cuda_graphs
:
return
entry
.
runnable
(
*
args
)
return
entry
.
runnable
(
*
args
)
if
entry
.
cudagraph
is
None
:
if
entry
.
cudagraph
is
None
:
...
...
vllm/entrypoints/llm.py
View file @
3597b06a
...
@@ -179,7 +179,8 @@ class LLM:
...
@@ -179,7 +179,8 @@ class LLM:
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
override_pooler_config
:
Optional
[
PoolerConfig
]
=
None
,
override_pooler_config
:
Optional
[
PoolerConfig
]
=
None
,
compilation_config
:
Optional
[
Union
[
int
,
dict
[
str
,
Any
]]]
=
None
,
compilation_config
:
Optional
[
Union
[
int
,
dict
[
str
,
Any
],
CompilationConfig
]]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
"""LLM constructor."""
"""LLM constructor."""
...
...
vllm/forward_context.py
View file @
3597b06a
...
@@ -94,6 +94,7 @@ class ForwardContext:
...
@@ -94,6 +94,7 @@ class ForwardContext:
virtual_engine
:
int
# set dynamically for each forward pass
virtual_engine
:
int
# set dynamically for each forward pass
# set dynamically for each forward pass
# set dynamically for each forward pass
dp_metadata
:
Optional
[
DPMetadata
]
=
None
dp_metadata
:
Optional
[
DPMetadata
]
=
None
skip_cuda_graphs
:
bool
=
False
_forward_context
:
Optional
[
ForwardContext
]
=
None
_forward_context
:
Optional
[
ForwardContext
]
=
None
...
@@ -108,11 +109,14 @@ def get_forward_context() -> ForwardContext:
...
@@ -108,11 +109,14 @@ def get_forward_context() -> ForwardContext:
@
contextmanager
@
contextmanager
def
set_forward_context
(
attn_metadata
:
Any
,
def
set_forward_context
(
vllm_config
:
VllmConfig
,
attn_metadata
:
Any
,
virtual_engine
:
int
=
0
,
vllm_config
:
VllmConfig
,
num_tokens
:
Optional
[
int
]
=
None
,
virtual_engine
:
int
=
0
,
num_tokens_across_dp
:
Optional
[
torch
.
Tensor
]
=
None
):
num_tokens
:
Optional
[
int
]
=
None
,
num_tokens_across_dp
:
Optional
[
torch
.
Tensor
]
=
None
,
skip_cuda_graphs
:
bool
=
False
,
):
"""A context manager that stores the current forward context,
"""A context manager that stores the current forward context,
can be attention metadata, etc.
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
Here we can inject common logic for every model forward pass.
...
@@ -135,7 +139,9 @@ def set_forward_context(attn_metadata: Any,
...
@@ -135,7 +139,9 @@ def set_forward_context(attn_metadata: Any,
static_forward_context
,
static_forward_context
,
virtual_engine
=
virtual_engine
,
virtual_engine
=
virtual_engine
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
dp_metadata
=
dp_metadata
)
dp_metadata
=
dp_metadata
,
skip_cuda_graphs
=
skip_cuda_graphs
,
)
try
:
try
:
yield
yield
...
...
vllm/v1/attention/backends/cpu_attn.py
View file @
3597b06a
...
@@ -7,7 +7,8 @@ from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
...
@@ -7,7 +7,8 @@ from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
TorchSDPAMetadata
)
TorchSDPAMetadata
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.ipex_attn
import
PagedAttention
from
vllm.attention.ops.ipex_attn
import
PagedAttention
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
...
@@ -53,7 +54,7 @@ class TorchSDPABackend:
...
@@ -53,7 +54,7 @@ class TorchSDPABackend:
return
False
return
False
class
TorchSDPAMetadataBuilderV1
:
class
TorchSDPAMetadataBuilderV1
(
AttentionMetadataBuilder
[
TorchSDPAMetadata
])
:
def
__init__
(
self
,
runner
:
CPUModelRunner
,
kv_cache_spec
:
AttentionSpec
,
def
__init__
(
self
,
runner
:
CPUModelRunner
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
)
->
None
:
block_table
:
BlockTable
)
->
None
:
...
@@ -118,9 +119,12 @@ class TorchSDPAMetadataBuilderV1:
...
@@ -118,9 +119,12 @@ class TorchSDPAMetadataBuilderV1:
return
True
return
True
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
def
build
(
self
,
common_prefix_len
:
int
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
):
common_attn_metadata
:
CommonAttentionMetadata
):
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
runner
=
self
.
runner
runner
=
self
.
runner
block_table
=
self
.
block_table
block_table
=
self
.
block_table
seq_lens_np
=
runner
.
seq_lens_np
[:
num_reqs
]
seq_lens_np
=
runner
.
seq_lens_np
[:
num_reqs
]
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
3597b06a
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Optional
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -21,13 +21,12 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
...
@@ -21,13 +21,12 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
():
...
@@ -306,7 +305,9 @@ def _get_sliding_window_configs(
...
@@ -306,7 +305,9 @@ def _get_sliding_window_configs(
return
sliding_window_configs
return
sliding_window_configs
class
FlashAttentionMetadataBuilder
:
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
get_flash_attn_version
()
==
3
def
__init__
(
self
,
runner
:
"GPUModelRunner"
,
kv_cache_spec
:
AttentionSpec
,
def
__init__
(
self
,
runner
:
"GPUModelRunner"
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
block_table
:
BlockTable
):
...
@@ -336,13 +337,14 @@ class FlashAttentionMetadataBuilder:
...
@@ -336,13 +337,14 @@ class FlashAttentionMetadataBuilder:
# populated on first build() call.
# populated on first build() call.
self
.
aot_sliding_window
:
Optional
[
tuple
[
int
,
int
]]
=
None
self
.
aot_sliding_window
:
Optional
[
tuple
[
int
,
int
]]
=
None
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
def
build
(
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
self
,
common_prefix_len
:
int
,
return
False
common_attn_metadata
:
CommonAttentionMetadata
)
->
FlashAttentionMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
):
max_seq_len
=
int
(
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
())
max_seq_len
=
int
(
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
())
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
...
@@ -496,6 +498,11 @@ class FlashAttentionMetadataBuilder:
...
@@ -496,6 +498,11 @@ class FlashAttentionMetadataBuilder:
)
)
return
attn_metadata
return
attn_metadata
def
can_run_in_cudagraph
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
# Full CUDA Graph always supported (FA2 support checked separately)
return
True
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
return
use_cascade_attention
(
*
args
,
**
kwargs
)
return
use_cascade_attention
(
*
args
,
**
kwargs
)
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
3597b06a
...
@@ -18,7 +18,8 @@ from vllm.attention.layer import Attention
...
@@ -18,7 +18,8 @@ from vllm.attention.layer import Attention
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
...
@@ -202,7 +203,7 @@ class FlashInferMetadata:
...
@@ -202,7 +203,7 @@ class FlashInferMetadata:
f
" received
{
self
.
head_dim
}
."
)
f
" received
{
self
.
head_dim
}
."
)
class
FlashInferMetadataBuilder
:
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
])
:
def
__init__
(
self
,
runner
:
GPUModelRunner
,
kv_cache_spec
:
AttentionSpec
,
def
__init__
(
self
,
runner
:
GPUModelRunner
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
block_table
:
BlockTable
):
...
@@ -399,9 +400,11 @@ class FlashInferMetadataBuilder:
...
@@ -399,9 +400,11 @@ class FlashInferMetadataBuilder:
kv_data_type
=
attn_metadata
.
data_type
,
kv_data_type
=
attn_metadata
.
data_type
,
)
)
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
def
build
(
self
,
common_prefix_len
:
int
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
):
common_attn_metadata
:
CommonAttentionMetadata
):
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
assert
self
.
_num_decodes
+
self
.
_num_prefills
==
num_reqs
assert
self
.
_num_decodes
+
self
.
_num_prefills
==
num_reqs
assert
(
self
.
_num_decode_tokens
+
assert
(
self
.
_num_decode_tokens
+
self
.
_num_prefill_tokens
==
num_actual_tokens
)
self
.
_num_prefill_tokens
==
num_actual_tokens
)
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
3597b06a
...
@@ -15,7 +15,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
...
@@ -15,7 +15,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
is_quantized_kv_cache
)
is_quantized_kv_cache
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
...
@@ -25,8 +26,6 @@ if current_platform.is_cuda():
...
@@ -25,8 +26,6 @@ if current_platform.is_cuda():
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
create_block_mask_compiled
=
torch
.
compile
(
create_block_mask
,
create_block_mask_compiled
=
torch
.
compile
(
create_block_mask
,
...
@@ -256,7 +255,8 @@ class FlexAttentionMetadata:
...
@@ -256,7 +255,8 @@ class FlexAttentionMetadata:
self
.
block_mask
=
self
.
build_block_mask
()
self
.
block_mask
=
self
.
build_block_mask
()
class
FlexAttentionMetadataBuilder
:
class
FlexAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlexAttentionMetadata
]):
def
__init__
(
self
,
runner
:
"GPUModelRunner"
,
kv_cache_spec
:
AttentionSpec
,
def
__init__
(
self
,
runner
:
"GPUModelRunner"
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
block_table
:
BlockTable
):
...
@@ -272,13 +272,12 @@ class FlexAttentionMetadataBuilder:
...
@@ -272,13 +272,12 @@ class FlexAttentionMetadataBuilder:
self
.
kv_cache_spec
=
kv_cache_spec
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_table
=
block_table
self
.
block_table
=
block_table
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
def
build
(
self
,
common_prefix_len
:
int
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
return
False
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
):
common_attn_metadata
:
CommonAttentionMetadata
):
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
...
@@ -332,9 +331,6 @@ class FlexAttentionMetadataBuilder:
...
@@ -332,9 +331,6 @@ class FlexAttentionMetadataBuilder:
)
)
return
out
return
out
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
return
False
class
FlexAttentionImpl
(
AttentionImpl
):
class
FlexAttentionImpl
(
AttentionImpl
):
sliding_window
:
Optional
[
tuple
[
int
,
int
]]
sliding_window
:
Optional
[
tuple
[
int
,
int
]]
...
...
vllm/v1/attention/backends/mla/common.py
View file @
3597b06a
...
@@ -207,7 +207,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -207,7 +207,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
,
round_down
from
vllm.utils
import
cdiv
,
round_down
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
...
@@ -329,7 +330,7 @@ class MLACommonMetadata(Generic[D]):
...
@@ -329,7 +330,7 @@ class MLACommonMetadata(Generic[D]):
M
=
TypeVar
(
"M"
,
bound
=
MLACommonMetadata
)
M
=
TypeVar
(
"M"
,
bound
=
MLACommonMetadata
)
class
MLACommonMetadataBuilder
(
Generic
[
M
]):
class
MLACommonMetadataBuilder
(
AttentionMetadataBuilder
[
M
]):
"""
"""
NOTE: Please read the comment at the top of the file before trying to
NOTE: Please read the comment at the top of the file before trying to
understand this class
understand this class
...
@@ -450,9 +451,32 @@ class MLACommonMetadataBuilder(Generic[M]):
...
@@ -450,9 +451,32 @@ class MLACommonMetadataBuilder(Generic[M]):
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
)
)
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
def
build_for_cudagraph_capture
(
common_prefix_len
:
int
,
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with MLA.
"""
m
=
common_attn_metadata
assert
m
.
num_reqs
==
m
.
num_actual_tokens
,
\
"MLA only supports decode-only full CUDAGraph capture. "
\
"Make sure all cudagraph capture sizes <= max_num_seq."
m
.
max_query_len
=
1
# decode-only
# Update state usually set in reorder_batch.
self
.
_num_decodes
=
m
.
num_reqs
self
.
_num_decode_tokens
=
m
.
num_actual_tokens
self
.
_num_prefills
=
0
self
.
_num_prefill_tokens
=
0
return
self
.
build
(
0
,
m
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
assert
self
.
_num_decodes
+
self
.
_num_prefills
==
num_reqs
assert
self
.
_num_decodes
+
self
.
_num_prefills
==
num_reqs
# Note(simon): be careful about the CPU <> GPU memory movement in this
# Note(simon): be careful about the CPU <> GPU memory movement in this
...
@@ -461,8 +485,11 @@ class MLACommonMetadataBuilder(Generic[M]):
...
@@ -461,8 +485,11 @@ class MLACommonMetadataBuilder(Generic[M]):
device
=
self
.
runner
.
device
device
=
self
.
runner
.
device
block_table
=
self
.
block_table
block_table
=
self
.
block_table
block_table_tensor
=
block_table
.
get_device_tensor
()[:
num_reqs
]
block_table_tensor
=
block_table
.
get_device_tensor
()[:
num_reqs
]
slot_mapping
=
block_table
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
block_table
.
slot_mapping
[:
num_actual_tokens
].
copy_
(
device
,
non_blocking
=
True
).
long
()
block_table
.
slot_mapping_cpu
[:
num_actual_tokens
],
non_blocking
=
True
)
block_table
.
slot_mapping
[
num_actual_tokens
:].
fill_
(
-
1
)
slot_mapping
=
block_table
.
slot_mapping
[:
num_actual_tokens
]
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
...
@@ -564,8 +591,9 @@ class MLACommonMetadataBuilder(Generic[M]):
...
@@ -564,8 +591,9 @@ class MLACommonMetadataBuilder(Generic[M]):
decode
=
decode_metadata
,
decode
=
decode_metadata
,
)
)
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
def
can_run_in_cudagraph
(
return
False
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
return
common_attn_metadata
.
max_query_len
==
1
class
MLACommonImpl
(
MLAAttentionImpl
[
M
],
Generic
[
M
]):
class
MLACommonImpl
(
MLAAttentionImpl
[
M
],
Generic
[
M
]):
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
3597b06a
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
from
typing
import
Any
,
ClassVar
,
Optional
import
torch
import
torch
...
@@ -44,7 +44,7 @@ class FlashMLABackend(MLACommonBackend):
...
@@ -44,7 +44,7 @@ class FlashMLABackend(MLACommonBackend):
@
dataclass
@
dataclass
class
FlashMLADecodeMetadata
(
MLACommonDecodeMetadata
):
class
FlashMLADecodeMetadata
(
MLACommonDecodeMetadata
):
tile_scheduler_metadata
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
tile_scheduler_metadata
:
torch
.
Tensor
num_splits
:
torch
.
Tensor
num_splits
:
torch
.
Tensor
...
@@ -54,14 +54,18 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
...
@@ -54,14 +54,18 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
True
# Decode-only
def
__init__
(
self
,
runner
,
kv_cache_spec
:
AttentionSpec
,
def
__init__
(
self
,
runner
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
block_table
:
BlockTable
):
super
().
__init__
(
runner
,
kv_cache_spec
,
block_table
)
super
().
__init__
(
runner
,
kv_cache_spec
,
block_table
,
FlashMLAMetadata
)
self
.
num_q_heads
=
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
num_q_heads
=
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
runner
.
parallel_config
)
self
.
runner
.
parallel_config
)
self
.
cg_buf_tile_scheduler_metadata
=
None
self
.
cg_buf_num_splits
=
None
def
_build_decode
(
self
,
block_table_tensor
:
torch
.
Tensor
,
def
_build_decode
(
self
,
block_table_tensor
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
)
->
FlashMLADecodeMetadata
:
seq_lens
:
torch
.
Tensor
)
->
FlashMLADecodeMetadata
:
tile_scheduler_metadata
,
num_splits
=
\
tile_scheduler_metadata
,
num_splits
=
\
...
@@ -71,6 +75,30 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
...
@@ -71,6 +75,30 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
1
,
# MQA for the decode path
1
,
# MQA for the decode path
)
)
if
self
.
runner
.
full_cuda_graph
:
# First time around (CUDAGraph capture), allocate the static buffer
if
self
.
cg_buf_tile_scheduler_metadata
is
None
:
self
.
cg_buf_tile_scheduler_metadata
=
tile_scheduler_metadata
self
.
cg_buf_num_splits
=
num_splits
else
:
assert
self
.
cg_buf_num_splits
is
not
None
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
assert
(
self
.
cg_buf_tile_scheduler_metadata
.
size
()
==
tile_scheduler_metadata
.
size
())
self
.
cg_buf_tile_scheduler_metadata
.
\
copy_
(
tile_scheduler_metadata
)
tile_scheduler_metadata
=
self
.
cg_buf_tile_scheduler_metadata
# Num splits is per-batch, varying size (batch_size,)
n
=
num_splits
.
size
(
0
)
# make sure static buffer is large enough
assert
n
<=
self
.
cg_buf_num_splits
.
size
(
0
)
num_splits_view
=
self
.
cg_buf_num_splits
[:
n
]
num_splits_view
.
copy_
(
num_splits
)
self
.
cg_buf_num_splits
[
n
:].
fill_
(
0
)
# fill the rest with 0s
num_splits
=
num_splits_view
return
FlashMLADecodeMetadata
(
return
FlashMLADecodeMetadata
(
block_table
=
block_table_tensor
,
block_table
=
block_table_tensor
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
3597b06a
...
@@ -66,7 +66,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -66,7 +66,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
def
__init__
(
self
,
runner
,
kv_cache_spec
:
AttentionSpec
,
def
__init__
(
self
,
runner
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
block_table
:
BlockTable
):
super
().
__init__
(
runner
,
kv_cache_spec
,
block_table
)
super
().
__init__
(
runner
,
kv_cache_spec
,
block_table
,
AiterMLAMetadata
)
assert
self
.
kv_cache_spec
.
block_size
==
1
,
"AITER MLA"
\
assert
self
.
kv_cache_spec
.
block_size
==
1
,
"AITER MLA"
\
"only supports block size 1."
"only supports block size 1."
...
...
vllm/v1/attention/backends/utils.py
View file @
3597b06a
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
abc
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Generic
,
TypeVar
import
numpy
as
np
import
torch
import
torch
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
@
dataclass
@
dataclass
class
CommonAttentionMetadata
:
class
CommonAttentionMetadata
:
"""
"""
A
ttention metadata
attributes that can be
shared
by
layers
in different KV
Per-batch a
ttention metadata
,
shared
across
layers
and backends.
cache groups and thus having different block table
.
AttentionMetadataBuilder instances use it to construct per-layer metadata
.
"""
"""
query_start_loc
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
...
@@ -18,6 +26,67 @@ class CommonAttentionMetadata:
...
@@ -18,6 +26,67 @@ class CommonAttentionMetadata:
"""(batch_size,), the length of each request including both computed tokens
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
and newly scheduled tokens"""
num_reqs
:
int
"""Number of requests"""
num_actual_tokens
:
int
"""Total number of tokens in batch"""
max_query_len
:
int
"""Longest query in batch"""
M
=
TypeVar
(
"M"
)
class
AttentionMetadataBuilder
(
abc
.
ABC
,
Generic
[
M
]):
# Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported
:
ClassVar
[
bool
]
=
False
@
abstractmethod
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
"""
Central method that builds attention metadata.
Some builders (MLA) require reorder_batch to be called prior to build.
"""
raise
NotImplementedError
def
can_run_in_cudagraph
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
"""
Can this batch (with given metadata) use CUDA Graphs for attention.
"""
return
False
def
build_for_cudagraph_capture
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
"""
Build attention metadata for CUDA graph capture. Uses build by default.
Subclasses that override this method should call self.build or
super().build_for_cudagraph_capture.
"""
return
self
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
def
use_cascade_attention
(
self
,
common_prefix_len
:
int
,
query_lens
:
np
.
ndarray
,
num_query_heads
:
int
,
num_kv_heads
:
int
,
use_alibi
:
bool
,
use_sliding_window
:
bool
,
num_sms
:
int
,
)
->
bool
:
return
False
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
"""
This method can reorder the batch if desired by the backend.
:return: Has the batch been reordered (default False).
"""
return
False
def
validate_kv_sharing_target
(
current_layer_name
,
target_layer_name
,
def
validate_kv_sharing_target
(
current_layer_name
,
target_layer_name
,
static_forward_context
):
static_forward_context
):
...
...
vllm/v1/spec_decode/eagle.py
View file @
3597b06a
...
@@ -138,15 +138,17 @@ class EagleProposer:
...
@@ -138,15 +138,17 @@ class EagleProposer:
max_query_len
=
query_lens
.
max
().
item
()
max_query_len
=
query_lens
.
max
().
item
()
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
cu_num_tokens
,
seq_lens
=
seq_lens
)
query_start_loc
=
cu_num_tokens
,
seq_lens
=
seq_lens
,
num_reqs
=
batch_size
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
)
assert
self
.
runner
is
not
None
assert
self
.
runner
is
not
None
# FIXME: need to consider multiple kv_cache_groups
# FIXME: need to consider multiple kv_cache_groups
attn_metadata
=
self
.
runner
.
attn_metadata_builder
.
build
(
attn_metadata
=
self
.
runner
.
attn_metadata_builder
.
build
(
num_reqs
=
batch_size
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
common_prefix_len
=
0
,
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
)
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
3597b06a
...
@@ -16,10 +16,8 @@ from tqdm import tqdm
...
@@ -16,10 +16,8 @@ from tqdm import tqdm
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionType
,
get_attn_backend
from
vllm.attention
import
AttentionType
,
get_attn_backend
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
from
vllm.attention.backends.abstract
import
AttentionBackend
AttentionMetadataBuilder
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
)
get_layers_from_vllm_config
)
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
...
@@ -41,7 +39,8 @@ from vllm.sequence import IntermediateTensors
...
@@ -41,7 +39,8 @@ from vllm.sequence import IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
check_use_alibi
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
KVCacheConfig
,
KVCacheSpec
,
...
@@ -89,6 +88,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -89,6 +88,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
lora_config
=
vllm_config
.
lora_config
self
.
lora_config
=
vllm_config
.
lora_config
self
.
load_config
=
vllm_config
.
load_config
self
.
load_config
=
vllm_config
.
load_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
parallel_config
=
vllm_config
.
parallel_config
...
@@ -197,7 +197,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -197,7 +197,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_sizes
=
[
self
.
cache_config
.
block_size
],
block_sizes
=
[
self
.
cache_config
.
block_size
],
)
)
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
self
.
use_cuda_graph
=
(
self
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
==
CompilationLevel
.
PIECEWISE
and
not
self
.
model_config
.
enforce_eager
)
and
not
self
.
model_config
.
enforce_eager
)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
...
@@ -205,8 +205,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -205,8 +205,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# self.cudagraph_batch_sizes sorts in ascending order.
# self.cudagraph_batch_sizes sorts in ascending order.
# The batch sizes in the config are in descending order.
# The batch sizes in the config are in descending order.
self
.
cudagraph_batch_sizes
=
list
(
self
.
cudagraph_batch_sizes
=
list
(
reversed
(
reversed
(
self
.
compilation_config
.
cudagraph_capture_sizes
))
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
))
self
.
full_cuda_graph
=
self
.
compilation_config
.
full_cuda_graph
# Cache the device properties.
# Cache the device properties.
self
.
_init_device_properties
()
self
.
_init_device_properties
()
...
@@ -555,7 +556,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -555,7 +556,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_prepare_inputs
(
def
_prepare_inputs
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
tuple
[
dict
[
str
,
Any
],
torch
.
Tensor
,
Optional
[
SpecDecodeMetadata
]]:
)
->
tuple
[
dict
[
str
,
Any
],
bool
,
torch
.
Tensor
,
Optional
[
SpecDecodeMetadata
]]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
attention_cuda_graphs: whether attention can run in cudagraph
logits_indices, spec_decode_metadata
]
"""
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
assert
total_num_scheduled_tokens
>
0
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
...
@@ -669,7 +678,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -669,7 +678,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
)
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
)
attn_metadata
:
dict
[
str
,
Any
]
=
{}
attn_metadata
:
dict
[
str
,
Any
]
=
{}
# Prepare the attention metadata for each KV cache group and make layers
# Prepare the attention metadata for each KV cache group and make layers
...
@@ -679,25 +693,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -679,25 +693,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Prepare for cascade attention if enabled & beneficial.
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
common_prefix_len
=
0
builder
=
self
.
attn_metadata_builders
[
kv_cache_group_id
]
if
self
.
cascade_attn_enabled
:
if
self
.
cascade_attn_enabled
:
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
num_scheduled_tokens
,
scheduler_output
.
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
],
num_common_prefix_blocks
[
kv_cache_group_id
],
kv_cache_group_spec
.
kv_cache_spec
,
kv_cache_group_spec
.
kv_cache_spec
,
self
.
attn_metadata_builders
[
kv_cache_group_id
]
,
builder
,
)
)
attn_metadata_i
=
(
attn_metadata_i
=
(
builder
.
build
(
self
.
attn_metadata_builders
[
kv_cache_group_id
].
build
(
common_prefix_len
=
common_prefix_len
,
num_reqs
=
num_reqs
,
common_attn_metadata
=
common_attn_metadata
,
num_actual_tokens
=
total_num_scheduled_tokens
,
))
max_query_len
=
max_num_scheduled_tokens
,
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
))
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
attn_metadata
[
layer_name
]
=
attn_metadata_i
attention_cuda_graphs
=
all
(
b
.
can_run_in_cudagraph
(
common_attn_metadata
)
for
b
in
self
.
attn_metadata_builders
)
use_spec_decode
=
len
(
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
if
not
use_spec_decode
:
if
not
use_spec_decode
:
...
@@ -726,7 +743,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -726,7 +743,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
lora_config
:
if
self
.
lora_config
:
self
.
set_active_loras
(
self
.
input_batch
,
num_scheduled_tokens
)
self
.
set_active_loras
(
self
.
input_batch
,
num_scheduled_tokens
)
return
attn_metadata
,
logits_indices
,
spec_decode_metadata
return
(
attn_metadata
,
attention_cuda_graphs
,
logits_indices
,
spec_decode_metadata
)
def
_compute_cascade_attn_prefix_len
(
def
_compute_cascade_attn_prefix_len
(
self
,
self
,
...
@@ -1121,7 +1139,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1121,7 +1139,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
self
.
intermediate_tensors
is
not
None
assert
self
.
intermediate_tensors
is
not
None
tp
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
tp
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
enabled_sp
=
self
.
vllm_config
.
compilation_config
.
pass_config
.
\
enabled_sp
=
self
.
compilation_config
.
pass_config
.
\
enable_sequence_parallelism
enable_sequence_parallelism
if
enabled_sp
:
if
enabled_sp
:
# When sequence parallelism is enabled, we always pad num_tokens
# When sequence parallelism is enabled, we always pad num_tokens
...
@@ -1189,8 +1207,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1189,8 +1207,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
self
.
kv_connector_no_forward
(
scheduler_output
)
return
self
.
kv_connector_no_forward
(
scheduler_output
)
# Prepare the decoder inputs.
# Prepare the decoder inputs.
attn_metadata
,
logits_indices
,
spec_decode_metadata
=
(
(
attn_metadata
,
attention_cuda_graphs
,
logits_indices
,
self
.
_prepare_inputs
(
scheduler_output
))
spec_decode_metadata
)
=
(
self
.
_prepare_inputs
(
scheduler_output
))
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
(
self
.
use_cuda_graph
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
...
@@ -1203,7 +1221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1203,7 +1221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Pad tokens to multiple of tensor_parallel_size when
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
# enabled collective fusion for SP
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
self
.
vllm_config
.
compilation_config
.
pass_config
.
\
if
self
.
compilation_config
.
pass_config
.
\
enable_sequence_parallelism
and
tp_size
>
1
:
enable_sequence_parallelism
and
tp_size
>
1
:
from
vllm.utils
import
round_up
from
vllm.utils
import
round_up
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
tp_size
)
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
tp_size
)
...
@@ -1255,12 +1273,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1255,12 +1273,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors
=
self
.
sync_and_slice_intermediate_tensors
(
intermediate_tensors
=
self
.
sync_and_slice_intermediate_tensors
(
num_input_tokens
,
intermediate_tensors
,
True
)
num_input_tokens
,
intermediate_tensors
,
True
)
# Some attention backends only support CUDA Graphs in pure decode.
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
# Run the decoder.
# Run the decoder.
# Use persistent buffers for CUDA graphs.
# Use persistent buffers for CUDA graphs.
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
self
.
vllm_config
,
attn_metadata
,
num_tokens
=
num_input_tokens
,
self
.
vllm_config
,
num_tokens_across_dp
=
num_tokens_across_dp
):
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
skip_cuda_graphs
=
skip_cuda_graphs
,
):
self
.
maybe_setup_kv_connector
(
scheduler_output
)
self
.
maybe_setup_kv_connector
(
scheduler_output
)
model_output
=
self
.
model
(
model_output
=
self
.
model
(
...
@@ -1769,7 +1795,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1769,7 +1795,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_dummy_run
(
def
_dummy_run
(
self
,
self
,
num_tokens
:
int
,
num_tokens
:
int
,
skip_attn
:
bool
=
Tru
e
,
capture_attn_cudagraph
:
bool
=
Fals
e
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Padding for DP
# Padding for DP
...
@@ -1790,9 +1816,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1790,9 +1816,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens
=
np
.
array
(
num_scheduled_tokens_list
,
num_scheduled_tokens
=
np
.
array
(
num_scheduled_tokens_list
,
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
if
skip_attn
:
attn_metadata
:
Optional
[
dict
[
str
,
Any
]]
=
None
attn_metadata
:
Optional
[
dict
[
str
,
Any
]]
=
None
if
capture_attn_cudagraph
:
else
:
attn_metadata
=
{}
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
# Make sure max_model_len is used at the graph capture time.
# Make sure max_model_len is used at the graph capture time.
self
.
seq_lens_np
[:
num_reqs
]
=
self
.
max_model_len
self
.
seq_lens_np
[:
num_reqs
]
=
self
.
max_model_len
...
@@ -1802,19 +1829,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1802,19 +1829,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
)
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
num_tokens
,
)
attn_metadata
=
{}
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
self
.
kv_cache_config
.
kv_cache_groups
):
attn_metadata_i
=
(
self
.
attn_metadata_builders
[
kv_cache_group_id
].
build
(
attn_metadata_i
=
self
.
attn_metadata_builders
[
num_reqs
=
num_reqs
,
kv_cache_group_id
].
build_for_cudagraph_capture
(
num_actual_tokens
=
num_tokens
,
common_attn_metadata
)
max_query_len
=
num_tokens
,
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
))
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
attn_metadata
[
layer_name
]
=
attn_metadata_i
...
@@ -2039,14 +2066,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2039,14 +2066,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Capture the large shapes first so that the smaller shapes
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
# can reuse the memory pool allocated for the large shapes.
with
graph_capture
(
device
=
self
.
device
):
with
graph_capture
(
device
=
self
.
device
):
skip_attn
=
not
self
.
vllm_config
.
compilation_config
.
full_cuda_graph
full_cg
=
self
.
full_cuda_graph
for
num_tokens
in
tqdm
(
reversed
(
self
.
cudagraph_batch_sizes
),
for
num_tokens
in
tqdm
(
reversed
(
self
.
cudagraph_batch_sizes
),
desc
=
"Capturing CUDA graphs"
,
desc
=
"Capturing CUDA graphs"
,
total
=
len
(
self
.
cudagraph_batch_sizes
)):
total
=
len
(
self
.
cudagraph_batch_sizes
)):
for
_
in
range
(
self
.
vllm_config
.
compilation_config
.
for
_
in
range
(
cudagraph_num_of_warmups
):
self
.
compilation_config
.
cudagraph_num_of_warmups
):
self
.
_dummy_run
(
num_tokens
,
skip_attn
=
skip_attn
)
self
.
_dummy_run
(
num_tokens
,
capture_attn_cudagraph
=
full_cg
)
self
.
_dummy_run
(
num_tokens
,
skip_attn
=
skip_attn
)
self
.
_dummy_run
(
num_tokens
,
capture_attn_cudagraph
=
full_cg
)
end_time
=
time
.
perf_counter
()
end_time
=
time
.
perf_counter
()
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
...
@@ -2089,20 +2116,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2089,20 +2116,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"Non-Attention backend is not supported by V1 "
"Non-Attention backend is not supported by V1 "
"GPUModelRunner."
)
"GPUModelRunner."
)
if
self
.
vllm_config
.
compilation_config
.
full_cuda_graph
:
attn_backend_name
=
attn_backend_i
.
__name__
flash_attn_version
=
get_flash_attn_version
()
if
attn_backend_name
!=
"FlashAttentionBackend"
or
\
flash_attn_version
!=
3
:
raise
ValueError
(
f
"full_cuda_graph is only supported with "
f
"FA3. Current attention backend is "
f
"
{
attn_backend_name
}
, FlashAttention version is "
f
"
{
flash_attn_version
}
."
)
block_table_i
=
self
.
input_batch
.
block_table
[
i
]
block_table_i
=
self
.
input_batch
.
block_table
[
i
]
attn_metadata_builder_i
=
attn_backend_i
.
get_builder_cls
()(
attn_metadata_builder_i
=
attn_backend_i
.
get_builder_cls
()(
weakref
.
proxy
(
self
),
kv_cache_spec
,
block_table_i
)
weakref
.
proxy
(
self
),
kv_cache_spec
,
block_table_i
,
)
if
(
self
.
full_cuda_graph
and
not
attn_metadata_builder_i
.
full_cudagraph_supported
):
raise
ValueError
(
f
"Full CUDAGraph not supported for "
f
"
{
attn_backend_i
.
__name__
}
. Turn off CompilationConfig."
f
"full_cuda_graph or use a different attention backend."
)
self
.
attn_backends
.
append
(
attn_backend_i
)
self
.
attn_backends
.
append
(
attn_backend_i
)
self
.
attn_metadata_builders
.
append
(
attn_metadata_builder_i
)
self
.
attn_metadata_builders
.
append
(
attn_metadata_builder_i
)
...
@@ -2142,9 +2169,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2142,9 +2169,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
to be reshaped to the desired shape before being used by the models.
to be reshaped to the desired shape before being used by the models.
Args:
Args:
kv_cache_config: The KV cache config
kv_cache_config: The KV cache config
Returns:
Returns:
dict[str, torch.Tensor]: A map between layer names to their
dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
corresponding memory buffer for KV cache.
"""
"""
kv_cache_raw_tensors
:
dict
[
str
,
torch
.
Tensor
]
=
{}
kv_cache_raw_tensors
:
dict
[
str
,
torch
.
Tensor
]
=
{}
...
@@ -2171,11 +2198,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2171,11 +2198,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Reshape the KV cache tensors to the desired shape and dtype.
Reshape the KV cache tensors to the desired shape and dtype.
Args:
Args:
kv_cache_config: The KV cache config
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
kv_cache_raw_tensors: The KV cache buffer of each layer, with
correct size but uninitialized shape.
correct size but uninitialized shape.
Returns:
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
corresponding memory buffer for KV cache.
"""
"""
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
...
@@ -2227,7 +2254,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2227,7 +2254,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Args:
Args:
kv_cache_config: The KV cache config
kv_cache_config: The KV cache config
Returns:
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
corresponding memory buffer for KV cache.
"""
"""
# Initialize the memory buffer for KV cache
# Initialize the memory buffer for KV cache
...
@@ -2245,10 +2272,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2245,10 +2272,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_caches
,
kv_caches
,
)
)
bind_kv_cache
(
bind_kv_cache
(
kv_caches
,
kv_caches
,
self
.
compilation_config
.
static_forward_context
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
self
.
kv_caches
)
return
kv_caches
return
kv_caches
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment