Unverified Commit 2d9243bf authored by Benjamin Lefaudeux, committed by GitHub
Browse files

[fix] OSS - typo + small perf fix (#256)

* typo, sorry about that

* small perf fix
parent 351f35e1
......@@ -105,9 +105,9 @@ class OSS(Optimizer):
self._device = list(self.per_device_params.keys())[0]
self.buckets: Dict[torch.device, List[Bucket]] = {}
# if torch.cuda.is_available() and self.world_size <= torch.cuda.device_count():
# broadcast_buffer_size = 0
# logging.warning("Assuming single node job, bucketing is disabled")
if torch.cuda.is_available() and self.world_size <= torch.cuda.device_count():
broadcast_buffer_size = 0
logging.warning("Assuming single node job, bucketing is disabled")
self.bucket_size = broadcast_buffer_size
for device, per_device in self.per_device_params.items():
......@@ -393,6 +393,8 @@ class OSS(Optimizer):
from a call to :meth:`state_dict`
"""
print("loading state dict")
# Check whether we got a local or global dict
if state_dict["local_state_dict"]:
self.load_local_state_dict(state_dict)
......@@ -521,7 +523,7 @@ class OSS(Optimizer):
for param in params:
# Bucket broadcast
if self.should_bucket_param[param]:
if self.bucket_size > 0 and self.should_bucket_param[param]:
assert bucket.append(param), "Bucket overflow: max %s - current %s - adding %s" % (
bucket.max_size,
bucket.current_offset,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment