OpenDAS / fairscale / Commits

Commit b46dcfaf (unverified)
[fix] OSS fp16 broadcast typo (#751)

Authored Jul 27, 2021 by Benjamin Lefaudeux; committed by GitHub on Jul 27, 2021.
Parent: 83b0b49e
Showing 3 changed files with 8 additions and 3 deletions (+8 -3):

- CHANGELOG.md (+2 -1)
- fairscale/optim/oss.py (+2 -2)
- tests/optim/test_oss.py (+4 -0)
CHANGELOG.md

```diff
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
 - FSDP: fixed metadata saving and shard consolidation for MoE cases [#746]
+- OSS: fixed the buckets which would stay in fp16 if `broadcast fp16` was required (#751)
 ### Added
 - FSDP: better performance; use `_allgather_base` and `_reduce_scatter_base` when available [#729]
```
fairscale/optim/oss.py

```diff
@@ -396,7 +396,7 @@ class OSS(Optimizer):
         OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
 
     def refresh_trainable(self) -> None:
-        """ Updates the partitioning and communication patterns if the trainability (`requires_grad`)
+        """Updates the partitioning and communication patterns if the trainability (`requires_grad`)
         of some parameters changed.
         """
 
@@ -551,7 +551,7 @@ class OSS(Optimizer):
         # Populate back the fp32 shards
         if self.broadcast_fp16:
             for device in self.buckets.keys():
-                for dst_rank in self.buckets[device].keys():
+                for dst_rank, bucket in self.buckets[device].items():
                     bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)
 
     def _setup_flat_buffers(self) -> None:
```
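The typo fixed here is a classic dict-iteration slip: the inner loop iterated `.keys()` while the body used `bucket`, so the fp16-to-fp32 restore never ran against the object it was meant to convert. A minimal sketch of the pattern, using a hypothetical `Bucket` stand-in rather than fairscale's real param-bucket class:

```python
import torch

# Hypothetical stand-in for fairscale's per-rank param bucket; illustrative only.
class Bucket:
    def __init__(self, numel: int, dtype: torch.dtype) -> None:
        self.buffer = torch.zeros(numel, dtype=dtype)

    def to(self, dtype: torch.dtype, **_ignored) -> None:
        # Convert the flat buffer, standing in for the real bucket's
        # dtype/device/alignment handling.
        self.buffer = self.buffer.to(dtype=dtype)

buckets = {"cpu": {0: Bucket(8, torch.float16), 1: Bucket(8, torch.float16)}}

for device in buckets.keys():
    # Buggy version: `for dst_rank in buckets[device].keys():` binds only the
    # key, so `bucket.to(...)` in the body would raise NameError (or silently
    # reuse a stale `bucket` from an enclosing scope) and the buffers stay fp16.
    for dst_rank, bucket in buckets[device].items():  # the fix: .items()
        bucket.to(dtype=torch.float32)

assert all(b.buffer.dtype == torch.float32 for b in buckets["cpu"].values())
```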
tests/optim/test_oss.py

```diff
@@ -536,6 +536,10 @@ def run_test_reproducibility(rank, world_size, tempfile_name, broadcast_fp16):
     assert torch.allclose(reference_loss, test_loss), f"{reference_loss} vs {test_loss}. Reproducibility is broken"
 
+    # Check that no matter what the buffer is back to fp32
+    for device in optimizer.buckets.keys():
+        for bucket in optimizer.buckets[device].values():
+            assert bucket.buffer.dtype == torch.float32
 
     dist.destroy_process_group()
```
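For context, a minimal single-process sketch of the scenario this regression test guards. It assumes a local fairscale install exposing the `broadcast_fp16` flag and `buckets` attribute seen in the test above; the gloo group, model, and single step are placeholders for the real multi-rank harness:

```python
import os
import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Single-process "gloo" group so OSS can be constructed locally (sketch only).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1, broadcast_fp16=True)

model(torch.randn(2, 4)).sum().backward()
optimizer.step()  # broadcasts shards in fp16, then should restore them to fp32

# The regression check from the test: every bucket buffer is back to fp32.
for device in optimizer.buckets.keys():
    for bucket in optimizer.buckets[device].values():
        assert bucket.buffer.dtype == torch.float32

dist.destroy_process_group()
```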