Unverified Commit 14491030 authored by Siddharth Goyal, committed by GitHub

[feat] Add AMPnet implementation in experimental dir (#304)

* Add AMPnet implementation (clean version)

* Move ampnet to experimental

* Move stuff around pipeline

* Address review comments and fix pre-commit errors

* Refactor and modify delegate functionality

* Modify header in pipe.py
parent 8a49a748
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import argparse
import logging
import math
import os
import sys
import time
import warnings
from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm
import torch
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
import torchtext
from torchtext.data.utils import get_tokenizer
from experimental.nn.ampnet_pipe import pipe
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule
from fairscale.optim import GradScaler
from fairscale.utils.testing import dist_init, get_worker_map
try:
from fairscale.optim import Adam # type: ignore
can_benchmark = True
except ImportError:
from torch.optim import Adam # type: ignore
can_benchmark = False
def init_random_seed(seed: int):
import numpy
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
numpy.random.seed(seed)
PIPE_CHUNKS = 2
iteration_count = 0
class EmbeddingLayer(nn.Embedding):
def __init__(self, ntoken, ninp, initrange):
super().__init__(ntoken, ninp)
self.ninp = ninp
self.weight.data.uniform_(-initrange, initrange)
def forward(self, src):
return super().forward(src) * math.sqrt(self.ninp)
class PositionalEncodingLayer(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncodingLayer, self).__init__()
self.dropout = nn.Dropout(p=dropout)
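# Precompute the fixed sinusoidal position-encoding table; it is reshaped below to
# (max_len, 1, d_model) so it broadcasts over the batch dimension in forward().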
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[: x.size(0), :]
return self.dropout(x)
class TransformerDecoderLayer(nn.TransformerEncoderLayer):
"""Though this class inherits from torch.nn.TransformerEncoderLayer,
it functions as a decoder in this model"""
def __init__(self, ninp, nhead, nhid, dropout):
super().__init__(ninp, nhead, nhid, dropout)
self.src_mask = None
def _generate_square_subsequent_mask(self, sz):
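# Causal mask: each position may attend only to itself and earlier positions;
# disallowed (future) entries are set to -inf.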
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
return mask
def forward(self, src):
global iteration_count
iteration_count += 1
if self.src_mask is None or self.src_mask.size(0) != len(src):
device = src.device
mask = self._generate_square_subsequent_mask(len(src)).to(device)
self.src_mask = mask
return super().forward(src, self.src_mask)
class LinearLayer(nn.Linear):
def __init__(self, ninp, ntoken, initrange):
super().__init__(ninp, ntoken)
self.bias.data.zero_()
self.weight.data.uniform_(-initrange, initrange)
class TransformerLMSequntial(nn.Sequential):
"""A small language model based on the design of GPT-2 using nn.Sequeitnal
for compatability with Pipe"""
def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder):
layers = [
EmbeddingLayer(ntokens, ninp, initrange),
PositionalEncodingLayer(ninp, dropout),
]
for _ in range(ndecoder):
layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout))
layers.append(LinearLayer(ninp, ntokens, initrange))
super(TransformerLMSequntial, self).__init__(*layers)
class MySGD(Optimizer):
r"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate (required)
"""
def __init__(self, params, lr=0.01):
defaults = dict(lr=lr)
super(MySGD, self).__init__(params, defaults)
def __setstate__(self, state):
super(MySGD, self).__setstate__(state)
def step(self, closure=None):
""" Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
d_p = p.grad.data
p.data.add_(d_p, alpha=-group["lr"])
return loss
def get_data(device):
with warnings.catch_warnings(record=True):
TEXT = torchtext.data.Field(
tokenize=get_tokenizer("basic_english"), init_token="<sos>", eos_token="<eos>", lower=True
)
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
TEXT.build_vocab(train_txt)
ntokens = len(TEXT.vocab.stoi)
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_txt, batch_size, TEXT, device)
val_data = batchify(val_txt, eval_batch_size, TEXT, device)
test_data = batchify(test_txt, eval_batch_size, TEXT, device)
return ntokens, train_data, val_data, test_data
def batchify(data, bsz, TEXT, device):
data = TEXT.numericalize([data.examples[0].text])
nbatch = data.size(0) // bsz
data = data.narrow(0, 0, nbatch * bsz)
data = data.view(bsz, -1).t().contiguous()
return data.to(device)
def get_batch(source, i, bptt):
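# Standard LM batching: inputs are source[i : i + seq_len], targets are the same
# tokens shifted one step ahead.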
seq_len = min(bptt, len(source) - 1 - i)
data = source[i : i + seq_len]
target = source[i + 1 : i + 1 + seq_len].view(-1)
return data, target
def make_model(args, device, ntokens):
ninp = 2048 # embedding dimension
nhid = 2048 # the dimension of the feedforward network model in nn.TransformerEncoder
nhead = 32 # the number of heads in the multiheadattention models
dropout = 0
initrange = 0.1
ndecoder = args.num_decoder_layers
if args.lazy_construction:
layers = [
LazyModule(lambda: EmbeddingLayer(ntokens, ninp, initrange)),
LazyModule(lambda: PositionalEncodingLayer(ninp, dropout)),
]
for _ in range(ndecoder):
layers.append(LazyModule(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout)))
layers.append(LazyModule(lambda: LinearLayer(ninp, ntokens, initrange)))
model = layers
else:
model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
criterion = nn.CrossEntropyLoss()
lr = 0.01 # learning rate
def make_adam(model):
# if args.ddp_zero:
# return OSS(params=model.parameters(), optim=Adam, group=get_data_parallel_group(), lr=lr)
# else:
return Adam(model.parameters(), lr=lr)
def make_custom_sgd(model):
return MySGD(model.parameters(), lr=lr)
optimizer = make_custom_sgd
scaler = GradScaler()
return model, criterion, optimizer, scaler
def safe_rank():
try:
return torch.distributed.get_rank()
except AssertionError:
return 0
class AMPnetDelegate(object):
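# Delegate handed to AMPnetPipe.interleave: provides input/target transforms,
# loss logging, and a periodic checkpoint hook for the AMPnet event loops.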
def __init__(self, vocab_size, iteration_per_batch=1000):
self.cur_epoch = 0
self.cur_iteration = 0
self.iteration_per_batch = iteration_per_batch
self.vocab_size = vocab_size
self.word_counter = 0
self.start_time = time.time()
self.log_interval = 1
self.total_loss = 0
def transform_input(self, cur_batch):
return cur_batch["input"]
def transform_target(self, cur_batch):
return cur_batch["target"].view(-1)
def log_loss(self, cur_batch, loss, count):
self.word_counter += cur_batch["ntokens"]
if count % self.log_interval == 0 and count > 0:
self.total_loss += loss.item()
cur_loss = self.total_loss / self.log_interval
elapsed = time.time() - self.start_time
print(
"| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
count, self.word_counter / elapsed, cur_loss, math.exp(cur_loss)
)
)
self.word_counter = 0
self.total_loss = 0
self.start_time = time.time()
def transform_output_before_loss(self, output_tensor):
return output_tensor.view(-1, self.vocab_size)
def check_and_save_weights(self, num_gradients):
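# Hook for periodically saving weights; intentionally a no-op in this benchmark.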
pass
def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
model.train()
from functools import reduce
import operator
num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
if model.group:
total = torch.Tensor([num_params])
if torch.cuda.is_available():
total = total.cuda()
torch.distributed.all_reduce(total, group=model.group)
logging.info(
f"training model, #prams = {num_params}, group: {model.group.rank()}, grank:"
f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
)
torch.distributed.barrier()
if model.group.rank() == 0:
logging.info(f"total #prams = {total.item()}")
else:
logging.info(f"training model, #prams = {num_params}")
vocab_size = 10000 # FIXME
total_loss = 0.0
start_time = time.time()
word_counter = 0
optimizer = optimizer(model)
transform_and_log = AMPnetDelegate(vocab_size)
model.interleave(lm_dataloader, criterion, optimizer, transform_and_log, args.min_update_interval)
if model.group.rank() == model.group.size() - 1:
print("Done with an epoch")
def evaluate(eval_model, data_source, criterion, bptt, ntokens):
eval_model.eval()
total_loss = 0.0
with torch.no_grad():
for i in range(0, data_source.size(0) - 1, bptt):
data, targets = get_batch(data_source, i, bptt)
output = eval_model(data)
output = output.to(targets.device)
output_flat = output.view(-1, ntokens)
total_loss += len(data) * criterion(output_flat, targets).item()
return total_loss / (len(data_source) - 1)
def get_number_of_words(data):
return data.size()[0] * data.size()[1]
def benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, ntokens, args):
epoch = 1
bptt = 35
start_time = time.time()
print("-" * 110)
print("| start of epoch {:1d}".format(epoch))
print("-" * 110)
epoch_start_time = time.time()
train(train_data, model, criterion, optimizer, ntokens, args)
val_loss = 1 # evaluate(model, val_data, criterion, bptt, ntokens)
print("-" * 89)
print(
"| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} ".format(
epoch, (time.time() - epoch_start_time), val_loss
)
)
print("-" * 110)
elapsed_time = time.time() - start_time
nwords = get_number_of_words(train_data) + get_number_of_words(val_data)
wps = nwords / elapsed_time
test_loss = 1 # evaluate(model, test_data, criterion, bptt, ntokens)
print("=" * 89)
print(
"| end of training | test loss {:5.2f} \n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
test_loss, elapsed_time, nwords, wps
)
)
print("=" * 110)
def generate_balance_weighted(num_devices, num_layers, fraction=0.5):
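# Give the last device roughly `fraction` of the average per-device share and
# spread the remaining layers evenly over the other devices.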
balance = []
layers_assigned = 0
average_count = num_layers / num_devices
last_layers = int(average_count * fraction)
balance = generate_balance(num_devices - 1, num_layers - last_layers)
balance.append(last_layers)
return balance
def generate_balance(num_devices, num_layers):
balance = []
layers_assigned = 0
for i in range(num_devices):
x = (num_layers - layers_assigned) / (num_devices - i)
if x.is_integer():
balance.append(int(x))
layers_assigned += x
else:
balance.append(math.ceil(x))
layers_assigned += math.ceil(x)
return balance
def make_model_and_data(args, device, new_data: bool = True):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if new_data:
vocab_size = 10000
model, criterion, optimizer, scaler = make_model(args, device, vocab_size)
lm_dataset = BenchmarkLMDataset()
lm_dataloader = DataLoader(
lm_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm
)
return {
"model": model,
"criterion": criterion,
"optimizer": optimizer,
"data": lm_dataloader,
"vocab_size": vocab_size,
}
else:
data = get_data(device)
ntokens, train_data, val_data, test_data = data
model, criterion, optimizer, scaler = make_model(args, device, ntokens)
return {
"model": model,
"criterion": criterion,
"optimizer": optimizer,
"data": data,
}
def run_mp_worker(args, available_workers):
new_data = True
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]
balance = generate_balance(get_pipeline_parallel_group().size(), len(model))
p = pipe.AMPnetPipe(
module=model,
balance=balance,
style=Pipe.AsyncSchedule,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
pipelined_backward=False,
checkpoint=args.checkpoint,
)
if torch.cuda.is_available():
p = p.cuda()
if new_data:
train(blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args)
else:
ntokens, train_data, val_data, test_data = blob["data"]
benchmark_language_model(train_data, val_data, test_data, p, blob["criterion"], blob["optimizer"], ntokens, args)
def run_worker(rank, world_size, args):
if args.world_size != 0:
world_size = args.world_size
dist_init(rank + args.rank_base, world_size, hostname=args.host)
initialize_model_parallel(1, world_size)
init_random_seed(0)
run_mp_worker(args, world_size)
rpc.shutdown()
torch.distributed.destroy_process_group()
def bench_multi_process(args, all_at_once=False):
if args.local_world_size != 0:
world_size = args.local_world_size
else:
world_size = min(torch.cuda.device_count(), 2)
mp.spawn(run_worker, args=(world_size, args), nprocs=world_size, join=True)
best_device_map = {
0: "mlx5_0:1",
1: "mlx5_0:1",
2: "mlx5_1:1",
3: "mlx5_1:1",
4: "mlx5_2:1",
5: "mlx5_2:1",
6: "mlx5_3:1",
7: "mlx5_3:1",
}
def bench_mpi(args):
guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
os.environ["UCX_NET_DEVICES"] = best_device_map[local_rank]
os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10638"
if args.socket_name:
os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name
os.environ["TP_SOCKET_IFNAME"] = args.socket_name
torch.distributed.init_process_group(backend="gloo", rank=guess_rank, world_size=world_size)
os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10639"
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(local_rank % torch.cuda.device_count())
rpc.init_rpc(
f"Test{rank}",
rank=rank,
world_size=world_size,
backend=rpc.BackendType.PROCESS_GROUP,
rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method),
)
backends = {"model_parallel_backend": "nccl", "pipeline_backend": "mpi", "ddp_backend": "nccl"}
initialize_model_parallel(1, world_size, **backends)
init_random_seed(0)
run_mp_worker(args, world_size)
rpc.shutdown()
torch.distributed.destroy_process_group()
parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--local-world-size", "-l", type=int, default=0, help="local world size")
parser.add_argument("--world-size", "-w", type=int, default=0, help="world size")
parser.add_argument("--rank-base", "-r", type=int, help="rank base", default=0)
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--no-mpi", action="store_true", default=False, help="disable mpi")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument(
"--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
)
parser.add_argument(
"--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
parser.add_argument("--min-update-interval", type=int, default=1, help="min update interval for ampnet")
"""
To run the script:
1. Build a suitable version of OpenMPI with a CUDA-enabled UCX backend.
2. To run on 2 GPUs:
<open-mpi-installed-dir>/bin/mpirun --host localhost:8 -np 2 --map-by node --mca pml ucx -x UCX_TLS=rc,sm,cuda_ipc,cuda_copy -x PYTHONPATH=$PWD -x PATH=$PATH -x LD_LIBRARY_PATH=$LD_LIBRARY_PATH -x UCX_RNDV_SCHEME=put_zcopy -x UCX_MEMTYPE_CACHE=n python3 benchmarks/experimental_ampnet.py --num-decoder-layers=8 --host localhost --batch-size 4
"""
if __name__ == "__main__":
args = parser.parse_args()
# bench_multi_process(args, all_at_once=True)
if args.no_mpi or "OMPI_COMM_WORLD_RANK" not in os.environ:
print(f"Can't run benchmark")
sys.exit(1)
else:
if os.environ["OMPI_COMM_WORLD_RANK"] == "0":
print(f"Running benchmark with args: {args}")
bench_mpi(args)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
################################################################################
# Import most common subpackages
################################################################################
from . import nn
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import time
from typing import Any, Dict, List, Tuple, Union
import torch
from torch import nn
from torch.autograd.profiler import record_function
from torch.distributed import ProcessGroup
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from fairscale.nn.pipe.async_schedule import (
AsyncMessageBody,
AsyncMessageType,
AsyncRecvOperator,
Location,
ModuleWrapper,
)
from fairscale.nn.pipe.checkpoint import Checkpointing
from fairscale.nn.pipe.messages import Transport
from fairscale.nn.pipe.microbatch import Batch
from fairscale.nn.pipe.types import (
EVENT_LOOP_ACTIVATIONS_QUEUE,
EVENT_LOOP_GRADIENTS_QUEUE,
PipeMessage,
TensorOrTensors,
)
from fairscale.nn.pipe.worker import Task
def create_task_without_skip_trackers(
checkpoint_stop: int, i: int, j: int, batch: Batch, partition: nn.Sequential,
) -> Task:
# Determine whether checkpointing or not.
# style is guaranteed to be PipelineStyle.AsyncSchedule
if i < checkpoint_stop:
def function(
input: TensorOrTensors, partition: nn.Sequential = partition, chunk_id: int = i, part_id: int = j,
) -> TensorOrTensors:
with record_function("chunk%d-part%d" % (chunk_id, part_id)):
return partition(input)
chk = Checkpointing(function, batch)
task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
del function, chk
else:
def compute(
batch: Batch = batch, partition: nn.Sequential = partition, chunk_id: int = i, part_id: int = j,
) -> Batch:
with record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
task = Task(None, compute=compute, finalize=None)
del compute
return task
class AsyncAMPnetEventLoop:
def __init__(
self,
partitions: List[ModuleWrapper],
group: ProcessGroup,
transport: Transport,
min_update_interval: int,
weight_prediction: bool,
checkpoint_stop: int,
input_device: Union[None, int, str, torch.device],
):
self.partitions = partitions
self.group = group
self.transport = transport
self.min_update_interval = min_update_interval
self.weight_prediction = weight_prediction
self.checkpoint_stop = checkpoint_stop
self.input_device = input_device
def perform_optimizer_step(self, optimizer, num_gradients):
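# Step once every `min_update_interval` accumulated gradients, or after every
# gradient when weight prediction is enabled.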
return (optimizer is not None) and ((num_gradients % self.min_update_interval == 0) or self.weight_prediction)
def async_send_inner(self, batch: Batch, index: int) -> Tuple[Batch, PipeMessage]:
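# Run this stage's partition on one microbatch and build the activation message
# addressed to the next pipeline rank.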
task = create_task_without_skip_trackers(
self.checkpoint_stop, index, self.group.rank(), batch, self.partitions[0].module,
)
result = task.compute()
task.finalize(result)
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
body = AsyncMessageBody(
AsyncMessageType.Activations,
index,
Location(this_rank, 0),
Location(ranks[ranks.index(this_rank) + 1], 0),
0,
)
message = PipeMessage(
this_rank,
ranks[ranks.index(this_rank) + 1],
queue_name=EVENT_LOOP_ACTIVATIONS_QUEUE,
args=body,
tensors=tuple([*result]),
)
return result, message
def async_grad_inner(self, message: PipeMessage, activations: Dict[int, Batch]) -> None:
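# Receive gradients for a previously forwarded microbatch and backprop through
# the cached activations for that microbatch.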
args: AsyncMessageBody = message.args
recvd_grads = self.transport.recv_message_tensors(message)
batch = activations[args.microbatch_index]
if len(recvd_grads.tensors) != len(batch):
raise RuntimeError("different number of tensors and gradients")
grads = []
final_tensors = []
for i, tensor in enumerate(batch):
if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
grads.append(recvd_grads.tensors[i])
final_tensors.append(tensor)
torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
del activations[args.microbatch_index]
def get_batch_from_message(self, message: PipeMessage, queue_name: int) -> Batch:
"""Get the tensor(s) wrapped in a `Batch` from a `PipeMessage`, applying
AsyncRecvOperator so we can intercept the backward pass"""
microbatch_index = message.args.microbatch_index
phony = torch.empty(0, device=self.transport.input_device, requires_grad=True)
result = AsyncRecvOperator.apply(phony, self.transport, message, queue_name)
if len(result) == 1:
batch = Batch(result[0], microbatch_index)
else:
batch = Batch(result, microbatch_index)
return batch
def event_loop_head_across_minibatches(
self, lm_dataloader: DataLoader, criterion: nn.Module, optimizer: Optimizer, transform_logger_object: Any
) -> None:
# handles one epoch
cur_rank = self.group.rank()
N = len(get_pipeline_parallel_ranks()) # for warmup phase
activations = dict()
count = 0
num_gradients = 0
lm_iter = iter(lm_dataloader)
# filling the pipeline: warmup -> all N - 1 forward passes
while True:
try:
cur_batch = next(lm_iter)
reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
batch = Batch(reqd_input, count)
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True) # type: ignore
activations[count], message = self.async_send_inner(batch, count)
self.transport.send_message(message, sync=True)
count += 1
if count == N - 1:
break
except StopIteration:
break
# steady state
while True:
try:
# 1 forward pass
cur_batch = next(lm_iter)
reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
batch = Batch(reqd_input, count)
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True) # type: ignore
activations[count], forward_message = self.async_send_inner(batch, count)
count += 1
# 1 backward pass
message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
args: AsyncMessageBody = message.args
assert args.message_type is AsyncMessageType.Gradients
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False) # type: ignore
self.async_grad_inner(message, activations)
# Send after grad
self.transport.send_message(forward_message, sync=True)
num_gradients += 1
if self.perform_optimizer_step(optimizer, num_gradients):
optimizer.step()
optimizer.zero_grad()
transform_logger_object.check_and_save_weights(num_gradients)
except StopIteration:
break
# remaining items for backward
remaining_items = len(activations)
for _ in range(remaining_items):
message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
args = message.args
assert args.message_type is AsyncMessageType.Gradients
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False) # type: ignore
self.async_grad_inner(message, activations)
num_gradients += 1
if self.perform_optimizer_step(optimizer, num_gradients):
optimizer.step()
optimizer.zero_grad()
transform_logger_object.check_and_save_weights(num_gradients)
def event_loop_tail_across_minibatches(
self, lm_dataloader: DataLoader, criterion: nn.Module, optimizer: Optimizer, transform_logger_object: Any
) -> None:
# handles one epoch
cur_rank = self.group.rank()
N = len(get_pipeline_parallel_ranks())
num_batches = len(lm_dataloader)
lm_iter = enumerate(lm_dataloader)
# last partition -> one forward / one backward -> no warmup
count = 0
num_gradients = 0
activations = dict()
log_interval = 1
word_counter = 0
total_loss = 0
while True:
try:
start_time = time.time()
microbatch_index, cur_batch = next(lm_iter)
reqd_target = transform_logger_object.transform_target(cur_batch).to(self.input_device)
# one forward
message = self.transport.recv_message_header(EVENT_LOOP_ACTIVATIONS_QUEUE)
args: AsyncMessageBody = message.args
assert args.microbatch_index == count
batch = self.get_batch_from_message(message, EVENT_LOOP_GRADIENTS_QUEUE)
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True) # type: ignore
task = create_task_without_skip_trackers(
self.checkpoint_stop, args.microbatch_index, self.group.rank(), batch, self.partitions[0].module,
)
output = task.compute()
activations[args.microbatch_index] = output
task.finalize(output)
# one backward
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False) # type: ignore
output_tensor = transform_logger_object.transform_output_before_loss(output.tensor)
loss = criterion(output_tensor, reqd_target)
loss.backward()
count += 1
num_gradients += 1
if self.perform_optimizer_step(optimizer, num_gradients):
optimizer.step()
optimizer.zero_grad()
transform_logger_object.check_and_save_weights(num_gradients)
transform_logger_object.log_loss(cur_batch, loss, count)
del loss
del activations[args.microbatch_index]
except StopIteration:
break
def event_loop_trunk_forward_helper(self, activations: Dict[int, Batch]) -> PipeMessage:
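# Middle-stage forward step: receive an activation message, run the local
# partition, and return the message destined for the next stage.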
message = self.transport.recv_message_header(EVENT_LOOP_ACTIVATIONS_QUEUE)
args: AsyncMessageBody = message.args
assert args.message_type is AsyncMessageType.Activations
batch = self.get_batch_from_message(message, EVENT_LOOP_GRADIENTS_QUEUE)
activations[args.microbatch_index], message = self.async_send_inner(batch, args.microbatch_index)
return message
def event_loop_trunk_backward_helper(self, activations: Dict[int, Batch]) -> None:
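# Middle-stage backward step: receive a gradient message and backprop through
# the matching cached activation.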
message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
args: AsyncMessageBody = message.args
assert args.message_type is AsyncMessageType.Gradients
self.async_grad_inner(message, activations)
def event_loop_across_minibatches(
self, lm_dataloader: DataLoader, criterion: nn.Module, optimizer: Optimizer, transform_logger_object: Any
) -> None:
activations: Dict[int, Batch] = dict()
num_microbatch = len(lm_dataloader)
num_activations = 0
num_gradients = 0
ranks = get_pipeline_parallel_ranks() # for warmup phase
N = len(ranks)
cur_rank = torch.distributed.get_rank()
# warmup phase (forward passes)
# cur_rank worker will do (max_rank - cur_rank) forward passes
n_warmup = ranks[-1] - cur_rank
for _ in range(n_warmup):
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True) # type: ignore
message = self.event_loop_trunk_forward_helper(activations)
self.transport.send_message(message, sync=True)
num_activations += 1
# common loop for the remaining items in the warmup phase and the steady phase
while num_activations < num_microbatch:
# 1 Forward
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True) # type: ignore
message = self.event_loop_trunk_forward_helper(activations)
num_activations += 1
# 1 Backward
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False) # type: ignore
self.event_loop_trunk_backward_helper(activations)
num_gradients += 1
if self.perform_optimizer_step(optimizer, num_gradients):
optimizer.step()
optimizer.zero_grad()
transform_logger_object.check_and_save_weights(num_gradients)
self.transport.send_message(message, sync=True)
# remaining backwards
remaining = len(activations)
for _ in range(remaining):
if self.weight_prediction:
optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False) # type: ignore
self.event_loop_trunk_backward_helper(activations)
num_gradients += 1
if self.perform_optimizer_step(optimizer, num_gradients):
optimizer.step()
optimizer.zero_grad()
transform_logger_object.check_and_save_weights(num_gradients)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""The AMPnetPipe interface."""
from typing import Any
from torch import nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe.types import PipelineStyle
from .ampnet import AsyncAMPnetEventLoop
__all__ = ["AMPnetPipe"]
class AMPnetPipe(Pipe):
"""
AMPnetPipe is the asynchronous version of the Pipe implementation: it avoids
the pipeline bubble by letting stages train on stale weights and gradients.
The implementation closely follows the AMPNet paper: https://arxiv.org/abs/1705.09786
"""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def interleave(
self,
lm_dataloader: DataLoader,
criterion: nn.Module,
optimizer: Optimizer,
transform_logger_object: Any,
min_update_interval: int = 1,
weight_prediction: bool = False,
) -> None:
partitions = self.mp_partitions
n = len(partitions)
# AMPnet implementation doesn't handle skip_trackers!
assert self.pipeline.style is PipelineStyle.AsyncSchedule # type: ignore
assert self.group
rank = self.group.rank()
transport = self.pipeline.transport # type: ignore
checkpoint_stop = self.pipeline.checkpoint_stop # type: ignore
ampnet_event_loop = AsyncAMPnetEventLoop(
partitions,
self.group,
transport,
min_update_interval,
weight_prediction,
checkpoint_stop,
self.input_device,
)
if rank == 0:
ampnet_event_loop.event_loop_head_across_minibatches(
lm_dataloader, criterion, optimizer, transform_logger_object
)
elif self.final_stage:
ampnet_event_loop.event_loop_tail_across_minibatches(
lm_dataloader, criterion, optimizer, transform_logger_object
)
else:
ampnet_event_loop.event_loop_across_minibatches(
lm_dataloader, criterion, optimizer, transform_logger_object
)
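For orientation, a minimal usage sketch of driving AMPnetPipe.interleave, modelled on the FakeDataset/MySGD test below and the benchmark's AMPnetDelegate. RandomRegressionDataset, SimpleDelegate, and run_ampnet_example are illustrative names, not part of this commit, and run_ampnet_example must be invoked on every pipeline rank inside the usual torch.distributed + RPC setup (e.g. via torch_spawn, as in the tests):

# Illustrative sketch only: assumes the distributed/RPC environment that
# torch_spawn (used by the tests below) sets up for each pipeline rank.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from experimental.nn.ampnet_pipe.pipe import AMPnetPipe
from fairscale.nn.pipe import Pipe
from fairscale.utils.testing import get_worker_map


class RandomRegressionDataset(Dataset):
    # Yields dict samples with the "input"/"target" keys the delegate hooks expect.
    def __len__(self):
        return 64

    def __getitem__(self, index):
        return {"input": torch.rand(10), "target": torch.rand(10)}


class SimpleDelegate:
    # Minimal delegate implementing the hooks called by AsyncAMPnetEventLoop.
    def transform_input(self, cur_batch):
        return cur_batch["input"]

    def transform_target(self, cur_batch):
        return cur_batch["target"]

    def transform_output_before_loss(self, output_tensor):
        return output_tensor

    def log_loss(self, cur_batch, loss, count):
        print("batch {}: loss {:.4f}".format(count, loss.item()))

    def check_and_save_weights(self, num_gradients):
        pass  # hook for periodic checkpointing


def run_ampnet_example():
    model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU())
    pipe = AMPnetPipe(
        module=model,
        balance=[2, 2],            # two layers per pipeline stage (2 ranks)
        style=Pipe.AsyncSchedule,  # AMPnet requires the async schedule
        worker_map=get_worker_map(),
        chunks=4,
        checkpoint="never",
    )
    dataloader = DataLoader(RandomRegressionDataset(), batch_size=4)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    pipe.interleave(dataloader, criterion, optimizer, SimpleDelegate(), min_update_interval=1)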
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from experimental.nn.ampnet_pipe.pipe import AMPnetPipe
from fairscale.nn.pipe import Pipe
from fairscale.utils.testing import get_worker_map, torch_spawn
class MySGD(Optimizer):
r"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate (required)
"""
def __init__(self, params, lr=0.01):
defaults = dict(lr=lr)
super(MySGD, self).__init__(params, defaults)
def __setstate__(self, state):
super(MySGD, self).__setstate__(state)
def step(self, closure=None):
""" Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
d_p = p.grad.data
p.data.add_(d_p, alpha=-group["lr"])
return loss
class FakeDataset(Dataset):
def __init__(
self, input_dim=10, output_dim=10, total_samples=100,
):
self.input_dim = input_dim
self.output_dim = output_dim
self.total_samples = total_samples
self.input_samples = torch.rand(self.total_samples, self.input_dim, self.output_dim)
self.target_samples = torch.rand(self.total_samples, self.input_dim, self.output_dim)
def __getitem__(self, index):
return {
"input": self.input_samples[index, :, :],
"target": self.target_samples[index, :, :],
}
def __len__(self):
return self.total_samples
@torch_spawn([2])
def async_event_loop_interleave_simple():
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(inplace=False), nn.Linear(10, 10), nn.ReLU(inplace=False))
pipe = AMPnetPipe(
module=model,
balance=[2, 2],
style=Pipe.AsyncSchedule,
worker_map=get_worker_map(),
chunks=10,
checkpoint="never",
)
fake_dataset = FakeDataset()
fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
loss = nn.MSELoss()
opt = MySGD(model.parameters(), lr=0.01)
pipe.interleave(fake_dataloader, loss, opt, 0)
@torch_spawn([4])
def async_event_loop_interleave_hard():
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10))
pipe = AMPnetPipe(
module=model,
balance=[1, 1, 1, 1],
style=Pipe.AsyncSchedule,
worker_map=get_worker_map(),
chunks=10,
checkpoint="never",
)
fake_dataset = FakeDataset()
fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
loss = nn.MSELoss()
opt = MySGD(model.parameters(), lr=0.01)
pipe.interleave(fake_dataloader, loss, opt, 0)
@@ -102,10 +102,10 @@ class AsyncRecvOperator(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def forward(ctx, phony: Tensor, transport: Transport, message: PipeMessage) -> Tensors:
+    def forward(ctx, phony: Tensor, transport: Transport, message: PipeMessage, queue_name: int) -> Tensors:
         ctx.transport = transport
         ctx.index = message.args.microbatch_index
+        ctx.queue_name = queue_name
         result = transport.recv_message_tensors(message)
         ctx.args = result.args
@@ -127,7 +127,7 @@ class AsyncRecvOperator(torch.autograd.Function):
             )
             ctx.transport.send_message(
                 PipeMessage(
-                    this_rank, ranks[ctx.args.source.stage], queue_name=EVENT_LOOP_QUEUE, args=body, tensors=tuple(grad),
+                    this_rank, ranks[ctx.args.source.stage], queue_name=ctx.queue_name, args=body, tensors=tuple(grad),
                 ),
                 sync=True,
             )
@@ -136,7 +136,7 @@ class AsyncRecvOperator(torch.autograd.Function):
         if tail_ctx:
             expected_gradients = tail_ctx.expected_gradients
             while expected_gradients > 0:
-                message = ctx.transport.recv_message_header(EVENT_LOOP_QUEUE)
+                message = ctx.transport.recv_message_header(ctx.queue_name)
                 args: AsyncMessageBody = message.args
                 assert args.message_type is AsyncMessageType.Gradients
@@ -304,7 +304,7 @@ class AsyncEventLoop:
         microbatch_index = message.args.microbatch_index
         phony = torch.empty(0, device=self.transport.input_device, requires_grad=True)
-        result = AsyncRecvOperator.apply(phony, self.transport, message)
+        result = AsyncRecvOperator.apply(phony, self.transport, message, EVENT_LOOP_QUEUE)
         if len(result) == 1:
             batch = Batch(result[0], microbatch_index)
         else:
...
@@ -14,7 +14,9 @@
 ACTIVATIONS_GRADS_QUEUE = 0
 SKIP_TENSOR_QUEUE = 1
 PORTAL_QUEUE = 2
 EVENT_LOOP_QUEUE = 3
-MESSAGE_GENERATION_START = 4
+EVENT_LOOP_ACTIVATIONS_QUEUE = 4
+EVENT_LOOP_GRADIENTS_QUEUE = 5
+MESSAGE_GENERATION_START = 6
 MessageGeneration = MESSAGE_GENERATION_START
...
@@ -946,7 +946,6 @@ def test_instantiate_partition():
     def check_partitions(model, balance, expected_order, expected_ranks):
         """Check the instantiated model matches expectation of order and rank
         model: a list of modules or an nn.Sequential
         balance: the balance argument to Pipe
         expected_order: the index of modules in `model` in the order they will
...