Unverified commit 9b79cc02, authored by Benjamin Lefaudeux and committed by GitHub

[chore] OSS - adding the profiler labels (#629)

parent 85dea5b2
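
This change wraps the main OSS code paths (the optimizer `step`, `clip_grad_norm`, and the parameter broadcast) in `torch.autograd.profiler.record_function` scopes, so they show up as named ranges such as `fairscale::oss::optim_step` in profiler traces. A minimal sketch of how the labels could be inspected, assuming a single-process "gloo" group purely for illustration (the model and hyper-parameters below are not part of this commit):

```python
import torch
import torch.distributed as dist
from torch.autograd import profiler

from fairscale.optim.oss import OSS

# Illustrative setup only: OSS needs an initialized process group,
# here a single-process "gloo" group.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = torch.nn.Linear(32, 32)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

with profiler.profile() as prof:
    loss = model(torch.randn(8, 32)).sum()
    loss.backward()
    optimizer.step()

# The record_function scopes added by this commit (e.g. "fairscale::oss::optim_step")
# appear as named entries in the profiler summary.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
```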
@@ -11,6 +11,7 @@ from math import inf
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

 import torch
+from torch.autograd import profiler
 import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
@@ -166,12 +167,14 @@ class OSS(Optimizer):
         OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

         # Catch a possible change of devices in between OSS construction and step()
+        with profiler.record_function("fairscale::oss::refresh_trainable"):
             if self._default_device.type != self.param_groups[0]["params"][0].device.type:
                 logging.info("OSS detected that the parameter changed devices, re-allocating buffers")
                 self._clear_cache()
                 self.refresh_trainable()

         # Run the optimizer step on this shard only:
+        with profiler.record_function("fairscale::oss::optim_step"):
             if closure is not None:
                 loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
             else:
@@ -214,6 +217,7 @@ class OSS(Optimizer):
         max_norm = float(max_norm)
         norm_type = float(norm_type)

+        with profiler.record_function("fairscale::oss::clip_grad_norm"):
             # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
             # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
             # 'model_parallel' flag is set in Megatron-LM:
@@ -513,6 +517,7 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""

+        with profiler.record_function("fairscale::oss::refresh_trainable"):
             # if NCCL broadcasts will be done in an independent stream
             # make sure that prior compute work is complete
             if torch.device("cuda").type == self._default_device.type:
@@ -535,7 +540,10 @@
                 for dst_rank, bucket in self.buckets[device].items():
                     work_handles.append(
                         dist.broadcast(
-                            tensor=bucket.buffer, src=self._local_to_global_rank[dst_rank], group=self.group, async_op=True,
+                            tensor=bucket.buffer,
+                            src=self._local_to_global_rank[dst_rank],
+                            group=self.group,
+                            async_op=True,
                         )
                     )
...