Unverified Commit 6d223777 authored by Benjamin Lefaudeux, committed by GitHub

[fix] Fix iGPT buckets with ShardedDDP (#223)

* Add proper unit testing; no solution other than disabling bucketing for now, the couple of options tested do not work
parent ce5860ea
......@@ -112,6 +112,9 @@ class ShardedDataParallel(nn.Module):
# for the subsequent FW to be correct
self.sync_buffers(blocking=True)
# Reset all the grad reduce and bucket state flags
self._grad_to_be_reduced = [True] * len(self._grad_to_be_reduced)
# Normal FW on the base model
return self.module(*inputs, **kwargs)
......@@ -179,9 +182,6 @@ class ShardedDataParallel(nn.Module):
# and execute the delayed actions (release gradients, unroll the buckets)
Variable._execution_engine.queue_callback(optimizer._consume_work_handles)
# Reset all the grad reduce and bucket state flags
self._grad_to_be_reduced = [True] * len(self._grad_to_be_reduced)
def reduce_direct(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
......@@ -229,22 +229,6 @@ class ShardedDataParallel(nn.Module):
)
if bucket.full():
def unwrap() -> None:
for flat in bucket.params:
if dst_rank != self.global_rank:
# this rank is not the owner, release the grad
flat.param.grad = None
else:
# this rank is the owner, unroll the results
assert flat.param.grad is not None
flat.param.grad.data.copy_(
bucket.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
)
bucket.reset()
bucket.buffer /= self.world_size
optimizer.work_handles.append(
......@@ -252,7 +236,7 @@ class ShardedDataParallel(nn.Module):
handle=dist.reduce(
tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
),
callback=unwrap,
callback=bucket.unroll,
)
)
......
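For context, the hunk above queues the reduce asynchronously and defers the per-bucket cleanup to a callback. Below is a minimal sketch of that pattern, assuming an initialized process group and using a plain (handle, callback) list in place of fairscale's Workhandle; it is illustrative only, not the library code.

import torch
import torch.distributed as dist

def reduce_bucket_async(buffer: torch.Tensor, dst_rank: int, callback, work_handles: list) -> None:
    # Pre-scale so that the summed reduction yields an average, as in the hunk above
    buffer /= dist.get_world_size()
    # Launch the reduce without blocking the backward pass; the handle completes later
    handle = dist.reduce(tensor=buffer, dst=dst_rank, async_op=True)
    # Keep the callback (e.g. bucket.unroll) next to its handle for later consumption
    work_handles.append((handle, callback))

The callbacks are only run once their handles have completed, which is what _consume_work_handles takes care of (sketched further down).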
......@@ -5,6 +5,7 @@
from collections import OrderedDict
import copy
from enum import Enum, auto
import itertools
from itertools import chain
import logging
......@@ -26,6 +27,11 @@ else:
_params_t = Any
class BucketFlush(Enum):
Reduce = auto()
Broadcast = auto()
class OSS(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_.
......@@ -68,6 +74,9 @@ class OSS(Optimizer):
broadcast_buffer_size: int = 2 ** 17,
**default: Any,
):
logging.warning("Disabling bucketing for now, error prone for some models")
broadcast_buffer_size = 0
# Hold all the model params in the root .param_groups
self.in_super_constructor = True
super().__init__(params, default)
......@@ -495,21 +504,6 @@ class OSS(Optimizer):
def _broadcast_params(self) -> None:
"""Helper function to broadcast all the parameters from a given device"""
# The unroll callback is called when the broadcast is done.
# If this rank is a recipient and the call was bucketed, the results from the broadcast are unrolled
# onto the corresponding parameters.
def get_unroll_callback(src_rank: int, bucket: Bucket) -> Callable:
def unroll() -> None:
if src_rank != self.rank:
for flat in bucket.params:
flat.param.data.copy_(
bucket.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
)
bucket.reset()
return unroll
with torch.no_grad():
for (
device,
......@@ -537,7 +531,7 @@ class OSS(Optimizer):
handle=dist.broadcast(
tensor=bucket.buffer, src=global_src_rank, group=self.group, async_op=True
),
callback=get_unroll_callback(src_rank, bucket),
callback=bucket.unroll,
)
)
......@@ -566,6 +560,30 @@ class OSS(Optimizer):
self.work_handles.clear()
def _handle_trailing_buckets(self, flush_type: BucketFlush) -> None:
"""
Go through the buckets and flush any that are not already empty.
.. warning:: A flush may already have been requested for a given bucket, so this needs to be handled carefully.
"""
for bucket_list in self.buckets.values():
for bucket in bucket_list:
if bucket.current_offset > 0:
self.work_handles.append(
Workhandle(
handle=dist.broadcast(
tensor=bucket.buffer, src=bucket.global_ref_rank, group=self.group, async_op=True,
)
if flush_type == BucketFlush.Broadcast
else dist.reduce(
tensor=bucket.buffer, dst=bucket.global_ref_rank, group=self.group, async_op=True,
),
callback=bucket.unroll,
)
)
self._consume_work_handles()
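_consume_work_handles is called here but its body sits outside this diff; as an assumption about the internals, it most likely waits on every in-flight collective and then fires the attached callback, roughly:

def _consume_work_handles(self) -> None:
    # Assumed behaviour: wait for all queued collectives, then run their callbacks
    # (bucket.unroll for bucketed calls, nothing for direct reduces/broadcasts)
    for work_handle in self.work_handles:
        work_handle.handle.wait()
        if work_handle.callback is not None:
            work_handle.callback()
    self.work_handles.clear()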
def _setup_bucket_strategy(self) -> None:
""" Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
(smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
......@@ -578,20 +596,21 @@ class OSS(Optimizer):
for device, per_rank_params in self.per_device_params.items():
for dst_rank, params in enumerate(per_rank_params):
offset = 0
bucket_size = self.buckets[device][dst_rank].max_size
for param in params:
if (offset + param.numel()) < bucket_size:
# This parameter is small enough to fit in the remaining size of the bucket
# Criteria to decide whether this parameter is bucketed or not:
# - there is enough room left in the bucket
# - the param is not the first one in the DAG, since it may be kicked out of autograd depending on the inputs
if (offset + param.numel()) < self.buckets[device][dst_rank].max_size and param.is_leaf:
self.should_bucket_param[param] = True
offset += param.numel()
else:
# The parameters are sorted by size, so all the following parameters
# will be too big and can be skipped
self.should_bucket_param[param] = False
# Register the max offset for this buffer
# Register the max offset for this buffer, and the reference rank
self.buckets[device][dst_rank].max_offset = offset
self.buckets[device][dst_rank].global_ref_rank = self.get_global_rank(self.group, dst_rank)
self.buckets[device][dst_rank].global_rank = self.global_rank
# Determine the max work handles in flight:
# - all the direct reduce/broadcast + 1 bucket
......
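A standalone sketch of the tagging strategy above: parameters sorted smallest-first are packed into the bucket while they fit, the rest are sent directly. The sizes and bucket capacity below are made-up illustration values.

bucket_max_size = 2 ** 17  # 131072 elements, matching the default broadcast_buffer_size
param_sizes = [512, 2048, 65536, 262144, 1048576]  # assumed to be sorted ascending

offset = 0
should_bucket = []
for numel in param_sizes:
    fits = (offset + numel) < bucket_max_size
    should_bucket.append(fits)
    if fits:
        offset += numel

print(should_bucket)  # [True, True, True, False, False]
print(offset)         # 68096 -> would become the bucket's max_offset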
......@@ -37,9 +37,39 @@ class Bucket:
# Handles to the params and their position in this tensor, can be useful for a callback
self.params: List[FlatParam] = []
# Optional callback, typically used to unroll the bucket
self.callback: Optional[Callable] = None
# Current status for this buffer
self.current_offset = 0
self.max_offset = 0
self.global_ref_rank = -1  # Either the destination rank (when reducing) or the source rank (when broadcasting)
self.global_rank = -1
self.gradients_based = False
def unroll(self) -> None:
"""
Distribute the contents of the flat buffer back to the attached parameters
"""
for flat in self.params:
if self.global_ref_rank != self.global_rank and self.gradients_based:
# this rank is not the owner, release the grad
flat.param.grad = None
else:
if self.gradients_based:
# this rank is the owner, unroll the results
assert flat.param.grad is not None
flat.param.grad.data.copy_(
self.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
)
else:
flat.param.data.copy_(
self.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
)
self.reset()
def reset(self) -> None:
""" empty the bucket """
......@@ -50,6 +80,7 @@ class Bucket:
""" add a tensor to the bucket """
end = self.current_offset + tensor.numel()
self.gradients_based = use_gradient
if end > self.max_size:
return False
......
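A self-contained illustration of the slice/view_as copy that Bucket.unroll performs: the flat buffer holds the concatenated data coming back from the collective, and each slice is copied into the parameter it belongs to. The FlatParam bookkeeping is simplified to plain (param, start, stop) tuples here.

import torch

p1 = torch.zeros(2, 3)
p2 = torch.zeros(4)
buffer = torch.arange(10, dtype=torch.float32)  # pretend this came back from a reduce/broadcast
slices = [(p1, 0, 6), (p2, 6, 10)]              # (param, start, stop), mirroring FlatParam

for param, start, stop in slices:
    param.data.copy_(buffer[start:stop].view_as(param.data))

print(p1)  # tensor([[0., 1., 2.], [3., 4., 5.]])
print(p2)  # tensor([6., 7., 8., 9.])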
......@@ -39,6 +39,7 @@ import torch
import torch.distributed as dist
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
......@@ -211,3 +212,81 @@ def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
return func
return prepare_test
class _Block(nn.Module):
def __init__(self, embed_dim: int, num_heads: int) -> None:
super().__init__()
self.ln_1 = nn.LayerNorm(embed_dim)
self.ln_2 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads) # type: ignore
self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),)
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
x = inputs[0]
attn_mask = torch.full((len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype)
attn_mask = torch.triu(attn_mask, diagonal=1)
x = self.ln_1(x)
a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
x = x + a
m = self.mlp(self.ln_2(x))
x = x + m
return x
class GPT2(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, num_layers: int, num_positions: int, num_vocab: int, num_classes: int
) -> None:
super().__init__()
self.embed_dim = embed_dim
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
self.position_embeddings = nn.Embedding(num_positions, embed_dim)
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(_Block(embed_dim, num_heads))
self.ln_f = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_vocab, bias=False)
self.clf_head = nn.Linear(embed_dim, num_classes)
def forward(self, x: torch.Tensor, classify=False) -> Any: # type: ignore
"""
Expect input as shape [sequence len, batch]
If classify, return classification logits
"""
length, batch = x.shape
h = self.token_embeddings(x)
# prepend sos token
sos = torch.ones(1, batch, self.embed_dim, device=x.device) * self.sos
h = torch.cat([sos, h[:-1, :, :]], dim=0)
# add positional embeddings
positions = torch.arange(length, device=x.device).unsqueeze(-1)
h = h + self.position_embeddings(positions).expand_as(h)
# transformer
for layer in self.layers:
h = layer(h)
h = self.ln_f(h)
logits = self.head(h)
if not classify:
# return logits
return logits
h = torch.mean(h, dim=0) # average pool over sequence
# return classification logits and generative logits
return self.clf_head(h), logits
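A quick usage sketch of the GPT2 test model defined above; the sizes are arbitrary and much smaller than the ones used in run_test_gpt2 further down.

import torch
from fairscale.utils.testing import GPT2  # added by this commit

model = GPT2(embed_dim=32, num_heads=2, num_layers=2, num_positions=16, num_vocab=64, num_classes=2)
tokens = torch.randint(64, (16, 4))    # [sequence len, batch] of int token ids
logits = model(tokens)                 # [sequence len, batch, num_vocab]
clf_logits, logits = model(tokens, classify=True)
print(logits.shape, clf_logits.shape)  # torch.Size([16, 4, 64]) torch.Size([4, 2])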
......@@ -24,6 +24,8 @@ skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda
skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")
from contextlib import suppress
from fairscale.utils.testing import GPT2
def test_step_on_cpu():
run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"), world_size=4)
......@@ -153,7 +155,6 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
loss.backward()
return loss
# The models should stay the same across the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
......@@ -214,15 +215,17 @@ def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
# Optim loop
def closure():
optimizer.zero_grad()
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor, input_tensor).abs().sum()
loss.backward()
return loss
# The models should stay the same across the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
optimizer_1.zero_grad()
optimizer_2.zero_grad()
_ = optimizer_1.step(closure=closure)
_ = optimizer_2.step(closure=closure)
dist.destroy_process_group()
......@@ -233,4 +236,49 @@ def test_two_optimizers():
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cpu"
mp.spawn(run_test_two_inputs, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
INPUT_DIM = 32
BATCH_SIZE = 10
STEPS = 10
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
model = GPT2(
embed_dim=512, num_heads=2, num_layers=24, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
).to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
# Optim loop
def closure():
optimizer.zero_grad()
# Force int inputs to prevent the first grad from firing
input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
loss = ddp_model(input_tensor).abs().sum()
loss.backward()
return loss
# Check for bucketing overflows
for i in range(STEPS):
_ = optimizer.step(closure=closure)
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
def test_gpt2():
# Check that the ShardedDDP wrapper handles the iGPT/GPT2 bucketing case (#223)
world_size = 2
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cuda"
mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)