Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dc5ce861
Unverified
Commit
dc5ce861
authored
Dec 02, 2024
by
youkaichao
Committed by
GitHub
Dec 03, 2024
Browse files
[torch.compile] remove compilation_context and simplify code (#10838)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
21fe7b48
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
128 additions
and
143 deletions
+128
-143
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+4
-5
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+17
-16
tests/models/decoder_only/language/test_jamba.py
tests/models/decoder_only/language/test_jamba.py
+3
-2
tests/models/decoder_only/language/test_mamba.py
tests/models/decoder_only/language/test_mamba.py
+3
-2
tests/worker/test_encoder_decoder_model_runner.py
tests/worker/test_encoder_decoder_model_runner.py
+2
-2
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+3
-2
vllm/compilation/backends.py
vllm/compilation/backends.py
+0
-4
vllm/compilation/compile_context.py
vllm/compilation/compile_context.py
+0
-23
vllm/config.py
vllm/config.py
+75
-8
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+2
-4
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+2
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+8
-6
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+3
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+6
-62
No files found.
tests/compile/piecewise/test_simple.py
View file @
dc5ce861
...
@@ -7,7 +7,6 @@ import torch
...
@@ -7,7 +7,6 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
torch.library
import
Library
from
torch.library
import
Library
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
...
@@ -81,6 +80,7 @@ def test_simple_piecewise_compile():
...
@@ -81,6 +80,7 @@ def test_simple_piecewise_compile():
use_cudagraph
=
True
,
use_cudagraph
=
True
,
splitting_ops
=
[
"silly.attention"
],
splitting_ops
=
[
"silly.attention"
],
cudagraph_copy_inputs
=
True
,
cudagraph_copy_inputs
=
True
,
cudagraph_capture_sizes
=
[
1
,
2
],
))
))
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
SillyModel
(
vllm_config
=
vllm_config
,
prefix
=
''
)
model
=
SillyModel
(
vllm_config
=
vllm_config
,
prefix
=
''
)
...
@@ -96,7 +96,6 @@ def test_simple_piecewise_compile():
...
@@ -96,7 +96,6 @@ def test_simple_piecewise_compile():
6
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
6
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
):
with
set_compile_context
([
1
,
2
]):
model
(
inputs
)
model
(
inputs
)
model
(
torch
.
randn
(
2
).
cuda
())
model
(
torch
.
randn
(
2
).
cuda
())
...
...
tests/compile/piecewise/test_toy_llama.py
View file @
dc5ce861
...
@@ -13,7 +13,6 @@ import torch
...
@@ -13,7 +13,6 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
torch.library
import
Library
from
torch.library
import
Library
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
...
@@ -256,6 +255,7 @@ def run_model(llama_config,
...
@@ -256,6 +255,7 @@ def run_model(llama_config,
compilation_config
=
CompilationConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
True
,
use_cudagraph
=
True
,
cudagraph_capture_sizes
=
[
1
,
2
],
)
)
if
split_attn
:
if
split_attn
:
compilation_config
.
splitting_ops
=
[
"silly.attention"
]
compilation_config
.
splitting_ops
=
[
"silly.attention"
]
...
@@ -273,7 +273,6 @@ def run_model(llama_config,
...
@@ -273,7 +273,6 @@ def run_model(llama_config,
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
positions
=
torch
.
arange
(
B
).
cuda
()
positions
=
torch
.
arange
(
B
).
cuda
()
with
set_compile_context
([
1
,
2
]):
model
(
input_ids
,
positions
)
model
(
input_ids
,
positions
)
model
(
input_ids
[:
2
],
positions
[:
2
])
model
(
input_ids
[:
2
],
positions
[:
2
])
model
(
input_ids
[:
1
],
positions
[:
1
])
model
(
input_ids
[:
1
],
positions
[:
1
])
...
@@ -379,10 +378,13 @@ def benchmark():
...
@@ -379,10 +378,13 @@ def benchmark():
level
=
CompilationLevel
.
PIECEWISE
,
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
True
,
use_cudagraph
=
True
,
splitting_ops
=
[
"silly.attention"
],
splitting_ops
=
[
"silly.attention"
],
cudagraph_capture_sizes
=
cudagraph_sizes
,
)
)
else
:
else
:
compilation_config
=
CompilationConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
)
level
=
CompilationLevel
.
PIECEWISE
,
cudagraph_capture_sizes
=
cudagraph_sizes
,
)
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
)
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
...
@@ -396,7 +398,6 @@ def benchmark():
...
@@ -396,7 +398,6 @@ def benchmark():
graphs
=
{}
graphs
=
{}
with
set_compile_context
(
cudagraph_sizes
):
model
(
input_ids
,
positions
)
model
(
input_ids
,
positions
)
for
b
in
cudagraph_sizes
[::
-
1
]:
for
b
in
cudagraph_sizes
[::
-
1
]:
if
not
piecewise
:
if
not
piecewise
:
...
...
tests/models/decoder_only/language/test_jamba.py
View file @
dc5ce861
import
pytest
import
pytest
from
tests.utils
import
multi_gpu_test
from
tests.utils
import
multi_gpu_test
from
vllm.config
import
VllmConfig
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.worker.model_runner
import
_get_graph_batch_size
from
...utils
import
check_outputs_equal
from
...utils
import
check_outputs_equal
...
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
...
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
# tensor dimensions aren't compatible
while
len
(
example_prompts
)
==
_get_graph_batch_size
(
len
(
example_prompts
)):
while
len
(
example_prompts
)
==
VllmConfig
.
get_graph_batch_size
(
len
(
example_prompts
)):
example_prompts
.
append
(
example_prompts
[
0
])
example_prompts
.
append
(
example_prompts
[
0
])
try
:
try
:
...
...
tests/models/decoder_only/language/test_mamba.py
View file @
dc5ce861
...
@@ -5,8 +5,8 @@ Run `pytest tests/models/test_mamba.py`.
...
@@ -5,8 +5,8 @@ Run `pytest tests/models/test_mamba.py`.
import
pytest
import
pytest
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
from
vllm.config
import
VllmConfig
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.worker.model_runner
import
_get_graph_batch_size
from
...utils
import
check_outputs_equal
from
...utils
import
check_outputs_equal
...
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
...
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
# tensor dimensions aren't compatible
while
len
(
example_prompts
)
==
_get_graph_batch_size
(
len
(
example_prompts
)):
while
len
(
example_prompts
)
==
VllmConfig
.
get_graph_batch_size
(
len
(
example_prompts
)):
example_prompts
.
append
(
example_prompts
[
0
])
example_prompts
.
append
(
example_prompts
[
0
])
try
:
try
:
...
...
tests/worker/test_encoder_decoder_model_runner.py
View file @
dc5ce861
...
@@ -4,12 +4,12 @@ from typing import List
...
@@ -4,12 +4,12 @@ from typing import List
import
pytest
import
pytest
import
torch
import
torch
from
vllm.config
import
VllmConfig
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.model_runner
import
_get_graph_batch_size
BATCH_SIZES
=
[
1
,
4
,
16
,
64
,
256
]
BATCH_SIZES
=
[
1
,
4
,
16
,
64
,
256
]
...
@@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
...
@@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
# With CUDA Graph capture and replay enabled, the decoder and encoder
# With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors
# input sequences will be padded. Create the expected padded tensors
# accordingly.
# accordingly.
graph_batch_size
=
_
get_graph_batch_size
(
expanded_batch_size
)
graph_batch_size
=
VllmConfig
.
get_graph_batch_size
(
expanded_batch_size
)
cuda_graph_pad_size
=
graph_batch_size
-
expanded_batch_size
cuda_graph_pad_size
=
graph_batch_size
-
expanded_batch_size
padded_seq_lens
=
seq_lens
+
list
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
padded_seq_lens
=
seq_lens
+
list
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
padded_encoder_seq_lens
=
encoder_seq_lens
+
list
(
padded_encoder_seq_lens
=
encoder_seq_lens
+
list
(
...
...
tests/worker/test_model_runner.py
View file @
dc5ce861
...
@@ -3,13 +3,14 @@ from typing import List
...
@@ -3,13 +3,14 @@ from typing import List
import
pytest
import
pytest
import
torch
import
torch
from
vllm.config
import
VllmConfig
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
get_open_port
from
vllm.utils
import
get_open_port
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
from
vllm.worker.model_runner
import
ModelRunner
def
_create_model_runner
(
model
:
str
,
*
args
,
**
kwargs
)
->
ModelRunner
:
def
_create_model_runner
(
model
:
str
,
*
args
,
**
kwargs
)
->
ModelRunner
:
...
@@ -176,7 +177,7 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -176,7 +177,7 @@ def test_prepare_decode_cuda_graph(batch_size):
model_input
.
attn_metadata
,
model_input
.
attn_metadata
.
slot_mapping
)
model_input
.
attn_metadata
,
model_input
.
attn_metadata
.
slot_mapping
)
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
expected_bs
=
_
get_graph_batch_size
(
len
(
seq_group_metadata_list
))
expected_bs
=
VllmConfig
.
get_graph_batch_size
(
len
(
seq_group_metadata_list
))
# Verify input metadata is correct for prompts.
# Verify input metadata is correct for prompts.
device
=
model_runner
.
device
device
=
model_runner
.
device
assert
attn_metadata
.
num_prefills
==
0
assert
attn_metadata
.
num_prefills
==
0
...
...
vllm/compilation/backends.py
View file @
dc5ce861
...
@@ -242,10 +242,6 @@ class VllmBackend:
...
@@ -242,10 +242,6 @@ class VllmBackend:
assert
not
self
.
_called
,
"VllmBackend can only be called once"
assert
not
self
.
_called
,
"VllmBackend can only be called once"
self
.
graph
=
graph
self
.
graph
=
graph
# config is updated now, because only here can
# we get the sizes to capture for cudagraph
# from compilation context
self
.
compilation_configs
.
init_during_runtime
()
self
.
configure_post_pass
()
self
.
configure_post_pass
()
self
.
split_gm
,
self
.
piecewise_graphs
=
split_graph
(
self
.
split_gm
,
self
.
piecewise_graphs
=
split_graph
(
...
...
vllm/compilation/compile_context.py
deleted
100644 → 0
View file @
21fe7b48
from
contextlib
import
contextmanager
from
typing
import
Any
_compile_context
:
Any
=
None
def
get_compile_context
()
->
Any
:
"""Get the current compile context."""
return
_compile_context
@
contextmanager
def
set_compile_context
(
context
:
Any
):
"""A context manager that stores the current compile context,
usually it is a list of sizes to specialize.
"""
global
_compile_context
prev_context
=
_compile_context
_compile_context
=
context
try
:
yield
finally
:
_compile_context
=
prev_context
vllm/config.py
View file @
dc5ce861
...
@@ -2357,15 +2357,10 @@ class CompilationConfig(BaseModel):
...
@@ -2357,15 +2357,10 @@ class CompilationConfig(BaseModel):
from
vllm.compilation.backends
import
VllmBackend
from
vllm.compilation.backends
import
VllmBackend
return
VllmBackend
(
self
)
return
VllmBackend
(
self
)
def
init_
during_runtime
(
self
):
def
init_
with_cudagraph_sizes
(
self
,
sizes_to_specialize
:
List
[
int
]
):
"""To complete the initialization of config,
"""To complete the initialization of config,
we need to know the compile context, which is only available
we need to know the cudagraph sizes."""
during the first run of the model.
"""
from
vllm.compilation.compile_context
import
get_compile_context
context
=
get_compile_context
()
context
=
copy
.
deepcopy
(
context
)
if
context
is
not
None
else
[]
sizes_to_specialize
:
List
[
int
]
=
context
if
self
.
cudagraph_capture_sizes
is
None
:
if
self
.
cudagraph_capture_sizes
is
None
:
self
.
capture_sizes
=
sizes_to_specialize
self
.
capture_sizes
=
sizes_to_specialize
else
:
else
:
...
@@ -2386,6 +2381,21 @@ class CompilationConfig(BaseModel):
...
@@ -2386,6 +2381,21 @@ class CompilationConfig(BaseModel):
self
.
inductor_compile_sizes
=
[]
self
.
inductor_compile_sizes
=
[]
self
.
compile_sizes
=
self
.
inductor_compile_sizes
self
.
compile_sizes
=
self
.
inductor_compile_sizes
# sort to make sure cudagraph capture sizes are in descending order
self
.
capture_sizes
.
sort
(
reverse
=
True
)
_BATCH_SIZE_ALIGNMENT
=
8
# all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
1025
)
]
@
dataclass
@
dataclass
class
VllmConfig
:
class
VllmConfig
:
...
@@ -2413,6 +2423,41 @@ class VllmConfig:
...
@@ -2413,6 +2423,41 @@ class VllmConfig:
kv_transfer_config
:
KVTransferConfig
=
field
(
default
=
None
,
kv_transfer_config
:
KVTransferConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
init
=
True
)
# type: ignore
@
staticmethod
def
get_graph_batch_size
(
batch_size
:
int
)
->
int
:
"""Returns the padded batch size given actual batch size.
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if
batch_size
<=
2
:
return
batch_size
elif
batch_size
<=
4
:
return
4
else
:
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
@
staticmethod
def
get_max_graph_batch_size
(
max_num_seqs
:
int
)
->
int
:
"""
max_num_seqs: Maximum number of sequences in a batch.
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
pad the max_num_seqs if necessary by calling get_graph_batch_size,
which will deal with some edge cases like 1, 2, 4.
if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
size. if not, it means the padded size is larger than the largest size
in _BATCH_SIZES_TO_CAPTURE, return the largest size in
_BATCH_SIZES_TO_CAPTURE.
"""
padded_size
=
VllmConfig
.
get_graph_batch_size
(
max_num_seqs
)
if
padded_size
in
_BATCH_SIZES_TO_CAPTURE
:
return
padded_size
assert
padded_size
>
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
return
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
@
staticmethod
@
staticmethod
def
_get_quantization_config
(
def
_get_quantization_config
(
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
...
@@ -2496,6 +2541,28 @@ class VllmConfig:
...
@@ -2496,6 +2541,28 @@ class VllmConfig:
self
.
compilation_config
.
pass_config
.
enable_reshape
=
False
self
.
compilation_config
.
pass_config
.
enable_reshape
=
False
self
.
compilation_config
.
level
=
CompilationLevel
.
PIECEWISE
self
.
compilation_config
.
level
=
CompilationLevel
.
PIECEWISE
if
not
envs
.
VLLM_USE_V1
:
max_batchsize_to_capture
=
0
if
self
.
scheduler_config
is
not
None
and
\
self
.
model_config
is
not
None
and
\
not
self
.
model_config
.
enforce_eager
:
max_batchsize_to_capture
=
\
self
.
get_max_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
batch_size_capture_list
=
[
size
for
size
in
_BATCH_SIZES_TO_CAPTURE
if
size
<=
max_batchsize_to_capture
]
else
:
batch_size_capture_list
=
[]
if
self
.
model_config
is
not
None
and
\
not
self
.
model_config
.
enforce_eager
:
batch_size_capture_list
=
[
1
,
2
,
4
]
+
[
i
for
i
in
range
(
8
,
513
,
8
)]
self
.
compilation_config
.
init_with_cudagraph_sizes
(
batch_size_capture_list
)
if
self
.
cache_config
is
not
None
and
\
if
self
.
cache_config
is
not
None
and
\
self
.
cache_config
.
cpu_offload_gb
>
0
and
\
self
.
cache_config
.
cpu_offload_gb
>
0
and
\
self
.
compilation_config
.
level
!=
CompilationLevel
.
NO_COMPILATION
:
self
.
compilation_config
.
level
!=
CompilationLevel
.
NO_COMPILATION
:
...
...
vllm/model_executor/models/jamba.py
View file @
dc5ce861
...
@@ -7,7 +7,7 @@ from transformers import JambaConfig
...
@@ -7,7 +7,7 @@ from transformers import JambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
_BATCH_SIZES_TO_CAPTURE
,
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
@@ -25,8 +25,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
...
@@ -25,8 +25,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams
)
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
_get_graph_batch_size
)
from
.interfaces
import
HasInnerState
,
SupportsLoRA
from
.interfaces
import
HasInnerState
,
SupportsLoRA
from
.utils
import
maybe_prefix
from
.utils
import
maybe_prefix
...
@@ -404,7 +402,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
...
@@ -404,7 +402,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
if
self
.
mamba_cache
is
None
:
if
self
.
mamba_cache
is
None
:
max_batch_size
=
(
_
get_graph_batch_size
(
max_batch_size
=
(
VllmConfig
.
get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
...
...
vllm/model_executor/models/mamba.py
View file @
dc5ce861
...
@@ -6,7 +6,7 @@ from torch import nn
...
@@ -6,7 +6,7 @@ from torch import nn
from
transformers
import
MambaConfig
from
transformers
import
MambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
_BATCH_SIZES_TO_CAPTURE
,
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
...
@@ -23,8 +23,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
...
@@ -23,8 +23,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams
)
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
_get_graph_batch_size
)
from
.utils
import
maybe_prefix
from
.utils
import
maybe_prefix
...
@@ -187,7 +185,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -187,7 +185,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
if
self
.
mamba_cache
is
None
:
if
self
.
mamba_cache
is
None
:
max_batch_size
=
(
_
get_graph_batch_size
(
max_batch_size
=
(
VllmConfig
.
get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
self
.
mamba_cache
=
MambaCacheManager
(
self
.
mamba_cache
=
MambaCacheManager
(
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
dc5ce861
...
@@ -8,7 +8,6 @@ import torch
...
@@ -8,7 +8,6 @@ import torch
import
torch.distributed
import
torch.distributed
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
...
@@ -100,7 +99,11 @@ class GPUModelRunner:
...
@@ -100,7 +99,11 @@ class GPUModelRunner:
==
CompilationLevel
.
PIECEWISE
==
CompilationLevel
.
PIECEWISE
and
not
self
.
model_config
.
enforce_eager
)
and
not
self
.
model_config
.
enforce_eager
)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
self
.
cudagraph_batch_sizes
=
[
1
,
2
,
4
]
+
[
i
for
i
in
range
(
8
,
513
,
8
)]
# The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order.
# The batch sizes in the config are in descending order.
self
.
cudagraph_batch_sizes
=
list
(
reversed
(
self
.
vllm_config
.
compilation_config
.
capture_sizes
))
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
device
=
self
.
device
)
...
@@ -548,7 +551,6 @@ class GPUModelRunner:
...
@@ -548,7 +551,6 @@ class GPUModelRunner:
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
self
.
device
)
for
_
in
range
(
self
.
num_attn_layers
)
for
_
in
range
(
self
.
num_attn_layers
)
]
]
with
set_compile_context
(
self
.
cudagraph_batch_sizes
):
# Trigger compilation for general shape.
# Trigger compilation for general shape.
hidden_states
=
self
.
_dummy_run
(
self
.
model
,
self
.
max_num_tokens
,
hidden_states
=
self
.
_dummy_run
(
self
.
model
,
self
.
max_num_tokens
,
dummy_kv_caches
)
dummy_kv_caches
)
...
...
vllm/worker/enc_dec_model_runner.py
View file @
dc5ce861
...
@@ -25,8 +25,7 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput,
...
@@ -25,8 +25,7 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput,
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_BACKEND
,
make_tensor_with_pad
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_BACKEND
,
make_tensor_with_pad
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
,
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
,
ModelInputForGPUWithSamplingMetadata
)
_get_graph_batch_size
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
_add_attn_metadata_broadcastable_dict
,
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
)
_add_sampling_metadata_broadcastable_dict
)
...
@@ -465,7 +464,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -465,7 +464,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# We will be using CUDA graph replay for this decode.
# We will be using CUDA graph replay for this decode.
max_len_of_block_table
=
self
.
get_max_block_per_batch
()
max_len_of_block_table
=
self
.
get_max_block_per_batch
()
batch_size
=
len
(
encoder_seq_lens
)
batch_size
=
len
(
encoder_seq_lens
)
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
graph_batch_size
=
self
.
vllm_config
.
get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
assert
graph_batch_size
>=
batch_size
cuda_graph_pad_size
=
graph_batch_size
-
batch_size
cuda_graph_pad_size
=
graph_batch_size
-
batch_size
# extend the cross_block_tables and encoder_seq_lens to match
# extend the cross_block_tables and encoder_seq_lens to match
...
...
vllm/worker/model_runner.py
View file @
dc5ce861
...
@@ -18,7 +18,6 @@ import vllm.envs as envs
...
@@ -18,7 +18,6 @@ import vllm.envs as envs
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention.backends.abstract
import
AttentionState
from
vllm.attention.backends.abstract
import
AttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.distributed
import
get_kv_transfer_group
,
get_pp_group
from
vllm.distributed
import
get_kv_transfer_group
,
get_pp_group
...
@@ -63,16 +62,7 @@ if TYPE_CHECKING:
...
@@ -63,16 +62,7 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
LORA_WARMUP_RANK
=
8
LORA_WARMUP_RANK
=
8
_BATCH_SIZE_ALIGNMENT
=
8
# all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE
=
[
1
,
2
,
4
]
+
[
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
1025
)
]
_NUM_WARMUP_ITERS
=
2
_NUM_WARMUP_ITERS
=
2
TModelInputForGPU
=
TypeVar
(
'TModelInputForGPU'
,
bound
=
"ModelInputForGPU"
)
TModelInputForGPU
=
TypeVar
(
'TModelInputForGPU'
,
bound
=
"ModelInputForGPU"
)
...
@@ -763,7 +753,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -763,7 +753,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
max_decode_seq_len
:
int
,
max_decode_seq_len
:
int
,
max_encoder_seq_len
:
int
=
0
)
->
bool
:
max_encoder_seq_len
:
int
=
0
)
->
bool
:
return
(
decode_only
and
not
self
.
runner
.
model_config
.
enforce_eager
return
(
decode_only
and
not
self
.
runner
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_decode_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
and
max_decode_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
and
max_encoder_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
and
max_encoder_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
and
batch_size
<=
self
.
runner
.
max_batchsize_to_capture
)
and
batch_size
<=
self
.
runner
.
max_batchsize_to_capture
)
...
@@ -811,7 +800,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -811,7 +800,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
max_encoder_seq_len
):
max_encoder_seq_len
):
return
-
1
return
-
1
graph_batch_size
=
_
get_graph_batch_size
(
batch_size
)
graph_batch_size
=
VllmConfig
.
get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
assert
graph_batch_size
>=
batch_size
return
graph_batch_size
-
batch_size
return
graph_batch_size
-
batch_size
...
@@ -1023,7 +1012,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1023,7 +1012,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
block_size
=
cache_config
.
block_size
self
.
max_seq_len_to_capture
=
self
.
model_config
.
max_seq_len_to_capture
self
.
max_seq_len_to_capture
=
self
.
model_config
.
max_seq_len_to_capture
self
.
max_batchsize_to_capture
=
_
get_max_graph_batch_size
(
self
.
max_batchsize_to_capture
=
VllmConfig
.
get_max_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
self
.
scheduler_config
.
max_num_seqs
)
self
.
graph_runners
:
List
[
Dict
[
int
,
CUDAGraphRunner
]]
=
[
self
.
graph_runners
:
List
[
Dict
[
int
,
CUDAGraphRunner
]]
=
[
...
@@ -1333,13 +1322,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1333,13 +1322,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype
=
self
.
model_config
.
dtype
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
device
=
self
.
device
)
graph_batch_size
=
self
.
max_batchsize_to_capture
batch_size_capture_list
=
[
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
]
if
self
.
model_config
.
enforce_eager
:
batch_size_capture_list
=
[]
with
set_compile_context
(
batch_size_capture_list
):
self
.
execute_model
(
model_input
,
kv_caches
,
intermediate_tensors
)
self
.
execute_model
(
model_input
,
kv_caches
,
intermediate_tensors
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
return
return
...
@@ -1459,18 +1441,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1459,18 +1441,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype
=
self
.
model_config
.
dtype
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
device
=
self
.
device
)
graph_batch_size
=
self
.
max_batchsize_to_capture
batch_size_capture_list
=
[
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
]
with
self
.
attn_state
.
graph_capture
(
with
self
.
attn_state
.
graph_capture
(
max_batch_size
),
graph_capture
()
as
graph_capture_context
:
max_batch_size
),
graph_capture
()
as
graph_capture_context
:
# NOTE: Capturing the largest batch size first may help reduce the
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
# memory usage of CUDA graph.
for
virtual_engine
in
range
(
for
virtual_engine
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
):
self
.
parallel_config
.
pipeline_parallel_size
):
for
batch_size
in
reversed
(
batch_size_capture_list
):
for
batch_size
in
\
self
.
vllm_config
.
compilation_config
.
capture_sizes
:
attn_metadata
=
(
attn_metadata
=
(
self
.
attn_state
.
graph_capture_get_metadata_for_batch
(
self
.
attn_state
.
graph_capture_get_metadata_for_batch
(
batch_size
,
batch_size
,
...
@@ -1993,37 +1971,3 @@ class CUDAGraphRunner(nn.Module):
...
@@ -1993,37 +1971,3 @@ class CUDAGraphRunner(nn.Module):
return
self
.
output_buffers
[
"hidden_states"
]
return
self
.
output_buffers
[
"hidden_states"
]
return
self
.
output_buffers
return
self
.
output_buffers
def
_get_graph_batch_size
(
batch_size
:
int
)
->
int
:
"""Returns the padded batch size given actual batch size.
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if
batch_size
<=
2
:
return
batch_size
elif
batch_size
<=
4
:
return
4
else
:
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
def
_get_max_graph_batch_size
(
max_num_seqs
:
int
)
->
int
:
"""
max_num_seqs: Maximum number of sequences in a batch.
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
pad the max_num_seqs if necessary by calling _get_graph_batch_size,
which will deal with some edge cases like 1, 2, 4.
if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
if not, it means the padded size is larger than the largest size in
_BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
"""
padded_size
=
_get_graph_batch_size
(
max_num_seqs
)
if
padded_size
in
_BATCH_SIZES_TO_CAPTURE
:
return
padded_size
assert
padded_size
>
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
return
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment