Unverified commit 9b79cc02 authored by Benjamin Lefaudeux, committed by GitHub

[chore] OSS - adding the profiler labels (#629)

parent 85dea5b2
```diff
@@ -11,6 +11,7 @@ from math import inf
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

 import torch
+from torch.autograd import profiler
 import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
```
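The newly imported `profiler.record_function` is a context manager that attaches a named range to everything executed inside it, so the labelled spans show up as their own rows in autograd profiler output. A minimal sketch of the mechanism, with a made-up label that only mimics the naming scheme used in this diff:

```python
import torch
from torch.autograd import profiler

with profiler.profile() as prof:
    # Everything under record_function is attributed to the named range
    with profiler.record_function("fairscale::oss::example_scope"):
        torch.randn(1024, 1024) @ torch.randn(1024, 1024)

# The labelled range appears alongside the individual ATen ops
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```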
```diff
@@ -166,16 +167,18 @@ class OSS(Optimizer):
         OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

         # Catch a possible change of devices in between OSS construction and step()
-        if self._default_device.type != self.param_groups[0]["params"][0].device.type:
-            logging.info("OSS detected that the parameter changed devices, re-allocating buffers")
-            self._clear_cache()
-            self.refresh_trainable()
+        with profiler.record_function("fairscale::oss::refresh_trainable"):
+            if self._default_device.type != self.param_groups[0]["params"][0].device.type:
+                logging.info("OSS detected that the parameter changed devices, re-allocating buffers")
+                self._clear_cache()
+                self.refresh_trainable()

         # Run the optimizer step on this shard only:
-        if closure is not None:
-            loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
-        else:
-            loss = self.optim.step(**kwargs)
+        with profiler.record_function("fairscale::oss::optim_step"):
+            if closure is not None:
+                loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
+            else:
+                loss = self.optim.step(**kwargs)

         # Sync all the updated shards in between the ranks
         self._broadcast_params()
```
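With `step()` split into the `fairscale::oss::refresh_trainable` and `fairscale::oss::optim_step` scopes, the buffer refresh and the shard-local optimizer work can be told apart in a trace. A rough end-to-end sketch of how the labels surface; the toy model, single-process gloo group, and localhost address are placeholders, not part of this change:

```python
import torch
import torch.distributed as dist
from torch.autograd import profiler

from fairscale.optim import OSS

# Single-process "distributed" setup purely so that OSS can be constructed
dist.init_process_group(backend="gloo", init_method="tcp://localhost:29500", rank=0, world_size=1)

model = torch.nn.Linear(32, 32)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

with profiler.profile() as prof:
    loss = model(torch.randn(8, 32)).sum()
    loss.backward()
    optimizer.step()

# The scopes added by this commit (fairscale::oss::refresh_trainable,
# fairscale::oss::optim_step, ...) appear as named rows in the summary
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))

dist.destroy_process_group()
```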
```diff
@@ -214,33 +217,34 @@ class OSS(Optimizer):
         max_norm = float(max_norm)
         norm_type = float(norm_type)

-        # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
-        # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
-        # 'model_parallel' flag is set in Megatron-LM:
-        # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
-        local_params = filter_params_fn(self._local_params) if filter_params_fn is not None else self._local_params
-        local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
-
-        # Compute the norm on this grad set,
-        # then sync all the norms from all ranks
-        if norm_type == inf:
-            total_norm = local_norm
-            # all reduce over data parallel and model parallel workers
-            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
-        else:
-            # local norm result can be accumulated with the remote ones if put to the right power
-            # n_i = sum_rank(a^p)^1/p
-            # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
-            # all reduce over data parallel and model parallel workers
-            total_norm = local_norm ** norm_type
-            dist.all_reduce(total_norm)
-            total_norm = total_norm ** (1.0 / norm_type)
-
-        clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
-        if clip_coef < 1:
-            for device, device_params in self._per_device_params.items():
-                for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
-                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore  # mypy trips on the filter
+        with profiler.record_function("fairscale::oss::clip_grad_norm"):
+            # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
+            # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
+            # 'model_parallel' flag is set in Megatron-LM:
+            # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
+            local_params = filter_params_fn(self._local_params) if filter_params_fn is not None else self._local_params
+            local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
+
+            # Compute the norm on this grad set,
+            # then sync all the norms from all ranks
+            if norm_type == inf:
+                total_norm = local_norm
+                # all reduce over data parallel and model parallel workers
+                dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
+            else:
+                # local norm result can be accumulated with the remote ones if put to the right power
+                # n_i = sum_rank(a^p)^1/p
+                # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
+                # all reduce over data parallel and model parallel workers
+                total_norm = local_norm ** norm_type
+                dist.all_reduce(total_norm)
+                total_norm = total_norm ** (1.0 / norm_type)
+
+            clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
+            if clip_coef < 1:
+                for device, device_params in self._per_device_params.items():
+                    for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
+                        p.grad.detach().mul_(clip_coef.to(device))  # type: ignore  # mypy trips on the filter

         return total_norm
```
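The comment block in the `else` branch compresses the norm algebra; the check below (standalone, with made-up per-rank gradient values and a plain Python sum standing in for `dist.all_reduce`) confirms that combining per-shard norms through their p-th powers reproduces the norm over the full gradient set:

```python
import torch

norm_type = 2.0

# Pretend each tensor holds the gradients owned by one rank
shards = [torch.randn(10), torch.randn(7), torch.randn(13)]

# What each rank computes locally (the calc_grad_norm step)
local_norms = [torch.norm(shard, p=norm_type) for shard in shards]

# Stand-in for the all_reduce: sum the p-th powers across ranks, then take the root
total_norm = sum(n ** norm_type for n in local_norms) ** (1.0 / norm_type)

# Same result as computing the norm over all gradients at once
reference = torch.norm(torch.cat(shards), p=norm_type)
assert torch.allclose(total_norm, reference)
```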
```diff
@@ -513,39 +517,43 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""

-        # if NCCL broadcasts will be done in an independent stream
-        # make sure that prior compute work is complete
-        if torch.device("cuda").type == self._default_device.type:
-            for device in self._per_device_params.keys():
-                torch.cuda.synchronize(device=device)
+        with profiler.record_function("fairscale::oss::refresh_trainable"):
+            # if NCCL broadcasts will be done in an independent stream
+            # make sure that prior compute work is complete
+            if torch.device("cuda").type == self._default_device.type:
+                for device in self._per_device_params.keys():
+                    torch.cuda.synchronize(device=device)

-        work_handles = []  # Work handles are consumed within this scope, no callback
+            work_handles = []  # Work handles are consumed within this scope, no callback

-        # Populate the fp16 shards
-        if self.broadcast_fp16:
-            for device in self.buckets.keys():
-                for dst_rank, bucket in self.buckets[device].items():
-                    bucket.to(dtype=torch.float16, device=device, non_blocking=True, keep_param_alignment=False)
+            # Populate the fp16 shards
+            if self.broadcast_fp16:
+                for device in self.buckets.keys():
+                    for dst_rank, bucket in self.buckets[device].items():
+                        bucket.to(dtype=torch.float16, device=device, non_blocking=True, keep_param_alignment=False)

-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
+                if torch.cuda.is_available():
+                    torch.cuda.synchronize()

-        # Exchange all the shards with the other ranks
-        for device in self.buckets.keys():
-            for dst_rank, bucket in self.buckets[device].items():
-                work_handles.append(
-                    dist.broadcast(
-                        tensor=bucket.buffer, src=self._local_to_global_rank[dst_rank], group=self.group, async_op=True,
-                    )
-                )
+            # Exchange all the shards with the other ranks
+            for device in self.buckets.keys():
+                for dst_rank, bucket in self.buckets[device].items():
+                    work_handles.append(
+                        dist.broadcast(
+                            tensor=bucket.buffer,
+                            src=self._local_to_global_rank[dst_rank],
+                            group=self.group,
+                            async_op=True,
+                        )
+                    )

-        _ = list(filter(lambda x: x.wait(), work_handles))
+            _ = list(filter(lambda x: x.wait(), work_handles))

-        # Populate back the fp32 shards
-        if self.broadcast_fp16:
-            for device in self.buckets.keys():
-                for dst_rank in self.buckets[device].keys():
-                    bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)
+            # Populate back the fp32 shards
+            if self.broadcast_fp16:
+                for device in self.buckets.keys():
+                    for dst_rank in self.buckets[device].keys():
+                        bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)

     def _setup_flat_buffers(self) -> None:
         """Make all params which are on the same device and tied to the same rank views of a single buffer.
```