Unverified commit 9b79cc02, authored by Benjamin Lefaudeux, committed by GitHub

[chore] OSS - adding the profiler labels (#629)

parent 85dea5b2
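The labels added below all use torch.autograd.profiler.record_function, so the main sharded-optimizer phases show up as named ranges in profiler traces. A minimal sketch of what such a label does (not part of this commit; the label string is only borrowed from the diff below):

```python
# Minimal sketch: anything executed inside record_function("name") is reported
# as a named event by the autograd profiler.
import torch
from torch.autograd import profiler

x = torch.randn(1024, 1024)

with profiler.profile() as prof:
    with profiler.record_function("fairscale::oss::optim_step"):  # label borrowed from the diff below
        y = x @ x

# The labelled region appears as its own row in the summary table.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```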
@@ -11,6 +11,7 @@ from math import inf
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

 import torch
+from torch.autograd import profiler
 import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
@@ -166,16 +167,18 @@ class OSS(Optimizer):
         OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

         # Catch a possible change of devices in between OSS construction and step()
-        if self._default_device.type != self.param_groups[0]["params"][0].device.type:
-            logging.info("OSS detected that the parameter changed devices, re-allocating buffers")
-            self._clear_cache()
-            self.refresh_trainable()
+        with profiler.record_function("fairscale::oss::refresh_trainable"):
+            if self._default_device.type != self.param_groups[0]["params"][0].device.type:
+                logging.info("OSS detected that the parameter changed devices, re-allocating buffers")
+                self._clear_cache()
+                self.refresh_trainable()

         # Run the optimizer step on this shard only:
-        if closure is not None:
-            loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
-        else:
-            loss = self.optim.step(**kwargs)
+        with profiler.record_function("fairscale::oss::optim_step"):
+            if closure is not None:
+                loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
+            else:
+                loss = self.optim.step(**kwargs)

         # Sync all the updated shards in between the ranks
         self._broadcast_params()
@@ -214,33 +217,34 @@ class OSS(Optimizer):
         max_norm = float(max_norm)
         norm_type = float(norm_type)

-        # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
-        # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
-        # 'model_parallel' flag is set in Megatron-LM:
-        # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
-        local_params = filter_params_fn(self._local_params) if filter_params_fn is not None else self._local_params
-
-        local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
-        # Compute the norm on this grad set,
-        # then sync all the norms from all ranks
-        if norm_type == inf:
-            total_norm = local_norm
-            # all reduce over data parallel and model parallel workers
-            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
-        else:
-            # local norm result can be accumulated with the remote ones if put to the right power
-            # n_i = sum_rank(a^p)^1/p
-            # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
-            # all reduce over data parallel and model parallel workers
-            total_norm = local_norm ** norm_type
-            dist.all_reduce(total_norm)
-            total_norm = total_norm ** (1.0 / norm_type)
-
-        clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
-        if clip_coef < 1:
-            for device, device_params in self._per_device_params.items():
-                for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
-                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore  # mypy trips on the filter
+        with profiler.record_function("fairscale::oss::clip_grad_norm"):
+            # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
+            # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
+            # 'model_parallel' flag is set in Megatron-LM:
+            # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
+            local_params = filter_params_fn(self._local_params) if filter_params_fn is not None else self._local_params
+
+            local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
+            # Compute the norm on this grad set,
+            # then sync all the norms from all ranks
+            if norm_type == inf:
+                total_norm = local_norm
+                # all reduce over data parallel and model parallel workers
+                dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
+            else:
+                # local norm result can be accumulated with the remote ones if put to the right power
+                # n_i = sum_rank(a^p)^1/p
+                # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
+                # all reduce over data parallel and model parallel workers
+                total_norm = local_norm ** norm_type
+                dist.all_reduce(total_norm)
+                total_norm = total_norm ** (1.0 / norm_type)
+
+            clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
+            if clip_coef < 1:
+                for device, device_params in self._per_device_params.items():
+                    for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
+                        p.grad.detach().mul_(clip_coef.to(device))  # type: ignore  # mypy trips on the filter

         return total_norm
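The comment in the hunk above relies on the identity n_total = (sum_i n_i^p)^(1/p) for p-norms computed per shard. A standalone check of that identity (plain PyTorch, not fairscale code):

```python
# Verify that accumulating per-rank p-norms as sum(n_i**p)**(1/p) matches the
# norm computed over the concatenation of all shards.
import torch

norm_type = 2.0
grads_per_rank = [torch.randn(10), torch.randn(10)]  # stand-ins for two ranks' local grads

# Reference: norm over the concatenation of all shards
reference = torch.cat(grads_per_rank).norm(norm_type)

# Per-rank norms, accumulated the way clip_grad_norm does it
local_norms = [g.norm(norm_type) for g in grads_per_rank]
total = sum(n ** norm_type for n in local_norms) ** (1.0 / norm_type)

assert torch.allclose(reference, total)
```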
@@ -513,39 +517,43 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""

-        # if NCCL broadcasts will be done in an independent stream
-        # make sure that prior compute work is complete
-        if torch.device("cuda").type == self._default_device.type:
-            for device in self._per_device_params.keys():
-                torch.cuda.synchronize(device=device)
+        with profiler.record_function("fairscale::oss::refresh_trainable"):
+            # if NCCL broadcasts will be done in an independent stream
+            # make sure that prior compute work is complete
+            if torch.device("cuda").type == self._default_device.type:
+                for device in self._per_device_params.keys():
+                    torch.cuda.synchronize(device=device)

-        work_handles = []  # Work handles are consumed within this scope, no callback
+            work_handles = []  # Work handles are consumed within this scope, no callback

-        # Populate the fp16 shards
-        if self.broadcast_fp16:
-            for device in self.buckets.keys():
-                for dst_rank, bucket in self.buckets[device].items():
-                    bucket.to(dtype=torch.float16, device=device, non_blocking=True, keep_param_alignment=False)
+            # Populate the fp16 shards
+            if self.broadcast_fp16:
+                for device in self.buckets.keys():
+                    for dst_rank, bucket in self.buckets[device].items():
+                        bucket.to(dtype=torch.float16, device=device, non_blocking=True, keep_param_alignment=False)

-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
+                if torch.cuda.is_available():
+                    torch.cuda.synchronize()

-        # Exchange all the shards with the other ranks
-        for device in self.buckets.keys():
-            for dst_rank, bucket in self.buckets[device].items():
-                work_handles.append(
-                    dist.broadcast(
-                        tensor=bucket.buffer, src=self._local_to_global_rank[dst_rank], group=self.group, async_op=True,
-                    )
-                )
+            # Exchange all the shards with the other ranks
+            for device in self.buckets.keys():
+                for dst_rank, bucket in self.buckets[device].items():
+                    work_handles.append(
+                        dist.broadcast(
+                            tensor=bucket.buffer,
+                            src=self._local_to_global_rank[dst_rank],
+                            group=self.group,
+                            async_op=True,
+                        )
+                    )

-        _ = list(filter(lambda x: x.wait(), work_handles))
+            _ = list(filter(lambda x: x.wait(), work_handles))

-        # Populate back the fp32 shards
-        if self.broadcast_fp16:
-            for device in self.buckets.keys():
-                for dst_rank in self.buckets[device].keys():
-                    bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)
+            # Populate back the fp32 shards
+            if self.broadcast_fp16:
+                for device in self.buckets.keys():
+                    for dst_rank in self.buckets[device].keys():
+                        bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)

     def _setup_flat_buffers(self) -> None:
         """Make all params which are on the same device and tied to the same rank views of a single buffer.