Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
fdeee889
Commit
fdeee889
authored
May 25, 2025
by
limm
Browse files
release v1.6.1 of mmcv
parent
df465820
Changes
490
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
193 additions
and
54 deletions
+193
-54
mmcv/runner/log_buffer.py
mmcv/runner/log_buffer.py
+4
-4
mmcv/runner/optimizer/builder.py
mmcv/runner/optimizer/builder.py
+4
-3
mmcv/runner/optimizer/default_constructor.py
mmcv/runner/optimizer/default_constructor.py
+24
-15
mmcv/runner/priority.py
mmcv/runner/priority.py
+2
-1
mmcv/runner/utils.py
mmcv/runner/utils.py
+10
-4
mmcv/tensorrt/__init__.py
mmcv/tensorrt/__init__.py
+2
-2
mmcv/tensorrt/init_plugins.py
mmcv/tensorrt/init_plugins.py
+42
-3
mmcv/tensorrt/preprocess.py
mmcv/tensorrt/preprocess.py
+17
-1
mmcv/tensorrt/tensorrt_utils.py
mmcv/tensorrt/tensorrt_utils.py
+70
-14
mmcv/utils/__init__.py
mmcv/utils/__init__.py
+18
-7
No files found.
Too many changes to show.
To preserve performance only
490 of 490+
files are displayed.
Plain diff
Email patch
mmcv/runner/log_buffer.py
View file @
fdeee889
...
...
@@ -12,16 +12,16 @@ class LogBuffer:
self
.
output
=
OrderedDict
()
self
.
ready
=
False
def
clear
(
self
):
def
clear
(
self
)
->
None
:
self
.
val_history
.
clear
()
self
.
n_history
.
clear
()
self
.
clear_output
()
def
clear_output
(
self
):
def
clear_output
(
self
)
->
None
:
self
.
output
.
clear
()
self
.
ready
=
False
def
update
(
self
,
vars
,
count
=
1
)
:
def
update
(
self
,
vars
:
dict
,
count
:
int
=
1
)
->
None
:
assert
isinstance
(
vars
,
dict
)
for
key
,
var
in
vars
.
items
():
if
key
not
in
self
.
val_history
:
...
...
@@ -30,7 +30,7 @@ class LogBuffer:
self
.
val_history
[
key
].
append
(
var
)
self
.
n_history
[
key
].
append
(
count
)
def
average
(
self
,
n
=
0
)
:
def
average
(
self
,
n
:
int
=
0
)
->
None
:
"""Average latest n values or all values."""
assert
n
>=
0
for
key
in
self
.
val_history
:
...
...
mmcv/runner/optimizer/builder.py
View file @
fdeee889
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
inspect
from
typing
import
Dict
,
List
import
torch
...
...
@@ -10,7 +11,7 @@ OPTIMIZERS = Registry('optimizer')
OPTIMIZER_BUILDERS
=
Registry
(
'optimizer builder'
)
def
register_torch_optimizers
():
def
register_torch_optimizers
()
->
List
:
torch_optimizers
=
[]
for
module_name
in
dir
(
torch
.
optim
):
if
module_name
.
startswith
(
'__'
):
...
...
@@ -26,11 +27,11 @@ def register_torch_optimizers():
TORCH_OPTIMIZERS
=
register_torch_optimizers
()
def
build_optimizer_constructor
(
cfg
):
def
build_optimizer_constructor
(
cfg
:
Dict
):
return
build_from_cfg
(
cfg
,
OPTIMIZER_BUILDERS
)
def
build_optimizer
(
model
,
cfg
):
def
build_optimizer
(
model
,
cfg
:
Dict
):
optimizer_cfg
=
copy
.
deepcopy
(
cfg
)
constructor_type
=
optimizer_cfg
.
pop
(
'constructor'
,
'DefaultOptimizerConstructor'
)
...
...
mmcv/runner/optimizer/default_constructor.py
View file @
fdeee889
# Copyright (c) OpenMMLab. All rights reserved.
import
warnings
from
typing
import
Dict
,
List
,
Optional
,
Union
import
torch
import
torch.nn
as
nn
from
torch.nn
import
GroupNorm
,
LayerNorm
from
mmcv.utils
import
_BatchNorm
,
_InstanceNorm
,
build_from_cfg
,
is_list_of
...
...
@@ -46,16 +48,17 @@ class DefaultOptimizerConstructor:
would not be added into optimizer. Default: False.
Note:
1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
override the effect of ``bias_lr_mult`` in the bias of offset
layer. So be careful when using both ``bias_lr_mult`` and
``dcn_offset_lr_mult``. If you wish to apply both of them to the
offset layer in deformable convs, set ``dcn_offset_lr_mult``
to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``.
override the effect of ``bias_lr_mult`` in the bias of offset layer.
So be careful when using both ``bias_lr_mult`` and
``dcn_offset_lr_mult``. If you wish to apply both of them to the offset
layer in deformable convs, set ``dcn_offset_lr_mult`` to the original
``dcn_offset_lr_mult`` * ``bias_lr_mult``.
2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
apply it to all the DCN layers in the model. So be careful when
the model contains multiple DCN layers in places other than
backbone.
apply it to all the DCN layers in the model. So be careful when the
model contains multiple DCN layers in places other than backbone.
Args:
model (:obj:`nn.Module`): The model with parameters to be optimized.
...
...
@@ -83,7 +86,7 @@ class DefaultOptimizerConstructor:
>>> # assume model have attribute model.backbone and model.cls_head
>>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
>>> paramwise_cfg = dict(custom_keys={
'
.
backbone': dict(lr_mult=0.1, decay_mult=0.9)})
'backbone': dict(lr_mult=0.1, decay_mult=0.9)})
>>> optim_builder = DefaultOptimizerConstructor(
>>> optimizer_cfg, paramwise_cfg)
>>> optimizer = optim_builder(model)
...
...
@@ -92,7 +95,9 @@ class DefaultOptimizerConstructor:
>>> # model.cls_head is (0.01, 0.95).
"""
def
__init__
(
self
,
optimizer_cfg
,
paramwise_cfg
=
None
):
def
__init__
(
self
,
optimizer_cfg
:
Dict
,
paramwise_cfg
:
Optional
[
Dict
]
=
None
):
if
not
isinstance
(
optimizer_cfg
,
dict
):
raise
TypeError
(
'optimizer_cfg should be a dict'
,
f
'but got
{
type
(
optimizer_cfg
)
}
'
)
...
...
@@ -102,7 +107,7 @@ class DefaultOptimizerConstructor:
self
.
base_wd
=
optimizer_cfg
.
get
(
'weight_decay'
,
None
)
self
.
_validate_cfg
()
def
_validate_cfg
(
self
):
def
_validate_cfg
(
self
)
->
None
:
if
not
isinstance
(
self
.
paramwise_cfg
,
dict
):
raise
TypeError
(
'paramwise_cfg should be None or a dict, '
f
'but got
{
type
(
self
.
paramwise_cfg
)
}
'
)
...
...
@@ -125,7 +130,7 @@ class DefaultOptimizerConstructor:
if
self
.
base_wd
is
None
:
raise
ValueError
(
'base_wd should not be None'
)
def
_is_in
(
self
,
param_group
,
param_group_list
)
:
def
_is_in
(
self
,
param_group
:
Dict
,
param_group_list
:
List
)
->
bool
:
assert
is_list_of
(
param_group_list
,
dict
)
param
=
set
(
param_group
[
'params'
])
param_set
=
set
()
...
...
@@ -134,7 +139,11 @@ class DefaultOptimizerConstructor:
return
not
param
.
isdisjoint
(
param_set
)
def
add_params
(
self
,
params
,
module
,
prefix
=
''
,
is_dcn_module
=
None
):
def
add_params
(
self
,
params
:
List
[
Dict
],
module
:
nn
.
Module
,
prefix
:
str
=
''
,
is_dcn_module
:
Union
[
int
,
float
,
None
]
=
None
)
->
None
:
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
...
...
@@ -231,7 +240,7 @@ class DefaultOptimizerConstructor:
prefix
=
child_prefix
,
is_dcn_module
=
is_dcn_module
)
def
__call__
(
self
,
model
):
def
__call__
(
self
,
model
:
nn
.
Module
):
if
hasattr
(
model
,
'module'
):
model
=
model
.
module
...
...
@@ -242,7 +251,7 @@ class DefaultOptimizerConstructor:
return
build_from_cfg
(
optimizer_cfg
,
OPTIMIZERS
)
# set param-wise lr and weight decay recursively
params
=
[]
params
:
List
[
Dict
]
=
[]
self
.
add_params
(
params
,
model
)
optimizer_cfg
[
'params'
]
=
params
...
...
mmcv/runner/priority.py
View file @
fdeee889
# Copyright (c) OpenMMLab. All rights reserved.
from
enum
import
Enum
from
typing
import
Union
class
Priority
(
Enum
):
...
...
@@ -39,7 +40,7 @@ class Priority(Enum):
LOWEST
=
100
def
get_priority
(
priority
)
:
def
get_priority
(
priority
:
Union
[
int
,
str
,
Priority
])
->
int
:
"""Get priority value.
Args:
...
...
mmcv/runner/utils.py
View file @
fdeee889
...
...
@@ -6,6 +6,8 @@ import time
import
warnings
from
getpass
import
getuser
from
socket
import
gethostname
from
types
import
ModuleType
from
typing
import
Optional
import
numpy
as
np
import
torch
...
...
@@ -13,7 +15,7 @@ import torch
import
mmcv
def
get_host_info
():
def
get_host_info
()
->
str
:
"""Get hostname and username.
Return empty string if exception raised, e.g. ``getpass.getuser()`` will
...
...
@@ -28,11 +30,13 @@ def get_host_info():
return
host
def
get_time_str
():
def
get_time_str
()
->
str
:
return
time
.
strftime
(
'%Y%m%d_%H%M%S'
,
time
.
localtime
())
def
obj_from_dict
(
info
,
parent
=
None
,
default_args
=
None
):
def
obj_from_dict
(
info
:
dict
,
parent
:
Optional
[
ModuleType
]
=
None
,
default_args
:
Optional
[
dict
]
=
None
):
"""Initialize an object from dict.
The dict must contain the key "type", which indicates the object type, it
...
...
@@ -67,7 +71,9 @@ def obj_from_dict(info, parent=None, default_args=None):
return
obj_type
(
**
args
)
def
set_random_seed
(
seed
,
deterministic
=
False
,
use_rank_shift
=
False
):
def
set_random_seed
(
seed
:
int
,
deterministic
:
bool
=
False
,
use_rank_shift
:
bool
=
False
)
->
None
:
"""Set random seed.
Args:
...
...
mmcv/tensorrt/__init__.py
View file @
fdeee889
...
...
@@ -22,9 +22,9 @@ if is_tensorrt_available():
# load tensorrt plugin lib
load_tensorrt_plugin
()
__all__
.
app
end
([
__all__
.
ext
end
([
'onnx2trt'
,
'save_trt_engine'
,
'load_trt_engine'
,
'TRTWraper'
,
'TRTWrapper'
])
__all__
.
app
end
([
'is_tensorrt_plugin_loaded'
,
'preprocess_onnx'
])
__all__
.
ext
end
([
'is_tensorrt_plugin_loaded'
,
'preprocess_onnx'
])
mmcv/tensorrt/init_plugins.py
View file @
fdeee889
...
...
@@ -2,10 +2,23 @@
import
ctypes
import
glob
import
os
import
warnings
def
get_tensorrt_op_path
():
def
get_tensorrt_op_path
()
->
str
:
"""Get TensorRT plugins library path."""
# Following strings of text style are from colorama package
bright_style
,
reset_style
=
'
\x1b
[1m'
,
'
\x1b
[0m'
red_text
,
blue_text
=
'
\x1b
[31m'
,
'
\x1b
[34m'
white_background
=
'
\x1b
[107m'
msg
=
white_background
+
bright_style
+
red_text
msg
+=
'DeprecationWarning: This function will be deprecated in future. '
msg
+=
blue_text
+
'Welcome to use the unified model deployment toolbox '
msg
+=
'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg
+=
reset_style
warnings
.
warn
(
msg
)
wildcard
=
os
.
path
.
join
(
os
.
path
.
abspath
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
))),
'_ext_trt.*.so'
)
...
...
@@ -18,18 +31,44 @@ def get_tensorrt_op_path():
plugin_is_loaded
=
False
def
is_tensorrt_plugin_loaded
():
def
is_tensorrt_plugin_loaded
()
->
bool
:
"""Check if TensorRT plugins library is loaded or not.
Returns:
bool: plugin_is_loaded flag
"""
# Following strings of text style are from colorama package
bright_style
,
reset_style
=
'
\x1b
[1m'
,
'
\x1b
[0m'
red_text
,
blue_text
=
'
\x1b
[31m'
,
'
\x1b
[34m'
white_background
=
'
\x1b
[107m'
msg
=
white_background
+
bright_style
+
red_text
msg
+=
'DeprecationWarning: This function will be deprecated in future. '
msg
+=
blue_text
+
'Welcome to use the unified model deployment toolbox '
msg
+=
'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg
+=
reset_style
warnings
.
warn
(
msg
)
global
plugin_is_loaded
return
plugin_is_loaded
def
load_tensorrt_plugin
():
def
load_tensorrt_plugin
()
->
None
:
"""load TensorRT plugins library."""
# Following strings of text style are from colorama package
bright_style
,
reset_style
=
'
\x1b
[1m'
,
'
\x1b
[0m'
red_text
,
blue_text
=
'
\x1b
[31m'
,
'
\x1b
[34m'
white_background
=
'
\x1b
[107m'
msg
=
white_background
+
bright_style
+
red_text
msg
+=
'DeprecationWarning: This function will be deprecated in future. '
msg
+=
blue_text
+
'Welcome to use the unified model deployment toolbox '
msg
+=
'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg
+=
reset_style
warnings
.
warn
(
msg
)
global
plugin_is_loaded
lib_path
=
get_tensorrt_op_path
()
if
(
not
plugin_is_loaded
)
and
os
.
path
.
exists
(
lib_path
):
...
...
mmcv/tensorrt/preprocess.py
View file @
fdeee889
# Copyright (c) OpenMMLab. All rights reserved.
import
warnings
import
numpy
as
np
import
onnx
def
preprocess_onnx
(
onnx_model
)
:
def
preprocess_onnx
(
onnx_model
:
onnx
.
ModelProto
)
->
onnx
.
ModelProto
:
"""Modify onnx model to match with TensorRT plugins in mmcv.
There are some conflict between onnx node definition and TensorRT limit.
...
...
@@ -18,6 +21,19 @@ def preprocess_onnx(onnx_model):
Returns:
onnx.ModelProto: Modified onnx model.
"""
# Following strings of text style are from colorama package
bright_style
,
reset_style
=
'
\x1b
[1m'
,
'
\x1b
[0m'
red_text
,
blue_text
=
'
\x1b
[31m'
,
'
\x1b
[34m'
white_background
=
'
\x1b
[107m'
msg
=
white_background
+
bright_style
+
red_text
msg
+=
'DeprecationWarning: This function will be deprecated in future. '
msg
+=
blue_text
+
'Welcome to use the unified model deployment toolbox '
msg
+=
'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg
+=
reset_style
warnings
.
warn
(
msg
)
graph
=
onnx_model
.
graph
nodes
=
graph
.
node
initializers
=
graph
.
initializer
...
...
mmcv/tensorrt/tensorrt_utils.py
View file @
fdeee889
# Copyright (c) OpenMMLab. All rights reserved.
import
warnings
from
typing
import
Union
import
onnx
import
tensorrt
as
trt
...
...
@@ -8,12 +9,12 @@ import torch
from
.preprocess
import
preprocess_onnx
def
onnx2trt
(
onnx_model
,
opt_shape_dict
,
log_level
=
trt
.
Logger
.
ERROR
,
fp16_mode
=
False
,
max_workspace_size
=
0
,
device_id
=
0
)
:
def
onnx2trt
(
onnx_model
:
Union
[
str
,
onnx
.
ModelProto
]
,
opt_shape_
dict
:
dict
,
log_level
:
trt
.
ILogger
.
Severity
=
trt
.
Logger
.
ERROR
,
fp16_mode
:
bool
=
False
,
max_workspace_size
:
int
=
0
,
device_id
:
int
=
0
)
->
trt
.
ICudaEngine
:
"""Convert onnx model to tensorrt engine.
Arguments:
...
...
@@ -40,7 +41,20 @@ def onnx2trt(onnx_model,
>>> device_id=0)
>>> })
"""
device
=
torch
.
device
(
'cuda:{}'
.
format
(
device_id
))
# Following strings of text style are from colorama package
bright_style
,
reset_style
=
'
\x1b
[1m'
,
'
\x1b
[0m'
red_text
,
blue_text
=
'
\x1b
[31m'
,
'
\x1b
[34m'
white_background
=
'
\x1b
[107m'
msg
=
white_background
+
bright_style
+
red_text
msg
+=
'DeprecationWarning: This function will be deprecated in future. '
msg
+=
blue_text
+
'Welcome to use the unified model deployment toolbox '
msg
+=
'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg
+=
reset_style
warnings
.
warn
(
msg
)
device
=
torch
.
device
(
f
'cuda:
{
device_id
}
'
)
# create builder and network
logger
=
trt
.
Logger
(
log_level
)
builder
=
trt
.
Builder
(
logger
)
...
...
@@ -87,18 +101,31 @@ def onnx2trt(onnx_model,
return
engine
def
save_trt_engine
(
engine
,
path
)
:
def
save_trt_engine
(
engine
:
trt
.
ICudaEngine
,
path
:
str
)
->
None
:
"""Serialize TensorRT engine to disk.
Arguments:
engine (tensorrt.ICudaEngine): TensorRT engine to serialize
path (str): disk path to write the engine
"""
# Following strings of text style are from colorama package
bright_style
,
reset_style
=
'
\x1b
[1m'
,
'
\x1b
[0m'
red_text
,
blue_text
=
'
\x1b
[31m'
,
'
\x1b
[34m'
white_background
=
'
\x1b
[107m'
msg
=
white_background
+
bright_style
+
red_text
msg
+=
'DeprecationWarning: This function will be deprecated in future. '
msg
+=
blue_text
+
'Welcome to use the unified model deployment toolbox '
msg
+=
'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg
+=
reset_style
warnings
.
warn
(
msg
)
with
open
(
path
,
mode
=
'wb'
)
as
f
:
f
.
write
(
bytearray
(
engine
.
serialize
()))
def
load_trt_engine
(
path
)
:
def
load_trt_engine
(
path
:
str
)
->
trt
.
ICudaEngine
:
"""Deserialize TensorRT engine from disk.
Arguments:
...
...
@@ -107,6 +134,19 @@ def load_trt_engine(path):
Returns:
tensorrt.ICudaEngine: the TensorRT engine loaded from disk
"""
# Following strings of text style are from colorama package
bright_style
,
reset_style
=
'
\x1b
[1m'
,
'
\x1b
[0m'
red_text
,
blue_text
=
'
\x1b
[31m'
,
'
\x1b
[34m'
white_background
=
'
\x1b
[107m'
msg
=
white_background
+
bright_style
+
red_text
msg
+=
'DeprecationWarning: This function will be deprecated in future. '
msg
+=
blue_text
+
'Welcome to use the unified model deployment toolbox '
msg
+=
'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg
+=
reset_style
warnings
.
warn
(
msg
)
with
trt
.
Logger
()
as
logger
,
trt
.
Runtime
(
logger
)
as
runtime
:
with
open
(
path
,
mode
=
'rb'
)
as
f
:
engine_bytes
=
f
.
read
()
...
...
@@ -114,7 +154,7 @@ def load_trt_engine(path):
return
engine
def
torch_dtype_from_trt
(
dtype
)
:
def
torch_dtype_from_trt
(
dtype
:
trt
.
DataType
)
->
Union
[
torch
.
dtype
,
TypeError
]
:
"""Convert pytorch dtype to TensorRT dtype."""
if
dtype
==
trt
.
bool
:
return
torch
.
bool
...
...
@@ -130,7 +170,8 @@ def torch_dtype_from_trt(dtype):
raise
TypeError
(
'%s is not supported by torch'
%
dtype
)
def
torch_device_from_trt
(
device
):
def
torch_device_from_trt
(
device
:
trt
.
TensorLocation
)
->
Union
[
torch
.
device
,
TypeError
]:
"""Convert pytorch device to TensorRT device."""
if
device
==
trt
.
TensorLocation
.
DEVICE
:
return
torch
.
device
(
'cuda'
)
...
...
@@ -154,7 +195,21 @@ class TRTWrapper(torch.nn.Module):
"""
def
__init__
(
self
,
engine
,
input_names
=
None
,
output_names
=
None
):
super
(
TRTWrapper
,
self
).
__init__
()
# Following strings of text style are from colorama package
bright_style
,
reset_style
=
'
\x1b
[1m'
,
'
\x1b
[0m'
red_text
,
blue_text
=
'
\x1b
[31m'
,
'
\x1b
[34m'
white_background
=
'
\x1b
[107m'
msg
=
white_background
+
bright_style
+
red_text
msg
+=
'DeprecationWarning: This tool will be deprecated in future. '
msg
+=
blue_text
+
\
'Welcome to use the unified model deployment toolbox '
msg
+=
'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg
+=
reset_style
warnings
.
warn
(
msg
)
super
().
__init__
()
self
.
engine
=
engine
if
isinstance
(
self
.
engine
,
str
):
self
.
engine
=
load_trt_engine
(
engine
)
...
...
@@ -231,5 +286,6 @@ class TRTWraper(TRTWrapper):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
warnings
.
warn
(
'TRTWraper will be deprecated in'
' future. Please use TRTWrapper instead'
)
warnings
.
warn
(
'TRTWraper will be deprecated in'
' future. Please use TRTWrapper instead'
,
DeprecationWarning
)
mmcv/utils/__init__.py
View file @
fdeee889
...
...
@@ -36,17 +36,26 @@ except ImportError:
'is_method_overridden'
,
'has_method'
]
else
:
from
.device_type
import
(
IS_IPU_AVAILABLE
,
IS_MLU_AVAILABLE
,
IS_MPS_AVAILABLE
)
from
.env
import
collect_env
from
.hub
import
load_url
from
.logging
import
get_logger
,
print_log
from
.parrots_jit
import
jit
,
skip_no_elena
from
.parrots_wrapper
import
(
TORCH_VERSION
,
BuildExtension
,
CppExtension
,
CUDAExtension
,
DataLoader
,
PoolDataLoader
,
SyncBatchNorm
,
_AdaptiveAvgPoolNd
,
_AdaptiveMaxPoolNd
,
_AvgPoolNd
,
_BatchNorm
,
_ConvNd
,
_ConvTransposeMixin
,
_InstanceNorm
,
_MaxPoolNd
,
get_build_config
,
is_rocm_pytorch
,
_get_cuda_home
)
# yapf: disable
from
.parrots_wrapper
import
(
IS_CUDA_AVAILABLE
,
TORCH_VERSION
,
BuildExtension
,
CppExtension
,
CUDAExtension
,
DataLoader
,
PoolDataLoader
,
SyncBatchNorm
,
_AdaptiveAvgPoolNd
,
_AdaptiveMaxPoolNd
,
_AvgPoolNd
,
_BatchNorm
,
_ConvNd
,
_ConvTransposeMixin
,
_get_cuda_home
,
_InstanceNorm
,
_MaxPoolNd
,
get_build_config
,
is_rocm_pytorch
)
# yapf: enable
from
.registry
import
Registry
,
build_from_cfg
from
.seed
import
worker_init_fn
from
.torch_ops
import
torch_meshgrid
from
.trace
import
is_jit_tracing
from
.hub
import
load_url
__all__
=
[
'Config'
,
'ConfigDict'
,
'DictAction'
,
'collect_env'
,
'get_logger'
,
'print_log'
,
'is_str'
,
'iter_cast'
,
'list_cast'
,
'tuple_cast'
,
...
...
@@ -66,5 +75,7 @@ else:
'assert_dict_has_keys'
,
'assert_keys_equal'
,
'assert_is_norm_layer'
,
'assert_params_all_zeros'
,
'check_python_script'
,
'is_method_overridden'
,
'is_jit_tracing'
,
'is_rocm_pytorch'
,
'_get_cuda_home'
,
'load_url'
,
'has_method'
'_get_cuda_home'
,
'load_url'
,
'has_method'
,
'IS_CUDA_AVAILABLE'
,
'worker_init_fn'
,
'IS_MLU_AVAILABLE'
,
'IS_IPU_AVAILABLE'
,
'IS_MPS_AVAILABLE'
,
'torch_meshgrid'
]
Prev
1
…
21
22
23
24
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment