Unverified Commit 180ab8c8 authored by Benjamin Lefaudeux, committed by GitHub

[OSS] Fixing the fp16 broadcast and catching this case in the unit test (#795)

parent 31e36453
...@@ -16,6 +16,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- activation checkpoint: Ensure outputs of checkpointed modules only require grad if either
the input requires grad or if the parameters require grad. [#787]
- OSS: fix the broadcast_fp16 option, which was broken after a refactor and had no effect (bugfix). [#795]
- OSS: update the default device when refreshing the params, so that moving the model to GPU after
the OSS wrap no longer triggers warnings or slows down the job (ease of use). [#786]
### Added
- FSDP: Added support for returning the original names of parameters when `named_parameters` is called on
the module. To retrieve the original names of the parameters along with the params, you need to
...
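The broadcast_fp16 entry above refers to an option on the OSS optimizer wrapper: when enabled, the sharded parameters are bucketed and broadcast in fp16 after each step to cut communication volume. A minimal sketch of turning it on follows; the single-process gloo setup, the model, and the learning rate are illustrative assumptions, and the assumption that broadcast_fp16 is a constructor flag comes only from the changelog entry, not from this diff.

```python
import os
import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Single-process process group so the sketch can run standalone (assumed setup).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(16, 16)
optimizer = OSS(
    params=model.parameters(),
    optim=torch.optim.SGD,  # base optimizer whose state is sharded across ranks
    lr=0.1,
    broadcast_fp16=True,    # the option this commit repairs; it previously had no effect
)

model(torch.randn(4, 16)).sum().backward()
optimizer.step()  # each rank's updated shard is broadcast to the other ranks in fp16
```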
...@@ -32,7 +32,7 @@ class Bucket:
Move the underlying buffer
"""
assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
- self.buffer.to(device, dtype, non_blocking)
+ self.buffer = self.buffer.to(device, dtype, non_blocking)
class ParamBucket(Bucket):
...@@ -98,6 +98,8 @@ class ParamBucket(Bucket):
self._fill = 0
for p in self._params:
if p.dtype != self.buffer.dtype:
p.data = p.data.to(self.buffer.dtype)
self._add_param_as_view(p, keep_existing_value=False)
...
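The one-line change in Bucket.to above is the root cause of the broken flag: torch.Tensor.to() is out of place, so converting the buffer without assigning the result back silently discards the fp16 copy. A small stand-alone illustration in plain PyTorch (no fairscale types involved, names are illustrative):

```python
import torch

buffer = torch.zeros(8, dtype=torch.float32)

# Tensor.to() returns a new tensor when the dtype (or device) changes;
# ignoring the return value leaves the original buffer untouched.
buffer.to(torch.float16)
print(buffer.dtype)  # torch.float32 -- the cast was dropped

# The fix applied above: assign the converted tensor back.
buffer = buffer.to(torch.float16)
print(buffer.dtype)  # torch.float16
```

The second hunk applies the same care to the individual params: each param's .data is cast to the buffer's dtype before being re-attached as a view, so the flat buffer and the params cannot end up with mismatched types after a dtype change (the keep_param_alignment round trip exercised in the test below).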
...@@ -48,7 +48,10 @@ def test_type_change():
# Move the bucket to fp16 and back
bucket.to(dtype=torch.float16, device=param.device)
assert bucket.buffer.dtype == torch.float16
bucket.to(dtype=torch.float32, device=param.device, keep_param_alignment=True)
assert bucket.buffer.dtype == torch.float32
# Same with the reference tensor
param_.to(dtype=torch.float16)
...