Unverified commit c386e937, authored by Benjamin Lefaudeux, committed by GitHub
Browse files

[OSS] Balance the trainable params only (#262)

* fix, one liner

* adjust so that frozen trunks still get spread across ranks, even though this should have little consequence

* removing dead code, hopeful unit test fix

* now with some linting..

* adding a proper unit test case
parent ca74ee22
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
from collections import OrderedDict, deque from collections import OrderedDict, deque
import copy import copy
from enum import Enum, auto
import itertools import itertools
from itertools import chain from itertools import chain
import logging import logging
...@@ -27,11 +26,6 @@ else: ...@@ -27,11 +26,6 @@ else:
_params_t = Any _params_t = Any
class BucketFlush(Enum):
    """Kind of collective used when flushing a bucket.

    Members get their values from ``auto()``, in declaration order.
    """

    # Flush by reducing the bucket buffer
    Reduce = auto()
    # Flush by broadcasting the bucket buffer
    Broadcast = auto()
class OSS(Optimizer): class OSS(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>` """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_. optimizer and shards its state as described by ZeRO_.
...@@ -139,7 +133,16 @@ class OSS(Optimizer): ...@@ -139,7 +133,16 @@ class OSS(Optimizer):
# Add this param to rank with smallest size. # Add this param to rank with smallest size.
rank = sizes.index(min(sizes)) rank = sizes.index(min(sizes))
param_lists[rank].append(param) param_lists[rank].append(param)
sizes[rank] += param.numel()
# We're partitioning the optimizer state,
# so trainable parameters are the ones which really count
if param.requires_grad:
sizes[rank] += param.numel()
else:
# Spread frozen params on a per-tensor basis
# Mostly useful to balance the partitions, e.g. when fine-tuning
# Not required strictly speaking
sizes[rank] += 1
for rank, params in enumerate(param_lists): for rank, params in enumerate(param_lists):
param_group_rank = copy.copy(param_group) param_group_rank = copy.copy(param_group)
...@@ -585,30 +588,6 @@ class OSS(Optimizer): ...@@ -585,30 +588,6 @@ class OSS(Optimizer):
if work_handle.callback is not None: if work_handle.callback is not None:
work_handle.callback() work_handle.callback()
def _handle_trailing_buckets(self, flush_type: BucketFlush) -> None:
    """Flush every bucket that still holds pending data.

    For each bucket in ``self.buckets`` whose ``current_offset`` is positive,
    an asynchronous collective is started on ``bucket.buffer`` and the
    resulting handle is queued in ``self.work_handles`` together with
    ``bucket.unroll`` as its callback. The queued handles are then consumed.

    .. warning::
        A flush may already have been requested for a bucket; this case
        needs to be handled carefully.

    Args:
        flush_type: ``BucketFlush.Broadcast`` broadcasts the buffer from the
            bucket's ``global_ref_rank``; any other value reduces the buffer
            to that rank.
    """
    for bucket_list in self.buckets.values():
        for bucket in bucket_list:
            if bucket.current_offset > 0:
                # Only one of the two collectives is ever launched per bucket
                if flush_type == BucketFlush.Broadcast:
                    handle = dist.broadcast(
                        tensor=bucket.buffer, src=bucket.global_ref_rank, group=self.group, async_op=True,
                    )
                else:
                    handle = dist.reduce(
                        tensor=bucket.buffer, dst=bucket.global_ref_rank, group=self.group, async_op=True,
                    )

                self.work_handles.append(Workhandle(handle=handle, callback=bucket.unroll))

    # NOTE(review): indentation was lost in the reviewed view; this call is
    # assumed to run once, after all the buckets have been visited - confirm.
    self._consume_work_handles()
def _setup_bucket_strategy(self) -> None: def _setup_bucket_strategy(self) -> None:
""" Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered """ Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
(smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent (smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
......
...@@ -178,16 +178,47 @@ class TestSingleRank(unittest.TestCase): ...@@ -178,16 +178,47 @@ class TestSingleRank(unittest.TestCase):
def run_test_add_param_group(rank, world_size, tempfile_name): def run_test_add_param_group(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name) dist_init(rank, world_size, tempfile_name)
params = []
for size in [4, 5, 2, 6, 4]: # Test with all parameters trainable to begin with
params.append(torch.rand(size, 1)) def all_trainable():
o = optim.OSS(params, lr=0.1) params = []
assert len(o.param_groups) == 1 for size in [4, 5, 2, 6, 4]:
o.add_param_group({"params": [torch.rand(3, 1)]}) params.append(torch.rand(size, 1))
assert len(o.param_groups) == 2
# Verify that added group is added to the correct partition making all have 8 elements. # Make sure that the params are trainable, enforces size-based partitioning
assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8 for p in params:
assert len(o.optim.param_groups) == 2 p.requires_grad = True
o = optim.OSS(params, lr=0.1)
assert len(o.param_groups) == 1
o.add_param_group({"params": [torch.rand(3, 1)]})
assert len(o.param_groups) == 2
# Verify that added group is added to the correct partition making all have 8 elements.
assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
assert len(o.optim.param_groups) == 2
# Test a pathological config with a first big non-trainable param
def some_trainable():
params = []
for size in [100, 3, 5, 2, 6, 4]:
params.append(torch.rand(size, 1))
# Make every param but the first trainable: the first (big) one stays frozen
for p in params[1:]:
p.requires_grad = True
o = optim.OSS(params, lr=0.1)
assert len(o.param_groups) == 1
o.add_param_group({"params": [torch.rand(3, 1)]})
assert len(o.param_groups) == 2
assert len(o.optim.param_groups) == 2
all_trainable()
some_trainable()
dist.destroy_process_group() dist.destroy_process_group()
...@@ -303,6 +334,11 @@ def run_test_sharding(rank, world_size, tempfile_name): ...@@ -303,6 +334,11 @@ def run_test_sharding(rank, world_size, tempfile_name):
params = [] params = []
for size in [5, 4, 2, 6, 4, 3]: for size in [5, 4, 2, 6, 4, 3]:
params.append(torch.rand(size, 1)) params.append(torch.rand(size, 1))
# Make sure that the params are trainable, enforces size-based partitioning
for p in params:
p.requires_grad = True
o = optim.OSS(params, lr=0.1) o = optim.OSS(params, lr=0.1)
assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8 assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment