Unverified Commit b46dcfaf authored by Benjamin Lefaudeux, committed by GitHub

[fix] OSS fp16 broadcast typo (#751)

parent 83b0b49e
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
 - FSDP: fixed metadata saving and shard consolidation for MoE cases [#746]
+- OSS: fixed the buckets which would stay in fp16 if `broadcast_fp16` was required (#751)
 ### Added
 - FSDP: better performance; use `_allgather_base` and `_reduce_scatter_base` when available [#729]
@@ -21,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - FSDP: Ensure requires_grad of FlatParameter is consistent with requires_grad of the original parameters. [#721]
 - doc: Thoroughly improved the doc for FSDP. [#711]
 - cleanup: Remove examples/ doc from the repo. [#712]
 - cleanup: Future proof storage size test. [#735]
 - cleanup: Migrate away from legacy torchtext iterators. [#713]
 - chore: Updated torch 1.9 to release version. [#717]
...
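The CHANGELOG entry above states the bug in plain terms: with `broadcast_fp16` enabled, OSS casts its flat buckets down to fp16 for the inter-rank broadcast and is supposed to cast them back to fp32 afterwards, but the restore pass never actually reached each bucket. Below is a minimal, self-contained sketch of that round trip; the plain dict of tensors is a stand-in for fairscale's per-device/per-rank bucket objects, not the library's actual API.

```python
# Minimal sketch of the fp16 broadcast round trip (toy tensors, not fairscale's
# bucket objects): buckets are cast to fp16 for the broadcast and must be cast
# back to fp32 afterwards, which is the step the bug skipped.
import torch

buckets = {torch.device("cpu"): {0: torch.randn(8), 1: torch.randn(8)}}

# Down-cast every bucket so the broadcast moves half as many bytes.
for device, per_rank in buckets.items():
    for rank, bucket in per_rank.items():
        per_rank[rank] = bucket.to(dtype=torch.float16)

# ... dist.broadcast(...) on each fp16 bucket would happen here ...

# Up-cast back to fp32 so the optimizer keeps working in full precision.
for device, per_rank in buckets.items():
    for rank, bucket in per_rank.items():
        per_rank[rank] = bucket.to(dtype=torch.float32)

assert all(b.dtype == torch.float32 for per_rank in buckets.values() for b in per_rank.values())
```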
@@ -396,7 +396,7 @@ class OSS(Optimizer):
         OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

     def refresh_trainable(self) -> None:
-        """ Updates the partitioning and communication patterns if the trainability (`requires_grad`)
+        """Updates the partitioning and communication patterns if the trainability (`requires_grad`)
         of some parameters changed.
         """
@@ -551,7 +551,7 @@ class OSS(Optimizer):
         # Populate back the fp32 shards
         if self.broadcast_fp16:
             for device in self.buckets.keys():
-                for dst_rank in self.buckets[device].keys():
+                for dst_rank, bucket in self.buckets[device].items():
                     bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)

     def _setup_flat_buffers(self) -> None:
...
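The one-line change in oss.py above is the actual typo: the inner loop iterated over `.keys()` only, so the name `bucket` was never rebound to the current bucket, and the `.to(torch.float32, ...)` call kept operating on whatever object `bucket` happened to point at. The sketch below reproduces the pattern with a toy dict of tensors (an assumption, not fairscale's bucket class) to show why `.items()` fixes it.

```python
# Toy reproduction of the .keys() vs .items() typo; plain tensors stand in for
# fairscale's per-rank buckets.
import torch

buckets = {0: torch.zeros(4, dtype=torch.float16),
           1: torch.zeros(4, dtype=torch.float16)}

# Buggy shape: the loop only yields keys, so `bucket` is never bound to the
# current value. The body either raises NameError or silently reuses a stale
# `bucket` left over from an earlier loop in the same function:
#
#     for dst_rank in buckets.keys():
#         bucket.to(dtype=torch.float32)   # wrong / stale object
#
# Fixed shape: items() binds each bucket in turn, so every one gets converted.
for dst_rank, bucket in buckets.items():
    buckets[dst_rank] = bucket.to(dtype=torch.float32)

assert all(b.dtype == torch.float32 for b in buckets.values())
```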
@@ -536,6 +536,10 @@ def run_test_reproducibility(rank, world_size, tempfile_name, broadcast_fp16):
     assert torch.allclose(reference_loss, test_loss), f"{reference_loss} vs {test_loss}. Reproducibility is broken"

+    # Check that no matter what the buffer is back to fp32
+    for device in optimizer.buckets.keys():
+        for bucket in optimizer.buckets[device].values():
+            assert bucket.buffer.dtype == torch.float32

     dist.destroy_process_group()
...
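The test change mirrors the fix: after `step()` the buckets must be back in fp32 regardless of the `broadcast_fp16` setting. A hedged helper form of that check is sketched below; `assert_buckets_fp32` is our name, not part of fairscale, and it assumes `optimizer` is an OSS instance exposing the `buckets` mapping and per-bucket `buffer` tensor seen in the diff.

```python
# Hypothetical helper mirroring the regression check added to the test: walk the
# optimizer's {device: {rank: bucket}} mapping and assert every flat buffer has
# been restored to fp32 after step().
import torch

def assert_buckets_fp32(optimizer) -> None:
    for device in optimizer.buckets.keys():
        for bucket in optimizer.buckets[device].values():
            assert bucket.buffer.dtype == torch.float32, (
                f"bucket for {device} was left in {bucket.buffer.dtype}"
            )
```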