"vscode:/vscode.git/clone" did not exist on "568dc42d4fc3a70466897242cd371d4c3034f48c"
Unverified commit 7fdd7ecf, authored by Benjamin Lefaudeux, committed by GitHub

[perf][OSS] Clip grad norm: minor obvious speedup (#363)

Cache this iterator for an easy speedup.
parent 5c3ff9bd
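
For context, the patch applies a simple lazy-caching pattern: build the grad-filtered, chained parameter iterator once, memoize it on the instance, and reset the cache whenever the shard partitions are rebuilt. Below is a minimal, self-contained sketch of that pattern; the class name ShardedOptimLike, its constructor arguments, and the refresh() helper are illustrative stand-ins rather than fairscale's actual API, while local_params, _local_params, per_device_params, and rank mirror the names in the diff.

# Sketch of the caching pattern used by the commit; "ShardedOptimLike" and its
# constructor are illustrative stand-ins, not fairscale's actual class.
from itertools import chain
from typing import Any, Dict, Iterable, List, Optional

import torch


class ShardedOptimLike:
    def __init__(self, per_device_params: Dict[torch.device, List[List[torch.Tensor]]], rank: int) -> None:
        self.per_device_params = per_device_params
        self.rank = rank
        self._local_params: Optional[Iterable[Any]] = None  # lazily built cache

    @property
    def local_params(self) -> Iterable[torch.Tensor]:
        # Build the chained, grad-filtered iterator only once per refresh.
        if self._local_params is None:
            self._local_params = chain(
                *[
                    list(filter(lambda x: x.grad is not None, device_params[self.rank]))
                    for device_params in self.per_device_params.values()
                ]
            )
        return self._local_params

    def refresh(self) -> None:
        # Any event that changes the partition must invalidate the cache,
        # mirroring the `self._local_params = None` added in the last hunk.
        self._local_params = None

The invalidation in the last hunk (setting _local_params back to None) is placed alongside the other per-partition lookup resets, so the cached iterator stays in sync whenever the partitions are rebuilt.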
@@ -8,7 +8,7 @@
 import copy
 from itertools import chain
 import logging
 from math import inf
-from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Iterable, List, Optional, Type, Union
 import torch
 import torch.distributed as dist
@@ -81,6 +81,7 @@ class OSS(Optimizer):
         self._partition_parameters: List[List[dict]] = []
         self._index_to_param: Dict[int, torch.Tensor] = {}
         self._param_to_index: Dict[int, int] = {}
+        self._local_params: Optional[Iterable[Any]] = None

         # Build the wrapped optimizer, responsible for a shard of the params
         self.group = group if group is not None else dist.group.WORLD
@@ -143,6 +144,17 @@ class OSS(Optimizer):
         return self._partition_parameters

+    @property
+    def local_params(self) -> Iterable[torch.Tensor]:
+        if self._local_params is None:
+            self._local_params = chain(
+                *[
+                    list(filter(lambda x: x.grad is not None, device_params[self.rank]))
+                    for device_params in self.per_device_params.values()
+                ]
+            )
+        return self._local_params
+
     @property
     def index_to_param(self) -> Dict[int, torch.Tensor]:
         """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
@@ -255,25 +267,16 @@ class OSS(Optimizer):
         max_norm = float(max_norm)
         norm_type = float(norm_type)

-        # Filter out the grad-less params, concatenate params from all devices
-        local_params = chain(
-            *[
-                list(filter(lambda x: x.grad is not None, device_params[self.rank]))
-                for device_params in self.per_device_params.values()
-            ]
-        )
-
         # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
         # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
         # 'model_parallel' flag is set in Megatron-LM:
         # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
-        if filter_params_fn is not None:
-            local_params = filter_params_fn(local_params)
+        local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params

         # Compute the norm on this grad set,
         # then sync all the norms from all ranks
         if norm_type == inf:
-            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)  # type: ignore
+            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)
             # all reduce over data parallel and model parallel workers
             dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
         else:
@@ -508,6 +511,7 @@ class OSS(Optimizer):
         self._param_rank.clear()
         self._index_to_param.clear()
         self._param_to_index.clear()
+        self._local_params = None

     @staticmethod
     def get_global_rank(group: Any, rank: int) -> int:
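
For completeness, a hypothetical usage sketch of the method touched by this commit. The single-process "gloo" process group, the toy model, and the exact OSS constructor arguments are assumptions for illustration and may differ between fairscale versions.

# Hypothetical usage of OSS.clip_grad_norm in a training step (not from the commit).
import os

import torch
import torch.distributed as dist
from fairscale.optim import OSS


def run() -> None:
    # Single-process process group so the all_reduce inside clip_grad_norm works.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = torch.nn.Linear(8, 2)
    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()

    # The norm is computed over this rank's shard (the cached local_params),
    # then synced across ranks before the gradients are scaled.
    optimizer.clip_grad_norm(max_norm=1.0, norm_type=2.0)
    optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    run()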