Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6e650f56
Unverified
Commit
6e650f56
authored
Jan 24, 2025
by
youkaichao
Committed by
GitHub
Jan 24, 2025
Browse files
[torch.compile] decouple compile sizes and cudagraph sizes (#12243)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
3f50c148
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
95 additions
and
58 deletions
+95
-58
vllm/compilation/backends.py
vllm/compilation/backends.py
+5
-5
vllm/config.py
vllm/config.py
+37
-34
vllm/engine/metrics.py
vllm/engine/metrics.py
+2
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+10
-8
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+12
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+17
-10
vllm/worker/worker.py
vllm/worker/worker.py
+12
-0
No files found.
vllm/compilation/backends.py
View file @
6e650f56
...
@@ -680,7 +680,7 @@ class VllmBackend:
...
@@ -680,7 +680,7 @@ class VllmBackend:
class
ConcreteSizeEntry
:
class
ConcreteSizeEntry
:
runtime_shape
:
int
runtime_shape
:
int
need_to_compile
:
bool
# the size is in compile_sizes
need_to_compile
:
bool
# the size is in compile_sizes
use_cudagraph
:
bool
# the size is in capture_sizes
use_cudagraph
:
bool
# the size is in
cudagraph_
capture_sizes
compiled
:
bool
=
False
compiled
:
bool
=
False
runnable
:
Callable
=
None
# type: ignore
runnable
:
Callable
=
None
# type: ignore
...
@@ -727,8 +727,8 @@ class PiecewiseBackend:
...
@@ -727,8 +727,8 @@ class PiecewiseBackend:
self
.
compile_sizes
:
Set
[
int
]
=
set
(
self
.
compile_sizes
:
Set
[
int
]
=
set
(
self
.
compilation_config
.
compile_sizes
)
self
.
compilation_config
.
compile_sizes
)
self
.
capture_sizes
:
Set
[
int
]
=
set
(
self
.
cudagraph_
capture_sizes
:
Set
[
int
]
=
set
(
self
.
compilation_config
.
capture_sizes
self
.
compilation_config
.
cudagraph_
capture_sizes
)
if
self
.
compilation_config
.
use_cudagraph
else
set
()
)
if
self
.
compilation_config
.
use_cudagraph
else
set
()
self
.
first_run_finished
=
False
self
.
first_run_finished
=
False
...
@@ -746,11 +746,11 @@ class PiecewiseBackend:
...
@@ -746,11 +746,11 @@ class PiecewiseBackend:
# to_be_compiled_sizes tracks the remaining sizes to compile,
# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
# and updates during the compilation process, so we need to copy it
self
.
to_be_compiled_sizes
:
Set
[
int
]
=
self
.
compile_sizes
.
copy
()
self
.
to_be_compiled_sizes
:
Set
[
int
]
=
self
.
compile_sizes
.
copy
()
for
shape
in
self
.
compile_sizes
.
union
(
self
.
capture_sizes
):
for
shape
in
self
.
compile_sizes
.
union
(
self
.
cudagraph_
capture_sizes
):
self
.
concrete_size_entries
[
shape
]
=
ConcreteSizeEntry
(
self
.
concrete_size_entries
[
shape
]
=
ConcreteSizeEntry
(
runtime_shape
=
shape
,
runtime_shape
=
shape
,
need_to_compile
=
shape
in
self
.
compile_sizes
,
need_to_compile
=
shape
in
self
.
compile_sizes
,
use_cudagraph
=
shape
in
self
.
capture_sizes
,
use_cudagraph
=
shape
in
self
.
cudagraph_
capture_sizes
,
)
)
def
check_for_ending_compilation
(
self
):
def
check_for_ending_compilation
(
self
):
...
...
vllm/config.py
View file @
6e650f56
...
@@ -2711,10 +2711,11 @@ class CompilationConfig(BaseModel):
...
@@ -2711,10 +2711,11 @@ class CompilationConfig(BaseModel):
- use_inductor: whether to use inductor compilation.
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for cudagraph sizes that are
is compiled. In addition, compile for compile_sizes,
in candidate_compile_sizes, using configurations
using configurations in inductor_compile_config.
in inductor_compile_config.
- compile_sizes: sizes to compile for inductor. In addition
- candidate_compile_sizes: sizes to compile for inductor.
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture.
- inductor_compile_config: additional configurations for inductor.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
- inductor_passes: additional passes for inductor. It is a dictionary
...
@@ -2742,7 +2743,7 @@ class CompilationConfig(BaseModel):
...
@@ -2742,7 +2743,7 @@ class CompilationConfig(BaseModel):
splitting_ops
:
List
[
str
]
=
Field
(
default
=
None
)
# type: ignore
splitting_ops
:
List
[
str
]
=
Field
(
default
=
None
)
# type: ignore
use_inductor
:
bool
=
True
use_inductor
:
bool
=
True
candidate_
compile_sizes
:
Optional
[
List
[
int
]]
=
Field
(
default
=
None
)
compile_sizes
:
Optional
[
List
[
Union
[
int
,
str
]
]]
=
Field
(
default
=
None
)
inductor_compile_config
:
Dict
=
Field
(
default_factory
=
dict
)
inductor_compile_config
:
Dict
=
Field
(
default_factory
=
dict
)
inductor_passes
:
Dict
[
str
,
str
]
=
Field
(
default_factory
=
dict
)
inductor_passes
:
Dict
[
str
,
str
]
=
Field
(
default_factory
=
dict
)
...
@@ -2790,8 +2791,6 @@ class CompilationConfig(BaseModel):
...
@@ -2790,8 +2791,6 @@ class CompilationConfig(BaseModel):
pass_config
:
PassConfig
=
Field
(
default_factory
=
PassConfig
)
pass_config
:
PassConfig
=
Field
(
default_factory
=
PassConfig
)
# not configurable, computed after init
# not configurable, computed after init
compile_sizes
:
List
[
int
]
=
PrivateAttr
capture_sizes
:
List
[
int
]
=
PrivateAttr
max_capture_size
:
int
=
PrivateAttr
max_capture_size
:
int
=
PrivateAttr
local_cache_dir
:
str
=
PrivateAttr
# local cache dir for each rank
local_cache_dir
:
str
=
PrivateAttr
# local cache dir for each rank
# optimization:
# optimization:
...
@@ -2918,43 +2917,47 @@ class CompilationConfig(BaseModel):
...
@@ -2918,43 +2917,47 @@ class CompilationConfig(BaseModel):
from
vllm.compilation.backends
import
VllmBackend
from
vllm.compilation.backends
import
VllmBackend
return
VllmBackend
(
vllm_config
)
return
VllmBackend
(
vllm_config
)
def
init_with_cudagraph_sizes
(
self
,
sizes_to_specialize
:
List
[
int
]):
def
init_with_cudagraph_sizes
(
self
,
cudagraph_capture_sizes
:
List
[
int
])
->
None
:
"""To complete the initialization of config,
"""To complete the initialization of config,
we need to know the cudagraph sizes."""
we need to know the cudagraph sizes."""
if
self
.
cudagraph_capture_sizes
is
None
:
if
self
.
cudagraph_capture_sizes
is
None
:
self
.
capture_sizes
=
sizes_to_special
ize
self
.
cudagraph_
capture_sizes
=
cudagraph_capture_s
ize
s
else
:
else
:
self
.
capture_sizes
=
self
.
cudagraph_capture_sizes
# de-duplicate the sizes provided by the config
self
.
cudagraph_capture_sizes
=
list
(
set
(
self
.
cudagraph_capture_sizes
))
logger
.
info
((
"cudagraph sizes specified by model runner"
logger
.
info
((
"cudagraph sizes specified by model runner"
" %s is overridden by config %s"
),
" %s is overridden by config %s"
),
sizes_to_specialize
,
self
.
cudagraph_capture_sizes
)
cudagraph_capture_sizes
,
self
.
cudagraph_capture_sizes
)
if
self
.
candidate_compile_sizes
is
None
:
computed_compile_sizes
=
[]
self
.
candidate_compile_sizes
=
[]
if
self
.
compile_sizes
is
not
None
:
self
.
compile_sizes
=
[
# de-duplicate the sizes provided by the config
x
for
x
in
self
.
candidate_compile_sizes
if
x
in
self
.
capture_sizes
self
.
compile_sizes
=
list
(
set
(
self
.
compile_sizes
))
]
for
x
in
self
.
compile_sizes
:
ignored_sizes
=
[
if
isinstance
(
x
,
str
):
x
for
x
in
self
.
candidate_compile_sizes
assert
x
==
"cudagraph_capture_sizes"
,
\
if
x
not
in
self
.
capture_sizes
"Unrecognized size type in compile_sizes, "
\
]
f
"expect 'cudagraph_capture_sizes', got
{
x
}
"
if
ignored_sizes
:
computed_compile_sizes
.
extend
(
self
.
cudagraph_capture_sizes
)
logger
.
warning
((
"candidate_compile_sizes %s are ignored "
else
:
"because they are not cudagraph capture sizes."
),
assert
isinstance
(
x
,
int
)
ignored_sizes
)
computed_compile_sizes
.
append
(
x
)
self
.
compile_sizes
=
computed_compile_sizes
# type: ignore
# sort to make sure cudagraph capture sizes are in descending order
# sort to make sure cudagraph capture sizes are in descending order
self
.
capture_sizes
.
sort
(
reverse
=
True
)
self
.
cudagraph_
capture_sizes
.
sort
(
reverse
=
True
)
self
.
max_capture_size
=
self
.
capture_sizes
[
self
.
max_capture_size
=
self
.
cudagraph_
capture_sizes
[
0
]
if
self
.
capture_sizes
else
0
0
]
if
self
.
cudagraph_
capture_sizes
else
0
# pre-compute the mapping from batch size to padded graph size
# pre-compute the mapping from batch size to padded graph size
self
.
bs_to_padded_graph_size
=
[
self
.
bs_to_padded_graph_size
=
[
0
for
i
in
range
(
self
.
max_capture_size
+
1
)
0
for
i
in
range
(
self
.
max_capture_size
+
1
)
]
]
for
end
,
start
in
zip
(
self
.
capture_sizes
,
for
end
,
start
in
zip
(
self
.
cudagraph_
capture_sizes
,
self
.
capture_sizes
[
1
:]
+
[
0
]):
self
.
cudagraph_
capture_sizes
[
1
:]
+
[
0
]):
for
bs
in
range
(
start
,
end
):
for
bs
in
range
(
start
,
end
):
if
bs
==
start
:
if
bs
==
start
:
self
.
bs_to_padded_graph_size
[
bs
]
=
start
self
.
bs_to_padded_graph_size
[
bs
]
=
start
...
@@ -3225,14 +3228,14 @@ class VllmConfig:
...
@@ -3225,14 +3228,14 @@ class VllmConfig:
However, if users specify the cudagraph capture sizes through
However, if users specify the cudagraph capture sizes through
compilation config, we will use the specified sizes instead.
compilation config, we will use the specified sizes instead.
In the end, `vllm_config.compilation_config.capture_sizes`
will be the
In the end, `vllm_config.compilation_config.
cudagraph_
capture_sizes`
final sizes to capture cudagraph (in descending order).
will be the
final sizes to capture cudagraph (in descending order).
During runtime, if batchsize is larger than
During runtime, if batchsize is larger than
`vllm_config.compilation_config.capture_sizes`,
`vllm_config.compilation_config.
cudagraph_
capture_sizes`,
no cudagraph will be used.
no cudagraph will be used.
If the batch size is no larger than
If the batch size is no larger than
`vllm_config.compilation_config.capture_sizes`,
`vllm_config.compilation_config.
cudagraph_
capture_sizes`,
we can quickly find the padded graph size for a given batch size by
we can quickly find the padded graph size for a given batch size by
looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
"""
"""
...
...
vllm/engine/metrics.py
View file @
6e650f56
...
@@ -120,7 +120,8 @@ class Metrics:
...
@@ -120,7 +120,8 @@ class Metrics:
labelnames
=
labelnames
)
labelnames
=
labelnames
)
buckets
=
[
1
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8096
]
buckets
=
[
1
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8096
]
if
not
vllm_config
.
model_config
.
enforce_eager
:
if
not
vllm_config
.
model_config
.
enforce_eager
:
buckets
=
vllm_config
.
compilation_config
.
capture_sizes
.
copy
()
buckets
=
vllm_config
.
compilation_config
.
\
cudagraph_capture_sizes
.
copy
()
buckets
.
sort
()
buckets
.
sort
()
self
.
histogram_iteration_tokens
=
self
.
_histogram_cls
(
self
.
histogram_iteration_tokens
=
self
.
_histogram_cls
(
name
=
"vllm:iteration_tokens_total"
,
name
=
"vllm:iteration_tokens_total"
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
6e650f56
import
gc
import
gc
import
time
import
time
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Tuple
,
cast
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
cast
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -128,7 +128,8 @@ class GPUModelRunner:
...
@@ -128,7 +128,8 @@ class GPUModelRunner:
# self.cudagraph_batch_sizes sorts in ascending order.
# self.cudagraph_batch_sizes sorts in ascending order.
# The batch sizes in the config are in descending order.
# The batch sizes in the config are in descending order.
self
.
cudagraph_batch_sizes
=
list
(
self
.
cudagraph_batch_sizes
=
list
(
reversed
(
self
.
vllm_config
.
compilation_config
.
capture_sizes
))
reversed
(
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
))
# Cache the device properties.
# Cache the device properties.
self
.
device_properties
=
torch
.
cuda
.
get_device_properties
(
self
.
device
)
self
.
device_properties
=
torch
.
cuda
.
get_device_properties
(
self
.
device
)
...
@@ -834,10 +835,12 @@ class GPUModelRunner:
...
@@ -834,10 +835,12 @@ class GPUModelRunner:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
_dummy_run
(
def
_dummy_run
(
self
,
self
,
model
:
nn
.
Module
,
num_tokens
:
int
,
num_tokens
:
int
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
Optional
[
List
[
torch
.
Tensor
]
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
model
=
self
.
model
if
kv_caches
is
None
:
kv_caches
=
self
.
kv_caches
if
self
.
is_multimodal_model
:
if
self
.
is_multimodal_model
:
input_ids
=
None
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
...
@@ -963,8 +966,7 @@ class GPUModelRunner:
...
@@ -963,8 +966,7 @@ class GPUModelRunner:
self
.
encoder_cache
[
"tmp"
]
=
dict
(
enumerate
(
dummy_encoder_outputs
))
self
.
encoder_cache
[
"tmp"
]
=
dict
(
enumerate
(
dummy_encoder_outputs
))
# Trigger compilation for general shape.
# Trigger compilation for general shape.
hidden_states
=
self
.
_dummy_run
(
self
.
model
,
self
.
max_num_tokens
,
hidden_states
=
self
.
_dummy_run
(
self
.
max_num_tokens
,
dummy_kv_caches
)
dummy_kv_caches
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
logits
=
logits
[:
self
.
max_num_tokens
]
logits
=
logits
[:
self
.
max_num_tokens
]
# TODO(woosuk): Consider the memory usage of the sampler.
# TODO(woosuk): Consider the memory usage of the sampler.
...
@@ -990,8 +992,8 @@ class GPUModelRunner:
...
@@ -990,8 +992,8 @@ class GPUModelRunner:
for
num_tokens
in
reversed
(
self
.
cudagraph_batch_sizes
):
for
num_tokens
in
reversed
(
self
.
cudagraph_batch_sizes
):
for
_
in
range
(
self
.
vllm_config
.
compilation_config
.
for
_
in
range
(
self
.
vllm_config
.
compilation_config
.
cudagraph_num_of_warmups
):
cudagraph_num_of_warmups
):
self
.
_dummy_run
(
self
.
model
,
num_tokens
,
self
.
kv_cache
s
)
self
.
_dummy_run
(
num_token
s
)
self
.
_dummy_run
(
self
.
model
,
num_tokens
,
self
.
kv_cache
s
)
self
.
_dummy_run
(
num_token
s
)
end_time
=
time
.
perf_counter
()
end_time
=
time
.
perf_counter
()
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
...
...
vllm/v1/worker/gpu_worker.py
View file @
6e650f56
...
@@ -206,6 +206,18 @@ class Worker:
...
@@ -206,6 +206,18 @@ class Worker:
self
.
model_runner
.
initialize_kv_cache
(
kv_cache_config
)
self
.
model_runner
.
initialize_kv_cache
(
kv_cache_config
)
def
compile_or_warm_up_model
(
self
)
->
None
:
def
compile_or_warm_up_model
(
self
)
->
None
:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
warmup_sizes
=
self
.
vllm_config
.
compilation_config
.
compile_sizes
.
copy
()
if
not
self
.
model_config
.
enforce_eager
:
warmup_sizes
=
[
x
for
x
in
warmup_sizes
if
x
not
in
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
]
for
size
in
sorted
(
warmup_sizes
,
reverse
=
True
):
logger
.
info
(
"Compile and warming up model for size %d"
,
size
)
self
.
model_runner
.
_dummy_run
(
size
)
if
not
self
.
model_config
.
enforce_eager
:
if
not
self
.
model_config
.
enforce_eager
:
self
.
model_runner
.
capture_model
()
self
.
model_runner
.
capture_model
()
# Reset the seed to ensure that the random state is not affected by
# Reset the seed to ensure that the random state is not affected by
...
...
vllm/worker/model_runner.py
View file @
6e650f56
...
@@ -1256,13 +1256,19 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1256,13 +1256,19 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
def
profile_run
(
self
)
->
None
:
max_num_batched_tokens
=
\
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
_dummy_run
(
max_num_batched_tokens
,
max_num_seqs
)
def
_dummy_run
(
self
,
max_num_batched_tokens
:
int
,
max_num_seqs
:
int
=
1
)
->
None
:
with
self
.
set_in_profile_run
():
with
self
.
set_in_profile_run
():
# Enable top-k sampling to reflect the accurate memory usage.
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params
=
\
sampling_params
=
\
SamplingParams
(
top_p
=
0.99
,
top_k
=
self
.
vocab_size
-
1
)
SamplingParams
(
top_p
=
0.99
,
top_k
=
self
.
vocab_size
-
1
)
max_num_batched_tokens
=
\
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# This represents the maximum number of different requests
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
# consumption create dummy lora request copies from the lora request
...
@@ -1491,13 +1497,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1491,13 +1497,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
for
virtual_engine
in
range
(
for
virtual_engine
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
):
self
.
parallel_config
.
pipeline_parallel_size
):
# Only rank 0 should print progress bar during capture
# Only rank 0 should print progress bar during capture
capture_sizes
=
(
cudagraph_capture_sizes
=
(
tqdm
(
tqdm
(
self
.
vllm_config
.
compilation_config
.
self
.
vllm_config
.
compilation_config
.
capture_sizes
,
cudagraph_capture_sizes
,
desc
=
"Capturing CUDA graph shapes"
,
desc
=
"Capturing CUDA graph shapes"
,
)
if
get_tensor_model_parallel_rank
()
==
0
else
)
if
get_tensor_model_parallel_rank
()
==
0
else
self
.
vllm_config
.
compilation_config
.
capture_sizes
)
self
.
vllm_config
.
compilation_config
.
for
batch_size
in
capture_sizes
:
cudagraph_capture_sizes
)
for
batch_size
in
cudagraph_capture_sizes
:
attn_metadata
=
(
attn_metadata
=
(
self
.
attn_state
.
graph_capture_get_metadata_for_batch
(
self
.
attn_state
.
graph_capture_get_metadata_for_batch
(
batch_size
,
batch_size
,
...
...
vllm/worker/worker.py
View file @
6e650f56
...
@@ -323,6 +323,18 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -323,6 +323,18 @@ class Worker(LocalOrDistributedWorkerBase):
self
.
gpu_cache
)
self
.
gpu_cache
)
def
_warm_up_model
(
self
)
->
None
:
def
_warm_up_model
(
self
)
->
None
:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
warmup_sizes
=
self
.
vllm_config
.
compilation_config
.
compile_sizes
.
copy
()
if
not
self
.
model_config
.
enforce_eager
:
warmup_sizes
=
[
x
for
x
in
warmup_sizes
if
x
not
in
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
]
for
size
in
sorted
(
warmup_sizes
,
reverse
=
True
):
logger
.
info
(
"Compile and warming up model for size %d"
,
size
)
self
.
model_runner
.
_dummy_run
(
size
)
if
not
self
.
model_config
.
enforce_eager
:
if
not
self
.
model_config
.
enforce_eager
:
self
.
model_runner
.
capture_model
(
self
.
gpu_cache
)
self
.
model_runner
.
capture_model
(
self
.
gpu_cache
)
# Reset the seed to ensure that the random state is not affected by
# Reset the seed to ensure that the random state is not affected by
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment