Unverified Commit 8778fa66 authored by Benjamin Lefaudeux, committed by GitHub

[fix] repro+fix (#365)

Fix a broken earlier commit that only worked for the first step.
parent 4dc605c9
@@ -8,7 +8,7 @@ import copy
 from itertools import chain
 import logging
 from math import inf
-from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Iterable, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union
 import torch
 import torch.distributed as dist
@@ -81,7 +81,7 @@ class OSS(Optimizer):
         self._partition_parameters: List[List[dict]] = []
         self._index_to_param: Dict[int, torch.Tensor] = {}
         self._param_to_index: Dict[int, int] = {}
-        self._local_params: Optional[Iterable[Any]] = None
+        self._local_params: Optional[List[torch.Tensor]] = None

         # Build the wrapped optimizer, responsible for a shard of the params
         self.group = group if group is not None else dist.group.WORLD
@@ -145,14 +145,20 @@ class OSS(Optimizer):
         return self._partition_parameters

     @property
-    def local_params(self) -> Iterable[torch.Tensor]:
+    def local_params(self) -> List[torch.Tensor]:
+        """ Iterable which goes through the parameters that this rank owns
+        """
         if self._local_params is None:
-            self._local_params = chain(
-                *[
-                    list(filter(lambda x: x.grad is not None, device_params[self.rank]))
-                    for device_params in self.per_device_params.values()
-                ]
-            )
+            self._local_params = list(
+                chain(
+                    *[
+                        list(filter(lambda x: x.grad is not None, device_params[self.rank]))
+                        for device_params in self.per_device_params.values()
+                    ]
+                )
+            )
+
+        # Make sure that the iterator is not consumed, only expose a copy
         return self._local_params

     @property
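Why the earlier commit only worked for the first step: `chain(*...)` returns a one-shot iterator, so the cached `self._local_params` was exhausted after the first traversal and every later optimizer step saw an empty parameter sequence; wrapping the chain in `list(...)` materializes it once. A minimal standalone sketch of the difference (the `per_device_params` dict and `rank` below are illustrative placeholders, not the actual fairscale objects):

from itertools import chain

# Illustrative stand-ins for the real per-device parameter buckets (hypothetical values)
per_device_params = {"cuda:0": [["w1", "w2"]], "cpu": [["w3"]]}
rank = 0

# Before the fix: chain() is a one-shot iterator, exhausted after the first traversal
cached = chain(*[device_params[rank] for device_params in per_device_params.values()])
print(list(cached))  # ['w1', 'w2', 'w3']  -> first step sees the parameters
print(list(cached))  # []                  -> second step sees nothing

# After the fix: the materialized list can be traversed on every step
cached = list(chain(*[device_params[rank] for device_params in per_device_params.values()]))
print(cached)  # ['w1', 'w2', 'w3']
print(cached)  # ['w1', 'w2', 'w3']

This is also why the test below now calls check(norm) a second time: the repeated call would fail if local_params were still a consumable iterator.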
@@ -632,6 +632,9 @@ def run_gradient_clipping(rank, world_size, tempfile_name):
         print(f"Checking norm {norm}")
         check(norm)
+
+        # Check twice, to catch a hypothetical consumed-iterator mistake
+        check(norm)

     dist.destroy_process_group()