import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import deepspeed
import argparse
import pytest
import copy
import json
import os
import numpy as np
import time

from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology
from deepspeed.ops.op_builder import OpBuilder

PipeTopo = PipeDataParallelTopology

from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec
from .common import distributed_test
from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args
from .test_pipe import AlexNetPipe, train_cifar

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 8):
    pytest.skip("NCCL-based 1-bit compression requires torch 1.8 or higher",
                allow_module_level=True)

rocm_version = OpBuilder.installed_rocm_version()
if rocm_version[0] > 4:
    pytest.skip("NCCL-based 1-bit compression is not yet supported with ROCm 5 until cupy supports ROCm 5",
                allow_module_level=True)


def test_onebitadam_fp16_basic(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitAdam",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "freeze_step": 2,
                       "cuda_aware": False, "comm_backend_name": "nccl"}
        },
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)

    @distributed_test(world_size=[1, 2])
    def _test_onebitadam_fp16_basic(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args, model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model, total_samples=50,
                                        hidden_dim=hidden_dim, device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_onebitadam_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)


def test_onebitadam_fp32_basic(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitAdam",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "freeze_step": 2,
                       "cuda_aware": False, "comm_backend_name": "nccl"}
        },
        "gradient_clipping": 1.0,
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)

    @distributed_test(world_size=[1, 2])
    def _test_onebitadam_fp32_basic(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args, model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model, total_samples=50,
                                        hidden_dim=hidden_dim, device=model.device,
                                        dtype=torch.float)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_onebitadam_fp32_basic(args=args, model=model, hidden_dim=hidden_dim)


def test_onebitadam_exp_avg_mask(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitAdam",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "freeze_step": 2,
                       "cuda_aware": False, "comm_backend_name": "nccl"}
        },
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)
    param_optimizer = list(model.named_parameters())
    mask1 = torch.zeros_like(param_optimizer[0][1].data)
    for col in range(mask1.size()[1]):
        mask1[0][col] += 1
    mask1 = torch.flatten(mask1)
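
    # The mask keeps only the first row of the first parameter (everything else is
    # zero), so after training the momentum (exp_avg) of the masked-out positions
    # should remain exactly zero.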
    optimizer_grouped_parameters = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask1
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]

    @distributed_test(world_size=[2])
    def _test_onebitadam_exp_avg_mask(args, model, hidden_dim):
        model, optimizer, _, _ = deepspeed.initialize(args=args, model=model,
                                                      model_parameters=optimizer_grouped_parameters)
        data_loader = random_dataloader(model=model, total_samples=50,
                                        hidden_dim=hidden_dim, device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
        # Test whether the momentum mask works
        for v in optimizer.state.values():
            if v['exp_avg'].size() == mask1.size():
                assert torch.allclose(v['exp_avg'],
                                      v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)),
                                      atol=1e-07), "Momentum mask is not working properly"

    _test_onebitadam_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim)


def test_onebitadam_checkpointing(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitAdam",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "freeze_step": 2,
                       "cuda_aware": False, "comm_backend_name": "nccl"}
        },
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)
    param_optimizer = list(model.named_parameters())
    mask1 = torch.zeros_like(param_optimizer[0][1].data)
    mask2 = torch.zeros_like(param_optimizer[0][1].data)
    for col in range(mask1.size()[1]):
        mask1[0][col] += 1
        mask2[1][col] += 1
    mask1 = torch.flatten(mask1)
    mask2 = torch.flatten(mask2)

    optimizer_grouped_parameters_1 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask1
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]
    optimizer_grouped_parameters_2 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask2
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]
    optimizer_grouped_parameters_3 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]

    @distributed_test(world_size=[2])
    def _test_onebitadam_checkpointing(mask1, mask2, args, model, hidden_dim):
        model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, model=model,
                                                          model_parameters=optimizer_grouped_parameters_1)
        data_loader = random_dataloader(model=model_1, total_samples=10,
                                        hidden_dim=hidden_dim, device=model_1.device)
        for n, batch in enumerate(data_loader):
            loss = model_1(batch[0], batch[1])
            model_1.backward(loss)
            model_1.step()
        # Test whether the momentum mask still exists after saving a checkpoint
        assert optimizer_1.optimizer.adam_freeze_key is True
        mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1,
                              atol=1e-07), "Incorrect momentum mask"
        save_folder = os.path.join(tmpdir, 'saved_checkpoint')
        model_1.save_checkpoint(save_folder, tag=None)
        assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1,
                              atol=1e-07), "Momentum mask should not change after saving checkpoint"

        model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, model=model,
                                                          model_parameters=optimizer_grouped_parameters_2)
        # Test whether the momentum mask stays the same after loading a checkpoint
        mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2,
                              atol=1e-07), "Incorrect momentum mask"
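        # The mask passed at initialization (mask2) must take precedence over the
        # mask stored in the checkpoint that is loaded next.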
atol=1e-07), f"Incorrect momentum mask" model_2.load_checkpoint(save_folder, tag=None, load_optimizer_states=True, load_lr_scheduler_states=True) assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Momentum mask should not change after loading checkpoint" # Test whether worker&server error is reset for v in optimizer_2.state.values(): assert 'worker_error' not in v, f"Incorrect worker error" assert 'server_error' not in v, f"Incorrect server error" assert optimizer_2.optimizer.adam_freeze_key is True model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, model=model, model_parameters=optimizer_grouped_parameters_3) optimizer_3.optimizer.freeze_step = 20 data_loader = random_dataloader(model=model_3, total_samples=50, hidden_dim=hidden_dim, device=model_3.device) for n, batch in enumerate(data_loader): loss = model_3(batch[0], batch[1]) model_3.backward(loss) model_3.step() assert optimizer_3.optimizer.adam_freeze_key is True # Test whether momentum mask stays the same after loading checkpoint assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Incorrect momentum mask" model_3.load_checkpoint(save_folder, tag=None, load_optimizer_states=True, load_lr_scheduler_states=True) assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Momentum mask should not change after loading checkpoint" # Test whether worker&server error is reset for v in optimizer_3.state.values(): assert 'worker_error' not in v, f"Incorrect worker error" assert 'server_error' not in v, f"Incorrect server error" assert optimizer_3.optimizer.adam_freeze_key is False _test_onebitadam_checkpointing(mask1, mask2, args=args, model=model, hidden_dim=hidden_dim) def test_onebitadam_checkpointing_overflow(tmpdir): config_dict = { "train_batch_size": 2, "steps_per_print": 1, "optimizer": { "type": "OneBitAdam", "params": { "lr": 0.00015, "weight_decay": 0.01, "freeze_step": 2, "cuda_aware": False, "comm_backend_name": "nccl" } }, "gradient_clipping": 1.0, "fp16": { "enabled": True, "loss_scale": 0, "initial_scale_power": 16 } } args = args_from_dict(tmpdir, config_dict) hidden_dim = 10 model = SimpleModel(hidden_dim) @distributed_test(world_size=[2]) def _test_onebitadam_checkpointing_overflow(args, model, hidden_dim): model, _, _, _ = deepspeed.initialize(args=args, model=model, model_parameters=model.parameters()) data_loader = random_dataloader(model=model, total_samples=100, hidden_dim=hidden_dim, device=model.device) save_folder = os.path.join(tmpdir, 'saved_checkpoint') for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) if dist.get_rank() == 0 and n >= 10: loss = loss * 1000000.0 model.backward(loss) dist.barrier() model.step() dist.barrier() model.save_checkpoint(save_folder, tag=None) _test_onebitadam_checkpointing_overflow(args=args, model=model, hidden_dim=hidden_dim) @pytest.mark.parametrize('topo', [ PipeTopo(num_pp=1, num_dp=4), PipeTopo(num_pp=2, num_dp=2), PipeTopo(num_pp=4, num_dp=1), ]) def test_onebitadam_fp16_pipeline(topo, tmpdir): config_dict = { "train_batch_size": 16, "train_micro_batch_size_per_gpu": 4, "steps_per_print": 20, "optimizer": { "type": "OneBitAdam", "params": { "lr": 0.00001, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 3e-7, "freeze_step": 200, "cuda_aware": False, "comm_backend_name": "nccl" } }, "gradient_clipping": 1.0, "zero_optimization": { "stage": 0 }, "fp16": { "enabled": True, "loss_scale": 0, "initial_scale_power": 16 }, "pipeline": { "seed_layers": True, "activation_checkpoint_interval": 1 } } args = 
    # Allocate model for consistent initial weights.
    init_net = AlexNetPipe()

    @distributed_test(world_size=4)
    def _helper(topo, tmpdir, steps=500):
        assert steps >= 100
        test_net = copy.deepcopy(init_net)
        test_model = PipelineModule(layers=test_net.to_layers(),
                                    topology=topo,
                                    loss_fn=nn.CrossEntropyLoss())
        test_losses = train_cifar(test_model, args, num_steps=steps,
                                  fp16=config_dict['fp16']['enabled'])

    _helper(topo, tmpdir)


def test_zerooneadam_fp16_basic(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "ZeroOneAdam",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "var_freeze_step": 4,
                       "var_update_scaler": 1, "local_step_scaler": 1,
                       "local_step_clipper": 2, "cuda_aware": False,
                       "comm_backend_name": "nccl"}
        },
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)

    @distributed_test(world_size=[1, 2])
    def _test_zerooneadam_fp16_basic(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args, model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model, total_samples=50,
                                        hidden_dim=hidden_dim, device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_zerooneadam_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)


def test_zerooneadam_fp32_basic(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "ZeroOneAdam",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "var_freeze_step": 4,
                       "var_update_scaler": 1, "local_step_scaler": 1,
                       "local_step_clipper": 2, "cuda_aware": False,
                       "comm_backend_name": "nccl"}
        },
        "gradient_clipping": 1.0,
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)

    @distributed_test(world_size=[1, 2])
    def _test_zerooneadam_fp32_basic(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args, model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model, total_samples=50,
                                        hidden_dim=hidden_dim, device=model.device,
                                        dtype=torch.float)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_zerooneadam_fp32_basic(args=args, model=model, hidden_dim=hidden_dim)


def test_zerooneadam_exp_avg_mask(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "ZeroOneAdam",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "var_freeze_step": 4,
                       "var_update_scaler": 1, "local_step_scaler": 1,
                       "local_step_clipper": 2, "cuda_aware": False,
                       "comm_backend_name": "nccl"}
        },
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)
    param_optimizer = list(model.named_parameters())
    mask1 = torch.zeros_like(param_optimizer[0][1].data)
    for col in range(mask1.size()[1]):
        mask1[0][col] += 1
    mask1 = torch.flatten(mask1)
    optimizer_grouped_parameters = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask1
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]

    @distributed_test(world_size=[2])
    def _test_zerooneadam_exp_avg_mask(args, model, hidden_dim):
        model, optimizer, _, _ = deepspeed.initialize(args=args, model=model,
                                                      model_parameters=optimizer_grouped_parameters)
        data_loader = random_dataloader(model=model, total_samples=50,
                                        hidden_dim=hidden_dim, device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
        # Test whether the momentum mask works
        for v in optimizer.state.values():
            if v['exp_avg'].size() == mask1.size():
                assert torch.allclose(v['exp_avg'],
                                      v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)),
                                      atol=1e-07), "Momentum mask is not working properly"

    _test_zerooneadam_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim)


def test_zerooneadam_checkpointing(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "ZeroOneAdam",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "var_freeze_step": 4,
                       "var_update_scaler": 1, "local_step_scaler": 1,
                       "local_step_clipper": 2, "cuda_aware": False,
                       "comm_backend_name": "nccl"}
        },
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)
    param_optimizer = list(model.named_parameters())
    mask1 = torch.zeros_like(param_optimizer[0][1].data)
    mask2 = torch.zeros_like(param_optimizer[0][1].data)
    for col in range(mask1.size()[1]):
        mask1[0][col] += 1
        mask2[1][col] += 1
    mask1 = torch.flatten(mask1)
    mask2 = torch.flatten(mask2)

    optimizer_grouped_parameters_1 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask1
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]
    optimizer_grouped_parameters_2 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask2
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]
    optimizer_grouped_parameters_3 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]

    @distributed_test(world_size=[2])
    def _test_zerooneadam_checkpointing(mask1, mask2, args, model, hidden_dim):
        model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, model=model,
                                                          model_parameters=optimizer_grouped_parameters_1)
        data_loader = random_dataloader(model=model_1, total_samples=10,
                                        hidden_dim=hidden_dim, device=model_1.device)
        for n, batch in enumerate(data_loader):
            loss = model_1(batch[0], batch[1])
            model_1.backward(loss)
            model_1.step()
        # Test whether the momentum mask still exists after saving a checkpoint
        mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1,
                              atol=1e-07), "Incorrect momentum mask"
        save_folder = os.path.join(tmpdir, 'saved_checkpoint')
        model_1.save_checkpoint(save_folder, tag=None)
        assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1,
                              atol=1e-07), "Momentum mask should not change after saving checkpoint"

        model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, model=model,
                                                          model_parameters=optimizer_grouped_parameters_2)
        # Test whether the momentum mask stays the same after loading a checkpoint
        mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2,
                              atol=1e-07), "Incorrect momentum mask"
        model_2.load_checkpoint(save_folder, tag=None,
                                load_optimizer_states=True,
                                load_lr_scheduler_states=True)
        assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2,
                              atol=1e-07), "Momentum mask should not change after loading checkpoint"
        # Test whether worker & server error buffers are reset
        for v in optimizer_2.state.values():
            assert 'worker_error' not in v, "Incorrect worker error"
            assert 'server_error' not in v, "Incorrect server error"
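
        # A third engine initialized without any mask should load the same checkpoint
        # without inheriting the mask, and with the compression error buffers reset.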
        model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, model=model,
                                                          model_parameters=optimizer_grouped_parameters_3)
        optimizer_3.optimizer.freeze_step = 20
        data_loader = random_dataloader(model=model_3, total_samples=50,
                                        hidden_dim=hidden_dim, device=model_3.device)
        for n, batch in enumerate(data_loader):
            loss = model_3(batch[0], batch[1])
            model_3.backward(loss)
            model_3.step()
        # Test whether the momentum mask stays the same after loading a checkpoint
        assert 'exp_avg_mask' not in optimizer_3.param_groups[0], "Incorrect momentum mask"
        model_3.load_checkpoint(save_folder, tag=None,
                                load_optimizer_states=True,
                                load_lr_scheduler_states=True)
        assert 'exp_avg_mask' not in optimizer_3.param_groups[0], "Momentum mask should not change after loading checkpoint"
        # Test whether worker & server error buffers are reset
        for v in optimizer_3.state.values():
            assert 'worker_error' not in v, "Incorrect worker error"
            assert 'server_error' not in v, "Incorrect server error"

    _test_zerooneadam_checkpointing(mask1, mask2, args=args, model=model, hidden_dim=hidden_dim)


def test_zerooneadam_checkpointing_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "ZeroOneAdam",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "var_freeze_step": 4,
                       "var_update_scaler": 1, "local_step_scaler": 1,
                       "local_step_clipper": 2, "cuda_aware": False,
                       "comm_backend_name": "nccl"}
        },
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)

    @distributed_test(world_size=[2])
    def _test_zerooneadam_checkpointing_overflow(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args, model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model, total_samples=100,
                                        hidden_dim=hidden_dim, device=model.device)
        save_folder = os.path.join(tmpdir, 'saved_checkpoint')
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            if dist.get_rank() == 0 and n >= 10:
                loss = loss * 1000000.0
            model.backward(loss)
            dist.barrier()
            model.step()
            dist.barrier()
            model.save_checkpoint(save_folder, tag=None)

    _test_zerooneadam_checkpointing_overflow(args=args, model=model, hidden_dim=hidden_dim)


@pytest.mark.parametrize('topo',
                         [
                             PipeTopo(num_pp=1, num_dp=4),
                             PipeTopo(num_pp=2, num_dp=2),
                             PipeTopo(num_pp=4, num_dp=1),
                         ])
def test_zerooneadam_fp16_pipeline(topo, tmpdir):
    config_dict = {
        "train_batch_size": 16,
        "train_micro_batch_size_per_gpu": 4,
        "steps_per_print": 20,
        "optimizer": {
            "type": "ZeroOneAdam",
            "params": {"lr": 0.00001, "betas": [0.9, 0.999], "eps": 1e-8,
                       "weight_decay": 3e-7, "var_freeze_step": 4,
                       "var_update_scaler": 1, "local_step_scaler": 1,
                       "local_step_clipper": 2, "cuda_aware": False,
                       "comm_backend_name": "nccl"}
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {"stage": 0},
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16},
        "pipeline": {"seed_layers": True, "activation_checkpoint_interval": 1}
    }
    args = args_from_dict(tmpdir, config_dict)
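
    # Same 4-GPU topology sweep as the 1-bit Adam pipeline test, here exercising
    # 0/1 Adam's variance-freeze and local-step schedule under pipeline parallelism.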
    # Allocate model for consistent initial weights.
    init_net = AlexNetPipe()

    @distributed_test(world_size=4)
    def _helper(topo, tmpdir, steps=500):
        assert steps >= 100
        test_net = copy.deepcopy(init_net)
        test_model = PipelineModule(layers=test_net.to_layers(),
                                    topology=topo,
                                    loss_fn=nn.CrossEntropyLoss())
        test_losses = train_cifar(test_model, args, num_steps=steps,
                                  fp16=config_dict['fp16']['enabled'])

    _helper(topo, tmpdir)


def test_onebitlamb_fp16_basic(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitLamb",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "max_coeff": 0.3,
                       "min_coeff": 0.01, "freeze_step": 2, "cuda_aware": False,
                       "comm_backend_name": "nccl", "coeff_beta": 0.9,
                       "factor_max": 1.0, "factor_min": 0.5, "factor_threshold": 0.1}
        },
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)

    @distributed_test(world_size=[1, 2])
    def _test_onebitlamb_fp16_basic(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args, model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model, total_samples=50,
                                        hidden_dim=hidden_dim, device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_onebitlamb_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)


def test_onebitlamb_fp32_basic(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitLamb",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "max_coeff": 0.3,
                       "min_coeff": 0.01, "freeze_step": 2, "cuda_aware": False,
                       "comm_backend_name": "nccl", "coeff_beta": 0.9,
                       "factor_max": 1.0, "factor_min": 0.5, "factor_threshold": 0.1}
        },
        "gradient_clipping": 1.0,
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)

    @distributed_test(world_size=[1, 2])
    def _test_onebitlamb_fp32_basic(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args, model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model, total_samples=50,
                                        hidden_dim=hidden_dim, device=model.device,
                                        dtype=torch.float)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_onebitlamb_fp32_basic(args=args, model=model, hidden_dim=hidden_dim)


def test_onebitlamb_exp_avg_mask(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitLamb",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "max_coeff": 0.3,
                       "min_coeff": 0.01, "freeze_step": 2, "cuda_aware": False,
                       "comm_backend_name": "nccl", "coeff_beta": 0.9,
                       "factor_max": 1.0, "factor_min": 0.5, "factor_threshold": 0.1}
        },
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)
    param_optimizer = list(model.named_parameters())
    mask1 = torch.zeros_like(param_optimizer[0][1].data)
    for col in range(mask1.size()[1]):
        mask1[0][col] += 1
    optimizer_grouped_parameters = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask1
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]

    @distributed_test(world_size=[2])
    def _test_onebitlamb_exp_avg_mask(args, model, hidden_dim):
        model, optimizer, _, _ = deepspeed.initialize(args=args, model=model,
                                                      model_parameters=optimizer_grouped_parameters)
        data_loader = random_dataloader(model=model, total_samples=50,
                                        hidden_dim=hidden_dim, device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
        # Test whether the momentum mask works
        for v in optimizer.state.values():
            if v['exp_avg'].size() == mask1.size():
                assert torch.allclose(v['exp_avg'],
                                      v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)),
                                      atol=1e-07), "Momentum mask is not working properly"

    _test_onebitlamb_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim)


def test_onebitlamb_checkpointing(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitLamb",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "max_coeff": 0.3,
                       "min_coeff": 0.01, "freeze_step": 2, "cuda_aware": False,
                       "comm_backend_name": "nccl", "coeff_beta": 0.9,
                       "factor_max": 1.0, "factor_min": 0.5, "factor_threshold": 0.1}
        },
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)
    param_optimizer = list(model.named_parameters())
    mask1 = torch.zeros_like(param_optimizer[0][1].data)
    mask2 = torch.zeros_like(param_optimizer[0][1].data)
    for col in range(mask1.size()[1]):
        mask1[0][col] += 1
        mask2[1][col] += 1

    optimizer_grouped_parameters_1 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask1
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]
    optimizer_grouped_parameters_2 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask2
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]
    optimizer_grouped_parameters_3 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]

    @distributed_test(world_size=[2])
    def _test_onebitlamb_checkpointing(mask1, mask2, args, model, hidden_dim):
        model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, model=model,
                                                          model_parameters=optimizer_grouped_parameters_1)
        data_loader = random_dataloader(model=model_1, total_samples=10,
                                        hidden_dim=hidden_dim, device=model_1.device)
        for n, batch in enumerate(data_loader):
            loss = model_1(batch[0], batch[1])
            model_1.backward(loss)
            model_1.step()
        # Test whether the momentum mask still exists after saving a checkpoint
        assert optimizer_1.optimizer.lamb_freeze_key is True
        mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1,
                              atol=1e-07), "Incorrect momentum mask"
        scaling_coeff_1 = []
        for v in optimizer_1.state.values():
            assert 'scaling_coeff' in v, "Incorrect scaling_coeff"
            scaling_coeff_1.append(v['scaling_coeff'])
        save_folder = os.path.join(tmpdir, 'saved_checkpoint')
        model_1.save_checkpoint(save_folder, tag=None)
        assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1,
                              atol=1e-07), "Momentum mask should not change after saving checkpoint"

        model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, model=model,
                                                          model_parameters=optimizer_grouped_parameters_2)
        # Test whether the momentum mask stays the same after loading a checkpoint
        mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2,
                              atol=1e-07), "Incorrect momentum mask"
        model_2.load_checkpoint(save_folder, tag=None,
                                load_optimizer_states=True,
                                load_lr_scheduler_states=True)
        assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2,
                              atol=1e-07), "Momentum mask should not change after loading checkpoint"
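        # 1-bit LAMB keeps extra state: the compression error buffers must be cleared
        # on load, while the per-tensor scaling_coeff values must round-trip through
        # the checkpoint unchanged.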
        # Test whether worker & server error buffers are reset
        assert len(optimizer_2.optimizer.worker_errors) == 0, "Incorrect worker error"
        assert len(optimizer_2.optimizer.server_errors) == 0, "Incorrect server error"
        # Test whether scaling_coeffs are loaded correctly
        scaling_coeff_2 = []
        for v in optimizer_2.state.values():
            assert 'scaling_coeff' in v, "Incorrect scaling_coeff"
            scaling_coeff_2.append(v['scaling_coeff'])
        assert list(sorted(scaling_coeff_2)) == list(sorted(scaling_coeff_1)), "Incorrect scaling_coeffs"
        assert optimizer_2.optimizer.lamb_freeze_key is True

        model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, model=model,
                                                          model_parameters=optimizer_grouped_parameters_3)
        optimizer_3.optimizer.freeze_step = 20
        data_loader = random_dataloader(model=model_3, total_samples=50,
                                        hidden_dim=hidden_dim, device=model_3.device)
        for n, batch in enumerate(data_loader):
            loss = model_3(batch[0], batch[1])
            model_3.backward(loss)
            model_3.step()
        assert optimizer_3.optimizer.lamb_freeze_key is True
        # Test whether the momentum mask stays the same after loading a checkpoint
        assert 'exp_avg_mask' not in optimizer_3.param_groups[0], "Incorrect momentum mask"
        model_3.load_checkpoint(save_folder, tag=None,
                                load_optimizer_states=True,
                                load_lr_scheduler_states=True)
        assert 'exp_avg_mask' not in optimizer_3.param_groups[0], "Momentum mask should not change after loading checkpoint"
        # Test whether worker & server error buffers are reset
        assert len(optimizer_3.optimizer.worker_errors) == 0, "Incorrect worker error"
        assert len(optimizer_3.optimizer.server_errors) == 0, "Incorrect server error"
        # Test whether scaling_coeff, lamb_coeff_freeze, and last_factor are reset
        for v in optimizer_3.state.values():
            assert v['lamb_coeff_freeze'] == 0.0, "Incorrect lamb_coeff_freeze"
            assert v['last_factor'] == 1.0, "Incorrect last_factor"
            assert 'scaling_coeff' not in v, "Incorrect scaling_coeff"
        assert optimizer_3.optimizer.lamb_freeze_key is False

    _test_onebitlamb_checkpointing(mask1, mask2, args=args, model=model, hidden_dim=hidden_dim)


def test_onebitlamb_checkpointing_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitLamb",
            "params": {"lr": 0.00015, "weight_decay": 0.01, "max_coeff": 0.3,
                       "min_coeff": 0.01, "freeze_step": 2, "cuda_aware": False,
                       "comm_backend_name": "nccl", "coeff_beta": 0.9,
                       "factor_max": 1.0, "factor_min": 0.5, "factor_threshold": 0.1}
        },
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16}
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim)

    @distributed_test(world_size=[2])
    def _test_onebitlamb_checkpointing_overflow(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args, model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model, total_samples=100,
                                        hidden_dim=hidden_dim, device=model.device)
        save_folder = os.path.join(tmpdir, 'saved_checkpoint')
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            if dist.get_rank() == 0 and n >= 10:
                loss = loss * 1000000.0
            model.backward(loss)
            dist.barrier()
            model.step()
            dist.barrier()
            model.save_checkpoint(save_folder, tag=None)

    _test_onebitlamb_checkpointing_overflow(args=args, model=model, hidden_dim=hidden_dim)
@pytest.mark.parametrize('topo',
                         [
                             PipeTopo(num_pp=1, num_dp=4),
                             PipeTopo(num_pp=2, num_dp=2),
                             PipeTopo(num_pp=4, num_dp=1),
                         ])
def test_onebitlamb_fp16_pipeline(topo, tmpdir):
    config_dict = {
        "train_batch_size": 16,
        "train_micro_batch_size_per_gpu": 4,
        "steps_per_print": 20,
        "optimizer": {
            "type": "OneBitLamb",
            "params": {"lr": 0.00001, "betas": [0.9, 0.999], "eps": 1e-8,
                       "weight_decay": 3e-7, "freeze_step": 200,
                       "cuda_aware": False, "comm_backend_name": "nccl"}
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {"stage": 0},
        "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16},
        "pipeline": {"seed_layers": True, "activation_checkpoint_interval": 1}
    }
    args = args_from_dict(tmpdir, config_dict)

    # Allocate model for consistent initial weights.
    init_net = AlexNetPipe()

    @distributed_test(world_size=4)
    def _helper(topo, tmpdir, steps=500):
        assert steps >= 100
        test_net = copy.deepcopy(init_net)
        test_model = PipelineModule(layers=test_net.to_layers(),
                                    topology=topo,
                                    loss_fn=nn.CrossEntropyLoss())
        test_losses = train_cifar(test_model, args, num_steps=steps,
                                  fp16=config_dict['fp16']['enabled'])

    _helper(topo, tmpdir)


@pytest.mark.sequential
def test_compressed_allreduce_basic(tmpdir):
    @distributed_test(world_size=[1, 2])
    def _test_compressed_allreduce_basic():
        from deepspeed.runtime.comm.nccl import NcclBackend
        size = dist.get_world_size()
        rank = dist.get_rank()
        backend = NcclBackend()
        local_rank = dist.get_rank()
        device = torch.device("cuda", dist.get_rank())

        # A simulated compression function using torch.distributed
        def torch_sim(a):
            a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
            scale = a.norm() / np.sqrt(a.numel())
            a_compressed = scale * a_sign
            a_sign = None
            worker_error = a - a_compressed
            dist.all_reduce(a_compressed)
            a_compressed.mul_(1 / dist.get_world_size())
            a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
            a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
            server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
            a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
            a_server_compressed = torch.cat(
                [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
            rank = dist.get_rank()
            server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
            torch.cuda.synchronize()
            torch.distributed.barrier()
            return a_server_compressed, worker_error, server_error

        tensor_size = 300 * 2**20
        server_size = int(tensor_size / size)
        if tensor_size % (8 * size) != 0:
            right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
        else:
            right_tensor_size = tensor_size
        right_server_size = right_tensor_size // size
        # Add a bias to the initialization of the gradient we are communicating,
        # to get rid of the case where some elements in the gradient are too small.
        a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
        worker_error = torch.zeros(right_tensor_size, device=device)
        server_error = torch.zeros(right_server_size, device=device)
        a_torch, worker_error_torch, server_error_torch = torch_sim(a)
        torch.cuda.empty_cache()
        a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)
        threshold = 1e-6
        magnitude_threshold = 1e-6
        diff_mask = (a_after - a_torch) > threshold
        diff_server_mask = torch.chunk(diff_mask, size)[rank]
        mpi_server = torch.chunk(a_after, size)[rank] + server_error
        torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
        # If a number in compensated_server_m is too small (e.g. 1e-8), calling sign() on it
        # can be problematic, so the test skips positions whose magnitude is below the threshold.
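        # Pass criterion: either the NCCL result matches the torch.distributed
        # simulation everywhere (within `threshold`), or every mismatching position
        # has a magnitude small enough that sign() is numerically unreliable there.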
        check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
        if torch.sum(check_mag_mask) != 0:
            print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))
        assert torch.sum(diff_server_mask) == 0 or torch.sum(check_mag_mask) == 0

    _test_compressed_allreduce_basic()