Unverified Commit 8778fa66 authored by Benjamin Lefaudeux, committed by GitHub

[fix] repro+fix (#365)

Fixes a broken earlier commit: the cached param iterator only worked for the first step.
parent 4dc605c9
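
The root cause: the previous commit cached `local_params` as a raw `itertools.chain` object, which is a single-use iterator, so only the first optimizer step saw any parameters. A minimal standalone sketch of the pitfall and the fix (illustrative only, not fairscale code):

```python
from itertools import chain

per_rank_params = [["w1", "w2"], ["w3"]]

# Broken: chain() returns a one-shot iterator.
cached = chain(*per_rank_params)
print(list(cached))  # ['w1', 'w2', 'w3'] -- first step works
print(list(cached))  # []                 -- later steps see nothing

# Fixed: materialize the chain into a list so it can be re-read.
cached = list(chain(*per_rank_params))
print(list(cached))  # ['w1', 'w2', 'w3']
print(list(cached))  # ['w1', 'w2', 'w3'] -- safe to iterate again
```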
@@ -8,7 +8,7 @@ import copy
 from itertools import chain
 import logging
 from math import inf
-from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Iterable, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union

 import torch
 import torch.distributed as dist
@@ -81,7 +81,7 @@ class OSS(Optimizer):
         self._partition_parameters: List[List[dict]] = []
         self._index_to_param: Dict[int, torch.Tensor] = {}
         self._param_to_index: Dict[int, int] = {}
-        self._local_params: Optional[Iterable[Any]] = None
+        self._local_params: Optional[List[torch.Tensor]] = None

         # Build the wrapped optimizer, responsible for a shard of the params
         self.group = group if group is not None else dist.group.WORLD
@@ -145,14 +145,20 @@ class OSS(Optimizer):
         return self._partition_parameters

     @property
-    def local_params(self) -> Iterable[torch.Tensor]:
+    def local_params(self) -> List[torch.Tensor]:
         """ Iterable which goes through the parameters that this rank owns
         """
         if self._local_params is None:
-            self._local_params = chain(
-                *[
-                    list(filter(lambda x: x.grad is not None, device_params[self.rank]))
-                    for device_params in self.per_device_params.values()
-                ]
-            )
+            self._local_params = list(
+                chain(
+                    *[
+                        list(filter(lambda x: x.grad is not None, device_params[self.rank]))
+                        for device_params in self.per_device_params.values()
+                    ]
+                )
+            )
+
+        # Make sure that the iterator is not consumed, only expose a copy
         return self._local_params

     @property
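
For reference, the same materialize-once pattern can be written with `functools.cached_property` (Python 3.8+). A simplified sketch under assumed names (`ShardedOptimLike`, with `per_device_params` shaped as in the diff), not the actual fairscale implementation:

```python
from functools import cached_property
from itertools import chain
from typing import Dict, List

import torch


class ShardedOptimLike:
    """Simplified stand-in for the OSS wrapper touched by this diff."""

    def __init__(self, per_device_params: Dict[torch.device, List[List[torch.Tensor]]], rank: int):
        self.per_device_params = per_device_params
        self.rank = rank

    @cached_property
    def local_params(self) -> List[torch.Tensor]:
        # Computed once on first access; every read returns the same re-iterable list.
        return list(
            chain(
                *[
                    [p for p in device_params[self.rank] if p.grad is not None]
                    for device_params in self.per_device_params.values()
                ]
            )
        )
```

Either way, reading `local_params` twice now returns the identical list, which is exactly what the doubled `check(norm)` in the test below exercises.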
@@ -632,6 +632,9 @@ def run_gradient_clipping(rank, world_size, tempfile_name):
         print(f"Checking norm {norm}")
         check(norm)
+
+        # Check twice, to catch a hypothetical iterator-consumption mistake
+        check(norm)

     dist.destroy_process_group()
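
Why the test clips twice: `torch.nn.utils.clip_grad_norm_` consumes whatever iterable it is given, so with the old iterator-valued `local_params` the second call would silently clip nothing. A standalone illustration (plain PyTorch, no distributed setup):

```python
from itertools import chain

import torch

p = torch.nn.Parameter(torch.ones(2))
p.grad = torch.full((2,), 10.0)

one_shot = chain([p])  # stands in for the old, iterator-valued local_params
torch.nn.utils.clip_grad_norm_(one_shot, max_norm=1.0)
print(p.grad.norm())  # ~1.0: the first call clipped the gradient

p.grad = torch.full((2,), 10.0)
torch.nn.utils.clip_grad_norm_(one_shot, max_norm=1.0)  # iterator exhausted
print(p.grad.norm())  # ~14.14: nothing was clipped this time
```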