Unverified Commit 6d223777 authored by Benjamin Lefaudeux, committed by GitHub

[fix] Fix iGPT buckets with ShardedDDP (#223)

* proper unit testing, but no solution other than disabling bucketing for now; a couple of options were tested and did not work
parent ce5860ea
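For context, a minimal sketch of the workaround this commit ships: bucketing is switched off by forcing `broadcast_buffer_size` to 0 in the OSS constructor. The snippet below mirrors the new unit test, but on a single-rank gloo group with a toy `Linear` model; the import paths and the single-rank setup are assumptions for illustration and have not been validated against this exact revision.

```python
import tempfile

import torch
import torch.distributed as dist

from fairscale.nn.data_parallel import ShardedDataParallel  # assumed import path
from fairscale.optim.oss import OSS

# Single-rank gloo group, just enough to exercise the wrappers locally.
dist.init_process_group(
    backend="gloo", init_method="file://" + tempfile.mkstemp()[1], rank=0, world_size=1
)

model = torch.nn.Linear(32, 32)

# broadcast_buffer_size=0 disables bucketing explicitly; as of this commit the
# OSS constructor forces it to 0 anyway and logs a warning.
optimizer = OSS(
    params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99, broadcast_buffer_size=0
)
ddp_model = ShardedDataParallel(model, optimizer)

loss = ddp_model(torch.rand(8, 32)).abs().sum()
loss.backward()   # gradients are reduced to their owning rank by the ShardedDDP hooks
optimizer.step()  # updated shards are broadcast back to all ranks

dist.destroy_process_group()
```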
...@@ -112,6 +112,9 @@ class ShardedDataParallel(nn.Module):
# for the subsequent FW to be correct
self.sync_buffers(blocking=True)
# Reset all the grad reduce and bucket state flags
self._grad_to_be_reduced = [True] * len(self._grad_to_be_reduced)
# Normal FW on the base model
return self.module(*inputs, **kwargs)
...@@ -179,9 +182,6 @@ class ShardedDataParallel(nn.Module):
# and execute the delayed actions (release gradients, unroll the buckets)
Variable._execution_engine.queue_callback(optimizer._consume_work_handles)
# Reset all the grad reduce and bucket state flags
self._grad_to_be_reduced = [True] * len(self._grad_to_be_reduced)
def reduce_direct(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
...@@ -229,22 +229,6 @@ class ShardedDataParallel(nn.Module):
)
if bucket.full():
def unwrap() -> None:
for flat in bucket.params:
if dst_rank != self.global_rank:
# this rank is not the owner, release the grad
flat.param.grad = None
else:
# this rank is the owner, unroll the results
assert flat.param.grad is not None
flat.param.grad.data.copy_(
bucket.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
)
bucket.reset()
bucket.buffer /= self.world_size
optimizer.work_handles.append(
...@@ -252,7 +236,7 @@ class ShardedDataParallel(nn.Module):
handle=dist.reduce(
tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
),
callback=unwrap,
callback=bucket.unroll,
)
)
......
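The change above swaps the local `unwrap` closure for `bucket.unroll` as the callback attached to the asynchronous reduce. As a standalone illustration of that work-handle pattern (queue an async collective together with a callback, then wait and run the callback once the communication completes), here is a sketch under assumed names: `WorkHandle` is a stand-in for fairscale's `Workhandle`, and a single-rank gloo group keeps it runnable.

```python
import tempfile
from dataclasses import dataclass
from typing import Any, Callable, List

import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo", init_method="file://" + tempfile.mkstemp()[1], rank=0, world_size=1
)

@dataclass
class WorkHandle:
    handle: Any          # the async request returned by the collective
    callback: Callable   # run after the request completes (e.g. unroll a bucket)

work_handles: List[WorkHandle] = []
buffer = torch.ones(8)

# Queue an asynchronous reduce; the callback is deferred until the handle is consumed.
work_handles.append(
    WorkHandle(
        handle=dist.reduce(tensor=buffer, dst=0, async_op=True),
        callback=lambda: print("bucket ready:", buffer),
    )
)

# Equivalent in spirit to _consume_work_handles: wait, then run the deferred actions.
for wh in work_handles:
    wh.handle.wait()
    wh.callback()
work_handles.clear()

dist.destroy_process_group()
```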
...@@ -5,6 +5,7 @@
from collections import OrderedDict
import copy
from enum import Enum, auto
import itertools
from itertools import chain
import logging
...@@ -26,6 +27,11 @@ else:
_params_t = Any
class BucketFlush(Enum):
Reduce = auto()
Broadcast = auto()
class OSS(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_.
...@@ -68,6 +74,9 @@ class OSS(Optimizer):
broadcast_buffer_size: int = 2 ** 17,
**default: Any,
):
logging.warning("Disabling bucketing for now, error prone for some models")
broadcast_buffer_size = 0
# Hold all the model params in the root .param_groups
self.in_super_constructor = True
super().__init__(params, default)
...@@ -495,21 +504,6 @@ class OSS(Optimizer):
def _broadcast_params(self) -> None:
"""Helper function to broadcast all the parameters from a given device"""
# The unroll callback is called when the broadcast is done.
# If this rank is a recipient and the call was bucketed, the results from the broadcast are unrolled
# onto the corresponding parameters.
def get_unroll_callback(src_rank: int, bucket: Bucket) -> Callable:
def unroll() -> None:
if src_rank != self.rank:
for flat in bucket.params:
flat.param.data.copy_(
bucket.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
)
bucket.reset()
return unroll
with torch.no_grad():
for (
device,
...@@ -537,7 +531,7 @@ class OSS(Optimizer):
handle=dist.broadcast(
tensor=bucket.buffer, src=global_src_rank, group=self.group, async_op=True
),
callback=get_unroll_callback(src_rank, bucket),
callback=bucket.unroll,
)
)
...@@ -566,6 +560,30 @@ class OSS(Optimizer):
self.work_handles.clear()
def _handle_trailing_buckets(self, flush_type: BucketFlush) -> None:
"""
Go through the buckets and flush any that are not already empty.
.. warning:: A flush may already have been requested for a bucket; this case needs to be handled carefully.
"""
for bucket_list in self.buckets.values():
for bucket in bucket_list:
if bucket.current_offset > 0:
self.work_handles.append(
Workhandle(
handle=dist.broadcast(
tensor=bucket.buffer, src=bucket.global_ref_rank, group=self.group, async_op=True,
)
if flush_type == BucketFlush.Broadcast
else dist.reduce(
tensor=bucket.buffer, dst=bucket.global_ref_rank, group=self.group, async_op=True,
),
callback=bucket.unroll,
)
)
self._consume_work_handles()
def _setup_bucket_strategy(self) -> None:
""" Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
(smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
...@@ -578,20 +596,21 @@ class OSS(Optimizer):
for device, per_rank_params in self.per_device_params.items():
for dst_rank, params in enumerate(per_rank_params):
offset = 0
bucket_size = self.buckets[device][dst_rank].max_size
for param in params:
if (offset + param.numel()) < bucket_size:
# This parameter is small enough to fit in the remaining size of the bucket
# Criteria to decide whether this parameter is to be bucketed or not:
# - enough room in the bucket
# - param not the first one in the DAG, because this may be kicked out of autograd (depending on inputs)
if (offset + param.numel()) < self.buckets[device][dst_rank].max_size and param.is_leaf:
self.should_bucket_param[param] = True
offset += param.numel()
else:
# The parameters are sorted by size, so all the following parameters
# will be too big and can be skipped
self.should_bucket_param[param] = False
# Register the max offset for this buffer
# Register the max offset for this buffer, and the reference rank
self.buckets[device][dst_rank].max_offset = offset
self.buckets[device][dst_rank].global_ref_rank = self.get_global_rank(self.group, dst_rank)
self.buckets[device][dst_rank].global_rank = self.global_rank
# Determine the max work handles in flight:
# - all the direct reduce/broadcast + 1 bucket
......
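The new bucketing criterion above adds the `param.is_leaf` check and registers the reference rank on each bucket. Below is a self-contained sketch of just the assignment rule; `assign_to_bucket` is an illustrative helper, not fairscale's API, and plain tensors stand in for the per-rank parameter lists.

```python
import torch

def assign_to_bucket(params, bucket_max_size):
    """Return a {param: bool} map: True if the parameter goes into the flat
    bucket, False if it is reduced/broadcast directly."""
    should_bucket = {}
    offset = 0
    # Parameters are considered smallest first, as in _setup_bucket_strategy.
    for p in sorted(params, key=lambda t: t.numel()):
        # Same two criteria as the code above:
        # - enough room left in the bucket
        # - the parameter is a leaf, so its gradient hook is guaranteed to fire
        if offset + p.numel() < bucket_max_size and p.is_leaf:
            should_bucket[p] = True
            offset += p.numel()
        else:
            should_bucket[p] = False
    return should_bucket

params = [torch.rand(n, requires_grad=True) for n in (4, 16, 64, 256)]
print([assign_to_bucket(params, bucket_max_size=100)[p] for p in params])
# -> [True, True, True, False]: the largest tensor does not fit and is sent directly
```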
...@@ -37,9 +37,39 @@ class Bucket:
# Handles to the params and their position in this tensor, can be useful for a callback
self.params: List[FlatParam] = []
# Optional callback, possibly to unwrap the bucket
self.callback: Optional[Callable] = None
# Current status for this buffer
self.current_offset = 0
self.max_offset = 0
self.global_ref_rank = -1 # Either the destination or the src rank, if reducing or broadcasting for instance
self.global_rank = -1
self.gradients_based = False
def unroll(self) -> None:
"""
Distribute the contents of the flat buffer back to the attached parameters
"""
for flat in self.params:
if self.global_ref_rank != self.global_rank and self.gradients_based:
# this rank is not the owner, release the grad
flat.param.grad = None
else:
if self.gradients_based:
# this rank is the owner, unroll the results
assert flat.param.grad is not None
flat.param.grad.data.copy_(
self.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
)
else:
flat.param.data.copy_(
self.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
)
self.reset()
def reset(self) -> None:
""" empty the bucket """
...@@ -50,6 +80,7 @@ class Bucket:
""" add a tensor to the bucket """
end = self.current_offset + tensor.numel()
self.gradients_based = use_gradient
if end > self.max_size:
return False
......
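`Bucket.unroll` above copies slices of the flat communication buffer back onto the attached parameters (or releases the gradients on non-owning ranks). A simplified, gradient-free sketch of that copy-back step, with illustrative names (`FlatSlice`, `unroll_buffer`) rather than the fairscale classes:

```python
from dataclasses import dataclass
from typing import List

import torch

@dataclass
class FlatSlice:
    param: torch.Tensor  # the destination parameter
    start: int           # slice bounds of this parameter inside the flat buffer
    stop: int

def unroll_buffer(buffer: torch.Tensor, slices: List[FlatSlice]) -> None:
    # Copy each flat slice back into its parameter, reshaping to the parameter's shape.
    for flat in slices:
        flat.param.data.copy_(buffer[flat.start : flat.stop].view_as(flat.param.data))

p1, p2 = torch.zeros(2, 2), torch.zeros(3)
buffer = torch.arange(7, dtype=torch.float32)  # pretend this was just broadcast/reduced
unroll_buffer(buffer, [FlatSlice(p1, 0, 4), FlatSlice(p2, 4, 7)])
print(p1, p2)  # p1 holds values 0..3 reshaped to 2x2, p2 holds 4..6
```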
...@@ -39,6 +39,7 @@ import torch
import torch.distributed as dist
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
...@@ -211,3 +212,81 @@ def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
return func
return prepare_test
class _Block(nn.Module):
def __init__(self, embed_dim: int, num_heads: int) -> None:
super().__init__()
self.ln_1 = nn.LayerNorm(embed_dim)
self.ln_2 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads) # type: ignore
self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),)
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
x = inputs[0]
attn_mask = torch.full((len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype)
attn_mask = torch.triu(attn_mask, diagonal=1)
x = self.ln_1(x)
a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
x = x + a
m = self.mlp(self.ln_2(x))
x = x + m
return x
class GPT2(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, num_layers: int, num_positions: int, num_vocab: int, num_classes: int
) -> None:
super().__init__()
self.embed_dim = embed_dim
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
self.token_embeddings = nn.Embedding(num_vocab, embed_dim)
self.position_embeddings = nn.Embedding(num_positions, embed_dim)
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(_Block(embed_dim, num_heads))
self.ln_f = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_vocab, bias=False)
self.clf_head = nn.Linear(embed_dim, num_classes)
def forward(self, x: torch.Tensor, classify=False) -> Any: # type: ignore
"""
Expect input as shape [sequence len, batch]
If classify, return classification logits
"""
length, batch = x.shape
h = self.token_embeddings(x)
# prepend sos token
sos = torch.ones(1, batch, self.embed_dim, device=x.device) * self.sos
h = torch.cat([sos, h[:-1, :, :]], dim=0)
# add positional embeddings
positions = torch.arange(length, device=x.device).unsqueeze(-1)
h = h + self.position_embeddings(positions).expand_as(h)
# transformer
for layer in self.layers:
h = layer(h)
h = self.ln_f(h)
logits = self.head(h)
if not classify:
# return logits
return logits
h = torch.mean(h, dim=0) # average pool over sequence
# return classification logits and generative logits
return self.clf_head(h), logits
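The GPT2 helper above is what the new ShardedDDP test exercises. A small usage sketch with assumed sizes; the only contract taken from the code is integer token ids shaped [sequence len, batch]:

```python
import torch

from fairscale.utils.testing import GPT2  # the class defined above

# Small, assumed hyperparameters, just for a quick forward pass.
model = GPT2(embed_dim=64, num_heads=2, num_layers=2, num_positions=128, num_vocab=512, num_classes=2)

tokens = torch.randint(512, (16, 4))               # 16 tokens per sequence, batch of 4
logits = model(tokens)                             # generative logits, shape [16, 4, 512]
clf_logits, logits = model(tokens, classify=True)  # classification head on the pooled sequence
print(logits.shape, clf_logits.shape)              # torch.Size([16, 4, 512]) torch.Size([4, 2])
```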
...@@ -24,6 +24,8 @@ skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda
skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")
from contextlib import suppress
from fairscale.utils.testing import GPT2
def test_step_on_cpu():
run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"), world_size=4)
...@@ -153,7 +155,6 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
loss.backward()
return loss
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
...@@ -214,15 +215,17 @@ def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
# Optim loop
def closure():
optimizer.zero_grad()
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor, input_tensor).abs().sum()
loss.backward()
return loss
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
optimizer_1.zero_grad()
optimizer_2.zero_grad()
_ = optimizer_1.step(closure=closure)
_ = optimizer_2.step(closure=closure)
dist.destroy_process_group()
...@@ -233,4 +236,49 @@ def test_two_optimizers():
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cpu"
mp.spawn(run_test_two_inputs, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
INPUT_DIM = 32
BATCH_SIZE = 10
STEPS = 10
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
model = GPT2(
embed_dim=512, num_heads=2, num_layers=24, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
).to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
# Optim loop
def closure():
optimizer.zero_grad()
# Force int inputs to prevent the first grad from firing
input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
loss = ddp_model(input_tensor).abs().sum()
loss.backward()
return loss
# Check for bucketing overflows
for i in range(STEPS):
_ = optimizer.step(closure=closure)
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
def test_gpt2():
# Check that ShardedDDP and OSS handle a GPT2-like model without bucketing overflows
world_size = 2
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cuda"
mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)