Unverified commit fb8d9137 authored by Benjamin Lefaudeux, committed by GitHub

[fix] Dead code removal for OSS (#276)

* Remove a call that has been dead since ShardedDDP landed; small speedup
* Unrelated: fill in the changelog
* Another nit
parent 7abaa2be
@@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- AdaScale: smoothing factor value fixed when using gradient accumulation (#235)
- Pipe: documentation on balancing functions (#243)
+- ShardedDDP: handle typical NLP models
+- ShardedDDP: better partitioning when finetuning
## [0.1.1] - 2020-12-01
### Fixed
...
@@ -171,8 +171,7 @@ def train(
else:
    final_loss = optimizer.step(closure)
prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
need_profiling = False  # only profile once
else:
...
@@ -208,9 +208,6 @@ class OSS(Optimizer):
else:
    loss = self.optim.step(**kwargs)
-# Depending on the DDP engine used, gradients specific to other ranks may still be loaded
-self._free_other_grads()
# Sync all the updated shards in between the ranks
self._broadcast_params()
@@ -507,17 +504,6 @@ class OSS(Optimizer):
# Discard this tensor/rank, broadcast necessary for syncing
broadcast_object(empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device)
-def _free_other_grads(self) -> None:
-    """Free all the gradients only useful for the other ranks"""
-    for rank, partition in enumerate(self.partition_parameters()):
-        if rank == self.rank:
-            continue
-        for p in partition:
-            for t in p["params"]:
-                t.grad = None
def _broadcast_params(self) -> None:
    """Helper function to broadcast all the parameters from a given device"""
...
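For context on why the removed loop is dead code: with ShardedDDP driving the backward pass, gradients are only reduced onto the rank that owns each shard, so the non-owned parameters never receive a `.grad` tensor and there is nothing to free. A minimal single-process sketch of that situation (plain PyTorch with hypothetical variable names, not the fairscale API):

```python
import torch

# Two toy "rank partitions" of a model's parameters (hypothetical example):
# rank 0 owns the first half, rank 1 owns the rest.
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(4)]
partitions = [params[:2], params[2:]]
my_rank = 0

# With a ShardedDDP-style reducer, gradients are only materialized for the
# parameters this rank owns; the other shards never get a .grad tensor.
for p in partitions[my_rank]:
    p.grad = torch.randn_like(p)

# The removed helper looped over every other rank's partition and set
# t.grad = None. When those shards never had gradients in the first place,
# the loop is pure per-step bookkeeping overhead.
freed = 0
for rank, partition in enumerate(partitions):
    if rank == my_rank:
        continue
    for t in partition:
        if t.grad is not None:
            freed += 1
        t.grad = None

print(f"gradients actually freed on non-owned shards: {freed}")  # prints 0
```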
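The step that stays, `self._broadcast_params()`, is the shard synchronization: each rank updates only the parameter slice it owns, then broadcasts that slice so every replica ends up with the full updated parameter set. A standalone sketch of this pattern (hypothetical script using `torch.distributed` with the gloo backend, not the fairscale implementation):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    # Minimal process-group setup for a local CPU demo.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Every rank holds the full parameter list, but only "owns" one slice.
    params = [torch.zeros(2) for _ in range(world_size)]

    # Local update on the owned shard only (stand-in for self.optim.step()).
    params[rank] += rank + 1.0

    # Sync: each shard is broadcast from its owner to all other ranks,
    # which is the role _broadcast_params plays after the sharded step.
    for owner, shard in enumerate(params):
        dist.broadcast(shard, src=owner)

    print(f"rank {rank}: {[p.tolist() for p in params]}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```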