Unverified Commit 180ab8c8, authored by Benjamin Lefaudeux, committed by GitHub

[OSS] Fixing the fp16 broadcast and catching this case in the unit test (#795)

parent 31e36453
@@ -16,6 +16,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- activation checkpoint: Ensure outputs of checkpointed modules only require grad if either
the input requires grad or if the parameters require grad. [#787]
- OSS: fix the broadcast_fp16 option, which was broken after a refactor and had become a no-op (bugfix). [#795]
- OSS: update the default device when refreshing the params, so that moving the model to GPU after
the OSS wrap no longer triggers warnings or slows down the jobs (ease of use). [#786]
### Added
- FSDP: Added support for returning the original names of parameters when `named_parameters` is called on
the module. To retrieve the original names of the parameters along with the params, you need to
@@ -32,7 +32,7 @@ class Bucket:
Move the underlying buffer
"""
assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
self.buffer.to(device, dtype, non_blocking)
self.buffer = self.buffer.to(device, dtype, non_blocking)
class ParamBucket(Bucket):
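The one-line fix above matters because `torch.Tensor.to()` is out of place: when a conversion is needed it returns a new tensor instead of mutating the receiver, so the result has to be assigned back to `self.buffer`. A minimal sketch of the difference (variable names are illustrative, not from the repository):

```python
import torch

buffer = torch.zeros(4, dtype=torch.float32)

# Calling .to() without using the return value leaves `buffer` unchanged:
buffer.to(torch.float16)
assert buffer.dtype == torch.float32  # still fp32 -- the converted copy was discarded

# Assigning the result back is what the patched Bucket.to() now does:
buffer = buffer.to(torch.float16)
assert buffer.dtype == torch.float16
```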
@@ -98,6 +98,8 @@ class ParamBucket(Bucket):
self._fill = 0
for p in self._params:
if p.dtype != self.buffer.dtype:
p.data = p.data.to(self.buffer.dtype)
self._add_param_as_view(p, keep_existing_value=False)
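The two added lines keep each parameter's dtype in sync with the bucket's flat buffer before the param is re-added as a view; otherwise a buffer that was just converted to fp16 would be asked to back fp32 views. A simplified, hypothetical sketch of that re-viewing step (`rebuild_views` and its arguments are illustrative, not the actual `ParamBucket` API):

```python
import torch

def rebuild_views(params, buffer):
    """Simplified sketch: repoint each param's .data to a slice of a flat buffer."""
    offset = 0
    for p in params:
        # Keep the param dtype in sync with the buffer before taking the view,
        # mirroring the `p.data = p.data.to(self.buffer.dtype)` line added above.
        if p.dtype != buffer.dtype:
            p.data = p.data.to(buffer.dtype)
        numel = p.numel()
        buffer[offset : offset + numel].copy_(p.data.flatten())
        p.data = buffer[offset : offset + numel].view_as(p)
        offset += numel

params = [torch.nn.Parameter(torch.randn(3)), torch.nn.Parameter(torch.randn(2, 2))]
flat = torch.zeros(sum(p.numel() for p in params), dtype=torch.float16)
rebuild_views(params, flat)
assert all(p.dtype == torch.float16 for p in params)
```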
......
@@ -48,7 +48,10 @@ def test_type_change():
# Move the bucket to fp16 and back
bucket.to(dtype=torch.float16, device=param.device)
assert bucket.buffer.dtype == torch.float16
bucket.to(dtype=torch.float32, device=param.device, keep_param_alignment=True)
assert bucket.buffer.dtype == torch.float32
# Same with the reference tensor
param_.to(dtype=torch.float16)
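The test above exercises the fp16/fp32 round-trip at the bucket level; at the user level the fixed flag is the `broadcast_fp16` option of `fairscale.optim.OSS`, which casts the parameter broadcast between ranks to fp16. A hedged usage sketch, assuming a distributed process group is already initialized and that the keyword arguments match the installed fairscale version:

```python
import torch
from fairscale.optim import OSS

# Assumes torch.distributed.init_process_group(...) has already been called
# and that a CUDA device is available.
model = torch.nn.Linear(16, 16).cuda()

# Wrap the base optimizer with OSS; broadcast_fp16=True asks OSS to compress the
# parameter broadcast to fp16, which is the option this commit repairs.
optimizer = OSS(
    model.parameters(),
    optim=torch.optim.SGD,
    lr=0.1,
    broadcast_fp16=True,
)
```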