import argparse
import os
import time
from functools import partial

import numpy as np
import torch
import torch.distributed as dist

from dataset import *
from loss import *
from modules import *
from modules.config import *
from optims import *


def run_epoch(
    epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=True
):
    universal = isinstance(model, UTransformer)
    with loss_compute:
        for i, g in enumerate(data_iter):
            with torch.set_grad_enabled(is_train):
                if universal:
                    output, loss_act = model(g)
                    if is_train:
                        loss_act.backward(retain_graph=True)
                else:
                    output = model(g)
                tgt_y = g.tgt_y
                n_tokens = g.n_tokens
                loss = loss_compute(output, tgt_y, n_tokens)

    if universal:
        # Report the fraction of nodes that reach each recurrence step
        # of the Universal Transformer (ACT halting statistics).
        for step in range(1, model.MAX_DEPTH + 1):
            print(
                "nodes entering step {}: {:.2f}%".format(
                    step, (100.0 * model.stat[step] / model.stat[0])
                )
            )
        model.reset_stat()
    print(
        "Epoch {} {}: Dev {} average loss: {}, accuracy {}".format(
            epoch,
            "Training" if is_train else "Evaluating",
            dev_rank,
            loss_compute.avg_loss,
            loss_compute.accuracy,
        )
    )


def run(dev_id, args):
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip=args.master_ip, master_port=args.master_port
    )
    world_size = args.ngpu
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=dist_init_method,
        world_size=world_size,
        rank=dev_id,
    )
    gpu_rank = torch.distributed.get_rank()
    assert gpu_rank == dev_id
    main(dev_id, args)


def main(dev_id, args):
    if dev_id == -1:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:{}".format(dev_id))
        # Set current device
        torch.cuda.set_device(device)

    # Prepare dataset
    dataset = get_dataset(args.dataset)
    V = dataset.vocab_size
    criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)
    dim_model = 512

    # Build graph pool
    graph_pool = GraphPool()

    # Create model
    model = make_model(
        V, V, N=args.N, dim_model=dim_model, universal=args.universal
    )
    # Share embedding weights between encoder, decoder and generator
    model.src_embed.lut.weight = model.tgt_embed.lut.weight
    model.generator.proj.weight = model.tgt_embed.lut.weight
    # Move model and criterion to the target device
    model, criterion = model.to(device), criterion.to(device)

    # Loss function
    if args.ngpu > 1:
        dev_rank = dev_id  # current device id
        ndev = args.ngpu  # total number of devices
        loss_compute = partial(
            MultiGPULossCompute, criterion, args.ngpu, args.grad_accum, model
        )
    else:  # cpu or single gpu case
        dev_rank = 0
        ndev = 1
        loss_compute = partial(SimpleLossCompute, criterion, args.grad_accum)

    if ndev > 1:
        # Average the randomly initialized parameters so all workers
        # start from identical weights.
        for param in model.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data /= ndev

    # Optimizer
    model_opt = NoamOpt(
        dim_model,
        0.1,
        4000,
        torch.optim.Adam(
            model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9
        ),
    )

    # Experiment tag used for attention plots and checkpoint names,
    # built once since args do not change during training.
    args_filter = [
        "batch",
        "gpus",
        "viz",
        "master_ip",
        "master_port",
        "grad_accum",
        "ngpu",
    ]
    exp_setting = "-".join(
        "{}".format(v) for k, v in vars(args).items() if k not in args_filter
    )

    # Train & evaluate
    for epoch in range(100):
        start = time.time()
        train_iter = dataset(
            graph_pool,
            mode="train",
            batch_size=args.batch,
            device=device,
            dev_rank=dev_rank,
            ndev=ndev,
        )
        model.train(True)
        run_epoch(
            epoch,
            train_iter,
            dev_rank,
            ndev,
            model,
            loss_compute(opt=model_opt),
            is_train=True,
        )
        if dev_rank == 0:
            model.att_weight_map = None
            model.eval()
            valid_iter = dataset(
                graph_pool,
                mode="valid",
                batch_size=args.batch,
                device=device,
                dev_rank=dev_rank,
                ndev=1,
            )
            run_epoch(
                epoch,
                valid_iter,
                dev_rank,
                1,
                model,
                loss_compute(opt=None),
                is_train=False,
            )
            end = time.time()
            print("epoch time: {}".format(end - start))

            # Visualize attention
            if args.viz:
                src_seq = dataset.get_seq_by_id(
                    VIZ_IDX, mode="valid", field="src"
                )
                tgt_seq = dataset.get_seq_by_id(
                    VIZ_IDX, mode="valid", field="tgt"
                )[:-1]
                draw_atts(
                    model.att_weight_map,
                    src_seq,
                    tgt_seq,
                    exp_setting,
                    "epoch_{}".format(epoch),
                )

            # Save a checkpoint for this epoch.
            with open(
                "checkpoints/{}-{}.pkl".format(exp_setting, epoch), "wb"
            ) as f:
                torch.save(model.state_dict(), f)


if __name__ == "__main__":
    if not os.path.exists("checkpoints"):
        os.makedirs("checkpoints")
    np.random.seed(1111)
    argparser = argparse.ArgumentParser("training translation model")
    argparser.add_argument(
        "--gpus",
        default="-1",
        type=str,
        help="comma-separated gpu ids; -1 for cpu",
    )
    argparser.add_argument("--N", default=6, type=int, help="enc/dec layers")
    argparser.add_argument("--dataset", default="multi30k", help="dataset")
    argparser.add_argument("--batch", default=128, type=int, help="batch size")
    argparser.add_argument(
        "--viz", action="store_true", help="visualize attention"
    )
    argparser.add_argument(
        "--universal", action="store_true", help="use universal transformer"
    )
    argparser.add_argument(
        "--master-ip", type=str, default="127.0.0.1", help="master ip address"
    )
    argparser.add_argument(
        "--master-port", type=str, default="12345", help="master port"
    )
    argparser.add_argument(
        "--grad-accum",
        type=int,
        default=1,
        help="accumulate gradients for this many times "
        "then update weights",
    )
    args = argparser.parse_args()
    print(args)

    devices = list(map(int, args.gpus.split(",")))
    if len(devices) == 1:
        # CPU (-1) or single-GPU training runs in the current process.
        args.ngpu = 0 if devices[0] < 0 else 1
        main(devices[0], args)
    else:
        # Multi-GPU training: one spawned process per device, joined via NCCL.
        args.ngpu = len(devices)
        mp = torch.multiprocessing.get_context("spawn")
        procs = []
        for dev_id in devices:
            procs.append(
                mp.Process(target=run, args=(dev_id, args), daemon=True)
            )
            procs[-1].start()
        for p in procs:
            p.join()