# Unit tests for FP16 training with a DeepSpeed-managed LAMB optimizer and a
# client-supplied torch.optim.AdamW optimizer, including models that contain
# a parameter which never receives a gradient (empty_grad=True).

import torch
import deepspeed

from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict


def test_lamb_fp16_basic(tmpdir):
    # DeepSpeed config: LAMB optimizer with gradient clipping, FP16 enabled.
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015,
                "max_grad_norm": 1.0
            }
        },
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1, 2])
    def _test_lamb_fp16_basic(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters(),
                                              dist_init_required=False)
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        # Run a short training loop to verify forward/backward/step succeed.
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_lamb_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)


def test_lamb_fp16_empty_grad(tmpdir):
    # Same LAMB + FP16 setup, but the model has a parameter that never
    # receives a gradient (empty_grad=True).
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015,
                "max_grad_norm": 1.0
            }
        },
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=True)

    @distributed_test(world_size=[1])
    def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters(),
                                              dist_init_required=False)
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_lamb_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)


def test_adamw_fp16_basic(tmpdir):
    # No optimizer in the DeepSpeed config: a client-constructed AdamW
    # optimizer is passed directly to deepspeed.initialize.
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_adamw_fp16_basic(args, model, hidden_dim):
        optimizer = torch.optim.AdamW(params=model.parameters())
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False)
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_adamw_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)


def test_adamw_fp16_empty_grad(tmpdir):
    # Client-supplied AdamW with FP16, on a model containing a parameter
    # that never receives a gradient (empty_grad=True).
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=True)

    @distributed_test(world_size=[1])
    def _test_adamw_fp16_empty_grad(args, model, hidden_dim):
        optimizer = torch.optim.AdamW(params=model.parameters())
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False)
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_adamw_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)