Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c08f3b2a
Unverified
Commit
c08f3b2a
authored
Apr 14, 2026
by
Lucas Kabela
Committed by
GitHub
Apr 14, 2026
Browse files
Measure encoder compile time separately from the LLM backbone (#39240)
Signed-off-by:
Lucas Kabela
<
lucaskabela@meta.com
>
parent
f02b3269
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
188 additions
and
149 deletions
+188
-149
vllm/benchmarks/startup.py
vllm/benchmarks/startup.py
+115
-130
vllm/compilation/backends.py
vllm/compilation/backends.py
+13
-3
vllm/compilation/piecewise_backend.py
vllm/compilation/piecewise_backend.py
+1
-0
vllm/config/compilation.py
vllm/config/compilation.py
+4
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+27
-5
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+8
-3
vllm/v1/worker/cpu_worker.py
vllm/v1/worker/cpu_worker.py
+6
-2
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+6
-3
vllm/v1/worker/worker_base.py
vllm/v1/worker/worker_base.py
+8
-3
No files found.
vllm/benchmarks/startup.py
View file @
c08f3b2a
...
@@ -16,7 +16,7 @@ import shutil
...
@@ -16,7 +16,7 @@ import shutil
import
tempfile
import
tempfile
import
time
import
time
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
from
typing
import
Any
,
NamedTuple
import
numpy
as
np
import
numpy
as
np
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -27,6 +27,82 @@ from vllm.benchmarks.lib.utils import (
...
@@ -27,6 +27,82 @@ from vllm.benchmarks.lib.utils import (
)
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
PERCENTAGES
=
[
10
,
25
,
50
,
75
,
90
,
99
]
class
MetricDesc
(
NamedTuple
):
"""Descriptor for a metric to collect from each iteration."""
iter_key
:
str
# key in the iteration result dict
suffix
:
str
# result key suffix, e.g. "startup", "compilation"
display_name
:
str
class
MetricStats
(
NamedTuple
):
"""Aggregated statistics for a single benchmark metric."""
key
:
str
# e.g. "cold_startup", "warm_encoder_compilation"
display_name
:
str
values
:
list
[
float
]
avg
:
float
percentiles
:
dict
[
int
,
float
]
_BASE_METRICS
=
[
MetricDesc
(
"total_startup_time"
,
"startup"
,
"Startup time"
),
MetricDesc
(
"compilation_time"
,
"compilation"
,
"Compilation time"
),
]
_ENCODER_METRIC
=
MetricDesc
(
"encoder_compilation_time"
,
"encoder_compilation"
,
"Encoder compilation time"
,
)
def
_compute_metric
(
phase
:
str
,
desc
:
MetricDesc
,
iterations
:
list
[
dict
[
str
,
float
]],
)
->
MetricStats
:
values
=
[
m
[
desc
.
iter_key
]
for
m
in
iterations
]
arr
=
np
.
array
(
values
)
return
MetricStats
(
key
=
f
"
{
phase
}
_
{
desc
.
suffix
}
"
,
display_name
=
desc
.
display_name
,
values
=
values
,
avg
=
float
(
np
.
mean
(
arr
)),
percentiles
=
dict
(
zip
(
PERCENTAGES
,
np
.
percentile
(
arr
,
PERCENTAGES
).
tolist
())),
)
def
_collect_phase_metrics
(
phase
:
str
,
iterations
:
list
[
dict
[
str
,
float
]],
has_encoder
:
bool
,
)
->
list
[
MetricStats
]:
metrics
=
[
_compute_metric
(
phase
,
desc
,
iterations
)
for
desc
in
_BASE_METRICS
]
if
has_encoder
:
metrics
.
append
(
_compute_metric
(
phase
,
_ENCODER_METRIC
,
iterations
))
return
metrics
def
_print_phase
(
phase_name
:
str
,
metrics
:
list
[
MetricStats
])
->
None
:
print
(
f
"
\n
{
phase_name
}
:"
)
for
m
in
metrics
:
print
(
f
"Avg
{
m
.
display_name
.
lower
()
}
:
{
m
.
avg
:.
2
f
}
seconds"
)
for
m
in
metrics
:
print
(
f
"
{
m
.
display_name
}
percentiles:"
)
for
pct
,
val
in
m
.
percentiles
.
items
():
print
(
f
"
{
pct
}
%:
{
val
:.
2
f
}
seconds"
)
def
_metric_to_json
(
m
:
MetricStats
)
->
dict
[
str
,
Any
]:
return
{
f
"avg_
{
m
.
key
}
_time"
:
m
.
avg
,
f
"
{
m
.
key
}
_times"
:
m
.
values
,
f
"
{
m
.
key
}
_percentiles"
:
m
.
percentiles
,
}
@
contextmanager
@
contextmanager
def
cold_startup
():
def
cold_startup
():
...
@@ -72,6 +148,7 @@ def run_startup_in_subprocess(engine_args, result_queue):
...
@@ -72,6 +148,7 @@ def run_startup_in_subprocess(engine_args, result_queue):
# Extract compilation time if available
# Extract compilation time if available
compilation_time
=
0.0
compilation_time
=
0.0
encoder_compilation_time
=
0.0
if
hasattr
(
llm
.
llm_engine
,
"vllm_config"
):
if
hasattr
(
llm
.
llm_engine
,
"vllm_config"
):
vllm_config
=
llm
.
llm_engine
.
vllm_config
vllm_config
=
llm
.
llm_engine
.
vllm_config
if
(
if
(
...
@@ -79,11 +156,15 @@ def run_startup_in_subprocess(engine_args, result_queue):
...
@@ -79,11 +156,15 @@ def run_startup_in_subprocess(engine_args, result_queue):
and
vllm_config
.
compilation_config
is
not
None
and
vllm_config
.
compilation_config
is
not
None
):
):
compilation_time
=
vllm_config
.
compilation_config
.
compilation_time
compilation_time
=
vllm_config
.
compilation_config
.
compilation_time
encoder_compilation_time
=
(
vllm_config
.
compilation_config
.
encoder_compilation_time
)
result_queue
.
put
(
result_queue
.
put
(
{
{
"total_startup_time"
:
total_startup_time
,
"total_startup_time"
:
total_startup_time
,
"compilation_time"
:
compilation_time
,
"compilation_time"
:
compilation_time
,
"encoder_compilation_time"
:
encoder_compilation_time
,
}
}
)
)
...
@@ -93,65 +174,20 @@ def run_startup_in_subprocess(engine_args, result_queue):
...
@@ -93,65 +174,20 @@ def run_startup_in_subprocess(engine_args, result_queue):
def
save_to_pytorch_benchmark_format
(
def
save_to_pytorch_benchmark_format
(
args
:
argparse
.
Namespace
,
results
:
dict
[
str
,
Any
]
args
:
argparse
.
Namespace
,
metrics
:
list
[
MetricStats
]
)
->
None
:
)
->
None
:
base_name
=
os
.
path
.
splitext
(
args
.
output_json
)[
0
]
base_name
=
os
.
path
.
splitext
(
args
.
output_json
)[
0
]
for
m
in
metrics
:
cold_startup_
records
=
convert_to_pytorch_benchmark_format
(
records
=
convert_to_pytorch_benchmark_format
(
args
=
args
,
args
=
args
,
metrics
=
{
metrics
=
{
f
"avg_
{
m
.
key
}
_time"
:
[
m
.
avg
]},
"avg_cold_startup_time"
:
[
results
[
"avg_cold_startup_time"
]],
},
extra_info
=
{
extra_info
=
{
"cold_startup_times"
:
results
[
"cold_startup_times"
]
,
f
"
{
m
.
key
}
_times"
:
m
.
values
,
"cold_startup
_percentiles"
:
results
[
"cold_startup_
percentiles
"
]
,
f
"
{
m
.
key
}
_percentiles"
:
m
.
percentiles
,
},
},
)
)
if
cold_startup_records
:
if
records
:
write_to_json
(
f
"
{
base_name
}
.cold_startup.pytorch.json"
,
cold_startup_records
)
write_to_json
(
f
"
{
base_name
}
.
{
m
.
key
}
.pytorch.json"
,
records
)
cold_compilation_records
=
convert_to_pytorch_benchmark_format
(
args
=
args
,
metrics
=
{
"avg_cold_compilation_time"
:
[
results
[
"avg_cold_compilation_time"
]],
},
extra_info
=
{
"cold_compilation_times"
:
results
[
"cold_compilation_times"
],
"cold_compilation_percentiles"
:
results
[
"cold_compilation_percentiles"
],
},
)
if
cold_compilation_records
:
write_to_json
(
f
"
{
base_name
}
.cold_compilation.pytorch.json"
,
cold_compilation_records
)
warm_startup_records
=
convert_to_pytorch_benchmark_format
(
args
=
args
,
metrics
=
{
"avg_warm_startup_time"
:
[
results
[
"avg_warm_startup_time"
]],
},
extra_info
=
{
"warm_startup_times"
:
results
[
"warm_startup_times"
],
"warm_startup_percentiles"
:
results
[
"warm_startup_percentiles"
],
},
)
if
warm_startup_records
:
write_to_json
(
f
"
{
base_name
}
.warm_startup.pytorch.json"
,
warm_startup_records
)
warm_compilation_records
=
convert_to_pytorch_benchmark_format
(
args
=
args
,
metrics
=
{
"avg_warm_compilation_time"
:
[
results
[
"avg_warm_compilation_time"
]],
},
extra_info
=
{
"warm_compilation_times"
:
results
[
"warm_compilation_times"
],
"warm_compilation_percentiles"
:
results
[
"warm_compilation_percentiles"
],
},
)
if
warm_compilation_records
:
write_to_json
(
f
"
{
base_name
}
.warm_compilation.pytorch.json"
,
warm_compilation_records
)
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
...
@@ -224,97 +260,46 @@ def main(args: argparse.Namespace):
...
@@ -224,97 +260,46 @@ def main(args: argparse.Namespace):
os
.
environ
[
"VLLM_ENABLE_V1_MULTIPROCESSING"
]
=
"0"
os
.
environ
[
"VLLM_ENABLE_V1_MULTIPROCESSING"
]
=
"0"
print
(
"Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.
\n
"
)
print
(
"Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.
\n
"
)
# Collect cold startup iterations
print
(
"Measuring cold startup time...
\n
"
)
print
(
"Measuring cold startup time...
\n
"
)
cold_startup_times
=
[]
cold_iterations
=
[]
cold_compilation_times
=
[]
for
i
in
tqdm
(
range
(
args
.
num_iters_cold
),
desc
=
"Cold startup iterations"
):
for
i
in
tqdm
(
range
(
args
.
num_iters_cold
),
desc
=
"Cold startup iterations"
):
with
cold_startup
():
with
cold_startup
():
metrics
=
create_llm_and_measure_startup
()
cold_iterations
.
append
(
create_llm_and_measure_startup
())
cold_startup_times
.
append
(
metrics
[
"total_startup_time"
])
cold_compilation_times
.
append
(
metrics
[
"compilation_time"
])
# Warmup for warm startup
# Warmup for warm startup
print
(
"
\n
Warming up for warm startup measurement...
\n
"
)
print
(
"
\n
Warming up for warm startup measurement...
\n
"
)
for
_
in
tqdm
(
range
(
args
.
num_iters_warmup
),
desc
=
"Warmup iterations"
):
for
_
in
tqdm
(
range
(
args
.
num_iters_warmup
),
desc
=
"Warmup iterations"
):
create_llm_and_measure_startup
()
create_llm_and_measure_startup
()
# Collect warm startup iterations
print
(
"
\n
Measuring warm startup time...
\n
"
)
print
(
"
\n
Measuring warm startup time...
\n
"
)
warm_startup_times
=
[]
warm_iterations
=
[]
warm_compilation_times
=
[]
for
i
in
tqdm
(
range
(
args
.
num_iters_warm
),
desc
=
"Warm startup iterations"
):
for
i
in
tqdm
(
range
(
args
.
num_iters_warm
),
desc
=
"Warm startup iterations"
):
metrics
=
create_llm_and_measure_startup
()
warm_iterations
.
append
(
create_llm_and_measure_startup
())
warm_startup_times
.
append
(
metrics
[
"total_startup_time"
])
warm_compilation_times
.
append
(
metrics
[
"compilation_time"
])
# Determine if encoder compilation occurred in any iteration
has_encoder
=
any
(
# Calculate statistics
m
[
"encoder_compilation_time"
]
>
0
for
m
in
cold_iterations
+
warm_iterations
cold_startup_array
=
np
.
array
(
cold_startup_times
)
)
cold_compilation_array
=
np
.
array
(
cold_compilation_times
)
warm_startup_array
=
np
.
array
(
warm_startup_times
)
cold_metrics
=
_collect_phase_metrics
(
"cold"
,
cold_iterations
,
has_encoder
)
warm_compilation_array
=
np
.
array
(
warm_compilation_times
)
warm_metrics
=
_collect_phase_metrics
(
"warm"
,
warm_iterations
,
has_encoder
)
all_metrics
=
cold_metrics
+
warm_metrics
avg_cold_startup
=
np
.
mean
(
cold_startup_array
)
avg_cold_compilation
=
np
.
mean
(
cold_compilation_array
)
avg_warm_startup
=
np
.
mean
(
warm_startup_array
)
avg_warm_compilation
=
np
.
mean
(
warm_compilation_array
)
percentages
=
[
10
,
25
,
50
,
75
,
90
,
99
]
cold_startup_percentiles
=
np
.
percentile
(
cold_startup_array
,
percentages
)
cold_compilation_percentiles
=
np
.
percentile
(
cold_compilation_array
,
percentages
)
warm_startup_percentiles
=
np
.
percentile
(
warm_startup_array
,
percentages
)
warm_compilation_percentiles
=
np
.
percentile
(
warm_compilation_array
,
percentages
)
# Print results
print
(
"
\n
"
+
"="
*
60
)
print
(
"
\n
"
+
"="
*
60
)
print
(
"STARTUP TIME BENCHMARK RESULTS"
)
print
(
"STARTUP TIME BENCHMARK RESULTS"
)
print
(
"="
*
60
)
print
(
"="
*
60
)
_print_phase
(
"COLD STARTUP"
,
cold_metrics
)
# Cold startup statistics
_print_phase
(
"WARM STARTUP"
,
warm_metrics
)
print
(
"
\n
COLD STARTUP:"
)
print
(
f
"Avg total startup time:
{
avg_cold_startup
:.
2
f
}
seconds"
)
print
(
f
"Avg compilation time:
{
avg_cold_compilation
:.
2
f
}
seconds"
)
print
(
"Startup time percentiles:"
)
for
percentage
,
percentile
in
zip
(
percentages
,
cold_startup_percentiles
):
print
(
f
"
{
percentage
}
%:
{
percentile
:.
2
f
}
seconds"
)
print
(
"Compilation time percentiles:"
)
for
percentage
,
percentile
in
zip
(
percentages
,
cold_compilation_percentiles
):
print
(
f
"
{
percentage
}
%:
{
percentile
:.
2
f
}
seconds"
)
# Warm startup statistics
print
(
"
\n
WARM STARTUP:"
)
print
(
f
"Avg total startup time:
{
avg_warm_startup
:.
2
f
}
seconds"
)
print
(
f
"Avg compilation time:
{
avg_warm_compilation
:.
2
f
}
seconds"
)
print
(
"Startup time percentiles:"
)
for
percentage
,
percentile
in
zip
(
percentages
,
warm_startup_percentiles
):
print
(
f
"
{
percentage
}
%:
{
percentile
:.
2
f
}
seconds"
)
print
(
"Compilation time percentiles:"
)
for
percentage
,
percentile
in
zip
(
percentages
,
warm_compilation_percentiles
):
print
(
f
"
{
percentage
}
%:
{
percentile
:.
2
f
}
seconds"
)
print
(
"="
*
60
)
print
(
"="
*
60
)
# Output JSON results if specified
# Output JSON results if specified
if
args
.
output_json
:
if
args
.
output_json
:
results
=
{
results
:
dict
[
str
,
Any
]
=
{}
"avg_cold_startup_time"
:
float
(
avg_cold_startup
),
for
m
in
all_metrics
:
"avg_cold_compilation_time"
:
float
(
avg_cold_compilation
),
results
.
update
(
_metric_to_json
(
m
))
"cold_startup_times"
:
cold_startup_times
,
"cold_compilation_times"
:
cold_compilation_times
,
"cold_startup_percentiles"
:
dict
(
zip
(
percentages
,
cold_startup_percentiles
.
tolist
())
),
"cold_compilation_percentiles"
:
dict
(
zip
(
percentages
,
cold_compilation_percentiles
.
tolist
())
),
"avg_warm_startup_time"
:
float
(
avg_warm_startup
),
"avg_warm_compilation_time"
:
float
(
avg_warm_compilation
),
"warm_startup_times"
:
warm_startup_times
,
"warm_compilation_times"
:
warm_compilation_times
,
"warm_startup_percentiles"
:
dict
(
zip
(
percentages
,
warm_startup_percentiles
.
tolist
())
),
"warm_compilation_percentiles"
:
dict
(
zip
(
percentages
,
warm_compilation_percentiles
.
tolist
())
),
}
with
open
(
args
.
output_json
,
"w"
)
as
f
:
with
open
(
args
.
output_json
,
"w"
)
as
f
:
json
.
dump
(
results
,
f
,
indent
=
4
)
json
.
dump
(
results
,
f
,
indent
=
4
)
save_to_pytorch_benchmark_format
(
args
,
result
s
)
save_to_pytorch_benchmark_format
(
args
,
all_metric
s
)
vllm/compilation/backends.py
View file @
c08f3b2a
...
@@ -265,6 +265,7 @@ class CompilerManager:
...
@@ -265,6 +265,7 @@ class CompilerManager:
compile_range
:
Range
,
compile_range
:
Range
,
graph_index
:
int
=
0
,
graph_index
:
int
=
0
,
num_graphs
:
int
=
1
,
num_graphs
:
int
=
1
,
is_encoder
:
bool
=
False
,
)
->
Any
:
)
->
Any
:
if
graph_index
==
0
:
if
graph_index
==
0
:
# before compiling the first graph, record the start time
# before compiling the first graph, record the start time
...
@@ -282,6 +283,9 @@ class CompilerManager:
...
@@ -282,6 +283,9 @@ class CompilerManager:
# after loading the last graph for this shape, record the time.
# after loading the last graph for this shape, record the time.
# there can be multiple graphs due to piecewise compilation.
# there can be multiple graphs due to piecewise compilation.
elapsed
=
time
.
perf_counter
()
-
compilation_start_time
elapsed
=
time
.
perf_counter
()
-
compilation_start_time
if
is_encoder
:
compilation_config
.
encoder_compilation_time
+=
elapsed
else
:
compilation_config
.
compilation_time
+=
elapsed
compilation_config
.
compilation_time
+=
elapsed
logger
.
info_once
(
logger
.
info_once
(
"Directly load the compiled graph(s) for compile range %s "
"Directly load the compiled graph(s) for compile range %s "
...
@@ -387,6 +391,9 @@ class CompilerManager:
...
@@ -387,6 +391,9 @@ class CompilerManager:
# after compiling the last graph, record the end time
# after compiling the last graph, record the end time
if
graph_index
==
num_graphs
-
1
:
if
graph_index
==
num_graphs
-
1
:
elapsed
=
time
.
perf_counter
()
-
compilation_start_time
elapsed
=
time
.
perf_counter
()
-
compilation_start_time
if
is_encoder
:
compilation_config
.
encoder_compilation_time
+=
elapsed
else
:
compilation_config
.
compilation_time
+=
elapsed
compilation_config
.
compilation_time
+=
elapsed
logger
.
info_once
(
logger
.
info_once
(
"Compiling a graph for compile range %s takes %.2f s"
,
"Compiling a graph for compile range %s takes %.2f s"
,
...
@@ -1130,6 +1137,9 @@ class VllmBackend:
...
@@ -1130,6 +1137,9 @@ class VllmBackend:
logger
.
info_once
(
logger
.
info_once
(
"Dynamo bytecode transform time: %.2f s"
,
dynamo_time
,
scope
=
"local"
"Dynamo bytecode transform time: %.2f s"
,
dynamo_time
,
scope
=
"local"
)
)
if
self
.
is_encoder
:
self
.
compilation_config
.
encoder_compilation_time
+=
dynamo_time
else
:
self
.
compilation_config
.
compilation_time
+=
dynamo_time
self
.
compilation_config
.
compilation_time
+=
dynamo_time
# Record Dynamo time in tracing if available
# Record Dynamo time in tracing if available
...
...
vllm/compilation/piecewise_backend.py
View file @
c08f3b2a
...
@@ -270,6 +270,7 @@ class PiecewiseBackend:
...
@@ -270,6 +270,7 @@ class PiecewiseBackend:
compile_range
=
range_entry
.
compile_range
,
compile_range
=
range_entry
.
compile_range
,
graph_index
=
self
.
piecewise_compile_index
,
graph_index
=
self
.
piecewise_compile_index
,
num_graphs
=
self
.
total_piecewise_compiles
,
num_graphs
=
self
.
total_piecewise_compiles
,
is_encoder
=
self
.
vllm_backend
.
is_encoder
,
)
)
range_entry
.
compiled
=
True
range_entry
.
compiled
=
True
...
...
vllm/config/compilation.py
View file @
c08f3b2a
...
@@ -710,6 +710,8 @@ class CompilationConfig:
...
@@ -710,6 +710,8 @@ class CompilationConfig:
"""files that are traced for compilation"""
"""files that are traced for compilation"""
compilation_time
:
float
=
field
(
default
=
0.0
,
init
=
False
)
compilation_time
:
float
=
field
(
default
=
0.0
,
init
=
False
)
"""time taken for compilation"""
"""time taken for compilation"""
encoder_compilation_time
:
float
=
field
(
default
=
0.0
,
init
=
False
)
"""time taken for multimodal encoder compilation"""
static_forward_context
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
,
init
=
False
)
static_forward_context
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
,
init
=
False
)
"""Per-model forward context
"""Per-model forward context
...
@@ -756,6 +758,7 @@ class CompilationConfig:
...
@@ -756,6 +758,7 @@ class CompilationConfig:
"local_cache_dir"
,
"local_cache_dir"
,
"traced_files"
,
"traced_files"
,
"compilation_time"
,
"compilation_time"
,
"encoder_compilation_time"
,
"static_forward_context"
,
"static_forward_context"
,
"pass_config"
,
# handled separately below
"pass_config"
,
# handled separately below
"dynamic_shapes_config"
,
# handled separately below
"dynamic_shapes_config"
,
# handled separately below
...
@@ -775,6 +778,7 @@ class CompilationConfig:
...
@@ -775,6 +778,7 @@ class CompilationConfig:
"enabled_custom_ops"
:
True
,
"enabled_custom_ops"
:
True
,
"disabled_custom_ops"
:
True
,
"disabled_custom_ops"
:
True
,
"compilation_time"
:
True
,
"compilation_time"
:
True
,
"encoder_compilation_time"
:
True
,
"traced_files"
:
True
,
"traced_files"
:
True
,
"inductor_compile_config"
:
{
"inductor_compile_config"
:
{
"post_grad_custom_post_pass"
:
True
,
"post_grad_custom_post_pass"
:
True
,
...
...
vllm/v1/engine/core.py
View file @
c08f3b2a
...
@@ -282,8 +282,30 @@ class EngineCore:
...
@@ -282,8 +282,30 @@ class EngineCore:
self
.
model_executor
.
initialize_from_config
(
kv_cache_configs
)
self
.
model_executor
.
initialize_from_config
(
kv_cache_configs
)
elapsed
=
time
.
time
()
-
start
elapsed
=
time
.
time
()
-
start
compile_time
=
vllm_config
.
compilation_config
.
compilation_time
encoder_compile_time
=
vllm_config
.
compilation_config
.
encoder_compilation_time
if
encoder_compile_time
>
0
:
logger
.
info_once
(
logger
.
info_once
(
"init engine (profile, create kv cache, warmup model) took %.2f seconds"
,
"init engine (profile, create kv cache, warmup model) took "
"%.2f s (compilation: %.2f s — language_model: %.2f s, "
"encoder: %.2f s)"
,
elapsed
,
compile_time
+
encoder_compile_time
,
compile_time
,
encoder_compile_time
,
scope
=
"local"
,
)
elif
compile_time
>
0
:
logger
.
info_once
(
"init engine (profile, create kv cache, warmup model) took "
"%.2f s (compilation: %.2f s)"
,
elapsed
,
compile_time
,
scope
=
"local"
,
)
else
:
logger
.
info_once
(
"init engine (profile, create kv cache, warmup model) took %.2f s"
,
elapsed
,
elapsed
,
scope
=
"local"
,
scope
=
"local"
,
)
)
...
...
vllm/v1/executor/abstract.py
View file @
c08f3b2a
...
@@ -22,7 +22,7 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
...
@@ -22,7 +22,7 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from
vllm.v1.engine
import
ReconfigureDistributedRequest
from
vllm.v1.engine
import
ReconfigureDistributedRequest
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.v1.worker.worker_base
import
CompilationTimes
,
WorkerBase
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
...
@@ -121,14 +121,19 @@ class Executor(ABC):
...
@@ -121,14 +121,19 @@ class Executor(ABC):
underlying workers.
underlying workers.
"""
"""
self
.
collective_rpc
(
"initialize_from_config"
,
args
=
(
kv_cache_configs
,))
self
.
collective_rpc
(
"initialize_from_config"
,
args
=
(
kv_cache_configs
,))
compilation_times
:
list
[
float
]
=
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
compilation_times
:
list
[
CompilationTimes
]
=
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
# Propagate compilation time from workers back to the main process.
# Propagate compilation time from workers back to the main process.
# With TP>1, compilation happens in worker processes, so the main
# With TP>1, compilation happens in worker processes, so the main
# process config is never updated. Use max across workers since they
# process config is never updated. Use max across workers since they
# compile in parallel.
# compile in parallel.
if
compilation_times
:
if
compilation_times
:
self
.
vllm_config
.
compilation_config
.
compilation_time
=
max
(
self
.
vllm_config
.
compilation_config
.
compilation_time
=
max
(
compilation_times
t
.
language_model
for
t
in
compilation_times
)
self
.
vllm_config
.
compilation_config
.
encoder_compilation_time
=
max
(
t
.
encoder
for
t
in
compilation_times
)
)
def
register_failure_callback
(
self
,
callback
:
FailureCallback
):
# noqa: B027
def
register_failure_callback
(
self
,
callback
:
FailureCallback
):
# noqa: B027
...
...
vllm/v1/worker/cpu_worker.py
View file @
c08f3b2a
...
@@ -13,6 +13,7 @@ from vllm.profiler.wrapper import TorchProfilerWrapper
...
@@ -13,6 +13,7 @@ from vllm.profiler.wrapper import TorchProfilerWrapper
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.v1.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.v1.worker.gpu_worker
import
Worker
,
init_worker_distributed_environment
from
vllm.v1.worker.gpu_worker
import
Worker
,
init_worker_distributed_environment
from
vllm.v1.worker.worker_base
import
CompilationTimes
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -104,12 +105,15 @@ class CPUWorker(Worker):
...
@@ -104,12 +105,15 @@ class CPUWorker(Worker):
def
determine_available_memory
(
self
)
->
int
:
def
determine_available_memory
(
self
)
->
int
:
return
self
.
cache_config
.
cpu_kvcache_space_bytes
or
0
return
self
.
cache_config
.
cpu_kvcache_space_bytes
or
0
def
compile_or_warm_up_model
(
self
)
->
float
:
def
compile_or_warm_up_model
(
self
)
->
CompilationTimes
:
# Reset the seed to ensure that the random state is not affected by
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
# the model initialization and profiling.
set_random_seed
(
self
.
model_config
.
seed
)
set_random_seed
(
self
.
model_config
.
seed
)
self
.
model_runner
.
warming_up_model
()
self
.
model_runner
.
warming_up_model
()
return
self
.
compilation_config
.
compilation_time
return
CompilationTimes
(
language_model
=
self
.
compilation_config
.
compilation_time
,
encoder
=
self
.
compilation_config
.
encoder_compilation_time
,
)
def
profile
(
self
,
is_start
:
bool
=
True
,
profile_prefix
:
str
|
None
=
None
):
def
profile
(
self
,
is_start
:
bool
=
True
,
profile_prefix
:
str
|
None
=
None
):
if
self
.
profiler
is
None
:
if
self
.
profiler
is
None
:
...
...
vllm/v1/worker/gpu_worker.py
View file @
c08f3b2a
...
@@ -56,7 +56,7 @@ from vllm.v1.outputs import (
...
@@ -56,7 +56,7 @@ from vllm.v1.outputs import (
)
)
from
vllm.v1.utils
import
compute_iteration_details
,
report_usage_stats
from
vllm.v1.utils
import
compute_iteration_details
,
report_usage_stats
from
vllm.v1.worker.utils
import
is_residual_scattered_for_sp
from
vllm.v1.worker.utils
import
is_residual_scattered_for_sp
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.v1.worker.worker_base
import
CompilationTimes
,
WorkerBase
from
vllm.v1.worker.workspace
import
init_workspace_manager
from
vllm.v1.worker.workspace
import
init_workspace_manager
from
...model_executor.model_loader
import
TensorizerLoader
from
...model_executor.model_loader
import
TensorizerLoader
...
@@ -547,7 +547,7 @@ class Worker(WorkerBase):
...
@@ -547,7 +547,7 @@ class Worker(WorkerBase):
self
.
model_runner
.
_init_kv_zero_meta
()
self
.
model_runner
.
_init_kv_zero_meta
()
@
instrument
(
span_name
=
"Warmup (GPU)"
)
@
instrument
(
span_name
=
"Warmup (GPU)"
)
def
compile_or_warm_up_model
(
self
)
->
float
:
def
compile_or_warm_up_model
(
self
)
->
CompilationTimes
:
warmup_sizes
:
list
[
int
]
=
[]
warmup_sizes
:
list
[
int
]
=
[]
if
self
.
vllm_config
.
compilation_config
.
mode
==
CompilationMode
.
VLLM_COMPILE
:
if
self
.
vllm_config
.
compilation_config
.
mode
==
CompilationMode
.
VLLM_COMPILE
:
...
@@ -689,7 +689,10 @@ class Worker(WorkerBase):
...
@@ -689,7 +689,10 @@ class Worker(WorkerBase):
# the model initialization and profiling.
# the model initialization and profiling.
set_random_seed
(
self
.
model_config
.
seed
)
set_random_seed
(
self
.
model_config
.
seed
)
return
self
.
compilation_config
.
compilation_time
return
CompilationTimes
(
language_model
=
self
.
compilation_config
.
compilation_time
,
encoder
=
self
.
compilation_config
.
encoder_compilation_time
,
)
def
reset_mm_cache
(
self
)
->
None
:
def
reset_mm_cache
(
self
)
->
None
:
self
.
model_runner
.
reset_mm_cache
()
self
.
model_runner
.
reset_mm_cache
()
...
...
vllm/v1/worker/worker_base.py
View file @
c08f3b2a
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
TYPE_CHECKING
,
Any
,
TypeVar
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
TypeVar
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -30,6 +30,11 @@ logger = init_logger(__name__)
...
@@ -30,6 +30,11 @@ logger = init_logger(__name__)
_R
=
TypeVar
(
"_R"
)
_R
=
TypeVar
(
"_R"
)
class
CompilationTimes
(
NamedTuple
):
language_model
:
float
encoder
:
float
class
WorkerBase
:
class
WorkerBase
:
"""Worker interface that allows vLLM to cleanly separate implementations for
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to
different hardware. Also abstracts control plane communication, e.g., to
...
@@ -86,11 +91,11 @@ class WorkerBase:
...
@@ -86,11 +91,11 @@ class WorkerBase:
"""Get specifications for KV cache implementation."""
"""Get specifications for KV cache implementation."""
raise
NotImplementedError
raise
NotImplementedError
def
compile_or_warm_up_model
(
self
)
->
float
:
def
compile_or_warm_up_model
(
self
)
->
CompilationTimes
:
"""Prepare model for execution through compilation/warmup.
"""Prepare model for execution through compilation/warmup.
Returns:
Returns:
The accumulated compilation time
in seconds.
Compilation times (language_model, encoder)
in seconds.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment