Unverified Commit 9e0df348 authored by Myle Ott, committed by GitHub

[fix]: Fix non-float buffers in FSDP (#427)

parent b89365e6
@@ -931,11 +931,15 @@ def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
 def cast_buffers_(
     module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
 ) -> None:
-    """Cast all of module.named_buffers to device, dtype."""
+    """Cast all of module.named_buffers to device and floating point buffers to dtype."""
     # if buffers are already on the right device and/or dtype this is just python loop cost
     assert dtype in {torch.float32, torch.float16}  # assumes compute_dtype == float16
     for key, buf in module.named_buffers(recurse=False):
         if buf is not None:
-            setattr(module, key, buf.to(dtype=dtype, device=device))
+            buf = buf.to(device=device)
+            if torch.is_floating_point(buf):
+                buf = buf.to(dtype=dtype)
+            setattr(module, key, buf)


 def free_storage_(data: torch.Tensor) -> None:
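For context, here is a minimal standalone sketch (not part of the commit) of how the patched `cast_buffers_` behaves when a module registers both a floating point and an integer buffer: the float buffer is moved and cast to the compute dtype, while the integer buffer is only moved and keeps its original dtype. The `Toy` module, buffer names, and CPU device below are illustrative assumptions.

```python
import torch
import torch.nn as nn


class Toy(nn.Module):
    """Toy module with one float buffer and one integer buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("scale", torch.ones(4))                      # float32
        self.register_buffer("counts", torch.zeros(4, dtype=torch.long))  # int64


def cast_buffers_(module: nn.Module, device=None, dtype=None) -> None:
    # Mirrors the patched helper: move every buffer, but only cast floats.
    for key, buf in module.named_buffers(recurse=False):
        if buf is not None:
            buf = buf.to(device=device)
            if torch.is_floating_point(buf):
                buf = buf.to(dtype=dtype)
            setattr(module, key, buf)


m = Toy()
cast_buffers_(m, device=torch.device("cpu"), dtype=torch.float16)
assert m.scale.dtype == torch.float16  # float buffer cast to compute dtype
assert m.counts.dtype == torch.long    # non-float buffer keeps its dtype
```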
@@ -29,8 +29,6 @@ from fairscale.utils.testing import (
 # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
 # All helper functions called by spawn must be either @classmethod, @staticmethod
-_BUFFER_NAME = "vocab_bias"
-

 class DistributedTest(unittest.TestCase):
     def setUp(self):
@@ -411,8 +409,9 @@ class TestLocalStateDict(DistributedTest):
         # Assert that parameters were updated since before training
         unchanged = []
+        buffers = {name for name, _ in model.module.named_buffers()}
         for k in state_1:
-            if (state_before_training[k] == state_after_training[k]).all() and (_BUFFER_NAME not in k):
+            if (state_before_training[k] == state_after_training[k]).all() and (k not in buffers):
                 unchanged.append(k)
         if unchanged:
             raise AssertionError(f"params {unchanged} not changed after training")
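A small illustration (assumed here, not taken from the commit) of what the new `buffers` set contains: exact buffer names from `named_buffers()`, so the exclusion check covers every registered buffer, including the newly added `long_buffer`, instead of relying on a substring match against a single hard-coded name.

```python
import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer("vocab_bias", torch.ones(4))
m.register_buffer("long_buffer", torch.zeros(4, dtype=torch.long))

# Exact buffer names, e.g. {"vocab_bias", "long_buffer"}: any state_dict key in
# this set is a buffer and is not expected to change during training.
buffers = {name for name, _ in m.named_buffers()}
assert "long_buffer" in buffers and "weight" not in buffers
```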
@@ -651,7 +650,8 @@ class TransformerWithSharedParams(nn.Module):
         self.output_proj = nn.Linear(d_model, d_vocab)
         # share the embedding and output projection weights
         self.output_proj.weight = self.embed_tokens.weight
-        self.register_buffer(_BUFFER_NAME, self.embed_tokens.weight.new_ones((d_model,)))
+        self.register_buffer("vocab_bias", self.embed_tokens.weight.new_ones((d_model,)))
+        self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long))

     def get_input(self, device):
         torch.manual_seed(1 + self.rank)  # keep everything deterministic
@@ -661,7 +661,7 @@ class TransformerWithSharedParams(nn.Module):

     def forward(self, src_ids, tgt_ids):
         src = self.embed_tokens(src_ids)
-        src = src + self.vocab_bias
+        src = src + self.vocab_bias + self.long_buffer.type_as(src)
         tgt = self.embed_tokens(tgt_ids)
         x = self.transformer(src, tgt)
         return self.output_proj(x)
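Finally, a hedged sketch of why the test model consumes the new buffer through `type_as(src)`: after buffer casting, the activations and `vocab_bias` may be fp16 while `long_buffer` stays int64, so converting it to `src`'s dtype at the point of use keeps the addition explicit rather than leaning on implicit type promotion. The tensor shapes below are illustrative only.

```python
import torch

src = torch.randn(2, 3, 8, dtype=torch.float16)  # activations in fp16
vocab_bias = torch.ones(8, dtype=torch.float16)  # float buffer, cast by cast_buffers_
long_buffer = torch.zeros(8, dtype=torch.long)   # non-float buffer, left as int64

# Cast the integer buffer to the activation dtype only where it is used,
# mirroring `self.long_buffer.type_as(src)` in the test model's forward.
out = src + vocab_bias + long_buffer.type_as(src)
assert out.dtype == torch.float16
```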