Unverified Commit 5e8a6422 authored by Siddharth Goyal, committed by GitHub

[feat] experimental: Add spectrain support (#372)

* experimental: Add spectrain support

* Address review comments

* Address review comments
parent ccda8bd0
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import torch
from torch.utils.data import Dataset

# TODO(sidgoyal): Refactor benchmarks to remove this file eventually.


def collate_sentences_lm(samples):
    if len(samples) == 0:
        return {}

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = torch.stack([s["source"] for s in samples], 0)
    tgt_tokens = torch.stack([s["target"] for s in samples], 0)
    ntokens = len(samples) * len(samples[0]["target"])
    src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples))

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "input": src_tokens,
        "target": tgt_tokens,
    }
    return batch


class BenchmarkLMDataset(Dataset):
    """
    Dataset to benchmark a translation like seq2seq task.

    Args:
        vocab_size (int, optional): size of the vocabulary (default 10000).
        max_source_positions (int, optional): max number of tokens in the
            source sentence (default: 1024).
        total_samples (int, optional): the total number of rows in the
            dataset (default: 10000).
    """

    def __init__(
        self, vocab_size=10000, max_source_positions=1024, total_samples=10000,
    ):
        self.vocab_size = vocab_size
        self.max_source_positions = max_source_positions
        self.total_samples = total_samples
        self.sizes = [self.max_source_positions] * self.total_samples

    def __getitem__(self, index):
        length = self.sizes[index]
        source = torch.randint(1, self.vocab_size, (length,))
        target = source.clone()
        return {
            "id": index,
            "source": source,
            "target": target,
        }

    def __len__(self):
        return self.total_samples
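For orientation, a minimal sketch (not part of this commit) of how the dataset and collate function above could be wired into a standard PyTorch DataLoader; the batch size below is illustrative:

from torch.utils.data import DataLoader

# Illustrative wiring only: pair the synthetic LM dataset with the custom
# collate function so each batch exposes "input"/"target" token tensors.
dataset = BenchmarkLMDataset()  # vocab_size=10000, max_source_positions=1024
loader = DataLoader(dataset, batch_size=8, collate_fn=collate_sentences_lm)

batch = next(iter(loader))
print(batch["input"].shape)  # torch.Size([8, 1024])
print(batch["ntokens"])      # 8192 tokens per batch (8 * 1024)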
@@ -131,7 +131,7 @@ class MySGD(Optimizer):
        lr (float): learning rate (required)
    """

-    def __init__(self, params, lr=0.01):
+    def __init__(self, params, lr):
        defaults = dict(lr=lr)
        super(MySGD, self).__init__(params, defaults)
@@ -140,7 +140,7 @@ class MySGD(Optimizer):
    def step(self, closure=None):
        """ Performs a single optimization step.

-        Arguments:
+        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
@@ -157,6 +157,109 @@ class MySGD(Optimizer):
        return loss


class SpectrainSGDMomentum(Optimizer):
    r"""
    Implements an SGD-with-momentum optimizer with Spectrain-based weight
    prediction. Please refer to the Spectrain paper
    (https://arxiv.org/pdf/1809.02839.pdf) for more details.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate (required)
        momentum (float): momentum (default=0.9)
    """

    def __init__(self, params, lr, momentum=0.9):
        defaults = dict(lr=lr, momentum=momentum)
        params = list(params)
        super(SpectrainSGDMomentum, self).__init__(params, defaults)
        self.old_weights = None
        self.cur_params, self.reference_params = self.prep_param_copies(params)

        for group in self.param_groups:
            for p in group["params"]:
                if momentum != 0:
                    param_state = self.state[p]
                    param_state["momentum_buffer"] = torch.zeros_like(p.data)

    def __setstate__(self, state):
        super(SpectrainSGDMomentum, self).__setstate__(state)

    def prep_param_copies(self, params):
        model_params = [param for param in params if param.requires_grad]
        reference_params = [param.clone().detach() for param in model_params]
        for param in reference_params:
            param.requires_grad = True
        return model_params, reference_params

    def copy_params(self, master_params, model_params):
        for model, master in zip(model_params, master_params):
            model.data.copy_(master.data)

    def modify_reference_params_using_current_params(self):
        self.copy_params(self.cur_params, self.reference_params)

    def modify_current_params_using_reference_params(self):
        self.copy_params(self.reference_params, self.cur_params)

    def update_weight_using_future_predictions(self, model_index, num_gpus, forward):
        if forward:
            # In the forward pass:
            # 1. save the current weights into the reference copy
            # 2. predict future weights and modify the parameters in place
            self.modify_reference_params_using_current_params()
            for group in self.param_groups:
                multiplier = group["lr"] * (model_index // 2 + num_gpus - model_index - 1)
                for p in group["params"]:
                    param_state = self.state[p]
                    p.data.sub_(param_state["momentum_buffer"].data, alpha=multiplier)
        else:
            # In the backward pass:
            # 1. restore the saved (reference) weights
            # 2. predict future weights and modify the parameters in place
            self.modify_current_params_using_reference_params()
            for group in self.param_groups:
                multiplier = group["lr"] * (model_index // 2)
                for p in group["params"]:
                    param_state = self.state[p]
                    p.data.sub_(param_state["momentum_buffer"].data, alpha=multiplier)

    def step(self, weight_prediction=True, closure=None):
        """ Performs a single optimization step.

        Args:
            weight_prediction (bool, optional): Enable weight prediction based updates
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        if weight_prediction:
            self.modify_current_params_using_reference_params()

        for group in self.param_groups:
            momentum = group["momentum"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if momentum != 0:
                    param_state = self.state[p]
                    buf = param_state["momentum_buffer"]
                    buf.data.mul_(momentum).add_(d_p, alpha=1 - momentum)
                    d_p = buf
                p.data.add_(d_p, alpha=-group["lr"])
        return loss
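As a quick orientation for the weight-prediction flow above (not part of this commit), here is a minimal single-process sketch of the intended call sequence: predict future weights before the forward pass, re-predict with the smaller lookahead before the backward pass, then step. The toy model, loss, and the model_index/num_gpus values are assumptions for illustration; in the real benchmark these calls are driven by the pipeline's interleave schedule.

import torch

# Toy, illustrative usage of SpectrainSGDMomentum (assumed values throughout).
model = torch.nn.Linear(16, 16)
criterion = torch.nn.MSELoss()
optimizer = SpectrainSGDMomentum(model.parameters(), lr=0.01, momentum=0.9)

x, target = torch.randn(4, 16), torch.randn(4, 16)

# Forward-time prediction: saves the current weights into the reference copy,
# then jumps ahead along the momentum buffer, roughly w_hat = w - lr * s * v.
# (On the very first step the momentum buffer is zero, so this is a no-op.)
optimizer.update_weight_using_future_predictions(model_index=3, num_gpus=4, forward=True)
loss = criterion(model(x), target)

# Backward-time prediction: restores the saved weights and applies the
# (smaller) backward lookahead before gradients are computed.
optimizer.update_weight_using_future_predictions(model_index=3, num_gpus=4, forward=False)
loss.backward()

# step(weight_prediction=True) restores the reference weights and then applies
# the momentum SGD update.
optimizer.step()
optimizer.zero_grad()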
def get_data(device):
    with warnings.catch_warnings(record=True) as fjldska:
        TEXT = torchtext.data.Field(
@@ -215,28 +318,21 @@ def make_model(args, device, ntokens):
    lr = 0.01  # learning rate

    def make_adam(model):
-        # if args.ddp_zero:
-        #     return OSS(params=model.parameters(), optim=Adam, group=get_data_parallel_group(), lr=lr)
-        # else:
        return Adam(model.parameters(), lr=lr)

-    def make_custom_sgd(model):
+    def make_custom_optimizer(model, args):
+        if args.spectrain:
+            return SpectrainSGDMomentum(model.parameters(), lr=lr)
+        else:
            return MySGD(model.parameters(), lr=lr)

-    optimizer = make_custom_sgd
+    optimizer = make_custom_optimizer
    scaler = GradScaler()

    return model, criterion, optimizer, scaler


-def safe_rank():
-    try:
-        return torch.distributed.get_rank()
-    except AssertionError:
-        return 0


-class AMPnetDelegate(object):
+class AsyncDelegate(object):
    def __init__(self, vocab_size, iteration_per_batch=1000):
        self.cur_epoch = 0
        self.cur_iteration = 0
@@ -300,9 +396,9 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
    start_time = time.time()
    word_counter = 0

-    optimizer = optimizer(model)
-    transform_and_log = AMPnetDelegate(vocab_size)
-    model.interleave(lm_dataloader, criterion, optimizer, transform_and_log, args.min_update_interval)
+    optimizer = optimizer(model, args)
+    transform_and_log = AsyncDelegate(vocab_size)
+    model.interleave(lm_dataloader, criterion, optimizer, transform_and_log, args.min_update_interval, args.spectrain)

    if model.group.rank() == model.group.size() - 1:
        print("Done with an epoch")
@@ -518,6 +614,7 @@ parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument("--spectrain", action="store_true", default=False, help="Use spectrain based weight prediction")
parser.add_argument(
"--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
)