Unverified Commit e141a93e authored by Siddharth Goyal's avatar Siddharth Goyal Committed by GitHub

[feat] experimental: Add xpipe support (#553)

parent 204392e5
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
@@ -203,7 +206,8 @@ class SpectrainSGDMomentum(Optimizer):
def modify_current_params_using_reference_params(self):
self.copy_params(self.reference_params, self.cur_params)
# The chunk_index and chunks parameters are unused for the spectrain use case.
def update_weight_using_future_predictions(self, model_index, num_gpus, chunk_index, chunks, forward):
if forward:
# In forward pass:
@@ -260,6 +264,226 @@ class SpectrainSGDMomentum(Optimizer):
return loss
class XpipeAdam(Optimizer):
r"""Implements Xpipe approach on top of Adam algorithm.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
The implementation of the L2 penalty follows changes proposed in
`Decoupled Weight Decay Regularization`_.
Xpipe details can be found here: https://arxiv.org/abs/1911.04610
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
params = list(params)
super(XpipeAdam, self).__init__(params, defaults)
self.cur_params, self.master_params = self.prep_param_copies(params)
_, self.forward_params = self.prep_param_copies(params)
_, self.backward_params = self.prep_param_copies(params)
for group in self.param_groups:
for p in group["params"]:
param_state = self.state[p]
param_state["step"] = 0
# Exponential moving average of gradient values
param_state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
param_state["exp_avg_sq"] = torch.zeros_like(p.data)
def __setstate__(self, state):
super(XpipeAdam, self).__setstate__(state)
for group in self.param_groups:
group.setdefault("amsgrad", False)
def prep_param_copies(self, params):
model_params = [param for param in params if param.requires_grad]
reference_params = [param.clone().detach() for param in model_params]
for param in reference_params:
param.requires_grad = True
return model_params, reference_params
def copy_params(self, master_params, model_params):
for model, master in zip(model_params, master_params):
model.data.copy_(master.data)
def update_weight_using_future_predictions(
self, model_index, num_gpus, current_microbatch_index, microbatches_per_minibatch, forward
):
if forward:
# Forward pass overview:
# if bellwether (first microbatch of a minibatch):
# 1. read from master copy
# 2. predict and modify
# 3. flush updates to forward copy
# else:
# 1. read from forward copy
if current_microbatch_index % microbatches_per_minibatch == 0:
# read from master copy
self.copy_params(self.master_params, self.cur_params)
microbatch_index = current_microbatch_index + 1
# predict and modify
for group in self.param_groups:
multiplier = group["lr"] * round(
(microbatch_index + num_gpus - model_index / 2 - 2) / microbatch_index
)
beta1, beta2 = group["betas"]
eps = group["eps"]
for p in group["params"]:
param_state = self.state[p]
temp1 = param_state["exp_avg"].data / (1 - beta1)
temp2 = ((param_state["exp_avg_sq"].data / (1 - beta2)) + eps).sqrt()
p.data.addcdiv_(temp1, temp2, value=-multiplier)
# flush updates to forward copy
self.copy_params(self.cur_params, self.forward_params)
else:
self.copy_params(self.forward_params, self.cur_params)
else:
# Backward pass overview:
# if bellwether (first microbatch of a minibatch):
# 1. read from master copy
# 2. predict and modify
# 3. flush updates to backward copy
# else:
# 1. read from backward copy
if current_microbatch_index % microbatches_per_minibatch == 0:
# read from master copy
self.copy_params(self.master_params, self.cur_params)
microbatch_index = current_microbatch_index + 1
# predict and modify
for group in self.param_groups:
multiplier = group["lr"] * (microbatch_index + model_index // 2 - 1) // microbatch_index
beta1, beta2 = group["betas"]
eps = group["eps"]
for p in group["params"]:
param_state = self.state[p]
temp1 = param_state["exp_avg"].data / (1 - beta1)
temp2 = ((param_state["exp_avg_sq"].data / (1 - beta2)) + eps).sqrt()
p.data.addcdiv_(temp1, temp2, value=-multiplier)
# flush updates to backward copy
self.copy_params(self.cur_params, self.backward_params)
else:
self.copy_params(self.backward_params, self.cur_params)
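# Illustrative arithmetic for the prediction above (example numbers are assumptions, not
# taken from the source): on the bellwether microbatch (current_microbatch_index == 0, so
# microbatch_index == 1), with num_gpus = 4 and model_index = 0, the forward multiplier is
# lr * round((1 + 4 - 0 / 2 - 2) / 1) = 3 * lr, i.e. the stage jumps roughly three Adam
# steps ahead of the master weights before caching the result in the forward copy.
# Non-bellwether microbatches skip the prediction and simply reuse the cached copy.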
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
amsgrad = group.get("amsgrad", False)
p_data = p.data
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p_data)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p_data)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data)
else:
state["exp_avg"] = state["exp_avg"].to(p_data)
state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data)
if amsgrad:
state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsgrad:
max_exp_avg_sq = state["max_exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
exp_avg_data = exp_avg.data
exp_avg_sq_data = exp_avg_sq.data
# Decay the first and second moment running average coefficient
exp_avg_data.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq_data.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq_data, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group["eps"])
else:
denom = exp_avg_sq_data.sqrt().add_(group["eps"])
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
if group["weight_decay"] != 0:
p_data.add_(p_data, alpha=-group["weight_decay"] * group["lr"])
p_data.addcdiv_(exp_avg_data, denom, value=-step_size)
return loss
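For orientation, here is a minimal, self-contained sketch of how a driver is expected to use XpipeAdam: weights are predicted before the forward and the backward work of every microbatch, and step() is applied once the minibatch's gradients have accumulated. The model, data, and loop below are hypothetical stand-ins chosen for illustration; the real driver is the AsyncAMPnetEventLoop further down.

import torch
import torch.nn as nn

# Hypothetical stand-ins for illustration only.
rank, num_gpus, chunks = 0, 2, 4  # chunks = microbatches per minibatch
toy_model = nn.Linear(8, 8)
optimizer = XpipeAdam(toy_model.parameters(), lr=1e-3)

microbatches = [torch.randn(4, 8) for _ in range(chunks)]
for i, microbatch in enumerate(microbatches):
    # Predict weights for this microbatch's forward work.
    optimizer.update_weight_using_future_predictions(rank, num_gpus, i, chunks, forward=True)
    loss = toy_model(microbatch).pow(2).mean()

    # Switch to the backward prediction before computing gradients.
    optimizer.update_weight_using_future_predictions(rank, num_gpus, i, chunks, forward=False)
    loss.backward()

# One Adam step per minibatch, after `chunks` gradients have accumulated.
optimizer.step()
optimizer.zero_grad()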
def get_data(device):
with warnings.catch_warnings(record=True) as fjldska:
TEXT = torchtext.data.Field(
@@ -321,7 +545,9 @@ def make_model(args, device, ntokens):
return Adam(model.parameters(), lr=lr)
def make_custom_optimizer(model, args):
if args.xpipe:
return XpipeAdam(model.parameters(), lr=lr)
elif args.spectrain:
return SpectrainSGDMomentum(model.parameters(), lr=lr)
else:
return MySGD(model.parameters(), lr=lr)
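For context, a minimal sketch of how the new flag is meant to be consumed, with a hypothetical command line and stand-in model and learning rate; SpectrainSGDMomentum and MySGD are the optimizers defined earlier in this benchmark, and in the benchmark itself lr comes from the enclosing scope.

import argparse
import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument("--spectrain", action="store_true", default=False)
parser.add_argument("--xpipe", action="store_true", default=False)
args = parser.parse_args(["--xpipe"])  # hypothetical command line

lr = 0.001               # stand-in learning rate
model = nn.Linear(4, 4)  # stand-in for the real pipelined model

# Same precedence as make_custom_optimizer above: --xpipe wins over --spectrain.
if args.xpipe:
    optimizer = XpipeAdam(model.parameters(), lr=lr)
elif args.spectrain:
    optimizer = SpectrainSGDMomentum(model.parameters(), lr=lr)
else:
    optimizer = MySGD(model.parameters(), lr=lr)

# Either flag turns on weight prediction for the interleaved training loop.
weight_prediction = args.spectrain or args.xpipe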
@@ -398,7 +624,9 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
optimizer = optimizer(model, args)
transform_and_log = AsyncDelegate(vocab_size)
model.interleave(
lm_dataloader, criterion, optimizer, transform_and_log, args.min_update_interval, args.spectrain or args.xpipe
)
if model.group.rank() == model.group.size() - 1:
print("Done with an epoch")
@@ -615,6 +843,7 @@ parser.add_argument("--max-batch", type=int, default=4, help="Max number of batc
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument("--spectrain", action="store_true", default=False, help="Use spectrain based weight prediction")
parser.add_argument("--xpipe", action="store_true", default=False, help="Use xpipe based weight prediction")
parser.add_argument(
"--lazy-construction", action="store_true", default=False, help="Construct the model layers lazily"
)
@@ -627,7 +856,9 @@ parser.add_argument("--min-update-interval", type=int, default=1, help="min upda
To run the script,
1. Please build a suitable version of OpenMPI with a CUDA-enabled UCX backend.
2. For running on 2 GPUs:
<open-mpi-installed-dir>/bin/mpirun --host localhost:8 -np 2 --map-by node --mca pml ucx -x UCX_TLS=rc,sm,cuda_ipc,cuda_copy -x PYTHONPATH=$PWD -x PATH=$PATH -x LD_LIBRARY_PATH=$LD_LIBRARY_PATH -x UCX_RNDV_SCHEME=put_zcopy -x UCX_MEMTYPE_CACHE=n python3 benchmarks/experimental/experimental_async_approaches.py --num-decoder-layers=8 --host localhost --batch-size 4
3. For Spectrain-based weight prediction, add `--spectrain` to the training command line.
4. For Xpipe-based weight prediction, add `--xpipe` to the training command line.
""" """
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -73,6 +73,7 @@ class AsyncAMPnetEventLoop: ...@@ -73,6 +73,7 @@ class AsyncAMPnetEventLoop:
weight_prediction: bool, weight_prediction: bool,
checkpoint_stop: int, checkpoint_stop: int,
input_device: Union[None, int, str, torch.device], input_device: Union[None, int, str, torch.device],
chunks: int,
):
self.partitions = partitions
self.group = group
@@ -81,9 +82,14 @@ class AsyncAMPnetEventLoop:
self.weight_prediction = weight_prediction
self.checkpoint_stop = checkpoint_stop
self.input_device = input_device
self.chunks = chunks
def perform_optimizer_step(self, optimizer: Any, num_gradients: Any) -> Any:
return (optimizer is not None) and (
(not self.weight_prediction and num_gradients % self.min_update_interval == 0)
or (self.weight_prediction and num_gradients % self.chunks == 0)
)
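# Illustration of the gating above (assumed values, not from the source): with
# weight_prediction enabled and chunks = 4, the optimizer step fires when num_gradients
# reaches 4, 8, 12, ... (once per minibatch); without weight prediction and
# min_update_interval = 1, it fires after every backward microbatch, as before.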
def async_send_inner(self, batch: Batch, index: int) -> Tuple[Batch, PipeMessage]:
task = create_task_without_skip_trackers(
@@ -160,7 +166,7 @@ class AsyncAMPnetEventLoop:
reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
batch = Batch(reqd_input, count)
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, count, self.chunks, forward=True)
activations[count], message = self.async_send_inner(batch, count)
self.transport.send_message(message, sync=True)
count += 1
@@ -177,7 +183,7 @@ class AsyncAMPnetEventLoop:
reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
batch = Batch(reqd_input, count)
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, count, self.chunks, forward=True)
activations[count], forward_message = self.async_send_inner(batch, count)
count += 1
@@ -186,7 +192,9 @@ class AsyncAMPnetEventLoop:
args: AsyncMessageBody = message.args
assert args.message_type is AsyncMessageType.Gradients
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(
cur_rank, N, num_gradients, self.chunks, forward=False
)
self.async_grad_inner(message, activations)
# Send after grad
@@ -208,7 +216,7 @@ class AsyncAMPnetEventLoop:
args = message.args
assert args.message_type is AsyncMessageType.Gradients
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, num_gradients, self.chunks, forward=False)
self.async_grad_inner(message, activations)
num_gradients += 1
@@ -248,7 +256,7 @@ class AsyncAMPnetEventLoop:
batch = self.get_batch_from_message(message, EVENT_LOOP_GRADIENTS_QUEUE)
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, count, self.chunks, forward=True)
task = create_task_without_skip_trackers(
self.checkpoint_stop, args.microbatch_index, self.group.rank(), batch, self.partitions[0].module,
)
@@ -257,7 +265,9 @@ class AsyncAMPnetEventLoop:
task.finalize(output)
# one backward
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(
cur_rank, N, num_gradients, self.chunks, forward=False
)
output_tensor = transform_logger_object.transform_output_before_loss(output.tensor)
loss = criterion(output_tensor, reqd_target)
@@ -307,7 +317,9 @@ class AsyncAMPnetEventLoop:
n_warmup = ranks[-1] - cur_rank
for _ in range(n_warmup):
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(
cur_rank, N, num_activations, self.chunks, forward=True
)
message = self.event_loop_trunk_forward_helper(activations)
self.transport.send_message(message, sync=True)
num_activations += 1
@@ -316,13 +328,15 @@ class AsyncAMPnetEventLoop:
while num_activations < num_microbatch:
# 1 Forward
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(
cur_rank, N, num_activations, self.chunks, forward=True
)
message = self.event_loop_trunk_forward_helper(activations)
num_activations += 1
# 1 Backward
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, num_gradients, self.chunks, forward=False)
self.event_loop_trunk_backward_helper(activations)
num_gradients += 1
if self.perform_optimizer_step(optimizer, num_gradients):
@@ -336,7 +350,7 @@ class AsyncAMPnetEventLoop:
remaining = len(activations)
for _ in range(remaining):
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, num_gradients, self.chunks, forward=False)
self.event_loop_trunk_backward_helper(activations)
num_gradients += 1
if self.perform_optimizer_step(optimizer, num_gradients):
...
@@ -56,6 +56,7 @@ class AMPnetPipe(AsyncPipe):
weight_prediction,
checkpoint_stop,
self.input_device,
self.chunks,
)
if rank == 0:
...