OpenDAS / fairscale, commit 180ab8c8 (unverified)

Authored Sep 13, 2021 by Benjamin Lefaudeux; committed by GitHub on Sep 13, 2021

[OSS] Fixing the fp16 broadcast and catching this case in the unit test (#795)

Parent: 31e36453
Showing 3 changed files with 10 additions and 1 deletion:

  CHANGELOG.md                         +4 -0
  fairscale/nn/misc/param_bucket.py    +3 -1
  tests/nn/misc/test_param_bucket.py   +3 -0
CHANGELOG.md @ 180ab8c8

```diff
@@ -16,6 +16,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - activation checkpoint: Ensure outputs of checkpointed modules only require grad if either
   the input requires grad or if the parameters require grad. [#787]
+- OSS: fix the broadcast_fp16 option, broken after a refactor, this flag was doing nothing (bugfix).[#795]
+- OSS: update default device when refreshing the params, meaning that moving the model to GPU after
+  the OSS wrap will not trigger warnings and slow the jobs (ease of use). [#786]

 ### Added
 - FSDP: Added support for returning the original names of parameters when `named_parameters` is called on
   the module. To retrieve the orginal names of the parameters along with the params, you need to
 ...
```
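The `broadcast_fp16` repair is the headline change. As a minimal sketch of where that flag lives, not part of this diff: OSS shards the wrapped optimizer's state across ranks and broadcasts updated parameters after each step, and `broadcast_fp16=True` makes those broadcasts use fp16 buffers to halve communication volume. The constructor defaults shown here and the distributed setup are assumptions.

```python
# Sketch only: assumes torch.distributed has already been initialized
# (e.g. via torch.distributed.init_process_group).
import torch
from fairscale.optim import OSS

def build_sharded_optimizer(model: torch.nn.Module) -> OSS:
    return OSS(
        model.parameters(),
        optim=torch.optim.SGD,   # the optimizer class OSS shards and wraps
        lr=0.1,
        broadcast_fp16=True,     # the option this commit fixes (#795)
    )
```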
fairscale/nn/misc/param_bucket.py @ 180ab8c8

```diff
@@ -32,7 +32,7 @@ class Bucket:
         Move the underlying buffer
         """
         assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
-        self.buffer.to(device, dtype, non_blocking)
+        self.buffer = self.buffer.to(device, dtype, non_blocking)


 class ParamBucket(Bucket):
```
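The one-line fix above hinges on a standard PyTorch behavior: `torch.Tensor.to` is out of place and returns a new tensor, so calling it without keeping the result is a silent no-op. That is why the flag "was doing nothing". A self-contained illustration:

```python
import torch

buffer = torch.ones(4, dtype=torch.float32)

# Pre-fix pattern: the converted tensor is created, then thrown away.
buffer.to(torch.float16)
assert buffer.dtype == torch.float32  # unchanged, the conversion was lost

# Post-fix pattern: re-assign the result, as the commit does for self.buffer.
buffer = buffer.to(torch.float16)
assert buffer.dtype == torch.float16
```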
```diff
@@ -98,6 +98,8 @@ class ParamBucket(Bucket):
         self._fill = 0
         for p in self._params:
+            if p.dtype != self.buffer.dtype:
+                p.data = p.data.to(self.buffer.dtype)
             self._add_param_as_view(p, keep_existing_value=False)
 ...
```
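The second hunk re-casts each param to the buffer's dtype before `_add_param_as_view` re-points it at the shared buffer: once the buffer has actually been converted, a param left in its old dtype can no longer be viewed into it. A rough sketch of the mechanism; the names and sizes are hypothetical and `_add_param_as_view` itself is not reproduced here:

```python
import torch

bucket_buffer = torch.zeros(8, dtype=torch.float16)  # buffer after .to(fp16)
p = torch.nn.Parameter(torch.randn(8, dtype=torch.float32))

# The added guard: align the param's dtype with the buffer first...
if p.dtype != bucket_buffer.dtype:
    p.data = p.data.to(bucket_buffer.dtype)

# ...so that re-adding it as a view of the shared buffer is legal,
# which is roughly what _add_param_as_view does inside ParamBucket.
bucket_buffer[: p.numel()].copy_(p.data.flatten())
p.data = bucket_buffer[: p.numel()].view_as(p.data)
assert p.data.dtype == bucket_buffer.dtype
```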
tests/nn/misc/test_param_bucket.py @ 180ab8c8

```diff
@@ -48,7 +48,10 @@ def test_type_change():
     # Move the bucket to fp16 and back
     bucket.to(dtype=torch.float16, device=param.device)
+    assert bucket.buffer.dtype == torch.float16
     bucket.to(dtype=torch.float32, device=param.device, keep_param_alignment=True)
+    assert bucket.buffer.dtype == torch.float32

     # Same with the reference tensor
     param_.to(dtype=torch.float16)
 ...
```
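Before the `Bucket.to` fix, the first new assert would fail: the bucket buffer silently stayed in fp32. To check the regression locally, the test can presumably be run on its own, e.g. with `pytest tests/nn/misc/test_param_bucket.py -k test_type_change`.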