import os
import json
import argparse

import torch


class SimpleModel(torch.nn.Module):
    """Small linear model whose forward pass returns a cross-entropy loss.

    Args:
        hidden_dim: Input and output width of the linear layer(s).
        empty_grad: When true, a second linear layer ``linear2`` is created.
            It is only used in ``forward`` when ``rank == 0``; on any other
            rank it is skipped, so its parameters get no gradient from the
            returned loss.
        rank: Rank stored on the model and consulted in ``forward``.
    """

    def __init__(self, hidden_dim, empty_grad=False, rank=0):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.rank = rank
        self.empty_grad = empty_grad

    def forward(self, x, y):
        """Return the cross-entropy loss of the layer output(s) against ``y``.

        Args:
            x: Input passed to the linear layer(s).
            y: Targets handed to ``torch.nn.CrossEntropyLoss``.
        """
        if self.rank == 0 and self.empty_grad:
            # Only rank 0 routes the input through the extra layer.
            hidden = self.linear(x) + self.linear2(x)
        else:
            hidden = self.linear(x)
        return self.cross_entropy_loss(hidden, y)


class SimpleOptimizer(torch.optim.Optimizer):
    """Plain gradient-descent optimizer: ``p <- p - lr * p.grad``.

    Args:
        params: Parameters (or parameter groups) to optimize.
        lr: Step size applied to every gradient.
    """

    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: Optional callable; when given it is invoked once before
                the update and its return value is returned.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    # Parameters that took no part in backward are skipped.
                    continue
                # Keyword ``alpha`` replaces the deprecated positional
                # ``add_(scalar, tensor)`` overload; the result is the same.
                p.data.add_(p.grad.data, alpha=-group['lr'])

        return loss


def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
    """Build a ``DataLoader`` over random inputs and random integer labels.

    Args:
        model: Object exposing ``train_micro_batch_size_per_gpu()``; the
            returned value is used as the loader's batch size.
        total_samples: Number of samples to generate.
        hidden_dim: Feature width of each sample; labels are drawn from
            ``[0, hidden_dim - 1]``.
        device: Device on which the tensors are created.
        dtype: Floating-point dtype of the input tensor.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(data, label)`` batches.
    """
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples,
                             hidden_dim,
                             device=device,
                             dtype=dtype)
    train_label = torch.empty(total_samples,
                              dtype=torch.long,
                              device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size)
    return train_loader


def create_config_from_dict(tmpdir, config_dict):
    """Write ``config_dict`` as JSON to ``<tmpdir>/temp_config.json``.

    Args:
        tmpdir: Existing directory in which the file is created.
        config_dict: JSON-serializable configuration.

    Returns:
        The path of the written file.
    """
    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path


def args_from_dict(tmpdir, config_dict):
    """Return an argparse namespace pointing at a freshly written config.

    The config is written via ``create_config_from_dict``. The returned
    namespace has ``deepspeed=True``, ``deepspeed_config`` set to the config
    path and ``local_rank=0``.

    Args:
        tmpdir: Existing directory in which the config file is created.
        config_dict: JSON-serializable configuration.
    """
    config_path = create_config_from_dict(tmpdir, config_dict)
    parser = argparse.ArgumentParser()
    # Parse an explicit empty argument list so sys.argv is never consulted.
    args = parser.parse_args(args=[])
    args.deepspeed = True
    args.deepspeed_config = config_path
    args.local_rank = 0
    return args