Unverified Commit c386e937 authored by Benjamin Lefaudeux, committed by GitHub

[OSS] Balance the trainable params only (#262)

* fix, one liner

* adjust so that frozen trunks still get spread out, even if this should have little consequence

* removing dead code, hopefully fixing the unit test

* now with some linting..

* adding a proper unit test case
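
For context, here is a minimal standalone sketch (not taken from this diff) of the greedy heuristic this change amounts to: each tensor goes to the currently smallest rank, trainable tensors are weighted by their element count, and frozen tensors only count as 1, so a large frozen trunk no longer skews the optimizer-state shards. Names and sizes below are made up for illustration.

```python
from typing import List

import torch


def partition_by_trainable_size(params: List[torch.Tensor], world_size: int) -> List[List[torch.Tensor]]:
    """Greedily assign each tensor to the rank with the smallest running size.

    Trainable tensors contribute their element count; frozen tensors contribute
    only 1, since they carry no optimizer state to shard.
    """
    param_lists: List[List[torch.Tensor]] = [[] for _ in range(world_size)]
    sizes = [0] * world_size

    for param in params:
        rank = sizes.index(min(sizes))  # smallest partition first
        param_lists[rank].append(param)
        sizes[rank] += param.numel() if param.requires_grad else 1

    return param_lists


# One big frozen trunk plus small trainable heads, mirroring the new unit test.
frozen_trunk = torch.rand(100, 1)  # requires_grad is False by default
heads = [torch.rand(size, 1).requires_grad_(True) for size in (3, 5, 2, 6, 4)]

shards = partition_by_trainable_size([frozen_trunk] + heads, world_size=2)
print([sum(p.numel() for p in shard if p.requires_grad) for shard in shards])
# Trainable elements split 9 / 11 here; counting the frozen trunk by numel
# (the old behaviour) would have pushed all 20 trainable elements onto rank 1.
```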
parent ca74ee22
@@ -5,7 +5,6 @@
from collections import OrderedDict, deque
import copy
from enum import Enum, auto
import itertools
from itertools import chain
import logging
@@ -27,11 +26,6 @@ else:
_params_t = Any
class BucketFlush(Enum):
Reduce = auto()
Broadcast = auto()
class OSS(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_.
@@ -139,7 +133,16 @@ class OSS(Optimizer):
# Add this param to rank with smallest size.
rank = sizes.index(min(sizes))
param_lists[rank].append(param)
# We're partitioning the optimizer state,
# so trainable parameters are the ones which really count
if param.requires_grad:
sizes[rank] += param.numel()
else:
# Spread frozen params on a per-tensor basis
# Mostly useful to balance partitions when fine-tuning, for instance
# Not strictly required
sizes[rank] += 1
for rank, params in enumerate(param_lists):
param_group_rank = copy.copy(param_group)
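
As a usage-level illustration of the fine-tuning scenario this hunk targets, a rough sketch (assuming `fairscale.optim.OSS` is available and `torch.distributed` has already been initialized; the module sizes are invented):

```python
import torch
import torch.nn as nn

from fairscale.optim import OSS

# Assumes torch.distributed.init_process_group(...) has been called already:
# OSS shards the optimizer state across the ranks of the process group.
model = nn.Sequential(
    nn.Linear(1024, 1024),  # large "trunk", frozen below
    nn.ReLU(),
    nn.Linear(1024, 10),    # small "head", stays trainable
)
for p in model[0].parameters():
    p.requires_grad = False  # frozen weights carry no optimizer state

# With this change, only the trainable head drives the per-rank size
# accounting; the frozen trunk is still spread around, one tensor at a time.
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1, momentum=0.9)
```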
@@ -585,30 +588,6 @@ class OSS(Optimizer):
if work_handle.callback is not None:
work_handle.callback()
def _handle_trailing_buckets(self, flush_type: BucketFlush) -> None:
"""
Go through the buckets, flush them if not already empty
.. warning: Could be that a bucket flush was already requested, needs to be handled carefully
"""
for bucket_list in self.buckets.values():
for bucket in bucket_list:
if bucket.current_offset > 0:
self.work_handles.append(
Workhandle(
handle=dist.broadcast(
tensor=bucket.buffer, src=bucket.global_ref_rank, group=self.group, async_op=True,
)
if flush_type == BucketFlush.Broadcast
else dist.reduce(
tensor=bucket.buffer, dst=bucket.global_ref_rank, group=self.group, async_op=True,
),
callback=bucket.unroll,
)
)
self._consume_work_handles()
def _setup_bucket_strategy(self) -> None:
""" Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
(smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
......
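
The `_setup_bucket_strategy` docstring above describes how parameters are split between the bucket and direct broadcast/reduce. Below is a rough sketch of that decision rule, with a hypothetical capacity threshold (not fairscale's actual fields):

```python
from typing import List, Tuple

import torch


def split_bucket_vs_direct(
    params: List[torch.Tensor], bucket_capacity: int
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Pack the smallest tensors into the bucket until it is full; the
    remaining, larger tensors are flagged for direct broadcast/reduce."""
    bucketed: List[torch.Tensor] = []
    direct: List[torch.Tensor] = []
    used = 0

    for p in sorted(params, key=lambda t: t.numel()):  # smallest first
        if used + p.numel() <= bucket_capacity:
            bucketed.append(p)
            used += p.numel()
        else:
            direct.append(p)

    return bucketed, direct
```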
@@ -178,17 +178,48 @@ class TestSingleRank(unittest.TestCase):
def run_test_add_param_group(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name)
# Test with all parameters trainable to begin with
def all_trainable():
params = []
for size in [4, 5, 2, 6, 4]:
params.append(torch.rand(size, 1))
# Make sure that the params are trainable, which enforces size-based partitioning
for p in params:
p.requires_grad = True
o = optim.OSS(params, lr=0.1)
assert len(o.param_groups) == 1
o.add_param_group({"params": [torch.rand(3, 1)]})
assert len(o.param_groups) == 2
# Verify that the added group goes to the correct partition, so that every rank ends up with 8 elements.
assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
assert len(o.optim.param_groups) == 2
# Test a pathological config with a first big non-trainable param
def some_trainable():
params = []
for size in [100, 3, 5, 2, 6, 4]:
params.append(torch.rand(size, 1))
# Make all params except the first trainable; the big frozen one should not skew the partitioning
for p in params[1:]:
p.requires_grad = True
o = optim.OSS(params, lr=0.1)
assert len(o.param_groups) == 1
o.add_param_group({"params": [torch.rand(3, 1)]})
assert len(o.param_groups) == 2
assert len(o.optim.param_groups) == 2
all_trainable()
some_trainable()
dist.destroy_process_group()
@@ -303,6 +334,11 @@ def run_test_sharding(rank, world_size, tempfile_name):
params = []
for size in [5, 4, 2, 6, 4, 3]:
params.append(torch.rand(size, 1))
# Make sure that the params are trainable, which enforces size-based partitioning
for p in params:
p.requires_grad = True
o = optim.OSS(params, lr=0.1)
assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8
......