Unverified Commit 29dd22e6 authored by Lingfan Yu, committed by GitHub

[Model] Support Multi-GPU for Transformer model (#356)

* multi-process version of transformer

* lots of fix

* fix bugs and accum gradients for multiple batches

* many fixes

* minor

* upd

* set torch device

* fix bugs

* fix and minor

* comments and clean up

* uncomment viz code
parent 1c63dd4f
@@ -2,14 +2,14 @@ from .graph import *
 from .fields import *
 from .utils import prepare_dataset
 import os
-import numpy as np
+import random

-class ClassificationDataset:
+class ClassificationDataset(object):
     "Dataset class for classification task."
     def __init__(self):
         raise NotImplementedError

-class TranslationDataset:
+class TranslationDataset(object):
     '''
     Dataset class for translation task.
     By default, the source language shares the same vocabulary with the target language.
@@ -22,17 +22,17 @@ class TranslationDataset:
         vocab_path = os.path.join(path, vocab)
         self.src = {}
         self.tgt = {}
-        with open(os.path.join(path, train + '.' + exts[0]), 'r') as f:
+        with open(os.path.join(path, train + '.' + exts[0]), 'r', encoding='utf-8') as f:
             self.src['train'] = f.readlines()
-        with open(os.path.join(path, train + '.' + exts[1]), 'r') as f:
+        with open(os.path.join(path, train + '.' + exts[1]), 'r', encoding='utf-8') as f:
             self.tgt['train'] = f.readlines()
-        with open(os.path.join(path, valid + '.' + exts[0]), 'r') as f:
+        with open(os.path.join(path, valid + '.' + exts[0]), 'r', encoding='utf-8') as f:
             self.src['valid'] = f.readlines()
-        with open(os.path.join(path, valid + '.' + exts[1]), 'r') as f:
+        with open(os.path.join(path, valid + '.' + exts[1]), 'r', encoding='utf-8') as f:
             self.tgt['valid'] = f.readlines()
-        with open(os.path.join(path, test + '.' + exts[0]), 'r') as f:
+        with open(os.path.join(path, test + '.' + exts[0]), 'r', encoding='utf-8') as f:
             self.src['test'] = f.readlines()
-        with open(os.path.join(path, test + '.' + exts[1]), 'r') as f:
+        with open(os.path.join(path, test + '.' + exts[1]), 'r', encoding='utf-8') as f:
             self.tgt['test'] = f.readlines()
         if not os.path.exists(vocab_path):
@@ -90,20 +90,30 @@ class TranslationDataset:
     def eos_id(self):
         return self.vocab[self.EOS_TOKEN]

-    def __call__(self, graph_pool, mode='train', batch_size=32, k=1, devices=['cpu']):
+    def __call__(self, graph_pool, mode='train', batch_size=32, k=1,
+                 device='cpu', dev_rank=0, ndev=1):
         '''
         Create a batched graph correspond to the mini-batch of the dataset.
         args:
             graph_pool: a GraphPool object for accelerating.
             mode: train/valid/test
             batch_size: batch size
-            devices: ['cpu'] or a list of gpu ids.
-            k: beam size(only required for test)
+            k: beam size(only required for test)
+            device: str or torch.device
+            dev_rank: rank (id) of current device
+            ndev: number of devices
         '''
-        dev_id, gs = 0, []
         src_data, tgt_data = self.src[mode], self.tgt[mode]
         n = len(src_data)
-        order = np.random.permutation(n) if mode == 'train' else range(n)
+        # make sure all devices have the same number of batch
+        n = n // ndev * ndev
+        # XXX: partition then shuffle may not be equivalent to shuffle then
+        # partition
+        order = list(range(dev_rank, n, ndev))
+        if mode == 'train':
+            random.shuffle(order)
         src_buf, tgt_buf = [], []
         for idx in order:
@@ -115,22 +125,16 @@ class TranslationDataset:
             tgt_buf.append(tgt_sample)
             if len(src_buf) == batch_size:
                 if mode == 'test':
-                    assert len(devices) == 1 # we only allow single gpu for inference
-                    yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
+                    yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device)
                 else:
-                    gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
-                    dev_id += 1
-                    if dev_id == len(devices):
-                        yield gs if len(devices) > 1 else gs[0]
-                        dev_id, gs = 0, []
+                    yield graph_pool(src_buf, tgt_buf, device=device)
                 src_buf, tgt_buf = [], []
         if len(src_buf) != 0:
             if mode == 'test':
-                yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
+                yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device)
             else:
-                gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
-                yield gs if len(devices) > 1 else gs[0]
+                yield graph_pool(src_buf, tgt_buf, device=device)

     def get_sequence(self, batch):
         "return a list of sequence from a list of index arrays"
@@ -151,8 +155,8 @@ def get_dataset(dataset):
         raise NotImplementedError
     elif dataset == 'copy' or dataset == 'sort':
         return TranslationDataset(
             'data/{}'.format(dataset),
             ('in', 'out'),
             train='train',
             valid='valid',
             test='test',
...
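The partitioning scheme above gives each worker a disjoint, equal-sized slice of the dataset by striding over sample indices with its device rank and shuffling only within that slice. A minimal, standalone sketch of the idea (the function and argument names are illustrative, not part of the example code):

import random

def partition_indices(n_samples, dev_rank, ndev, shuffle=True, seed=None):
    """Return the sample indices owned by worker `dev_rank` out of `ndev` workers."""
    # Drop the tail so every worker sees the same number of samples (and batches).
    n = n_samples // ndev * ndev
    # Round-robin striding: worker r takes indices r, r + ndev, r + 2 * ndev, ...
    order = list(range(dev_rank, n, ndev))
    if shuffle:
        rng = random.Random(seed)
        rng.shuffle(order)  # shuffle within this worker's own partition only
    return order

# Example: 10 samples split across 2 workers.
print(partition_indices(10, dev_rank=0, ndev=2, shuffle=False))  # [0, 2, 4, 6, 8]
print(partition_indices(10, dev_rank=1, ndev=2, shuffle=False))  # [1, 3, 5, 7, 9]

Truncating to a multiple of ndev keeps the batch count identical across workers, which the synchronized all_reduce in the new loss computation relies on.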
@@ -16,7 +16,7 @@ class Vocab:
             self.vocab_lst.append(self.pad_token)
         if self.unk_token is not None:
             self.vocab_lst.append(self.unk_token)
-        with open(path, 'r') as f:
+        with open(path, 'r', encoding='utf-8') as f:
             for token in f.readlines():
                 token = token.strip()
                 self.vocab_lst.append(token)
...
 import torch as T
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist

 class LabelSmoothing(nn.Module):
     """
@@ -37,14 +38,37 @@ class LabelSmoothing(nn.Module):

 class SimpleLossCompute(nn.Module):
     eps=1e-8
-    def __init__(self, criterion, opt=None):
-        """
-        opt is required during training
+    def __init__(self, criterion, grad_accum, opt=None):
+        """Loss function and optimizer for single device
+
+        Parameters
+        ----------
+        criterion: torch.nn.Module
+            criterion to compute loss
+        grad_accum: int
+            number of batches to accumulate gradients
+        opt: Optimizer
+            Model optimizer to use. If None, then no backward and update will be
+            performed
         """
         super(SimpleLossCompute, self).__init__()
         self.criterion = criterion
         self.opt = opt
-        self.reset()
+        self.acc_loss = 0
+        self.n_correct = 0
+        self.norm_term = 0
+        self.loss = 0
+        self.batch_count = 0
+        self.grad_accum = grad_accum
+
+    def __enter__(self):
+        self.batch_count = 0
+
+    def __exit__(self, type, value, traceback):
+        # if not enough batches accumulated and there are gradients not applied,
+        # do one more step
+        if self.batch_count > 0:
+            self.step()

     @property
     def avg_loss(self):
@@ -54,32 +78,56 @@ class SimpleLossCompute(nn.Module):
     def accuracy(self):
         return (self.n_correct + self.eps) / (self.norm_term + self.eps)

-    def reset(self):
-        self.acc_loss = 0
-        self.n_correct = 0
-        self.norm_term = 0
+    def step(self):
+        self.opt.step()
+        self.opt.optimizer.zero_grad()
+
+    def backward_and_step(self):
+        self.loss.backward()
+        self.batch_count += 1
+        # accumulate self.grad_accum times then synchronize and update
+        if self.batch_count == self.grad_accum:
+            self.step()
+            self.batch_count = 0

     def __call__(self, y_pred, y, norm):
         y_pred = y_pred.contiguous().view(-1, y_pred.shape[-1])
         y = y.contiguous().view(-1)
-        loss = self.criterion(
-            y_pred, y
-        ) / norm
+        self.loss = self.criterion(y_pred, y) / norm
         if self.opt is not None:
-            loss.backward()
-            self.opt.step()
-            self.opt.optimizer.zero_grad()
+            self.backward_and_step()
         self.n_correct += ((y_pred.max(dim=-1)[1] == y) & (y != self.criterion.padding_idx)).sum().item()
-        self.acc_loss += loss.item() * norm
+        self.acc_loss += self.loss.item() * norm
         self.norm_term += norm
-        return loss.item() * norm
+        return self.loss.item() * norm

 class MultiGPULossCompute(SimpleLossCompute):
-    def __init__(self, criterion, devices, opt=None, chunk_size=5):
-        self.criterion = criterion
-        self.opt = opt
-        self.devices = devices
-        self.chunk_size = chunk_size
+    def __init__(self, criterion, ndev, grad_accum, model, opt=None):
+        """Loss function and optimizer for multiple devices
+
+        Parameters
+        ----------
+        criterion: torch.nn.Module
+            criterion to compute loss
+        ndev: int
+            number of devices used
+        grad_accum: int
+            number of batches to accumulate gradients
+        model: torch.nn.Module
+            model to optimizer (needed to iterate and synchronize all parameters)
+        opt: Optimizer
+            Model optimizer to use. If None, then no backward and update will be
+            performed
+        """
+        super(MultiGPULossCompute, self).__init__(criterion, grad_accum, opt=opt)
+        self.ndev = ndev
+        self.model = model

-    def __call__(self, y_preds, ys, norms):
-        pass
+    def step(self):
+        # multi-gpu synchronize gradients
+        for param in self.model.parameters():
+            if param.requires_grad and param.grad is not None:
+                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
+                param.grad.data /= self.ndev
+        self.opt.step()
+        self.opt.optimizer.zero_grad()
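The rewritten loss classes combine two ideas: gradients are accumulated locally for grad_accum batches before each optimizer step, and in the multi-GPU case every step first averages gradients across workers with an all_reduce. A condensed sketch of that update rule, assuming torch.distributed has already been initialized (the helper below is illustrative, not part of the commit):

import torch
import torch.distributed as dist

def accumulate_and_step(model, optimizer, losses, grad_accum, ndev):
    """Backward over an iterable of `losses`, stepping once every `grad_accum` batches."""
    for batch_count, loss in enumerate(losses, 1):
        loss.backward()  # gradients accumulate in param.grad across batches
        if batch_count % grad_accum == 0:
            if ndev > 1:
                # Average gradients across workers so every replica applies
                # the same update: sum with all_reduce, then divide by ndev.
                for param in model.parameters():
                    if param.requires_grad and param.grad is not None:
                        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                        param.grad.data /= ndev
            optimizer.step()
            optimizer.zero_grad()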
@@ -12,7 +12,7 @@ class MultiHeadAttention(nn.Module):
         self.h = h
         # W_q, W_k, W_v, W_o
         self.linears = clones(
-            nn.Linear(dim_model, dim_model), 4
+            nn.Linear(dim_model, dim_model, bias=False), 4
         )

     def get(self, x, fields='qkv'):
...
@@ -46,11 +46,11 @@ class Decoder(nn.Module):
         layer = self.layers[i]
         def func(nodes):
             x = nodes.data['x']
-            if fields == 'kv':
-                norm_x = x # In enc-dec attention, x has already been normalized.
+            norm_x = layer.sublayer[l].norm(x) if fields.startswith('q') else x
+            if fields != 'qkv':
+                return layer.src_attn.get(norm_x, fields)
             else:
-                norm_x = layer.sublayer[l].norm(x)
-            return layer.self_attn.get(norm_x, fields)
+                return layer.self_attn.get(norm_x, fields)
         return func

     def post_func(self, i, l=0):
@@ -64,8 +64,6 @@ class Decoder(nn.Module):
             return {'x': x if i < self.N - 1 else self.norm(x)}
         return func

-lock = threading.Lock()
-
 class Transformer(nn.Module):
     def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k):
         super(Transformer, self).__init__()
@@ -124,9 +122,10 @@ class Transformer(nn.Module):
         self.update_graph(g, edges, [(pre_q, nodes), (pre_kv, nodes_e)], [(post_func, nodes)])

         # visualize attention
-        with lock:
-            if self.att_weight_map is None:
-                self._register_att_map(g, graph.nid_arr['enc'][VIZ_IDX], graph.nid_arr['dec'][VIZ_IDX])
+        """
+        if self.att_weight_map is None:
+            self._register_att_map(g, graph.nid_arr['enc'][VIZ_IDX], graph.nid_arr['dec'][VIZ_IDX])
+        """

         return self.generator(g.ndata['x'][nids['dec']])
...
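Stripped of DGL's node-function plumbing, the new decoder pre-function boils down to: layer-normalize the input only when it is about to be projected into queries, and send encoder-decoder attention fields to src_attn rather than self_attn. A hedged restatement of that branch logic (the layer attributes follow the diff; the wrapper itself is illustrative):

def project_fields(layer, l, x, fields='qkv'):
    # Queries come from the decoder state and still need pre-norm; keys/values
    # for enc-dec attention arrive already normalized from the encoder side.
    norm_x = layer.sublayer[l].norm(x) if fields.startswith('q') else x
    attn = layer.self_attn if fields == 'qkv' else layer.src_attn
    return attn.get(norm_x, fields)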
-# Mostly then same with PyTorch
-import threading
-import torch
-
-def get_a_var(obj):
-    if isinstance(obj, torch.Tensor):
-        return obj
-    if isinstance(obj, list) or isinstance(obj, tuple):
-        for result in map(get_a_var, obj):
-            if isinstance(result, torch.Tensor):
-                return result
-    if isinstance(obj, dict):
-        for result in map(get_a_var, obj.items()):
-            if isinstance(result, torch.Tensor):
-                return result
-    return None
-
-def parallel_apply(modules, inputs):
-    assert len(modules) == len(inputs)
-    lock = threading.Lock()
-    results = {}
-    grad_enabled = torch.is_grad_enabled()
-
-    def _worker(i, module, input):
-        torch.set_grad_enabled(grad_enabled)
-        try:
-            #with torch.cuda.device(device):
-            output = module(input)
-            with lock:
-                results[i] = output
-        except Exception as e:
-            with lock:
-                results[i] = e
-
-    if len(modules) > 1:
-        threads = [threading.Thread(target=_worker,
-                                    args=(i, module, input))
-                   for i, (module, input) in
-                   enumerate(zip(modules, inputs))]
-        for thread in threads:
-            thread.start()
-        for thread in threads:
-            thread.join()
-    else:
-        _worker(0, modules[0], inputs[0])
-
-    outputs = []
-    for i in range(len(inputs)):
-        output = results[i]
-        if isinstance(output, Exception):
-            raise output
-        outputs.append(output)
-    return outputs
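For context, the deleted helper above mirrored torch.nn.parallel.parallel_apply: one Python thread per module replica, each applied to its own input. A hypothetical usage sketch of the removed function (module sizes and input shapes are made up for illustration, and it assumes parallel_apply is still in scope):

import torch
import torch.nn as nn

replicas = [nn.Linear(4, 2), nn.Linear(4, 2)]     # one module per replica/device
inputs = [torch.randn(3, 4), torch.randn(3, 4)]   # one mini-batch per replica
outputs = parallel_apply(replicas, inputs)        # -> list of two (3, 2) tensors

The commit replaces this thread-based replication with one process per GPU plus gradient all_reduce, which avoids contending on the Python GIL inside a single process.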
"""
In current version we use multi30k as the default training and validation set.
Multi-GPU support is required to train the model on WMT14.
"""
from modules import * from modules import *
from parallel import * from loss import *
from loss import *
from optims import * from optims import *
from dataset import * from dataset import *
from modules.config import * from modules.config import *
from modules.viz import * #from modules.viz import *
from tqdm import tqdm
import numpy as np import numpy as np
import argparse import argparse
import torch
from functools import partial
import torch.distributed as dist
def run_epoch(data_iter, model, loss_compute, is_train=True): def run_epoch(epoch, data_iter, dev_rank, ndev, model, loss_compute, is_train=True):
universal = isinstance(model, UTransformer) universal = isinstance(model, UTransformer)
for i, g in tqdm(enumerate(data_iter)): with loss_compute:
with T.set_grad_enabled(is_train): for i, g in enumerate(data_iter):
if isinstance(model, list): with T.set_grad_enabled(is_train):
model = model[:len(gs)]
output = parallel_apply(model, g)
tgt_y = [g.tgt_y for g in gs]
n_tokens = [g.n_tokens for g in gs]
else:
if universal: if universal:
output, loss_act = model(g) output, loss_act = model(g)
if is_train: loss_act.backward(retain_graph=True) if is_train: loss_act.backward(retain_graph=True)
...@@ -30,70 +22,134 @@ def run_epoch(data_iter, model, loss_compute, is_train=True): ...@@ -30,70 +22,134 @@ def run_epoch(data_iter, model, loss_compute, is_train=True):
output = model(g) output = model(g)
tgt_y = g.tgt_y tgt_y = g.tgt_y
n_tokens = g.n_tokens n_tokens = g.n_tokens
loss = loss_compute(output, tgt_y, n_tokens) loss = loss_compute(output, tgt_y, n_tokens)
if universal: if universal:
for step in range(1, model.MAX_DEPTH + 1): for step in range(1, model.MAX_DEPTH + 1):
print("nodes entering step {}: {:.2f}%".format(step, (1.0 * model.stat[step] / model.stat[0]))) print("nodes entering step {}: {:.2f}%".format(step, (1.0 * model.stat[step] / model.stat[0])))
model.reset_stat() model.reset_stat()
print('average loss: {}'.format(loss_compute.avg_loss)) print('Epoch {} {}: Dev {} average loss: {}, accuracy {}'.format(
print('accuracy: {}'.format(loss_compute.accuracy)) epoch, "Training" if is_train else "Evaluating",
dev_rank, loss_compute.avg_loss, loss_compute.accuracy))
if __name__ == '__main__': def run(dev_id, args):
if not os.path.exists('checkpoints'): dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
os.makedirs('checkpoints') master_ip=args.master_ip, master_port=args.master_port)
np.random.seed(1111) world_size = args.ngpu
argparser = argparse.ArgumentParser('training translation model') torch.distributed.init_process_group(backend="nccl",
argparser.add_argument('--gpus', default='-1', type=str, help='gpu id') init_method=dist_init_method,
argparser.add_argument('--N', default=6, type=int, help='enc/dec layers') world_size=world_size,
argparser.add_argument('--dataset', default='multi30k', help='dataset') rank=dev_id)
argparser.add_argument('--batch', default=128, type=int, help='batch size') gpu_rank = torch.distributed.get_rank()
argparser.add_argument('--viz', action='store_true', help='visualize attention') assert gpu_rank == dev_id
argparser.add_argument('--universal', action='store_true', help='use universal transformer') main(dev_id, args)
args = argparser.parse_args()
args_filter = ['batch', 'gpus', 'viz']
exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter)
devices = ['cpu'] if args.gpus == '-1' else [int(gpu_id) for gpu_id in args.gpus.split(',')]
def main(dev_id, args):
if dev_id == -1:
device = torch.device('cpu')
else:
device = torch.device('cuda:{}'.format(dev_id))
# Set current device
th.cuda.set_device(device)
# Prepare dataset
dataset = get_dataset(args.dataset) dataset = get_dataset(args.dataset)
V = dataset.vocab_size V = dataset.vocab_size
criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1) criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)
dim_model = 512 dim_model = 512
# Build graph pool
graph_pool = GraphPool() graph_pool = GraphPool()
model = make_model(V, V, N=args.N, dim_model=dim_model, universal=args.universal) # Create model
model = make_model(V, V, N=args.N, dim_model=dim_model,
universal=args.universal)
# Sharing weights between Encoder & Decoder # Sharing weights between Encoder & Decoder
model.src_embed.lut.weight = model.tgt_embed.lut.weight model.src_embed.lut.weight = model.tgt_embed.lut.weight
model.generator.proj.weight = model.tgt_embed.lut.weight model.generator.proj.weight = model.tgt_embed.lut.weight
# Move model to corresponding device
model, criterion = model.to(device), criterion.to(device)
# Loss function
if args.ngpu > 1:
dev_rank = dev_id # current device id
ndev = args.ngpu # number of devices (including cpu)
loss_compute = partial(MultiGPULossCompute, criterion, args.ngpu,
args.grad_accum, model)
else: # cpu or single gpu case
dev_rank = 0
ndev = 1
loss_compute = partial(SimpleLossCompute, criterion, args.grad_accum)
model, criterion = model.to(devices[0]), criterion.to(devices[0]) if ndev > 1:
model_opt = NoamOpt(dim_model, 1, 400, for param in model.parameters():
T.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)) dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
if len(devices) > 1: param.data /= ndev
model, criterion = map(nn.parallel.replicate, [model, criterion], [devices, devices])
loss_compute = SimpleLossCompute if len(devices) == 1 else MultiGPULossCompute
# Optimizer
model_opt = NoamOpt(dim_model, 1, 4000,
T.optim.Adam(model.parameters(), lr=1e-3,
betas=(0.9, 0.98), eps=1e-9))
# Train & evaluate
for epoch in range(100): for epoch in range(100):
train_iter = dataset(graph_pool, mode='train', batch_size=args.batch, devices=devices) start = time.time()
valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch, devices=devices) train_iter = dataset(graph_pool, mode='train', batch_size=args.batch,
print('Epoch: {} Training...'.format(epoch)) device=device, dev_rank=dev_rank, ndev=ndev)
model.train(True) model.train(True)
run_epoch(train_iter, model, run_epoch(epoch, train_iter, dev_rank, ndev, model,
loss_compute(criterion, model_opt), is_train=True) loss_compute(opt=model_opt), is_train=True)
print('Epoch: {} Evaluating...'.format(epoch)) if dev_rank == 0:
model.att_weight_map = None model.att_weight_map = None
model.eval() model.eval()
run_epoch(valid_iter, model, valid_iter = dataset(graph_pool, mode='valid', batch_size=args.batch,
loss_compute(criterion, None), is_train=False) device=device, dev_rank=dev_rank, ndev=1)
# Visualize attention run_epoch(epoch, valid_iter, dev_rank, 1, model,
if args.viz: loss_compute(opt=None), is_train=False)
src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src') end = time.time()
tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1] print("epoch time: {}".format(end - start))
draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch))
# Visualize attention
if args.viz:
src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src')
tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1]
draw_atts(model.att_weight_map, src_seq, tgt_seq, exp_setting, 'epoch_{}'.format(epoch))
args_filter = ['batch', 'gpus', 'viz', 'master_ip', 'master_port', 'grad_accum', 'ngpu']
exp_setting = '-'.join('{}'.format(v) for k, v in vars(args).items() if k not in args_filter)
with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f:
torch.save(model.state_dict(), f)
if __name__ == '__main__':
if not os.path.exists('checkpoints'):
os.makedirs('checkpoints')
np.random.seed(1111)
argparser = argparse.ArgumentParser('training translation model')
argparser.add_argument('--gpus', default='-1', type=str, help='gpu id')
argparser.add_argument('--N', default=6, type=int, help='enc/dec layers')
argparser.add_argument('--dataset', default='multi30k', help='dataset')
argparser.add_argument('--batch', default=128, type=int, help='batch size')
argparser.add_argument('--viz', action='store_true',
help='visualize attention')
argparser.add_argument('--universal', action='store_true',
help='use universal transformer')
argparser.add_argument('--master-ip', type=str, default='127.0.0.1',
help='master ip address')
argparser.add_argument('--master-port', type=str, default='12345',
help='master port')
argparser.add_argument('--grad-accum', type=int, default=1,
help='accumulate gradients for this many times '
'then update weights')
args = argparser.parse_args()
print(args)
print('----------------------------------') devices = list(map(int, args.gpus.split(',')))
with open('checkpoints/{}-{}.pkl'.format(exp_setting, epoch), 'wb') as f: if len(devices) == 1:
th.save(model.state_dict(), f) args.ngpu = 0 if devices[0] < 0 else 1
main(devices[0], args)
else:
args.ngpu = len(devices)
mp = torch.multiprocessing.get_context('spawn')
procs = []
for dev_id in devices:
procs.append(mp.Process(target=run, args=(dev_id, args),
daemon=True))
procs[-1].start()
for p in procs:
p.join()
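The new launcher follows the usual one-process-per-GPU recipe: spawn a process for every device, initialize an NCCL process group keyed by rank, and run the same training loop in each process. A self-contained sketch of just that wiring, with placeholder names (train_on_device is not the script's main, and it assumes at least one visible GPU):

import torch
import torch.distributed as dist

def train_on_device(dev_id, args):
    # One process per GPU: rank == device id, world size == number of GPUs.
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://{}:{}'.format(args.master_ip, args.master_port),
        world_size=args.ngpu, rank=dev_id)
    torch.cuda.set_device(dev_id)
    # build the model and data iterators here, then run the training loop

if __name__ == '__main__':
    from types import SimpleNamespace
    # Placeholder arguments mirroring the script's --master-ip/--master-port/ngpu.
    args = SimpleNamespace(master_ip='127.0.0.1', master_port='12345',
                           ngpu=torch.cuda.device_count())
    mp = torch.multiprocessing.get_context('spawn')  # CUDA tensors require 'spawn'
    procs = [mp.Process(target=train_on_device, args=(dev_id, args), daemon=True)
             for dev_id in range(args.ngpu)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()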