Unverified commit 3c727ec5, authored by anj-s, committed by GitHub

[refactor] Remove unused variables, add configuration objects and basic cleanup for pipe benchmarks. (#252)

* [refactor] Remove unused variables and refactor common configurations

* move helper function to call site

* fixed lint errors

* fix lint errors

* fix lint errors

* fix lint errors

* fix import order

* format files

* remove unused imports

* fix lint errors

* address PR comments

* sorted imports

* add space

* modify comment

* added doc strings and addressed PR comments.

* addressed PR comments

* added another comment to clarify.

* fixing lint errors

* rename variable
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
parent 8321f682
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import warnings

import torchtext
from torchtext.data.utils import get_tokenizer


def get_wikitext2_data(device):
    """Return batched data from wikitext2 dataset for training, validation and testing."""
    with warnings.catch_warnings(record=True) as _:
        text_field = torchtext.data.Field(
            tokenize=get_tokenizer("basic_english"), init_token="<sos>", eos_token="<eos>", lower=True
        )
        train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(text_field)
        text_field.build_vocab(train_txt)
        ntokens = len(text_field.vocab.stoi)

        batch_size = 20
        eval_batch_size = 10
        train_data = batchify(train_txt, batch_size, text_field, device)
        val_data = batchify(val_txt, eval_batch_size, text_field, device)
        test_data = batchify(test_txt, eval_batch_size, text_field, device)

        return ntokens, train_data, val_data, test_data


def batchify(data, bsz, text_field, device):
    """Return batched data that is placed on the specified device."""
    data = text_field.numericalize([data.examples[0].text])
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)
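
(Aside, not part of the commit: a minimal sketch of what the narrow/view/t/contiguous sequence in batchify above does, using a plain tensor in place of a numericalized torchtext field; the sizes below are illustrative assumptions.)

import torch

flat = torch.arange(26)                    # stand-in for a flat stream of token ids
bsz = 4
nbatch = flat.size(0) // bsz               # 6 full columns, 2 leftover tokens
trimmed = flat.narrow(0, 0, nbatch * bsz)  # drop the tokens that do not fill a column
batched = trimmed.view(bsz, -1).t().contiguous()
print(batched.shape)                       # torch.Size([6, 4]), i.e. [nbatch, bsz]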
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import math

import torch
import torch.nn as nn


# TODO(anj-s): Identify if we need this initialization logic for the below wrapped layers.
class EmbeddingLayer(nn.Embedding):
    """Wrapped nn.Embedding layer to allow for weight initialization."""

    def __init__(self, ntoken, ninp, initrange):
        super().__init__(ntoken, ninp)
        self.ninp_sqrt = math.sqrt(ninp)
        self.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        return super().forward(src) * self.ninp_sqrt


class PositionalEncodingLayer(nn.Module):
    """PositionalEncoding layer for a given Transformer model."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncodingLayer, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)


class TransformerDecoderLayer(nn.TransformerEncoderLayer):
    """TransformerDecoder layer which inherits from nn.TransformerEncoderLayer."""

    def __init__(self, ninp, nhead, nhid, dropout):
        super().__init__(ninp, nhead, nhid, dropout)
        self.src_mask = None

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src):
        # TODO(anj-s): Fix the data format so that we have [seq_len, batch_size, embedding dim].
        # Currently real data has seq_len as the second dimension and batch_size as the first dimension.
        # We need to mask the sequence length dimension and not the batch size.
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        return super().forward(src, self.src_mask)


class LinearLayer(nn.Linear):
    """Wrapped nn.Linear layer to allow for weight initialization."""

    def __init__(self, ninp, ntoken, initrange):
        super().__init__(ninp, ntoken)
        self.bias.data.zero_()
        self.weight.data.uniform_(-initrange, initrange)


class TransformerLMSequntial(nn.Sequential):
    """A GPT-2 based nn.Sequential language model."""

    def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder):
        layers = [
            EmbeddingLayer(ntokens, ninp, initrange),
            PositionalEncodingLayer(ninp, dropout),
        ]
        for _ in range(ndecoder):
            layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout))

        layers.append(LinearLayer(ninp, ntokens, initrange))
        super(TransformerLMSequntial, self).__init__(*layers)
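
(Aside, not part of the commit: a minimal sketch of the causal mask produced by _generate_square_subsequent_mask above, shown for a toy size of 3; positions a token must not attend to are set to -inf so they vanish after the attention softmax.)

import torch

sz = 3
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
print(mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])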
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import argparse
from collections import defaultdict
from functools import reduce
import gc
import logging
import math
import operator
import os
import pprint
import time

from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm
import datasets
import models
import numpy
import torch
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader

from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
@@ -36,167 +41,46 @@ except ImportError:
def init_random_seed(seed: int):

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    numpy.random.seed(seed)


def make_model(args, device, config):
    ninp = config["ninp"]
    nhead = config["nhead"]
    initrange = config["initrange"]
    dropout = config["dropout"]
    vocab_size = config["vocab_size"]
    nhid = config["nhid"]
    lr = config["lr"]
    ndecoder = args.num_decoder_layers

    if args.lazy_construction:
        layers = [
            LazyModule(lambda: models.EmbeddingLayer(vocab_size, ninp, initrange)),
            LazyModule(lambda: models.PositionalEncodingLayer(ninp, dropout)),
        ]
        for _ in range(ndecoder):
            layers.append(LazyModule(lambda: models.TransformerDecoderLayer(ninp, nhead, nhid, dropout)))
        layers.append(LazyModule(lambda: models.LinearLayer(ninp, vocab_size, initrange)))
        model = layers
    else:
        model = models.TransformerLMSequntial(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)

    def make_adam(params):
        if args.ddp_zero:
            return OSS(params=params, optim=Adam, group=get_data_parallel_group(), lr=lr)
        else:
            return Adam(params, lr=lr)

    optimizer = make_adam

    return model, optimizer


def get_tensors_by_size_bucket():

    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
@@ -209,8 +93,6 @@ def get_tensors_by_size_bucket():
def dump_size_buckets(size_buckets, prefix=""):

    total = 0
    for key, value in size_buckets.items():
@@ -253,9 +135,6 @@ def check_size_buckets():
def dump_cuda_tensors():
    print(f"dumping cuda tensors...")

    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
@@ -270,16 +149,10 @@ def dump_cuda_tensors():
        total += this
        print(f"{key} : {value}, {this}")
    print(f"total size = {total}")

    pprint.pprint(torch.cuda.memory_stats())


def log_number_of_parameters(model):

    num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
    if model.group:
@@ -288,42 +161,55 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
        total = total.cuda()
        torch.distributed.all_reduce(total, group=model.group)
        logging.info(
            f"training model, #params = {num_params}, group: {model.group.rank()}, grank:"
            f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
        )
        torch.distributed.barrier()
        if model.group.rank() == 0:
            logging.info(f"total #params = {total.item()}")
    else:
        logging.info(f"training model, #params = {num_params}")


def get_device(model, index):
    if isinstance(model, DDP):
        model = model.module

    if not torch.cuda.is_available():
        return torch.device("cpu")
    if model.devices:
        return model.devices[index]
    else:
        return torch.cuda.current_device()


def get_fake_dataloader(lm_dataloader_len):
    fake_input = {"input": torch.zeros(args.batch_size)}

    class FakeDataset:
        def __getitem__(self, index):
            return fake_input

        def __len__(self):
            return lm_dataloader_len

    return FakeDataset()


def train(data_config, model, benchmark_config, args):
    lm_dataloader = data_config["data"]
    criterion = benchmark_config["criterion"]
    vocab_size = benchmark_config["vocab_size"]
    optimizer = data_config["optimizer"]

    model.train()
    log_number_of_parameters(model)

    total_loss = 0.0
    start_time = time.time()
    word_counter = 0

    optimizer = optimizer(model.parameters())

    pipe_group = model.group
@@ -335,26 +221,17 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
            find_unused_parameters=False,
        )

    # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        lm_dataloader = get_fake_dataloader(len(lm_dataloader))

    for i, batch in enumerate(lm_dataloader):
        if args.max_batch and i > args.max_batch:
            break
        optimizer.zero_grad()
        try:
            if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
                tmp = batch["input"].to(get_device(model, 0))
                output = model(tmp)
            else:
                output = model(batch["input"])
@@ -362,7 +239,7 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            target = batch["target"].to(get_device(model, -1))
            output = output.to(target.device)

            loss = criterion(output.view(-1, vocab_size), target.view(-1))
@@ -380,7 +257,7 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
            del output

        torch.nn.utils.clip_grad_value_(model.parameters(), benchmark_config["clip_value"])
        optimizer.step()

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
@@ -398,15 +275,18 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
                word_counter = 0
                total_loss = 0
                start_time = time.time()
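
(Aside, not part of the commit: a minimal sketch of the shape handling in the loss computation above; regardless of how the leading dimensions are ordered, the logits are flattened to [N, vocab_size] and the targets to [N] before nn.CrossEntropyLoss is applied. The sizes are illustrative assumptions.)

import torch
import torch.nn as nn

vocab_size = 11
output = torch.randn(7, 3, vocab_size)      # e.g. [seq_len, batch, vocab_size] logits
target = torch.randint(vocab_size, (7, 3))  # matching token ids
criterion = nn.CrossEntropyLoss()
loss = criterion(output.view(-1, vocab_size), target.view(-1))
print(loss.shape)                           # torch.Size([]), a scalar loss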


def evaluate(eval_model, data_source, criterion, bptt, ntokens):
    eval_model.eval()
    total_loss = 0.0

    def get_batch(source, i, bptt):
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i : i + seq_len]
        target = source[i + 1 : i + 1 + seq_len].view(-1)
        return data, target

    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i, bptt)
@@ -421,7 +301,10 @@ def get_number_of_words(data):
    return data.size()[0] * data.size()[1]


def benchmark_language_model(model_config, model, benchmark_config, args):
    ntokens, train_data, val_data, test_data = model_config["data"]
    optimizer = model_config["optimizer"]
    criterion = benchmark_config["criterion"]
    epoch = 1
    bptt = 35
    start_time = time.time()
@@ -497,64 +380,76 @@ def generate_balance(num_devices, num_layers):
    return balance


def make_model_and_data(args, config=None):
    """Return a dict with the given model, dataset and optimizer."""
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if args.use_synthetic_data:
        model, optimizer = make_model(args, device, config)
        lm_dataset = BenchmarkLMDataset()
        lm_dataloader = DataLoader(
            lm_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm
        )
        return {"model": model, "optimizer": optimizer, "data": lm_dataloader}
    else:
        data = datasets.get_wikitext2_data(device)
        ntokens, _, _, _ = data
        config["vocab_size"] = ntokens
        model, optimizer = make_model(args, device, ntokens)
        return {
            "model": model,
            "optimizer": optimizer,
            "data": data,
        }


def create_benchmark_config(model_name):
    """Return a dict with configurations required for benchmarking `model_name` model."""
    if model_name == "seq_pred":
        return {
            "vocab_size": 10000,
            "ninp": 2048,  # embedding dimension
            "nhid": 2048,  # the dimension of the feedforward network model in nn.TransformerEncoder
            "nhead": 32,  # the number of heads in the multiheadattention models
            "dropout": 0,
            "initrange": 0.1,
            "criterion": nn.CrossEntropyLoss(),
            "lr": 0.01,  # learning rate
            "scaler": GradScaler(),
            "clip_value": 0.05,
        }


def benchmark_single_process(args):
    """Benchmark a given model using a single process and multiple devices."""
    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    assert num_devices > 0
    init_random_seed(0)

    benchmark_config = create_benchmark_config(args.model_name)
    model_config = make_model_and_data(args, config=benchmark_config)
    model = model_config["model"]

    balance = generate_balance(min(num_devices, 4), len(model))
    p = pipe.Pipe(
        model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
    )
    del model
    del model_config["model"]

    if args.use_synthetic_data:
        train(model_config, p, benchmark_config, args)
    else:
        benchmark_language_model(model_config, p, benchmark_config, args)


def run_mp_worker(args, available_workers):

    benchmark_config = create_benchmark_config(args.model_name)
    model_config = make_model_and_data(args, config=benchmark_config)
    model = model_config["model"]

    balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
    p = pipe.Pipe(
@@ -566,7 +461,7 @@ def run_mp_worker(args, available_workers):
        input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
        pipelined_backward=args.pipelined_backward,
        checkpoint=args.checkpoint,
        loss_fn=benchmark_config["criterion"],
    )
    if torch.cuda.is_available():
        p = p.cuda()
@@ -574,11 +469,10 @@ def run_mp_worker(args, available_workers):
        print(f"running all at once")
        p.pipeline.all_at_once = True

    if args.use_synthetic_data:
        train(model_config, p, benchmark_config, args)
    else:
        benchmark_language_model(model_config, p, benchmark_config, args)


def run_worker(rank, world_size, args):
@@ -681,14 +575,18 @@ parser.add_argument(
parser.add_argument(
    "--no-pipelined-backward", dest="pipelined_backward", action="store_false", help="Pipelined backward pass"
)
parser.add_argument("--use_synthetic_data", default=True, help="Uses synthetic data for a sample training run.")
parser.add_argument(
    "--model_name", default="seq_pred", choices=["seq_pred", "transformer"], help="Model used to benchmark pipe."
)
parser.set_defaults(pipelined_backward=True)

if __name__ == "__main__":
    args = parser.parse_args()

    # TODO(anj-s): Add support for multiprocess benchmarking.
    if args.no_mpi or "OMPI_COMM_WORLD_RANK" not in os.environ:
        print(f"Running benchmark with args: {args}")
        benchmark_single_process(args)
    else:
        if os.environ["OMPI_COMM_WORLD_RANK"] == "0":
            print(f"Running benchmark with args: {args}")