Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c08f3b2a
Unverified
Commit
c08f3b2a
authored
Apr 14, 2026
by
Lucas Kabela
Committed by
GitHub
Apr 14, 2026
Browse files
Measure encoder compile time seperate from llm backbone (#39240)
Signed-off-by:
Lucas Kabela
<
lucaskabela@meta.com
>
parent
f02b3269
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
188 additions
and
149 deletions
+188
-149
vllm/benchmarks/startup.py
vllm/benchmarks/startup.py
+115
-130
vllm/compilation/backends.py
vllm/compilation/backends.py
+13
-3
vllm/compilation/piecewise_backend.py
vllm/compilation/piecewise_backend.py
+1
-0
vllm/config/compilation.py
vllm/config/compilation.py
+4
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+27
-5
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+8
-3
vllm/v1/worker/cpu_worker.py
vllm/v1/worker/cpu_worker.py
+6
-2
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+6
-3
vllm/v1/worker/worker_base.py
vllm/v1/worker/worker_base.py
+8
-3
No files found.
vllm/benchmarks/startup.py
View file @
c08f3b2a
...
...
@@ -16,7 +16,7 @@ import shutil
import
tempfile
import
time
from
contextlib
import
contextmanager
from
typing
import
Any
from
typing
import
Any
,
NamedTuple
import
numpy
as
np
from
tqdm
import
tqdm
...
...
@@ -27,6 +27,82 @@ from vllm.benchmarks.lib.utils import (
)
from
vllm.engine.arg_utils
import
EngineArgs
PERCENTAGES
=
[
10
,
25
,
50
,
75
,
90
,
99
]
class
MetricDesc
(
NamedTuple
):
"""Descriptor for a metric to collect from each iteration."""
iter_key
:
str
# key in the iteration result dict
suffix
:
str
# result key suffix, e.g. "startup", "compilation"
display_name
:
str
class
MetricStats
(
NamedTuple
):
"""Aggregated statistics for a single benchmark metric."""
key
:
str
# e.g. "cold_startup", "warm_encoder_compilation"
display_name
:
str
values
:
list
[
float
]
avg
:
float
percentiles
:
dict
[
int
,
float
]
_BASE_METRICS
=
[
MetricDesc
(
"total_startup_time"
,
"startup"
,
"Startup time"
),
MetricDesc
(
"compilation_time"
,
"compilation"
,
"Compilation time"
),
]
_ENCODER_METRIC
=
MetricDesc
(
"encoder_compilation_time"
,
"encoder_compilation"
,
"Encoder compilation time"
,
)
def
_compute_metric
(
phase
:
str
,
desc
:
MetricDesc
,
iterations
:
list
[
dict
[
str
,
float
]],
)
->
MetricStats
:
values
=
[
m
[
desc
.
iter_key
]
for
m
in
iterations
]
arr
=
np
.
array
(
values
)
return
MetricStats
(
key
=
f
"
{
phase
}
_
{
desc
.
suffix
}
"
,
display_name
=
desc
.
display_name
,
values
=
values
,
avg
=
float
(
np
.
mean
(
arr
)),
percentiles
=
dict
(
zip
(
PERCENTAGES
,
np
.
percentile
(
arr
,
PERCENTAGES
).
tolist
())),
)
def
_collect_phase_metrics
(
phase
:
str
,
iterations
:
list
[
dict
[
str
,
float
]],
has_encoder
:
bool
,
)
->
list
[
MetricStats
]:
metrics
=
[
_compute_metric
(
phase
,
desc
,
iterations
)
for
desc
in
_BASE_METRICS
]
if
has_encoder
:
metrics
.
append
(
_compute_metric
(
phase
,
_ENCODER_METRIC
,
iterations
))
return
metrics
def
_print_phase
(
phase_name
:
str
,
metrics
:
list
[
MetricStats
])
->
None
:
print
(
f
"
\n
{
phase_name
}
:"
)
for
m
in
metrics
:
print
(
f
"Avg
{
m
.
display_name
.
lower
()
}
:
{
m
.
avg
:.
2
f
}
seconds"
)
for
m
in
metrics
:
print
(
f
"
{
m
.
display_name
}
percentiles:"
)
for
pct
,
val
in
m
.
percentiles
.
items
():
print
(
f
"
{
pct
}
%:
{
val
:.
2
f
}
seconds"
)
def
_metric_to_json
(
m
:
MetricStats
)
->
dict
[
str
,
Any
]:
return
{
f
"avg_
{
m
.
key
}
_time"
:
m
.
avg
,
f
"
{
m
.
key
}
_times"
:
m
.
values
,
f
"
{
m
.
key
}
_percentiles"
:
m
.
percentiles
,
}
@
contextmanager
def
cold_startup
():
...
...
@@ -72,6 +148,7 @@ def run_startup_in_subprocess(engine_args, result_queue):
# Extract compilation time if available
compilation_time
=
0.0
encoder_compilation_time
=
0.0
if
hasattr
(
llm
.
llm_engine
,
"vllm_config"
):
vllm_config
=
llm
.
llm_engine
.
vllm_config
if
(
...
...
@@ -79,11 +156,15 @@ def run_startup_in_subprocess(engine_args, result_queue):
and
vllm_config
.
compilation_config
is
not
None
):
compilation_time
=
vllm_config
.
compilation_config
.
compilation_time
encoder_compilation_time
=
(
vllm_config
.
compilation_config
.
encoder_compilation_time
)
result_queue
.
put
(
{
"total_startup_time"
:
total_startup_time
,
"compilation_time"
:
compilation_time
,
"encoder_compilation_time"
:
encoder_compilation_time
,
}
)
...
...
@@ -93,65 +174,20 @@ def run_startup_in_subprocess(engine_args, result_queue):
def
save_to_pytorch_benchmark_format
(
args
:
argparse
.
Namespace
,
results
:
dict
[
str
,
Any
]
args
:
argparse
.
Namespace
,
metrics
:
list
[
MetricStats
]
)
->
None
:
base_name
=
os
.
path
.
splitext
(
args
.
output_json
)[
0
]
cold_startup_records
=
convert_to_pytorch_benchmark_format
(
args
=
args
,
metrics
=
{
"avg_cold_startup_time"
:
[
results
[
"avg_cold_startup_time"
]],
},
extra_info
=
{
"cold_startup_times"
:
results
[
"cold_startup_times"
],
"cold_startup_percentiles"
:
results
[
"cold_startup_percentiles"
],
},
)
if
cold_startup_records
:
write_to_json
(
f
"
{
base_name
}
.cold_startup.pytorch.json"
,
cold_startup_records
)
cold_compilation_records
=
convert_to_pytorch_benchmark_format
(
args
=
args
,
metrics
=
{
"avg_cold_compilation_time"
:
[
results
[
"avg_cold_compilation_time"
]],
},
extra_info
=
{
"cold_compilation_times"
:
results
[
"cold_compilation_times"
],
"cold_compilation_percentiles"
:
results
[
"cold_compilation_percentiles"
],
},
)
if
cold_compilation_records
:
write_to_json
(
f
"
{
base_name
}
.cold_compilation.pytorch.json"
,
cold_compilation_records
)
warm_startup_records
=
convert_to_pytorch_benchmark_format
(
args
=
args
,
metrics
=
{
"avg_warm_startup_time"
:
[
results
[
"avg_warm_startup_time"
]],
},
extra_info
=
{
"warm_startup_times"
:
results
[
"warm_startup_times"
],
"warm_startup_percentiles"
:
results
[
"warm_startup_percentiles"
],
},
)
if
warm_startup_records
:
write_to_json
(
f
"
{
base_name
}
.warm_startup.pytorch.json"
,
warm_startup_records
)
warm_compilation_records
=
convert_to_pytorch_benchmark_format
(
args
=
args
,
metrics
=
{
"avg_warm_compilation_time"
:
[
results
[
"avg_warm_compilation_time"
]],
},
extra_info
=
{
"warm_compilation_times"
:
results
[
"warm_compilation_times"
],
"warm_compilation_percentiles"
:
results
[
"warm_compilation_percentiles"
],
},
)
if
warm_compilation_records
:
write_to_json
(
f
"
{
base_name
}
.warm_compilation.pytorch.json"
,
warm_compilation_records
for
m
in
metrics
:
records
=
convert_to_pytorch_benchmark_format
(
args
=
args
,
metrics
=
{
f
"avg_
{
m
.
key
}
_time"
:
[
m
.
avg
]},
extra_info
=
{
f
"
{
m
.
key
}
_times"
:
m
.
values
,
f
"
{
m
.
key
}
_percentiles"
:
m
.
percentiles
,
},
)
if
records
:
write_to_json
(
f
"
{
base_name
}
.
{
m
.
key
}
.pytorch.json"
,
records
)
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
...
...
@@ -224,97 +260,46 @@ def main(args: argparse.Namespace):
os
.
environ
[
"VLLM_ENABLE_V1_MULTIPROCESSING"
]
=
"0"
print
(
"Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.
\n
"
)
# Collect cold startup iterations
print
(
"Measuring cold startup time...
\n
"
)
cold_startup_times
=
[]
cold_compilation_times
=
[]
cold_iterations
=
[]
for
i
in
tqdm
(
range
(
args
.
num_iters_cold
),
desc
=
"Cold startup iterations"
):
with
cold_startup
():
metrics
=
create_llm_and_measure_startup
()
cold_startup_times
.
append
(
metrics
[
"total_startup_time"
])
cold_compilation_times
.
append
(
metrics
[
"compilation_time"
])
cold_iterations
.
append
(
create_llm_and_measure_startup
())
# Warmup for warm startup
print
(
"
\n
Warming up for warm startup measurement...
\n
"
)
for
_
in
tqdm
(
range
(
args
.
num_iters_warmup
),
desc
=
"Warmup iterations"
):
create_llm_and_measure_startup
()
# Collect warm startup iterations
print
(
"
\n
Measuring warm startup time...
\n
"
)
warm_startup_times
=
[]
warm_compilation_times
=
[]
warm_iterations
=
[]
for
i
in
tqdm
(
range
(
args
.
num_iters_warm
),
desc
=
"Warm startup iterations"
):
metrics
=
create_llm_and_measure_startup
()
warm_startup_times
.
append
(
metrics
[
"total_startup_time"
])
warm_compilation_times
.
append
(
metrics
[
"compilation_time"
])
# Calculate statistics
cold_startup_array
=
np
.
array
(
cold_startup_times
)
cold_compilation_array
=
np
.
array
(
cold_compilation_times
)
warm_startup_array
=
np
.
array
(
warm_startup_times
)
warm_compilation_array
=
np
.
array
(
warm_compilation_times
)
avg_cold_startup
=
np
.
mean
(
cold_startup_array
)
avg_cold_compilation
=
np
.
mean
(
cold_compilation_array
)
avg_warm_startup
=
np
.
mean
(
warm_startup_array
)
avg_warm_compilation
=
np
.
mean
(
warm_compilation_array
)
percentages
=
[
10
,
25
,
50
,
75
,
90
,
99
]
cold_startup_percentiles
=
np
.
percentile
(
cold_startup_array
,
percentages
)
cold_compilation_percentiles
=
np
.
percentile
(
cold_compilation_array
,
percentages
)
warm_startup_percentiles
=
np
.
percentile
(
warm_startup_array
,
percentages
)
warm_compilation_percentiles
=
np
.
percentile
(
warm_compilation_array
,
percentages
)
warm_iterations
.
append
(
create_llm_and_measure_startup
())
# Determine if encoder compilation occurred in any iteration
has_encoder
=
any
(
m
[
"encoder_compilation_time"
]
>
0
for
m
in
cold_iterations
+
warm_iterations
)
cold_metrics
=
_collect_phase_metrics
(
"cold"
,
cold_iterations
,
has_encoder
)
warm_metrics
=
_collect_phase_metrics
(
"warm"
,
warm_iterations
,
has_encoder
)
all_metrics
=
cold_metrics
+
warm_metrics
# Print results
print
(
"
\n
"
+
"="
*
60
)
print
(
"STARTUP TIME BENCHMARK RESULTS"
)
print
(
"="
*
60
)
# Cold startup statistics
print
(
"
\n
COLD STARTUP:"
)
print
(
f
"Avg total startup time:
{
avg_cold_startup
:.
2
f
}
seconds"
)
print
(
f
"Avg compilation time:
{
avg_cold_compilation
:.
2
f
}
seconds"
)
print
(
"Startup time percentiles:"
)
for
percentage
,
percentile
in
zip
(
percentages
,
cold_startup_percentiles
):
print
(
f
"
{
percentage
}
%:
{
percentile
:.
2
f
}
seconds"
)
print
(
"Compilation time percentiles:"
)
for
percentage
,
percentile
in
zip
(
percentages
,
cold_compilation_percentiles
):
print
(
f
"
{
percentage
}
%:
{
percentile
:.
2
f
}
seconds"
)
# Warm startup statistics
print
(
"
\n
WARM STARTUP:"
)
print
(
f
"Avg total startup time:
{
avg_warm_startup
:.
2
f
}
seconds"
)
print
(
f
"Avg compilation time:
{
avg_warm_compilation
:.
2
f
}
seconds"
)
print
(
"Startup time percentiles:"
)
for
percentage
,
percentile
in
zip
(
percentages
,
warm_startup_percentiles
):
print
(
f
"
{
percentage
}
%:
{
percentile
:.
2
f
}
seconds"
)
print
(
"Compilation time percentiles:"
)
for
percentage
,
percentile
in
zip
(
percentages
,
warm_compilation_percentiles
):
print
(
f
"
{
percentage
}
%:
{
percentile
:.
2
f
}
seconds"
)
_print_phase
(
"COLD STARTUP"
,
cold_metrics
)
_print_phase
(
"WARM STARTUP"
,
warm_metrics
)
print
(
"="
*
60
)
# Output JSON results if specified
if
args
.
output_json
:
results
=
{
"avg_cold_startup_time"
:
float
(
avg_cold_startup
),
"avg_cold_compilation_time"
:
float
(
avg_cold_compilation
),
"cold_startup_times"
:
cold_startup_times
,
"cold_compilation_times"
:
cold_compilation_times
,
"cold_startup_percentiles"
:
dict
(
zip
(
percentages
,
cold_startup_percentiles
.
tolist
())
),
"cold_compilation_percentiles"
:
dict
(
zip
(
percentages
,
cold_compilation_percentiles
.
tolist
())
),
"avg_warm_startup_time"
:
float
(
avg_warm_startup
),
"avg_warm_compilation_time"
:
float
(
avg_warm_compilation
),
"warm_startup_times"
:
warm_startup_times
,
"warm_compilation_times"
:
warm_compilation_times
,
"warm_startup_percentiles"
:
dict
(
zip
(
percentages
,
warm_startup_percentiles
.
tolist
())
),
"warm_compilation_percentiles"
:
dict
(
zip
(
percentages
,
warm_compilation_percentiles
.
tolist
())
),
}
results
:
dict
[
str
,
Any
]
=
{}
for
m
in
all_metrics
:
results
.
update
(
_metric_to_json
(
m
))
with
open
(
args
.
output_json
,
"w"
)
as
f
:
json
.
dump
(
results
,
f
,
indent
=
4
)
save_to_pytorch_benchmark_format
(
args
,
result
s
)
save_to_pytorch_benchmark_format
(
args
,
all_metric
s
)
vllm/compilation/backends.py
View file @
c08f3b2a
...
...
@@ -265,6 +265,7 @@ class CompilerManager:
compile_range
:
Range
,
graph_index
:
int
=
0
,
num_graphs
:
int
=
1
,
is_encoder
:
bool
=
False
,
)
->
Any
:
if
graph_index
==
0
:
# before compiling the first graph, record the start time
...
...
@@ -282,7 +283,10 @@ class CompilerManager:
# after loading the last graph for this shape, record the time.
# there can be multiple graphs due to piecewise compilation.
elapsed
=
time
.
perf_counter
()
-
compilation_start_time
compilation_config
.
compilation_time
+=
elapsed
if
is_encoder
:
compilation_config
.
encoder_compilation_time
+=
elapsed
else
:
compilation_config
.
compilation_time
+=
elapsed
logger
.
info_once
(
"Directly load the compiled graph(s) for compile range %s "
"from the cache, took %.3f s"
,
...
...
@@ -387,7 +391,10 @@ class CompilerManager:
# after compiling the last graph, record the end time
if
graph_index
==
num_graphs
-
1
:
elapsed
=
time
.
perf_counter
()
-
compilation_start_time
compilation_config
.
compilation_time
+=
elapsed
if
is_encoder
:
compilation_config
.
encoder_compilation_time
+=
elapsed
else
:
compilation_config
.
compilation_time
+=
elapsed
logger
.
info_once
(
"Compiling a graph for compile range %s takes %.2f s"
,
str
(
compile_range
),
...
...
@@ -1130,7 +1137,10 @@ class VllmBackend:
logger
.
info_once
(
"Dynamo bytecode transform time: %.2f s"
,
dynamo_time
,
scope
=
"local"
)
self
.
compilation_config
.
compilation_time
+=
dynamo_time
if
self
.
is_encoder
:
self
.
compilation_config
.
encoder_compilation_time
+=
dynamo_time
else
:
self
.
compilation_config
.
compilation_time
+=
dynamo_time
# Record Dynamo time in tracing if available
start_time
=
int
(
torch_compile_start_time
*
1e9
)
...
...
vllm/compilation/piecewise_backend.py
View file @
c08f3b2a
...
...
@@ -270,6 +270,7 @@ class PiecewiseBackend:
compile_range
=
range_entry
.
compile_range
,
graph_index
=
self
.
piecewise_compile_index
,
num_graphs
=
self
.
total_piecewise_compiles
,
is_encoder
=
self
.
vllm_backend
.
is_encoder
,
)
range_entry
.
compiled
=
True
...
...
vllm/config/compilation.py
View file @
c08f3b2a
...
...
@@ -710,6 +710,8 @@ class CompilationConfig:
"""files that are traced for compilation"""
compilation_time
:
float
=
field
(
default
=
0.0
,
init
=
False
)
"""time taken for compilation"""
encoder_compilation_time
:
float
=
field
(
default
=
0.0
,
init
=
False
)
"""time taken for multimodal encoder compilation"""
static_forward_context
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
,
init
=
False
)
"""Per-model forward context
...
...
@@ -756,6 +758,7 @@ class CompilationConfig:
"local_cache_dir"
,
"traced_files"
,
"compilation_time"
,
"encoder_compilation_time"
,
"static_forward_context"
,
"pass_config"
,
# handled separately below
"dynamic_shapes_config"
,
# handled separately below
...
...
@@ -775,6 +778,7 @@ class CompilationConfig:
"enabled_custom_ops"
:
True
,
"disabled_custom_ops"
:
True
,
"compilation_time"
:
True
,
"encoder_compilation_time"
:
True
,
"traced_files"
:
True
,
"inductor_compile_config"
:
{
"post_grad_custom_post_pass"
:
True
,
...
...
vllm/v1/engine/core.py
View file @
c08f3b2a
...
...
@@ -282,11 +282,33 @@ class EngineCore:
self
.
model_executor
.
initialize_from_config
(
kv_cache_configs
)
elapsed
=
time
.
time
()
-
start
logger
.
info_once
(
"init engine (profile, create kv cache, warmup model) took %.2f seconds"
,
elapsed
,
scope
=
"local"
,
)
compile_time
=
vllm_config
.
compilation_config
.
compilation_time
encoder_compile_time
=
vllm_config
.
compilation_config
.
encoder_compilation_time
if
encoder_compile_time
>
0
:
logger
.
info_once
(
"init engine (profile, create kv cache, warmup model) took "
"%.2f s (compilation: %.2f s — language_model: %.2f s, "
"encoder: %.2f s)"
,
elapsed
,
compile_time
+
encoder_compile_time
,
compile_time
,
encoder_compile_time
,
scope
=
"local"
,
)
elif
compile_time
>
0
:
logger
.
info_once
(
"init engine (profile, create kv cache, warmup model) took "
"%.2f s (compilation: %.2f s)"
,
elapsed
,
compile_time
,
scope
=
"local"
,
)
else
:
logger
.
info_once
(
"init engine (profile, create kv cache, warmup model) took %.2f s"
,
elapsed
,
scope
=
"local"
,
)
return
scheduler_kv_cache_config
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
...
...
vllm/v1/executor/abstract.py
View file @
c08f3b2a
...
...
@@ -22,7 +22,7 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from
vllm.v1.engine
import
ReconfigureDistributedRequest
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.v1.worker.worker_base
import
CompilationTimes
,
WorkerBase
if
TYPE_CHECKING
:
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
...
...
@@ -121,14 +121,19 @@ class Executor(ABC):
underlying workers.
"""
self
.
collective_rpc
(
"initialize_from_config"
,
args
=
(
kv_cache_configs
,))
compilation_times
:
list
[
float
]
=
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
compilation_times
:
list
[
CompilationTimes
]
=
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
# Propagate compilation time from workers back to the main process.
# With TP>1, compilation happens in worker processes, so the main
# process config is never updated. Use max across workers since they
# compile in parallel.
if
compilation_times
:
self
.
vllm_config
.
compilation_config
.
compilation_time
=
max
(
compilation_times
t
.
language_model
for
t
in
compilation_times
)
self
.
vllm_config
.
compilation_config
.
encoder_compilation_time
=
max
(
t
.
encoder
for
t
in
compilation_times
)
def
register_failure_callback
(
self
,
callback
:
FailureCallback
):
# noqa: B027
...
...
vllm/v1/worker/cpu_worker.py
View file @
c08f3b2a
...
...
@@ -13,6 +13,7 @@ from vllm.profiler.wrapper import TorchProfilerWrapper
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.v1.worker.gpu_worker
import
Worker
,
init_worker_distributed_environment
from
vllm.v1.worker.worker_base
import
CompilationTimes
logger
=
init_logger
(
__name__
)
...
...
@@ -104,12 +105,15 @@ class CPUWorker(Worker):
def
determine_available_memory
(
self
)
->
int
:
return
self
.
cache_config
.
cpu_kvcache_space_bytes
or
0
def
compile_or_warm_up_model
(
self
)
->
float
:
def
compile_or_warm_up_model
(
self
)
->
CompilationTimes
:
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed
(
self
.
model_config
.
seed
)
self
.
model_runner
.
warming_up_model
()
return
self
.
compilation_config
.
compilation_time
return
CompilationTimes
(
language_model
=
self
.
compilation_config
.
compilation_time
,
encoder
=
self
.
compilation_config
.
encoder_compilation_time
,
)
def
profile
(
self
,
is_start
:
bool
=
True
,
profile_prefix
:
str
|
None
=
None
):
if
self
.
profiler
is
None
:
...
...
vllm/v1/worker/gpu_worker.py
View file @
c08f3b2a
...
...
@@ -56,7 +56,7 @@ from vllm.v1.outputs import (
)
from
vllm.v1.utils
import
compute_iteration_details
,
report_usage_stats
from
vllm.v1.worker.utils
import
is_residual_scattered_for_sp
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.v1.worker.worker_base
import
CompilationTimes
,
WorkerBase
from
vllm.v1.worker.workspace
import
init_workspace_manager
from
...model_executor.model_loader
import
TensorizerLoader
...
...
@@ -547,7 +547,7 @@ class Worker(WorkerBase):
self
.
model_runner
.
_init_kv_zero_meta
()
@
instrument
(
span_name
=
"Warmup (GPU)"
)
def
compile_or_warm_up_model
(
self
)
->
float
:
def
compile_or_warm_up_model
(
self
)
->
CompilationTimes
:
warmup_sizes
:
list
[
int
]
=
[]
if
self
.
vllm_config
.
compilation_config
.
mode
==
CompilationMode
.
VLLM_COMPILE
:
...
...
@@ -689,7 +689,10 @@ class Worker(WorkerBase):
# the model initialization and profiling.
set_random_seed
(
self
.
model_config
.
seed
)
return
self
.
compilation_config
.
compilation_time
return
CompilationTimes
(
language_model
=
self
.
compilation_config
.
compilation_time
,
encoder
=
self
.
compilation_config
.
encoder_compilation_time
,
)
def
reset_mm_cache
(
self
)
->
None
:
self
.
model_runner
.
reset_mm_cache
()
...
...
vllm/v1/worker/worker_base.py
View file @
c08f3b2a
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
typing
import
TYPE_CHECKING
,
Any
,
TypeVar
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
TypeVar
import
torch
import
torch.nn
as
nn
...
...
@@ -30,6 +30,11 @@ logger = init_logger(__name__)
_R
=
TypeVar
(
"_R"
)
class
CompilationTimes
(
NamedTuple
):
language_model
:
float
encoder
:
float
class
WorkerBase
:
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to
...
...
@@ -86,11 +91,11 @@ class WorkerBase:
"""Get specifications for KV cache implementation."""
raise
NotImplementedError
def
compile_or_warm_up_model
(
self
)
->
float
:
def
compile_or_warm_up_model
(
self
)
->
CompilationTimes
:
"""Prepare model for execution through compilation/warmup.
Returns:
The accumulated compilation time
in seconds.
Compilation times (language_model, encoder)
in seconds.
"""
raise
NotImplementedError
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment