Unverified commit 7fdd7ecf authored by Benjamin Lefaudeux, committed by GitHub

[perf][OSS] Clip grad norm: minor obvious speedup (#363)

Cache this iterator for an easy speedup.
parent 5c3ff9bd
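
Before the diff, a minimal usage sketch of the method this commit speeds up. Only OSS and clip_grad_norm come from the change itself; the model, the hyper-parameters, and the single-process "gloo" process group are illustrative assumptions so the snippet can run standalone on one machine.

import os

import torch
import torch.distributed as dist

from fairscale.optim import OSS

# Single-process process group, only so the sketch runs without a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(16, 4)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)

loss = model(torch.randn(8, 16)).sum()
loss.backward()

# With this commit, the per-rank list of gradient-carrying parameters is
# built once and cached instead of being rebuilt on every call.
optimizer.clip_grad_norm(max_norm=1.0)
optimizer.step()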
@@ -8,7 +8,7 @@ import copy
 from itertools import chain
 import logging
 from math import inf
-from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Iterable, List, Optional, Type, Union
 
 import torch
 import torch.distributed as dist
@@ -81,6 +81,7 @@ class OSS(Optimizer):
         self._partition_parameters: List[List[dict]] = []
         self._index_to_param: Dict[int, torch.Tensor] = {}
         self._param_to_index: Dict[int, int] = {}
+        self._local_params: Optional[Iterable[Any]] = None
 
         # Build the wrapped optimizer, responsible for a shard of the params
         self.group = group if group is not None else dist.group.WORLD
@@ -143,6 +144,17 @@ class OSS(Optimizer):
 
         return self._partition_parameters
 
+    @property
+    def local_params(self) -> Iterable[torch.Tensor]:
+        if self._local_params is None:
+            self._local_params = chain(
+                *[
+                    list(filter(lambda x: x.grad is not None, device_params[self.rank]))
+                    for device_params in self.per_device_params.values()
+                ]
+            )
+        return self._local_params
+
     @property
     def index_to_param(self) -> Dict[int, torch.Tensor]:
         """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
@@ -255,25 +267,16 @@ class OSS(Optimizer):
         max_norm = float(max_norm)
         norm_type = float(norm_type)
 
-        # Filter out the grad-less params, concatenate params from all devices
-        local_params = chain(
-            *[
-                list(filter(lambda x: x.grad is not None, device_params[self.rank]))
-                for device_params in self.per_device_params.values()
-            ]
-        )
-
         # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
         # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
         # 'model_parallel' flag is set in Megatron-LM:
         # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
-        if filter_params_fn is not None:
-            local_params = filter_params_fn(local_params)
+        local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params
 
         # Compute the norm on this grad set,
         # then sync all the norms from all ranks
         if norm_type == inf:
-            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)  # type: ignore
+            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)
             # all reduce over data parallel and model parallel workers
             dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
         else:
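
The comments in the hunk above describe the model-parallel use case: replicated parameters should only contribute to the norm on one rank, while anything flagged 'model_parallel' contributes everywhere. A hedged sketch of such a filter_params_fn follows; the `model_parallel` attribute mirrors the Megatron-LM convention linked above, but the helper name and the rank argument are illustrative assumptions, not fairscale code.

from typing import Callable, Iterable, List

import torch


def make_model_parallel_filter(rank: int) -> Callable[[Iterable[torch.Tensor]], List[torch.Tensor]]:
    # Keep parameters flagged as model-parallel on every rank, and everything
    # else only on rank zero, so replicated params are counted exactly once.
    def filter_params_fn(params: Iterable[torch.Tensor]) -> List[torch.Tensor]:
        return [p for p in params if getattr(p, "model_parallel", False) or rank == 0]

    return filter_params_fn

# Illustrative call: optimizer.clip_grad_norm(1.0, filter_params_fn=make_model_parallel_filter(rank))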
@@ -508,6 +511,7 @@ class OSS(Optimizer):
         self._param_rank.clear()
         self._index_to_param.clear()
         self._param_to_index.clear()
+        self._local_params = None
 
     @staticmethod
     def get_global_rank(group: Any, rank: int) -> int: