import torch
import pytest
import json
import argparse
import os

from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict

import deepspeed


def run_unbalanced_gradients(model, data_loader):
    # Alternately freeze the even- or odd-indexed parameters each iteration so
    # only a subset of parameters produce gradients, exercising the optimizer
    # with "unbalanced" (partially missing) gradients.
    def drop_some_gradients(model, iter):
        odd_iteration = iter % 2
        for i, p in enumerate(model.parameters()):
            p.requires_grad = (i % 2) == odd_iteration

    def enable_grads(model):
        # Restore gradient computation for all parameters after each step.
        for p in model.parameters():
            p.requires_grad = True

    for i, batch in enumerate(data_loader):
        drop_some_gradients(model, i + 1)
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()
        enable_grads(model)


@pytest.mark.parametrize('zero_stage', [1, 2])
def test_zero_unbalanced_gradients(tmpdir, zero_stage):
    # ZeRO stage 1/2 config with fp16 and gradient accumulation.
    config_dict = {
        "train_micro_batch_size_per_gpu": 2,
        "gradient_accumulation_steps": 2,
        "steps_per_print": 1,
        "zero_optimization": {
            "stage": zero_stage
        },
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-3
            }
        },
        "fp16": {
            "enabled": True,
            "initial_scale_power": 8
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 4

    model = SimpleModel(hidden_dim=hidden_dim)

    @distributed_test(world_size=[1])
    def _test_zero_unbalanced_gradients(args, model, hidden_dim):
        # Run a short single-GPU training loop; the test passes if the ZeRO
        # optimizer handles iterations where some parameters have no gradients.
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=16,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        run_unbalanced_gradients(model, data_loader)

    _test_zero_unbalanced_gradients(args=args, model=model, hidden_dim=hidden_dim)