Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
e777bddb
"docs/vscode:/vscode.git/clone" did not exist on "9185eee858cdb26da4ccb030124d2ffb76e1ebbd"
Commit
e777bddb
authored
Sep 02, 2021
by
Thor Johnsen
Browse files
Optional NCCL communicator argument to init method
parent
9e295728
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
6 deletions
+9
-6
apex/contrib/bottleneck/bottleneck.py
apex/contrib/bottleneck/bottleneck.py
+9
-6
No files found.
apex/contrib/bottleneck/bottleneck.py
View file @
e777bddb
...
@@ -393,7 +393,7 @@ class SpatialBottleneck(torch.nn.Module):
...
@@ -393,7 +393,7 @@ class SpatialBottleneck(torch.nn.Module):
def
__init__
(
self
,
in_channels
,
bottleneck_channels
,
out_channels
,
stride
=
1
,
groups
=
1
,
def
__init__
(
self
,
in_channels
,
bottleneck_channels
,
out_channels
,
stride
=
1
,
groups
=
1
,
dilation
=
1
,
norm_func
=
None
,
use_cudnn
=
False
,
explicit_nhwc
=
False
,
dilation
=
1
,
norm_func
=
None
,
use_cudnn
=
False
,
explicit_nhwc
=
False
,
spatial_group_size
=
1
):
spatial_group_size
=
1
,
communicator
=
None
):
super
(
SpatialBottleneck
,
self
).
__init__
()
super
(
SpatialBottleneck
,
self
).
__init__
()
if
groups
!=
1
:
if
groups
!=
1
:
raise
RuntimeError
(
'Only support groups == 1'
)
raise
RuntimeError
(
'Only support groups == 1'
)
...
@@ -454,11 +454,14 @@ class SpatialBottleneck(torch.nn.Module):
...
@@ -454,11 +454,14 @@ class SpatialBottleneck(torch.nn.Module):
assert
(
num_groups
*
spatial_group_size
==
world_size
),
"torch.distributed.get_world_size() must be multiple of group_size"
assert
(
num_groups
*
spatial_group_size
==
world_size
),
"torch.distributed.get_world_size() must be multiple of group_size"
rank
=
dist
.
get_rank
()
rank
=
dist
.
get_rank
()
self
.
local_rank
=
rank
%
spatial_group_size
self
.
local_rank
=
rank
%
spatial_group_size
if
communicator
is
None
:
for
group
in
range
(
num_groups
):
for
group
in
range
(
num_groups
):
ranks
=
list
(
range
(
group
*
spatial_group_size
,(
group
+
1
)
*
spatial_group_size
))
ranks
=
list
(
range
(
group
*
spatial_group_size
,(
group
+
1
)
*
spatial_group_size
))
comm
=
torch
.
distributed
.
new_group
(
ranks
=
ranks
)
comm
=
torch
.
distributed
.
new_group
(
ranks
=
ranks
)
if
rank
in
ranks
:
if
rank
in
ranks
:
self
.
communicator
=
comm
self
.
communicator
=
comm
else
:
self
.
communicator
=
communicator
self
.
stream1
=
torch
.
cuda
.
Stream
()
self
.
stream1
=
torch
.
cuda
.
Stream
()
self
.
spatial_args
=
self
.
spatial_group_size
,
self
.
local_rank
,
self
.
communicator
,
self
.
stream1
self
.
spatial_args
=
self
.
spatial_group_size
,
self
.
local_rank
,
self
.
communicator
,
self
.
stream1
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment