Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
9e0df348
Unverified
Commit
9e0df348
authored
Feb 23, 2021
by
Myle Ott
Committed by
GitHub
Feb 23, 2021
Browse files
[fix]: Fix non-float buffers in FSDP (#427)
parent
b89365e6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
7 deletions
+11
-7
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+6
-2
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp.py
+5
-5
No files found.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
View file @
9e0df348
...
@@ -931,11 +931,15 @@ def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
...
@@ -931,11 +931,15 @@ def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
def
cast_buffers_
(
def
cast_buffers_
(
module
:
nn
.
Module
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
module
:
nn
.
Module
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
)
->
None
:
)
->
None
:
"""Cast all of module.named_buffers to device
,
dtype."""
"""Cast all of module.named_buffers to device
and floating point buffers to
dtype."""
# if buffers are already on the right device and/or dtype this is just python loop cost
# if buffers are already on the right device and/or dtype this is just python loop cost
assert
dtype
in
{
torch
.
float32
,
torch
.
float16
}
# assumes compute_dtype == float16
for
key
,
buf
in
module
.
named_buffers
(
recurse
=
False
):
for
key
,
buf
in
module
.
named_buffers
(
recurse
=
False
):
if
buf
is
not
None
:
if
buf
is
not
None
:
setattr
(
module
,
key
,
buf
.
to
(
dtype
=
dtype
,
device
=
device
))
buf
=
buf
.
to
(
device
=
device
)
if
torch
.
is_floating_point
(
buf
):
buf
=
buf
.
to
(
dtype
=
dtype
)
setattr
(
module
,
key
,
buf
)
def
free_storage_
(
data
:
torch
.
Tensor
)
->
None
:
def
free_storage_
(
data
:
torch
.
Tensor
)
->
None
:
...
...
tests/nn/data_parallel/test_fsdp.py
View file @
9e0df348
...
@@ -29,8 +29,6 @@ from fairscale.utils.testing import (
...
@@ -29,8 +29,6 @@ from fairscale.utils.testing import (
# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either @classmethod, @staticmethod
# All helper functions called by spawn must be either @classmethod, @staticmethod
_BUFFER_NAME
=
"vocab_bias"
class
DistributedTest
(
unittest
.
TestCase
):
class
DistributedTest
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -411,8 +409,9 @@ class TestLocalStateDict(DistributedTest):
...
@@ -411,8 +409,9 @@ class TestLocalStateDict(DistributedTest):
# Assert that parameters were updated since before training
# Assert that parameters were updated since before training
unchanged
=
[]
unchanged
=
[]
buffers
=
{
name
for
name
,
_
in
model
.
module
.
named_buffers
()}
for
k
in
state_1
:
for
k
in
state_1
:
if
(
state_before_training
[
k
]
==
state_after_training
[
k
]).
all
()
and
(
_BUFFER_NAME
not
in
k
):
if
(
state_before_training
[
k
]
==
state_after_training
[
k
]).
all
()
and
(
k
not
in
buffers
):
unchanged
.
append
(
k
)
unchanged
.
append
(
k
)
if
unchanged
:
if
unchanged
:
raise
AssertionError
(
f
"params
{
unchanged
}
not changed after training"
)
raise
AssertionError
(
f
"params
{
unchanged
}
not changed after training"
)
...
@@ -651,7 +650,8 @@ class TransformerWithSharedParams(nn.Module):
...
@@ -651,7 +650,8 @@ class TransformerWithSharedParams(nn.Module):
self
.
output_proj
=
nn
.
Linear
(
d_model
,
d_vocab
)
self
.
output_proj
=
nn
.
Linear
(
d_model
,
d_vocab
)
# share the embedding and output projection weights
# share the embedding and output projection weights
self
.
output_proj
.
weight
=
self
.
embed_tokens
.
weight
self
.
output_proj
.
weight
=
self
.
embed_tokens
.
weight
self
.
register_buffer
(
_BUFFER_NAME
,
self
.
embed_tokens
.
weight
.
new_ones
((
d_model
,)))
self
.
register_buffer
(
"vocab_bias"
,
self
.
embed_tokens
.
weight
.
new_ones
((
d_model
,)))
self
.
register_buffer
(
"long_buffer"
,
torch
.
zeros_like
(
self
.
vocab_bias
,
dtype
=
torch
.
long
))
def
get_input
(
self
,
device
):
def
get_input
(
self
,
device
):
torch
.
manual_seed
(
1
+
self
.
rank
)
# keep everything deterministic
torch
.
manual_seed
(
1
+
self
.
rank
)
# keep everything deterministic
...
@@ -661,7 +661,7 @@ class TransformerWithSharedParams(nn.Module):
...
@@ -661,7 +661,7 @@ class TransformerWithSharedParams(nn.Module):
def
forward
(
self
,
src_ids
,
tgt_ids
):
def
forward
(
self
,
src_ids
,
tgt_ids
):
src
=
self
.
embed_tokens
(
src_ids
)
src
=
self
.
embed_tokens
(
src_ids
)
src
=
src
+
self
.
vocab_bias
src
=
src
+
self
.
vocab_bias
+
self
.
long_buffer
.
type_as
(
src
)
tgt
=
self
.
embed_tokens
(
tgt_ids
)
tgt
=
self
.
embed_tokens
(
tgt_ids
)
x
=
self
.
transformer
(
src
,
tgt
)
x
=
self
.
transformer
(
src
,
tgt
)
return
self
.
output_proj
(
x
)
return
self
.
output_proj
(
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment