'''Copyright The Microsoft DeepSpeed Team'''

import copy

import torch
import torch.nn as nn
import pytest

import deepspeed
import deepspeed.comm as dist
from deepspeed.pipe import PipelineModule
from deepspeed.utils import RepeatingLoader
from deepspeed.accelerator import get_accelerator

from unit.common import DistributedTest

HIDDEN_DIM = 32
LAYERS = 8


@pytest.fixture
def sequential_model():
    """A small MLP: LAYERS hidden Linear layers followed by a scalar output head."""
    model = torch.nn.Sequential(
        *[nn.Linear(HIDDEN_DIM, HIDDEN_DIM) for _ in range(LAYERS)],
        nn.Linear(HIDDEN_DIM, 1),
    )
    return model


@pytest.fixture
def simple_config():
    """Minimal DeepSpeed config: Adam optimizer, batch size 1, pipeline checkpointing."""
    config_dict = {
        "train_batch_size": 1,
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.001,
                "betas": [0.9, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "pipeline": {
            "activation_checkpoint_interval": 1
        }
    }
    return config_dict


@pytest.fixture
def batch_input():
    """A single random input row matching the model's hidden dimension."""
    return torch.randn(1, HIDDEN_DIM)


class TestPipeModuleSequential(DistributedTest):
    """Verify that wrapping an nn.Sequential in a 2-stage PipelineModule
    preserves both the parameter count and the forward-pass output."""
    world_size = 2

    def test(self, sequential_model, simple_config, batch_input):
        # Reference run: the unmodified sequential model on CPU.
        base_model = copy.deepcopy(sequential_model)
        base_input = batch_input.clone().detach()
        base_output = base_model(base_input)
        base_params = sum(p.numel() for p in base_model.parameters())

        pipe_model = copy.deepcopy(sequential_model)
        pipe_model = PipelineModule(layers=pipe_model, num_stages=2)

        # Ensure all parameters are accounted for: each rank holds only its
        # stage's parameters, so all-reduce the per-rank counts before comparing.
        my_params = sum(p.numel() for p in pipe_model.parameters())
        total_pipe_params = torch.tensor([my_params], dtype=torch.long).to(get_accelerator().device_name())
        dist.all_reduce(total_pipe_params)
        total_pipe_params = total_pipe_params.item()
        assert total_pipe_params == base_params

        pipe_model, _, _, _ = deepspeed.initialize(config=simple_config,
                                                   model=pipe_model,
                                                   model_parameters=list(pipe_model.parameters()))

        # BUG FIX: is_first_stage / is_last_stage are methods on PipelineEngine,
        # and the original referenced them without calling — a bound method is
        # always truthy, so the condition was unconditionally True. Call them so
        # only boundary stages build a loader. (With num_stages=2 every rank IS
        # a boundary stage, so the observable behavior of this test is unchanged.)
        if pipe_model.is_first_stage() or pipe_model.is_last_stage():
            pipe_input = base_input.clone().detach().to(get_accelerator().device_name())
            # label 0 is meaningless
            dataset = [(pipe_input, 0)]
            loader = RepeatingLoader(dataset)
            data_iter = iter(loader)
        else:
            data_iter = None

        pipe_output = pipe_model.eval_batch(data_iter=data_iter)

        # Compare on CPU; allow small numerical drift from checkpointed recompute.
        base_output = base_output.to('cpu')
        pipe_output = pipe_output.to('cpu')
        assert torch.allclose(base_output, pipe_output, atol=1e-4)