ModelZoo / SOLOv2-pytorch

Commit 82356fd9, authored Oct 08, 2018 by Kai Chen
parent 3d2b79bd

support chunk when reducing grads
Showing 2 changed files with 35 additions and 43 deletions (+35 -43):

mmdet/core/utils/__init__.py   (+3 -4)
mmdet/core/utils/dist_utils.py (+32 -39)
mmdet/core/utils/__init__.py @ 82356fd9
-from .dist_utils import (init_dist, reduce_grads, DistOptimizerHook,
-                         DistSamplerSeedHook)
+from .dist_utils import init_dist, allreduce_grads, DistOptimizerHook
 from .misc import tensor2imgs, unmap, multi_apply

 __all__ = [
-    'init_dist', 'reduce_grads', 'DistOptimizerHook', 'DistSamplerSeedHook',
-    'tensor2imgs', 'unmap', 'multi_apply'
+    'init_dist', 'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs',
+    'unmap', 'multi_apply'
 ]
mmdet/core/utils/dist_utils.py @ 82356fd9
@@ -4,9 +4,9 @@ from collections import OrderedDict
 import torch
 import torch.multiprocessing as mp
 import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from torch.nn.utils import clip_grad
-from mmcv.runner import Hook, OptimizerHook
+from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
+                          _take_tensors)
+from mmcv.runner import OptimizerHook


 def init_dist(launcher, backend='nccl', **kwargs):
@@ -38,59 +38,52 @@ def _init_dist_slurm(backend, **kwargs):
     raise NotImplementedError


-# modified from
-# https://github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py#L9
-def all_reduce_coalesced(tensors):
-    buckets = OrderedDict()
-    for tensor in tensors:
-        tp = tensor.type()
-        if tp not in buckets:
-            buckets[tp] = []
-        buckets[tp].append(tensor)
-
-    world_size = dist.get_world_size()
-    for tp in buckets:
-        bucket = buckets[tp]
-        coalesced = _flatten_dense_tensors(bucket)
-        dist.all_reduce(coalesced)
-        coalesced.div_(world_size)
-        for buf, synced in zip(bucket,
-                               _unflatten_dense_tensors(coalesced, bucket)):
-            buf.copy_(synced)
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+    if bucket_size_mb > 0:
+        bucket_size_bytes = bucket_size_mb * 1024 * 1024
+        buckets = _take_tensors(tensors, bucket_size_bytes)
+    else:
+        buckets = OrderedDict()
+        for tensor in tensors:
+            tp = tensor.type()
+            if tp not in buckets:
+                buckets[tp] = []
+            buckets[tp].append(tensor)
+        buckets = buckets.values()
+
+    for bucket in buckets:
+        flat_tensors = _flatten_dense_tensors(bucket)
+        dist.all_reduce(flat_tensors)
+        flat_tensors.div_(world_size)
+        for tensor, synced in zip(
+                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
+            tensor.copy_(synced)
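Aside: _take_tensors is a private PyTorch helper that splits a tensor list into same-dtype buckets whose total byte size stays under a limit; that is what enables the chunked all-reduce here instead of one flat buffer per dtype. A standalone sketch (not part of the commit) of the bucketing behavior:

# Sketch: _take_tensors yields same-type buckets capped at a byte limit.
import torch
from torch._utils import _take_tensors

tensors = [torch.zeros(1024, 1024) for _ in range(8)]  # 4 MB each in fp32
bucket_size_bytes = 10 * 1024 * 1024  # 10 MB cap per bucket
for bucket in _take_tensors(tensors, bucket_size_bytes):
    print(len(bucket))  # two 4 MB tensors per bucket fit under the cap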
-def reduce_grads(model, coalesce=True):
+def allreduce_grads(model, coalesce=True, bucket_size_mb=-1):
     grads = [
         param.grad.data for param in model.parameters()
         if param.requires_grad and param.grad is not None
     ]
+    world_size = dist.get_world_size()
     if coalesce:
-        all_reduce_coalesced(grads)
+        _allreduce_coalesced(grads, world_size, bucket_size_mb)
     else:
-        world_size = dist.get_world_size()
         for tensor in grads:
             dist.all_reduce(tensor.div_(world_size))
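A usage sketch for the renamed entry point (assumes mmdet at this commit is installed; a single-process gloo group keeps it runnable without a multi-GPU launch, and with world_size 1 the averaging is a no-op):

# Usage sketch, not from the commit.
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from mmdet.core.utils import allreduce_grads

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()
# average gradients across ranks, flattened into 16 MB chunks
allreduce_grads(model, coalesce=True, bucket_size_mb=16)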
 class DistOptimizerHook(OptimizerHook):

-    def __init__(self, grad_clip=None, coalesce=True):
+    def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1):
         self.grad_clip = grad_clip
         self.coalesce = coalesce
+        self.bucket_size_mb = bucket_size_mb

     def after_train_iter(self, runner):
         runner.optimizer.zero_grad()
         runner.outputs['loss'].backward()
-        reduce_grads(runner.model, self.coalesce)
+        allreduce_grads(runner.model, self.coalesce, self.bucket_size_mb)
         if self.grad_clip is not None:
-            clip_grad.clip_grad_norm_(
-                filter(lambda p: p.requires_grad, runner.model.parameters()),
-                **self.grad_clip)
+            self.clip_grads(runner.model.parameters())
         runner.optimizer.step()
-
-
-class DistSamplerSeedHook(Hook):
-
-    def before_epoch(self, runner):
-        runner.data_loader.sampler.set_epoch(runner.epoch)
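The commit also removes DistSamplerSeedHook from this module (it disappears from the file and from the package exports above), and grad clipping now reuses the clip_grads method inherited from mmcv's OptimizerHook instead of calling clip_grad.clip_grad_norm_ directly. Since mmdetection builds this hook roughly as DistOptimizerHook(**cfg.optimizer_config), the new knob can be set from a config file; a hypothetical snippet (values are illustrative, not from the commit):

# Hypothetical optimizer_config entry; values are illustrative.
optimizer_config = dict(
    grad_clip=dict(max_norm=35, norm_type=2),
    coalesce=True,       # flatten grads before all_reduce
    bucket_size_mb=512)  # chunk the flattened grads into 512 MB buckets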