Unverified Commit 0178e6cc authored by Olatunji Ruwase, committed by GitHub

Fix unbalanced gradients bug in ZeRO-2 gradient accumulation (#545)

* Use zero-tensors for missing gradients to avoid size mismatch

* Unit test for unbalanced gradients in ZeRO

* Formatting fixes
parent 6b28bc5d
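
For context on the fix below: with gradient accumulation under ZeRO stage 2, any parameter that is frozen (requires_grad=False) or otherwise unused during a micro-step keeps grad=None. The stage-2 optimizer flattens each partition's gradients into a fixed-size buffer, so skipping the None entries produces a flat tensor smaller than the partition expects, hence the size mismatch. A minimal sketch of the problem and the zero-fill remedy, using torch's flattening helper rather than DeepSpeed's actual code path (names and sizes here are illustrative):

    import torch
    from torch._utils import _flatten_dense_tensors

    params = [torch.nn.Parameter(torch.ones(4)) for _ in range(3)]
    params[0].grad = torch.full((4, ), 0.5)
    params[2].grad = torch.full((4, ), 2.0)
    # params[1].grad is still None: it received no gradient this micro-step.

    # Buggy path: skipping None grads yields 8 elements where the partition
    # layout expects 12, so copies into the fixed-size buffer fail.
    with_skips = _flatten_dense_tensors([p.grad for p in params if p.grad is not None])
    print(with_skips.numel())  # 8

    # Fixed path (this commit): substitute zero tensors, which keeps the
    # layout stable and contributes nothing to the accumulated gradient sum.
    for p in params:
        if p.grad is None:
            p.grad = torch.zeros_like(p)
    no_skips = _flatten_dense_tensors([p.grad for p in params])
    print(no_skips.numel())  # 12
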
@@ -1273,7 +1273,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
         current_size = 0
         for i, tensor in enumerate(tensor_list):
             if tensor.grad is None:
-                continue
+                tensor.grad = torch.zeros_like(tensor)
             tensor = tensor.grad
             num_elements = tensor.numel()
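
Zero-filling is safe here because gradient accumulation sums the gradients of successive micro-steps: a zero tensor adds nothing to the running sum, while keeping every partition's flattened size identical from one micro-step to the next.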
@@ -32,8 +32,8 @@ class LinearStack(torch.nn.Module):
         self.output_dim = output_dim
         self.hidden_dim = hidden_dim

-        self.input_layer = VerboseLinear(in_features=self.input_dim,
-                                         out_features=self.hidden_dim)
+        self.input_layer = torch.nn.Linear(in_features=self.input_dim,
+                                           out_features=self.hidden_dim)
         self.layers = torch.nn.ModuleList([
             torch.nn.Linear(in_features=self.hidden_dim,
                             out_features=self.hidden_dim,
import torch
import pytest
import json
import argparse
import os
from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict
import deepspeed


def run_unbalanced_gradients(model, data_loader):
    def drop_some_gradients(model, iter):
        # Alternate which half of the parameters receives gradients so that
        # consecutive micro-steps see different (unbalanced) gradient sets.
        odd_iteration = iter % 2
        for i, p in enumerate(model.parameters()):
            p.requires_grad = (i % 2) == odd_iteration

    def enable_grads(model):
        for p in model.parameters():
            p.requires_grad = True

    for i, batch in enumerate(data_loader):
        drop_some_gradients(model, i + 1)
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()
        enable_grads(model)


@pytest.mark.parametrize('zero_stage', [1, 2])
def test_zero_unbalanced_gradients(tmpdir, zero_stage):
    config_dict = {
        "train_micro_batch_size_per_gpu": 2,
        "gradient_accumulation_steps": 2,
        "steps_per_print": 1,
        "zero_optimization": {
            "stage": zero_stage
        },
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-3
            }
        },
        "fp16": {
            "enabled": True,
            "initial_scale_power": 8
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 4
    model = SimpleModel(hidden_dim=hidden_dim)

    @distributed_test(world_size=[1])
    def _test_zero_unbalanced_gradients(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=16,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        run_unbalanced_gradients(model, data_loader)

    _test_zero_unbalanced_gradients(args=args, model=model, hidden_dim=hidden_dim)
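
To run just this test locally (assuming the file lands under DeepSpeed's tests/unit/ directory, which this page does not show):

    pytest -k test_zero_unbalanced_gradients tests/unit/

The @distributed_test decorator spawns the requested world size itself, so a plain pytest invocation is enough; no separate torch.distributed launcher is needed.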