Unverified Commit 9e0df348 authored by Myle Ott, committed by GitHub

[fix]: Fix non-float buffers in FSDP (#427)

parent b89365e6
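
Note on the fix (illustration, not part of the commit): with mixed precision, FSDP casts module buffers to the compute dtype, and the old cast_buffers_ applied buf.to(dtype=..., device=...) to every buffer, so a non-float buffer such as a torch.long tensor came back as torch.float16. A minimal standalone sketch of the pre- and post-fix behavior:

import torch

long_buf = torch.zeros(4, dtype=torch.long)  # e.g. an integer buffer on a module

# pre-fix behavior: an unconditional cast converts the integer buffer to fp16
blind_cast = long_buf.to(dtype=torch.float16, device="cpu")
assert blind_cast.dtype == torch.float16  # integer dtype is lost

# post-fix behavior: only floating point buffers are cast to the compute dtype
safe_cast = long_buf.to(device="cpu")
if torch.is_floating_point(safe_cast):
    safe_cast = safe_cast.to(dtype=torch.float16)
assert safe_cast.dtype == torch.long  # non-float buffer keeps its dtype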
@@ -931,11 +931,15 @@ def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
 def cast_buffers_(
     module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
 ) -> None:
-    """Cast all of module.named_buffers to device, dtype."""
+    """Cast all of module.named_buffers to device and floating point buffers to dtype."""
     # if buffers are already on the right device and/or dtype this is just python loop cost
+    assert dtype in {torch.float32, torch.float16}  # assumes compute_dtype == float16
     for key, buf in module.named_buffers(recurse=False):
         if buf is not None:
-            setattr(module, key, buf.to(dtype=dtype, device=device))
+            buf = buf.to(device=device)
+            if torch.is_floating_point(buf):
+                buf = buf.to(dtype=dtype)
+            setattr(module, key, buf)


 def free_storage_(data: torch.Tensor) -> None:
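
For reference, a self-contained sketch of the patched helper applied to a toy module with mixed-dtype buffers; ToyModule and its buffer names are hypothetical, only cast_buffers_ comes from the diff above:

import torch
import torch.nn as nn
from typing import Optional


def cast_buffers_(
    module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> None:
    """Cast all of module.named_buffers to device and floating point buffers to dtype."""
    assert dtype in {torch.float32, torch.float16}  # assumes compute_dtype == float16
    for key, buf in module.named_buffers(recurse=False):
        if buf is not None:
            buf = buf.to(device=device)
            if torch.is_floating_point(buf):
                buf = buf.to(dtype=dtype)
            setattr(module, key, buf)


class ToyModule(nn.Module):  # hypothetical module, not part of the commit
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("float_buf", torch.ones(3))
        self.register_buffer("long_buf", torch.zeros(3, dtype=torch.long))


m = ToyModule()
cast_buffers_(m, device=torch.device("cpu"), dtype=torch.float16)
assert m.float_buf.dtype == torch.float16  # floating point buffer is cast
assert m.long_buf.dtype == torch.long      # non-float buffer keeps its dtype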
@@ -29,8 +29,6 @@ from fairscale.utils.testing import (
 # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
 # All helper functions called by spawn must be either @classmethod, @staticmethod
-_BUFFER_NAME = "vocab_bias"
-

 class DistributedTest(unittest.TestCase):
     def setUp(self):
@@ -411,8 +409,9 @@ class TestLocalStateDict(DistributedTest):
         # Assert that parameters were updated since before training
         unchanged = []
+        buffers = {name for name, _ in model.module.named_buffers()}
         for k in state_1:
-            if (state_before_training[k] == state_after_training[k]).all() and (_BUFFER_NAME not in k):
+            if (state_before_training[k] == state_after_training[k]).all() and (k not in buffers):
                 unchanged.append(k)
         if unchanged:
             raise AssertionError(f"params {unchanged} not changed after training")
@@ -651,7 +650,8 @@ class TransformerWithSharedParams(nn.Module):
         self.output_proj = nn.Linear(d_model, d_vocab)
         # share the embedding and output projection weights
         self.output_proj.weight = self.embed_tokens.weight
-        self.register_buffer(_BUFFER_NAME, self.embed_tokens.weight.new_ones((d_model,)))
+        self.register_buffer("vocab_bias", self.embed_tokens.weight.new_ones((d_model,)))
+        self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long))

     def get_input(self, device):
         torch.manual_seed(1 + self.rank)  # keep everything deterministic
@@ -661,7 +661,7 @@ class TransformerWithSharedParams(nn.Module):
     def forward(self, src_ids, tgt_ids):
         src = self.embed_tokens(src_ids)
-        src = src + self.vocab_bias
+        src = src + self.vocab_bias + self.long_buffer.type_as(src)
         tgt = self.embed_tokens(tgt_ids)
         x = self.transformer(src, tgt)
         return self.output_proj(x)
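
The new long_buffer exercises the non-float path end to end: it stays torch.long on the module, and the forward pass casts it only at use time with .type_as(src). A small standalone sketch of that pattern (tensor shapes are illustrative):

import torch

src = torch.randn(2, 3).half()                   # activations in the fp16 compute dtype
long_buffer = torch.zeros(3, dtype=torch.long)   # buffer kept in its integer dtype

# type_as casts the integer buffer to src's dtype at use time, so the stored
# buffer itself is never converted to fp16.
out = src + long_buffer.type_as(src)
assert out.dtype == torch.float16
assert long_buffer.dtype == torch.long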