Unverified Commit 9b79cc02 authored by Benjamin Lefaudeux, committed by GitHub

[chore] OSS - adding the profiler labels (#629)

parent 85dea5b2
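
As context for the diff below: wrapping a region in torch.autograd.profiler.record_function makes it show up as a named range in profiler traces. A minimal sketch of the mechanism, separate from this commit (the matmul is just a stand-in for real work):

import torch
from torch.autograd import profiler

x = torch.randn(128, 128)

with profiler.profile() as prof:
    # The label below becomes its own row in the aggregated profile table.
    with profiler.record_function("fairscale::oss::optim_step"):
        y = x @ x  # stand-in for the actual optimizer step

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))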
@@ -11,6 +11,7 @@ from math import inf
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
 
 import torch
+from torch.autograd import profiler
 import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
@@ -166,12 +167,14 @@ class OSS(Optimizer):
         OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
 
         # Catch a possible change of devices in between OSS construction and step()
+        with profiler.record_function("fairscale::oss::refresh_trainable"):
             if self._default_device.type != self.param_groups[0]["params"][0].device.type:
                 logging.info("OSS detected that the parameter changed devices, re-allocating buffers")
                 self._clear_cache()
                 self.refresh_trainable()
 
         # Run the optimizer step on this shard only:
+        with profiler.record_function("fairscale::oss::optim_step"):
             if closure is not None:
                 loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
             else:
@@ -214,6 +217,7 @@ class OSS(Optimizer):
         max_norm = float(max_norm)
         norm_type = float(norm_type)
 
+        with profiler.record_function("fairscale::oss::clip_grad_norm"):
             # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
             # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
             # 'model_parallel' flag is set in Megatron-LM:
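
The comments above describe the filter this block applies; a hedged sketch of that rank-zero / model-parallel test (the helper name is hypothetical, not part of the commit):

import torch

def _should_count(param: torch.nn.Parameter, rank: int) -> bool:
    # Count a gradient towards the norm only once across ranks: rank zero
    # covers replicated parameters, while the 'model_parallel' attribute
    # (set by Megatron-LM style frameworks) marks shards that live on a
    # single rank only and must therefore always be counted.
    return getattr(param, "model_parallel", False) or rank == 0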
@@ -513,6 +517,7 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
 
+        with profiler.record_function("fairscale::oss::refresh_trainable"):
             # if NCCL broadcasts will be done in an independent stream
             # make sure that prior compute work is complete
             if torch.device("cuda").type == self._default_device.type:
@@ -535,7 +540,10 @@ class OSS(Optimizer):
                 for dst_rank, bucket in self.buckets[device].items():
                     work_handles.append(
                         dist.broadcast(
-                            tensor=bucket.buffer, src=self._local_to_global_rank[dst_rank], group=self.group, async_op=True,
+                            tensor=bucket.buffer,
+                            src=self._local_to_global_rank[dst_rank],
+                            group=self.group,
+                            async_op=True,
                         )
                     )
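
The reformatted call above is part of an async pattern: with async_op=True each dist.broadcast returns a work handle immediately, and the handles are waited on after all broadcasts are queued, so the transfers can overlap. A hedged, self-contained sketch of that pattern (the function and variable names are illustrative, and a process group must already be initialized via dist.init_process_group):

import torch.distributed as dist

def broadcast_buckets(buckets, group=None):
    # buckets: mapping of source rank -> flat parameter buffer ("bucket")
    work_handles = [
        dist.broadcast(tensor=buffer, src=src, group=group, async_op=True)
        for src, buffer in buckets.items()
    ]
    # All broadcasts are now in flight; block until each one completes.
    for handle in work_handles:
        handle.wait()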