OpenDAS / apex
Commit 5d1993cf authored Apr 29, 2020 by Thor Johnsen

Don't pad between consecutive parameters

parent e1a4deba
Showing 1 changed file with 6 additions and 2 deletions:

apex/contrib/optimizers/distributed_fused_adam.py (+6, -2)
apex/contrib/optimizers/distributed_fused_adam.py @ 5d1993cf
@@ -101,6 +101,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._grads_info = []
         for group in self.param_groups:
             self._param_group = group
+            prev = None
             for p in group['params']:
                 torch.distributed.broadcast(p, 0)
                 if not p.requires_grad:
@@ -119,7 +120,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 self._grads_info.append({"param_grads_size": p_grads_size, "param_offset": p_offset})
                 wrapper(p, p_i, p_grads_size, p_offset)
                 p_offset += p_grads_size
-                # enforce 128b alignment (64 * fp16)
-                p_offset = ((p_offset + 63) // 64) * 64
+                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
+                # RNN is one example of consecutive parameters:
+                # (weight_ih, weight_hh, bias_ih, bias_hh)
+                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
+                    p_offset = ((p_offset + 63) // 64) * 64
                 p_i += 1
         self._grads_generated = [False] * len(self._grads_info)
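For reference, the condition added above treats two parameters as "consecutive" when the second one starts at the byte where the first one ends, so no alignment padding is inserted between their gradient slots; otherwise the running offset is still rounded up to a multiple of 64 fp16 elements (128 bytes). Below is a minimal sketch of that contiguity check on plain tensors, outside apex; the helper name is illustrative and not part of the commit.

import torch

def is_consecutive(prev, p):
    # True when p begins immediately after prev's last element in memory,
    # i.e. the two tensors are laid out back to back.
    return prev.data_ptr() + prev.numel() * prev.element_size() == p.data_ptr()

flat = torch.empty(6)          # one contiguous buffer
a, b = flat[:4], flat[4:]      # two views placed back to back, like fused RNN weights
c = torch.empty(4)             # independently allocated tensor

print(is_consecutive(a, b))    # True  -> no padding between the two gradient slots
print(is_consecutive(a, c))    # False -> offset rounded up: ((p_offset + 63) // 64) * 64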