Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
c30e91db
Unverified
Commit
c30e91db
authored
Jun 14, 2020
by
Cao Yuhang
Committed by
GitHub
Jun 14, 2020
Browse files
share torch version (#343)
parent
b87e774f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
28 additions
and
17 deletions
+28
-17
mmcv/parallel/distributed.py
mmcv/parallel/distributed.py
+3
-2
mmcv/parallel/distributed_deprecated.py
mmcv/parallel/distributed_deprecated.py
+2
-1
mmcv/runner/dist_utils.py
mmcv/runner/dist_utils.py
+3
-1
mmcv/runner/hooks/logger/tensorboard.py
mmcv/runner/hooks/logger/tensorboard.py
+2
-3
mmcv/utils/__init__.py
mmcv/utils/__init__.py
+3
-1
mmcv/utils/env.py
mmcv/utils/env.py
+4
-0
mmcv/utils/parrots_wrapper.py
mmcv/utils/parrots_wrapper.py
+11
-9
No files found.
mmcv/parallel/distributed.py
View file @
c30e91db
...
...
@@ -3,6 +3,7 @@ import torch
from
torch.nn.parallel.distributed
import
(
DistributedDataParallel
,
_find_tensors
)
from
mmcv.utils
import
TORCH_VERSION
from
.scatter_gather
import
scatter_kwargs
...
...
@@ -47,7 +48,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
else
:
self
.
reducer
.
prepare_for_backward
([])
else
:
if
torch
.
__version__
>
'1.2'
:
if
TORCH_VERSION
>
'1.2'
:
self
.
require_forward_param_sync
=
False
return
output
...
...
@@ -79,6 +80,6 @@ class MMDistributedDataParallel(DistributedDataParallel):
else
:
self
.
reducer
.
prepare_for_backward
([])
else
:
if
torch
.
__version__
>
'1.2'
:
if
TORCH_VERSION
>
'1.2'
:
self
.
require_forward_param_sync
=
False
return
output
mmcv/parallel/distributed_deprecated.py
View file @
c30e91db
...
...
@@ -5,6 +5,7 @@ import torch.nn as nn
from
torch._utils
import
(
_flatten_dense_tensors
,
_take_tensors
,
_unflatten_dense_tensors
)
from
mmcv.utils
import
TORCH_VERSION
from
.scatter_gather
import
scatter_kwargs
...
...
@@ -37,7 +38,7 @@ class MMDistributedDataParallel(nn.Module):
self
.
_dist_broadcast_coalesced
(
module_states
,
self
.
broadcast_bucket_size
)
if
self
.
broadcast_buffers
:
if
torch
.
__version__
<
'1.0'
:
if
TORCH_VERSION
<
'1.0'
:
buffers
=
[
b
.
data
for
b
in
self
.
module
.
_all_buffers
()]
else
:
buffers
=
[
b
.
data
for
b
in
self
.
module
.
buffers
()]
...
...
mmcv/runner/dist_utils.py
View file @
c30e91db
...
...
@@ -7,6 +7,8 @@ import torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
from
mmcv.utils
import
TORCH_VERSION
def
init_dist
(
launcher
,
backend
=
'nccl'
,
**
kwargs
):
if
mp
.
get_start_method
(
allow_none
=
True
)
is
None
:
...
...
@@ -49,7 +51,7 @@ def _init_dist_slurm(backend, port=29500):
def
get_dist_info
():
if
torch
.
__version__
<
'1.0'
:
if
TORCH_VERSION
<
'1.0'
:
initialized
=
dist
.
_initialized
else
:
if
dist
.
is_available
():
...
...
mmcv/runner/hooks/logger/tensorboard.py
View file @
c30e91db
# Copyright (c) Open-MMLab. All rights reserved.
import
os.path
as
osp
import
torch
from
mmcv.utils
import
TORCH_VERSION
from
...dist_utils
import
master_only
from
..hook
import
HOOKS
from
.base
import
LoggerHook
...
...
@@ -22,7 +21,7 @@ class TensorboardLoggerHook(LoggerHook):
@
master_only
def
before_run
(
self
,
runner
):
if
torch
.
__version__
<
'1.1'
or
torch
.
__version__
==
'parrots'
:
if
TORCH_VERSION
<
'1.1'
or
TORCH_VERSION
==
'parrots'
:
try
:
from
tensorboardX
import
SummaryWriter
except
ImportError
:
...
...
mmcv/utils/__init__.py
View file @
c30e91db
# Copyright (c) Open-MMLab. All rights reserved.
from
.config
import
Config
,
ConfigDict
,
DictAction
from
.env
import
TORCH_VERSION
from
.logging
import
get_logger
,
print_log
from
.misc
import
(
check_prerequisites
,
concat_list
,
is_list_of
,
is_seq_of
,
is_str
,
is_tuple_of
,
iter_cast
,
list_cast
,
...
...
@@ -29,5 +30,6 @@ __all__ = [
'CUDA_HOME'
,
'SyncBatchNorm'
,
'_AdaptiveAvgPoolNd'
,
'_AdaptiveMaxPoolNd'
,
'_AvgPoolNd'
,
'_BatchNorm'
,
'_ConvNd'
,
'_ConvTransposeMixin'
,
'_InstanceNorm'
,
'_MaxPoolNd'
,
'get_build_config'
,
'BuildExtension'
,
'CppExtension'
,
'CUDAExtension'
,
'DataLoader'
,
'PoolDataLoader'
'CppExtension'
,
'CUDAExtension'
,
'DataLoader'
,
'PoolDataLoader'
,
'TORCH_VERSION'
]
mmcv/utils/env.py
0 → 100644
View file @
c30e91db
# This file holding some environment constant for sharing by other files
import
torch
TORCH_VERSION
=
torch
.
__version__
mmcv/utils/parrots_wrapper.py
View file @
c30e91db
...
...
@@ -2,9 +2,11 @@ from functools import partial
import
torch
from
.env
import
TORCH_VERSION
def
_get_cuda_home
():
if
torch
.
__version__
==
'parrots'
:
if
TORCH_VERSION
==
'parrots'
:
from
parrots.utils.build_extension
import
CUDA_HOME
else
:
from
torch.utils.cpp_extension
import
CUDA_HOME
...
...
@@ -12,7 +14,7 @@ def _get_cuda_home():
def
get_build_config
():
if
torch
.
__version__
==
'parrots'
:
if
TORCH_VERSION
==
'parrots'
:
from
parrots.config
import
get_build_info
return
get_build_info
()
else
:
...
...
@@ -20,7 +22,7 @@ def get_build_config():
def
_get_conv
():
if
torch
.
__version__
==
'parrots'
:
if
TORCH_VERSION
==
'parrots'
:
from
parrots.nn.modules.conv
import
_ConvNd
,
_ConvTransposeMixin
else
:
from
torch.nn.modules.conv
import
_ConvNd
,
_ConvTransposeMixin
...
...
@@ -28,7 +30,7 @@ def _get_conv():
def
_get_dataloader
():
if
torch
.
__version__
==
'parrots'
:
if
TORCH_VERSION
==
'parrots'
:
from
torch.utils.data
import
DataLoader
,
PoolDataLoader
else
:
from
torch.utils.data
import
DataLoader
...
...
@@ -37,7 +39,7 @@ def _get_dataloader():
def
_get_extension
():
if
torch
.
__version__
==
'parrots'
:
if
TORCH_VERSION
==
'parrots'
:
from
parrots.utils.build_extension
import
BuildExtension
,
Extension
CppExtension
=
partial
(
Extension
,
cuda
=
False
)
CUDAExtension
=
partial
(
Extension
,
cuda
=
True
)
...
...
@@ -48,7 +50,7 @@ def _get_extension():
def
_get_pool
():
if
torch
.
__version__
==
'parrots'
:
if
TORCH_VERSION
==
'parrots'
:
from
parrots.nn.modules.pool
import
(
_AdaptiveAvgPoolNd
,
_AdaptiveMaxPoolNd
,
_AvgPoolNd
,
_MaxPoolNd
)
...
...
@@ -60,7 +62,7 @@ def _get_pool():
def
_get_norm
():
if
torch
.
__version__
==
'parrots'
:
if
TORCH_VERSION
==
'parrots'
:
from
parrots.nn.modules.batchnorm
import
_BatchNorm
,
_InstanceNorm
SyncBatchNorm_
=
torch
.
nn
.
SyncBatchNorm2d
else
:
...
...
@@ -81,11 +83,11 @@ _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
class
SyncBatchNorm
(
SyncBatchNorm_
):
def
_specify_ddp_gpu_num
(
self
,
gpu_size
):
if
torch
.
__version__
!=
'parrots'
:
if
TORCH_VERSION
!=
'parrots'
:
super
().
_specify_ddp_gpu_num
(
gpu_size
)
def
_check_input_dim
(
self
,
input
):
if
torch
.
__version__
==
'parrots'
:
if
TORCH_VERSION
==
'parrots'
:
if
input
.
dim
()
<
2
:
raise
ValueError
(
f
'expected at least 2D input (got
{
input
.
dim
()
}
D input)'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment