Unverified Commit 2d9243bf authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[fix] OSS - typo + small perf fix (#256)

* typo, sorry about that

* small perf fix
parent 351f35e1
...@@ -105,9 +105,9 @@ class OSS(Optimizer): ...@@ -105,9 +105,9 @@ class OSS(Optimizer):
self._device = list(self.per_device_params.keys())[0] self._device = list(self.per_device_params.keys())[0]
self.buckets: Dict[torch.device, List[Bucket]] = {} self.buckets: Dict[torch.device, List[Bucket]] = {}
# if torch.cuda.is_available() and self.world_size <= torch.cuda.device_count(): if torch.cuda.is_available() and self.world_size <= torch.cuda.device_count():
# broadcast_buffer_size = 0 broadcast_buffer_size = 0
# logging.warning("Assuming single node job, bucketing is disabled") logging.warning("Assuming single node job, bucketing is disabled")
self.bucket_size = broadcast_buffer_size self.bucket_size = broadcast_buffer_size
for device, per_device in self.per_device_params.items(): for device, per_device in self.per_device_params.items():
...@@ -393,6 +393,8 @@ class OSS(Optimizer): ...@@ -393,6 +393,8 @@ class OSS(Optimizer):
from a call to :meth:`state_dict` from a call to :meth:`state_dict`
""" """
print("loading state dict")
# Check whether we got a local or global dict # Check whether we got a local or global dict
if state_dict["local_state_dict"]: if state_dict["local_state_dict"]:
self.load_local_state_dict(state_dict) self.load_local_state_dict(state_dict)
...@@ -521,7 +523,7 @@ class OSS(Optimizer): ...@@ -521,7 +523,7 @@ class OSS(Optimizer):
for param in params: for param in params:
# Bucket broadcast # Bucket broadcast
if self.should_bucket_param[param]: if self.bucket_size > 0 and self.should_bucket_param[param]:
assert bucket.append(param), "Bucket overflow: max %s - current %s - adding %s" % ( assert bucket.append(param), "Bucket overflow: max %s - current %s - adding %s" % (
bucket.max_size, bucket.max_size,
bucket.current_offset, bucket.current_offset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment