Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dc5ce861
Unverified
Commit
dc5ce861
authored
Dec 02, 2024
by
youkaichao
Committed by
GitHub
Dec 03, 2024
Browse files
[torch.compile] remove compilation_context and simplify code (#10838)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
21fe7b48
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
128 additions
and
143 deletions
+128
-143
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+4
-5
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+17
-16
tests/models/decoder_only/language/test_jamba.py
tests/models/decoder_only/language/test_jamba.py
+3
-2
tests/models/decoder_only/language/test_mamba.py
tests/models/decoder_only/language/test_mamba.py
+3
-2
tests/worker/test_encoder_decoder_model_runner.py
tests/worker/test_encoder_decoder_model_runner.py
+2
-2
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+3
-2
vllm/compilation/backends.py
vllm/compilation/backends.py
+0
-4
vllm/compilation/compile_context.py
vllm/compilation/compile_context.py
+0
-23
vllm/config.py
vllm/config.py
+75
-8
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+2
-4
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+2
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+8
-6
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+3
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+6
-62
No files found.
tests/compile/piecewise/test_simple.py
View file @
dc5ce861
...
...
@@ -7,7 +7,6 @@ import torch
from
torch
import
nn
from
torch.library
import
Library
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
...
...
@@ -81,6 +80,7 @@ def test_simple_piecewise_compile():
use_cudagraph
=
True
,
splitting_ops
=
[
"silly.attention"
],
cudagraph_copy_inputs
=
True
,
cudagraph_capture_sizes
=
[
1
,
2
],
))
with
set_current_vllm_config
(
vllm_config
):
model
=
SillyModel
(
vllm_config
=
vllm_config
,
prefix
=
''
)
...
...
@@ -96,11 +96,10 @@ def test_simple_piecewise_compile():
6
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
with
set_compile_context
([
1
,
2
]):
model
(
inputs
)
model
(
inputs
)
model
(
torch
.
randn
(
2
).
cuda
())
model
(
torch
.
randn
(
1
).
cuda
())
model
(
torch
.
randn
(
2
).
cuda
())
model
(
torch
.
randn
(
1
).
cuda
())
input
=
torch
.
zeros
(
2
).
cuda
()
global
global_counter
...
...
tests/compile/piecewise/test_toy_llama.py
View file @
dc5ce861
...
...
@@ -13,7 +13,6 @@ import torch
from
torch
import
nn
from
torch.library
import
Library
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
...
...
@@ -256,6 +255,7 @@ def run_model(llama_config,
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
True
,
cudagraph_capture_sizes
=
[
1
,
2
],
)
if
split_attn
:
compilation_config
.
splitting_ops
=
[
"silly.attention"
]
...
...
@@ -273,10 +273,9 @@ def run_model(llama_config,
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
positions
=
torch
.
arange
(
B
).
cuda
()
with
set_compile_context
([
1
,
2
]):
model
(
input_ids
,
positions
)
model
(
input_ids
[:
2
],
positions
[:
2
])
model
(
input_ids
[:
1
],
positions
[:
1
])
model
(
input_ids
,
positions
)
model
(
input_ids
[:
2
],
positions
[:
2
])
model
(
input_ids
[:
1
],
positions
[:
1
])
input_ids
[:
2
].
zero_
()
output
=
model
(
input_ids
[:
2
],
positions
[:
2
])
...
...
@@ -379,10 +378,13 @@ def benchmark():
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
True
,
splitting_ops
=
[
"silly.attention"
],
cudagraph_capture_sizes
=
cudagraph_sizes
,
)
else
:
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
)
level
=
CompilationLevel
.
PIECEWISE
,
cudagraph_capture_sizes
=
cudagraph_sizes
,
)
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
)
with
set_current_vllm_config
(
vllm_config
):
...
...
@@ -396,17 +398,16 @@ def benchmark():
graphs
=
{}
with
set_compile_context
(
cudagraph_sizes
):
model
(
input_ids
,
positions
)
for
b
in
cudagraph_sizes
[::
-
1
]:
if
not
piecewise
:
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
,
pool
=
pool
):
output
=
model
(
input_ids
[:
b
],
positions
[:
b
])
graphs
[
b
]
=
(
graph
,
output
)
else
:
model
(
input_ids
,
positions
)
for
b
in
cudagraph_sizes
[::
-
1
]:
if
not
piecewise
:
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
,
pool
=
pool
):
output
=
model
(
input_ids
[:
b
],
positions
[:
b
])
graphs
[
b
]
=
(
model
,
output
)
graphs
[
b
]
=
(
graph
,
output
)
else
:
output
=
model
(
input_ids
[:
b
],
positions
[:
b
])
graphs
[
b
]
=
(
model
,
output
)
for
b
in
cudagraph_sizes
:
if
piecewise
:
# noqa is for `Function definition does not bind loop variable`
...
...
tests/models/decoder_only/language/test_jamba.py
View file @
dc5ce861
import
pytest
from
tests.utils
import
multi_gpu_test
from
vllm.config
import
VllmConfig
from
vllm.sampling_params
import
SamplingParams
from
vllm.worker.model_runner
import
_get_graph_batch_size
from
...utils
import
check_outputs_equal
...
...
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while
len
(
example_prompts
)
==
_get_graph_batch_size
(
len
(
example_prompts
)):
while
len
(
example_prompts
)
==
VllmConfig
.
get_graph_batch_size
(
len
(
example_prompts
)):
example_prompts
.
append
(
example_prompts
[
0
])
try
:
...
...
tests/models/decoder_only/language/test_mamba.py
View file @
dc5ce861
...
...
@@ -5,8 +5,8 @@ Run `pytest tests/models/test_mamba.py`.
import
pytest
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
from
vllm.config
import
VllmConfig
from
vllm.sampling_params
import
SamplingParams
from
vllm.worker.model_runner
import
_get_graph_batch_size
from
...utils
import
check_outputs_equal
...
...
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while
len
(
example_prompts
)
==
_get_graph_batch_size
(
len
(
example_prompts
)):
while
len
(
example_prompts
)
==
VllmConfig
.
get_graph_batch_size
(
len
(
example_prompts
)):
example_prompts
.
append
(
example_prompts
[
0
])
try
:
...
...
tests/worker/test_encoder_decoder_model_runner.py
View file @
dc5ce861
...
...
@@ -4,12 +4,12 @@ from typing import List
import
pytest
import
torch
from
vllm.config
import
VllmConfig
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.model_runner
import
_get_graph_batch_size
BATCH_SIZES
=
[
1
,
4
,
16
,
64
,
256
]
...
...
@@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
# With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors
# accordingly.
graph_batch_size
=
_
get_graph_batch_size
(
expanded_batch_size
)
graph_batch_size
=
VllmConfig
.
get_graph_batch_size
(
expanded_batch_size
)
cuda_graph_pad_size
=
graph_batch_size
-
expanded_batch_size
padded_seq_lens
=
seq_lens
+
list
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
padded_encoder_seq_lens
=
encoder_seq_lens
+
list
(
...
...
tests/worker/test_model_runner.py
View file @
dc5ce861
...
...
@@ -3,13 +3,14 @@ from typing import List
import
pytest
import
torch
from
vllm.config
import
VllmConfig
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
get_open_port
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
from
vllm.worker.model_runner
import
ModelRunner
def
_create_model_runner
(
model
:
str
,
*
args
,
**
kwargs
)
->
ModelRunner
:
...
...
@@ -176,7 +177,7 @@ def test_prepare_decode_cuda_graph(batch_size):
model_input
.
attn_metadata
,
model_input
.
attn_metadata
.
slot_mapping
)
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
expected_bs
=
_
get_graph_batch_size
(
len
(
seq_group_metadata_list
))
expected_bs
=
VllmConfig
.
get_graph_batch_size
(
len
(
seq_group_metadata_list
))
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
assert
attn_metadata
.
num_prefills
==
0
...
...
vllm/compilation/backends.py
View file @
dc5ce861
...
...
@@ -242,10 +242,6 @@ class VllmBackend:
assert
not
self
.
_called
,
"VllmBackend can only be called once"
self
.
graph
=
graph
# config is updated now, because only here can
# we get the sizes to capture for cudagraph
# from compilation context
self
.
compilation_configs
.
init_during_runtime
()
self
.
configure_post_pass
()
self
.
split_gm
,
self
.
piecewise_graphs
=
split_graph
(
...
...
vllm/compilation/compile_context.py
deleted
100644 → 0
View file @
21fe7b48
from
contextlib
import
contextmanager
from
typing
import
Any
_compile_context
:
Any
=
None
def
get_compile_context
()
->
Any
:
"""Get the current compile context."""
return
_compile_context
@
contextmanager
def
set_compile_context
(
context
:
Any
):
"""A context manager that stores the current compile context,
usually it is a list of sizes to specialize.
"""
global
_compile_context
prev_context
=
_compile_context
_compile_context
=
context
try
:
yield
finally
:
_compile_context
=
prev_context
vllm/config.py
View file @
dc5ce861
...
...
@@ -2357,15 +2357,10 @@ class CompilationConfig(BaseModel):
from
vllm.compilation.backends
import
VllmBackend
return
VllmBackend
(
self
)
def
init_
during_runtime
(
self
):
def
init_
with_cudagraph_sizes
(
self
,
sizes_to_specialize
:
List
[
int
]
):
"""To complete the initialization of config,
we need to know the compile context, which is only available
during the first run of the model.
"""
from
vllm.compilation.compile_context
import
get_compile_context
context
=
get_compile_context
()
context
=
copy
.
deepcopy
(
context
)
if
context
is
not
None
else
[]
sizes_to_specialize
:
List
[
int
]
=
context
we need to know the cudagraph sizes."""
if
self
.
cudagraph_capture_sizes
is
None
:
self
.
capture_sizes
=
sizes_to_specialize
else
:
...
...
@@ -2386,6 +2381,21 @@ class CompilationConfig(BaseModel):
self
.
inductor_compile_sizes
=
[]
self
.
compile_sizes
=
self
.
inductor_compile_sizes
# sort to make sure cudagraph capture sizes are in descending order
self
.
capture_sizes
.
sort
(
reverse
=
True
)
_BATCH_SIZE_ALIGNMENT
=
8
# all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
1025
)
]
@
dataclass
class
VllmConfig
:
...
...
@@ -2413,6 +2423,41 @@ class VllmConfig:
kv_transfer_config
:
KVTransferConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
@
staticmethod
def
get_graph_batch_size
(
batch_size
:
int
)
->
int
:
"""Returns the padded batch size given actual batch size.
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if
batch_size
<=
2
:
return
batch_size
elif
batch_size
<=
4
:
return
4
else
:
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
@
staticmethod
def
get_max_graph_batch_size
(
max_num_seqs
:
int
)
->
int
:
"""
max_num_seqs: Maximum number of sequences in a batch.
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
pad the max_num_seqs if necessary by calling get_graph_batch_size,
which will deal with some edge cases like 1, 2, 4.
if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
size. if not, it means the padded size is larger than the largest size
in _BATCH_SIZES_TO_CAPTURE, return the largest size in
_BATCH_SIZES_TO_CAPTURE.
"""
padded_size
=
VllmConfig
.
get_graph_batch_size
(
max_num_seqs
)
if
padded_size
in
_BATCH_SIZES_TO_CAPTURE
:
return
padded_size
assert
padded_size
>
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
return
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
@
staticmethod
def
_get_quantization_config
(
model_config
:
ModelConfig
,
...
...
@@ -2496,6 +2541,28 @@ class VllmConfig:
self
.
compilation_config
.
pass_config
.
enable_reshape
=
False
self
.
compilation_config
.
level
=
CompilationLevel
.
PIECEWISE
if
not
envs
.
VLLM_USE_V1
:
max_batchsize_to_capture
=
0
if
self
.
scheduler_config
is
not
None
and
\
self
.
model_config
is
not
None
and
\
not
self
.
model_config
.
enforce_eager
:
max_batchsize_to_capture
=
\
self
.
get_max_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
batch_size_capture_list
=
[
size
for
size
in
_BATCH_SIZES_TO_CAPTURE
if
size
<=
max_batchsize_to_capture
]
else
:
batch_size_capture_list
=
[]
if
self
.
model_config
is
not
None
and
\
not
self
.
model_config
.
enforce_eager
:
batch_size_capture_list
=
[
1
,
2
,
4
]
+
[
i
for
i
in
range
(
8
,
513
,
8
)]
self
.
compilation_config
.
init_with_cudagraph_sizes
(
batch_size_capture_list
)
if
self
.
cache_config
is
not
None
and
\
self
.
cache_config
.
cpu_offload_gb
>
0
and
\
self
.
compilation_config
.
level
!=
CompilationLevel
.
NO_COMPILATION
:
...
...
vllm/model_executor/models/jamba.py
View file @
dc5ce861
...
...
@@ -7,7 +7,7 @@ from transformers import JambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
_BATCH_SIZES_TO_CAPTURE
,
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
...
@@ -25,8 +25,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
_get_graph_batch_size
)
from
.interfaces
import
HasInnerState
,
SupportsLoRA
from
.utils
import
maybe_prefix
...
...
@@ -404,7 +402,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
if
self
.
mamba_cache
is
None
:
max_batch_size
=
(
_
get_graph_batch_size
(
max_batch_size
=
(
VllmConfig
.
get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
...
...
vllm/model_executor/models/mamba.py
View file @
dc5ce861
...
...
@@ -6,7 +6,7 @@ from torch import nn
from
transformers
import
MambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
_BATCH_SIZES_TO_CAPTURE
,
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
...
...
@@ -23,8 +23,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
_get_graph_batch_size
)
from
.utils
import
maybe_prefix
...
...
@@ -187,7 +185,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
if
self
.
mamba_cache
is
None
:
max_batch_size
=
(
_
get_graph_batch_size
(
max_batch_size
=
(
VllmConfig
.
get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
self
.
mamba_cache
=
MambaCacheManager
(
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
dc5ce861
...
...
@@ -8,7 +8,6 @@ import torch
import
torch.distributed
import
torch.nn
as
nn
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.forward_context
import
set_forward_context
...
...
@@ -100,7 +99,11 @@ class GPUModelRunner:
==
CompilationLevel
.
PIECEWISE
and
not
self
.
model_config
.
enforce_eager
)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
self
.
cudagraph_batch_sizes
=
[
1
,
2
,
4
]
+
[
i
for
i
in
range
(
8
,
513
,
8
)]
# The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order.
# The batch sizes in the config are in descending order.
self
.
cudagraph_batch_sizes
=
list
(
reversed
(
self
.
vllm_config
.
compilation_config
.
capture_sizes
))
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
...
...
@@ -548,10 +551,9 @@ class GPUModelRunner:
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
for
_
in
range
(
self
.
num_attn_layers
)
]
with
set_compile_context
(
self
.
cudagraph_batch_sizes
):
# Trigger compilation for general shape.
hidden_states
=
self
.
_dummy_run
(
self
.
model
,
self
.
max_num_tokens
,
dummy_kv_caches
)
# Trigger compilation for general shape.
hidden_states
=
self
.
_dummy_run
(
self
.
model
,
self
.
max_num_tokens
,
dummy_kv_caches
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
logits
=
logits
[:
self
.
max_num_tokens
]
# TODO(woosuk): Consider the memory usage of the sampler.
...
...
vllm/worker/enc_dec_model_runner.py
View file @
dc5ce861
...
...
@@ -25,8 +25,7 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput,
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_BACKEND
,
make_tensor_with_pad
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
,
_get_graph_batch_size
)
ModelInputForGPUWithSamplingMetadata
)
from
vllm.worker.model_runner_base
import
(
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
)
...
...
@@ -465,7 +464,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# We will be using CUDA graph replay for this decode.
max_len_of_block_table
=
self
.
get_max_block_per_batch
()
batch_size
=
len
(
encoder_seq_lens
)
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
graph_batch_size
=
self
.
vllm_config
.
get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
cuda_graph_pad_size
=
graph_batch_size
-
batch_size
# extend the cross_block_tables and encoder_seq_lens to match
...
...
vllm/worker/model_runner.py
View file @
dc5ce861
...
...
@@ -18,7 +18,6 @@ import vllm.envs as envs
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention.backends.abstract
import
AttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.distributed
import
get_kv_transfer_group
,
get_pp_group
...
...
@@ -63,16 +62,7 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
LORA_WARMUP_RANK
=
8
_BATCH_SIZE_ALIGNMENT
=
8
# all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
1025
)
]
_NUM_WARMUP_ITERS
=
2
TModelInputForGPU
=
TypeVar
(
'TModelInputForGPU'
,
bound
=
"ModelInputForGPU"
)
...
...
@@ -763,7 +753,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
max_decode_seq_len
:
int
,
max_encoder_seq_len
:
int
=
0
)
->
bool
:
return
(
decode_only
and
not
self
.
runner
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_decode_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
and
max_encoder_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
and
batch_size
<=
self
.
runner
.
max_batchsize_to_capture
)
...
...
@@ -811,7 +800,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
max_encoder_seq_len
):
return
-
1
graph_batch_size
=
_
get_graph_batch_size
(
batch_size
)
graph_batch_size
=
VllmConfig
.
get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
return
graph_batch_size
-
batch_size
...
...
@@ -1023,7 +1012,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
max_seq_len_to_capture
=
self
.
model_config
.
max_seq_len_to_capture
self
.
max_batchsize_to_capture
=
_
get_max_graph_batch_size
(
self
.
max_batchsize_to_capture
=
VllmConfig
.
get_max_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
self
.
graph_runners
:
List
[
Dict
[
int
,
CUDAGraphRunner
]]
=
[
...
...
@@ -1333,14 +1322,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
graph_batch_size
=
self
.
max_batchsize_to_capture
batch_size_capture_list
=
[
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
]
if
self
.
model_config
.
enforce_eager
:
batch_size_capture_list
=
[]
with
set_compile_context
(
batch_size_capture_list
):
self
.
execute_model
(
model_input
,
kv_caches
,
intermediate_tensors
)
self
.
execute_model
(
model_input
,
kv_caches
,
intermediate_tensors
)
torch
.
cuda
.
synchronize
()
return
...
...
@@ -1459,18 +1441,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
graph_batch_size
=
self
.
max_batchsize_to_capture
batch_size_capture_list
=
[
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
]
with
self
.
attn_state
.
graph_capture
(
max_batch_size
),
graph_capture
()
as
graph_capture_context
:
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for
virtual_engine
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
):
for
batch_size
in
reversed
(
batch_size_capture_list
):
for
batch_size
in
\
self
.
vllm_config
.
compilation_config
.
capture_sizes
:
attn_metadata
=
(
self
.
attn_state
.
graph_capture_get_metadata_for_batch
(
batch_size
,
...
...
@@ -1993,37 +1971,3 @@ class CUDAGraphRunner(nn.Module):
return
self
.
output_buffers
[
"hidden_states"
]
return
self
.
output_buffers
def
_get_graph_batch_size
(
batch_size
:
int
)
->
int
:
"""Returns the padded batch size given actual batch size.
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if
batch_size
<=
2
:
return
batch_size
elif
batch_size
<=
4
:
return
4
else
:
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
def
_get_max_graph_batch_size
(
max_num_seqs
:
int
)
->
int
:
"""
max_num_seqs: Maximum number of sequences in a batch.
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
pad the max_num_seqs if necessary by calling _get_graph_batch_size,
which will deal with some edge cases like 1, 2, 4.
if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
if not, it means the padded size is larger than the largest size in
_BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
"""
padded_size
=
_get_graph_batch_size
(
max_num_seqs
)
if
padded_size
in
_BATCH_SIZES_TO_CAPTURE
:
return
padded_size
assert
padded_size
>
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
return
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment