Unverified commit f7813d6d authored by anj-s, committed by GitHub

[feature] Add support for OffloadModel to enable training large models on 1 GPU. (#432)



* clean start

* removing per layer split strategy, probably not that useful indeed

* initial transformer benchmark

* hack, enable testing ViT + offload, python3 benchmarks/oss.py  --epochs 2 --optim_type oss_offload_ddp --batch_size=32 --model vit_large_patch16_224

* proper cuda streams and device, something off in terms of mems consumption

* minor, stashing

* unit test fix

* removing all the distributed parts

* simpler test, needs debugging

* working OOP, running a model which does not fit on the gpu memory

* spring cleaning

* removing the ill-advised optimizer bits, better keep that orthogonal

* [offload] Add support for activation offloading + other changes (#367)

* initial fwd/bwd commit

* checkpoint work

* modify shard loop

* activation offloading and test to start with

* fix lint errors

* update comments

* fix lint

* remove unused var

* remove commented out lines

* modify name

* remove break

* remove profiler comments

* avoid saving inputs

* fix lint errors
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>

* [offload] Add support for fp16 training (#374)

* initial fwd/bwd commit

* checkpoint work

* modify shard loop

* activation offloading and test to start with

* fix lint errors

* update comments

* fix lint

* remove unused var

* remove commented out lines

* modify name

* remove break

* remove profiler comments

* add support for fp16

* add unit tests

* fix lint errors

* fix test failure
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>

* [offload] Add support for activation checkpointing for all layers. (#381)

* initial fwd/bwd commit

* checkpoint work

* modify shard loop

* activation offloading and test to start with

* fix lint errors

* update comments

* fix lint

* remove unused var

* remove commented out lines

* modify name

* remove break

* remove profiler comments

* add support for fp16

* add unit tests

* fix lint errors

* fix test failure

* cp work, incorrect output dimensions still need to be fixed

* fixed activation outputs

* intermediate cp of work

* add tests

* fix lint errors
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>

* add support for microbatches

* revert benchmark config changes

* add parametrization

* fix lint errors and tests

* skip test for 1.5

* fix lint errors

* skip test if there are no GPUs

* fix lint errors

* fix lint errors

* move experimental to the fairscale repo

* lint error fixes

* modify test imports

* lint error fixes

* move offload files to the experimental directory

* move tests and benchmarks to their folder

* fix mypy errors

* cp intermediate working benchmarks

* more changes

* split benchmark configs

* remove print statements

* fix lint errors

* remove unused print

* stress testing

* remove unused file

* change param name

* lint fixes

* move file to the right folder

* offload_experimental

* add doc string

* add error message
Co-authored-by: Benjamin Lefaudeux <benjamin.lefaudeux@gmail.com>
Co-authored-by: Benjamin Lefaudeux <benjamin.lefaudeux@protonmail.com>
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
parent 7ee228bf
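A minimal usage sketch of the OffloadModel API added by this PR, mirroring the unit test further down in this diff. The layer sizes, slice count and learning rate are illustrative only, and at least one CUDA device is assumed.

import torch

from fairscale.experimental.nn.offload import OffloadModel

# OffloadModel currently only supports nn.Sequential models.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.Linear(16, 16),
    torch.nn.Linear(16, 4),
)

# Shard the model into two slices; each slice is moved CPU->GPU on demand
# for its FW/BW pass and moved back to the CPU afterwards.
offload_model = OffloadModel(
    model_cpu=model,
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=2,
)
optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

# Inputs and targets are expected to live on the GPU.
inputs = torch.ones(8, 16).cuda()
labels = torch.ones(8, 4).cuda()

offload_model.train()
loss = torch.nn.MSELoss(reduction="sum")(offload_model(inputs), labels)
loss.backward()
optimizer.step()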
@@ -44,7 +44,7 @@ def get_real_dataloaders(args, benchmark_config, model_specs):
test_dataset = data_process(iter(io.open(test_filepath, encoding="utf8")))
def batchify(data):
batch_size = args.batch_size
batch_size = benchmark_config["batch_size"]
return _batchify(data, batch_size)
total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import contextlib
from functools import reduce
import logging
import math
import operator
import time
import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor
from benchmarks.datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders
from benchmarks.datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders
from benchmarks.golden_configs.lm_wikitext2 import Offload_Sequential as offload_seq
from benchmarks.golden_configs.lm_wikitext2 import Offload_Transformer as lm_wikitext2
from benchmarks.models import transformer_lm
from fairscale.experimental.nn.offload import OffloadModel
def init_random_seed(seed: int):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
def get_model_and_optimizer(args, device, benchmark_config, model_specs):
"""Return instantiated model and optimizer function."""
if args.model_name == "lm":
model = get_lm_model(args, device, model_specs)
lr = benchmark_config["lr"]
def make_adam(params):
return Adam(params, lr=lr)
optimizer = make_adam
elif args.model_name == "seq":
model = get_seq_model(args, device, model_specs)
optimizer = torch.optim.SGD
model = OffloadModel(
model_cpu=model,
device=torch.device("cuda"),
offload_device=torch.device("cpu"),
num_slices=benchmark_config["slices"],
checkpoint_activation=benchmark_config["checkpoint_activation"],
num_microbatches=benchmark_config["num_microbatches"],
)
return model, optimizer
def get_seq_model(args, device, model_specs):
model = torch.nn.Sequential(
torch.nn.Linear(model_specs["inputs"] * model_specs["inputs"], model_specs["hidden"]),
*([torch.nn.Linear(model_specs["hidden"], model_specs["hidden"]) for _ in range(model_specs["layers"])]),
torch.nn.Linear(model_specs["hidden"], model_specs["outputs"]),
)
return model.cpu()
def get_lm_model(args, device, config):
"""Get language model(based on GPT-2) used for sequence prediction."""
ninp = config["ninp"]
nhead = config["nhead"]
initrange = config["initrange"]
dropout = config["dropout"]
vocab_size = config["vocab_size"]
nhid = config["nhid"]
ndecoder = config["num_decoder_layers"]
return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
def log_number_of_parameters(model):
num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
logging.info(f"training model, #params = {num_params}")
def _get_fp16_context(use_fp16=False):
if use_fp16:
return torch.cuda.amp.autocast()
else:
return contextlib.nullcontext()
def _get_profiler_context(use_profiler=False):
if use_profiler:
return torch.autograd.profiler.profile(use_cuda=True, profile_memory=True)
else:
return contextlib.nullcontext()
def _get_profiler_record_context(record_name, use_profiler=False):
if use_profiler:
return torch.autograd.profiler.record_function(record_name)
else:
return contextlib.nullcontext()
def train_seq(model_config, benchmark_config, model_specs, args):
device = torch.device("cuda")
torch.cuda.set_device(0)
torch.manual_seed(5)
model = model_config["model"]
criterion = benchmark_config["criterion"]
optimizer = model_config["optimizer"](model.parameters(), lr=benchmark_config["lr"])
dataloader, _, _ = model_config["data"]
def train_epoch(args):
model.train()
for batch_inputs, batch_outputs in dataloader:
batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda")
start = time.time_ns()
with _get_profiler_context() as prof:
optimizer.zero_grad()
inputs = batch_inputs.reshape(-1, model_specs["inputs"] * model_specs["inputs"])
with _get_profiler_record_context("model_training"):
with _get_fp16_context(use_fp16=args.use_fp16):
output = model(inputs)
loss = criterion(output, target=batch_outputs)
loss.backward()
optimizer.step()
logging.info(
"Memory stats are {:.2f}GB".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30)
)
logging.info(
"Loss {:.2f} - throughput {:.2f}fps".format(
loss.item(), benchmark_config["batch_size"] / (time.time_ns() - start) * 10 ** 9
)
)
if args.use_profiler:
prof.export_chrome_trace("/tmp/offload_prof")
train_epoch(args)
def train(model_config, model, benchmark_config, model_specs, args):
lm_dataloader, _, _ = model_config["data"]
criterion = benchmark_config["criterion"]
vocab_size = model_specs["vocab_size"]
optimizer = model_config["optimizer"]
model.train()
log_number_of_parameters(model)
total_loss = 0.0
word_counter = 0
optimizer = optimizer(model.parameters())
total_tokens = 0
total_tokens_per_log_interval = 0
bptt = 2
start_time = time.time()
epoch_start_time = 0.0
def get_batch(source):
seq_len = len(source) - 1
data = source[0:seq_len]
target = source[1 : 1 + seq_len]
return data, target
for i, batch in enumerate(lm_dataloader):
if i == 1:
epoch_start_time = time.time()
source, target = get_batch(batch)
if i > 0:
total_tokens += source.numel()
optimizer.zero_grad()
output = model(source)
target = target.to("cuda")
output = output.to(target.device)
loss = criterion(output.view(-1, vocab_size), target.view(-1))
loss.backward()
torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
optimizer.step()
total_loss += loss.item()
log_interval = 1
total_tokens_per_log_interval += source.numel()
if i % log_interval == 0 and i > 0:
cur_loss = total_loss / log_interval
elapsed = time.time() - start_time
print(
"| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
)
)
total_tokens_per_log_interval = 0
total_loss = 0
start_time = time.time()
if epoch_start_time != 0:
wps = total_tokens / (time.time() - epoch_start_time)
else:
raise RuntimeError(
"Unable to benchmark on a single batch. Increase the size " " of the dataset and rerun the benchmark."
)
return wps, loss.item()
def verify_peak_memory(rank, golden_config, std_dev):
print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]))
current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
golden_ref = golden_config["peak_mem_usage"][rank]
if not current_device_usage < golden_ref * std_dev:
raise RuntimeError(
"Peak memory usage for cuda device {:d} is {:d} which"
"is less than golden reference value of {:d}".format(rank, current_device_usage, golden_ref)
)
def verify_lm_run(wps, golden_config, args):
"""Verify that words per second for a given benchmark run matches the golden data."""
# Verify wps only on the last rank in multiprocess pipe
if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
# Assert that words per second is within 3 standard deviations of the average
# of five golden runs
print("Throughput(wps) is {:.2f}.".format(wps))
if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
raise RuntimeError(
"Throughput(wps):{:.2f} is below the golden threshold of an "
"average value of {:.2f} and standard dev of {:.2f}.".format(
wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
)
)
if args.multiprocess:
verify_peak_memory(dist.get_rank(), golden_config, 1.5)
else:
for i in range(4):
verify_peak_memory(i, golden_config, 1.1)
def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
epoch = benchmark_config["epochs"]
start_time = time.time()
print("-" * 110)
print("| start of epoch {:1d}".format(epoch))
print("-" * 110)
wps, loss = train(model_config, model, benchmark_config, model_specs, args)
elapsed_time = time.time() - start_time
print("-" * 110)
print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
print("-" * 110)
print("Throughput(wps) is {:.2f}.".format(wps))
print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
# TODO(anj-s): Enable golden config data verification.
def get_synthetic_dataloaders(args, device, benchmark_config, model_specs):
"""Returns dataloader for synthetic data."""
if args.model_name == "lm":
return get_synthetic_wikitext2_dataloaders(args, benchmark_config, model_specs)
elif args.model_name == "seq":
transform = ToTensor()
dataloader = DataLoader(
FakeData(
image_size=(1, model_specs["inputs"], model_specs["inputs"]),
num_classes=model_specs["outputs"],
transform=transform,
),
batch_size=benchmark_config["batch_size"],
)
return dataloader, dataloader, dataloader
else:
raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")
def get_real_dataloaders(args, device, benchmark_config, model_specs):
"""Returns dataloaders for real data."""
if args.model_name == "lm":
data = get_real_wikitext2_dataloaders(args, benchmark_config, model_specs)
ntokens, train_dataloader, valid_dataloader, test_dataloader = data
model_specs["vocab_size"] = ntokens
return train_dataloader, valid_dataloader, test_dataloader
else:
raise RuntimeError(f"Unrecognized args.model_mame {args.model_name}")
def create_model_config(args, benchmark_config=None, model_specs=None):
"""Return a dict with the given model, dataset and optimizer."""
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu")
if args.model_name == "lm":
if args.use_synthetic_data:
dataloader_fn = get_synthetic_dataloaders
else:
dataloader_fn = get_real_dataloaders
data = dataloader_fn(args, device, benchmark_config, model_specs)
model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
return {
"model": model,
"optimizer": optimizer,
"data": data,
}
elif args.model_name == "seq":
data = get_synthetic_dataloaders(
args, device, offload_seq.get_benchmark_config(), offload_seq.get_model_config()
)
model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
return {
"model": model,
"optimizer": optimizer,
"data": data,
}
else:
raise RuntimeError(f"Unrecognized args.model_mame {args.model_name}")
def create_benchmark_config(model_name):
"""Return a dict with configurations required for benchmarking `model_name` model."""
if args.model_name == "lm":
return lm_wikitext2.get_benchmark_config()
elif args.model_name == "seq":
return offload_seq.get_benchmark_config()
else:
raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")
def get_golden_config(model_name, args):
"""Return a dict with the golden data for throughput and memory usage."""
if model_name == "lm":
return lm_wikitext2.get_golden_real_stats(False)
else:
raise RuntimeError(f"Unrecognized args.model_mame {args.model_name}")
def get_model_specs(model_name):
"""Return a dict with configurations required for configuring `model_name` model."""
if model_name == "lm":
return lm_wikitext2.get_model_config()
elif model_name == "seq":
return offload_seq.get_model_config()
else:
raise RuntimeError("Unrecognized args.model_mame " % args.model_name)
def run_benchmark(args):
"""Benchmark a given model using a single process and single devices."""
# We need at least 1 GPU to benchmark the offload model API.
num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 0
assert num_devices > 0
init_random_seed(0)
if args.model_name == "lm":
benchmark_config = create_benchmark_config(args.model_name)
model_specs = get_model_specs(args.model_name)
model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
model = model_config["model"]
if args.dry_run:
train(model_config, model, benchmark_config, model_specs, args)
else:
benchmark_language_model(model_config, model, benchmark_config, model_specs, args)
elif args.model_name == "seq":
benchmark_config = create_benchmark_config(args.model_name)
model_specs = get_model_specs(args.model_name)
model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
model = model_config["model"]
train_seq(model_config, benchmark_config, model_specs, args)
else:
raise RuntimeError(f"Unable to recognize model name {args.model_name}")
parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
"--debug", action="store_true", help="Print debugging statements which is more verbose than the default."
)
parser.add_argument(
"--model_name", default="lm", type=str, help="Language Model(LM) used to benchmark nn.pipe.",
)
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--use_fp16", action="store_true", default=False)
parser.add_argument("--checkpoint_activation", action="store_true", default=False)
parser.add_argument("--use_profiler", action="store_true", default=False)
if __name__ == "__main__":
args = parser.parse_args()
logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
logging.info("Benchmark arguments: %s" % args)
run_benchmark(args)
@@ -5,46 +5,95 @@ import torch.nn as nn
from fairscale.optim import GradScaler
def get_model_config():
return {
"vocab_size": 10000,
"ninp": 2048, # embedding dimension
"nhid": 2048, # the dimension of the feedforward network model in nn.TransformerEncoder
"nhead": 32, # the number of heads in the multiheadattention models
"dropout": 0,
"initrange": 0.1,
"scaler": GradScaler(),
"clip_value": 0.05,
"num_decoder_layers": 10,
"seq_len": 32,
}
def get_benchmark_config():
return {
"epochs": 1,
"lr": 0.001, # learning rate
"batch_size": 8,
"criterion": nn.CrossEntropyLoss(),
}
def get_golden_real_stats(multiprocess=False):
if not multiprocess:
class Offload_Transformer:
def get_model_config():
return {
"avg_wps": 703.778,
"std_dev_wps": 5.732,
"peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496],
"vocab_size": 10000,
"ninp": 2048, # embedding dimension
"nhid": 2048, # the dimension of the feedforward network model in nn.TransformerEncoder
"nhead": 32, # the number of heads in the multiheadattention models
"dropout": 0,
"initrange": 0.1,
"scaler": GradScaler(),
"clip_value": 0.05,
"num_decoder_layers": 10,
"seq_len": 32,
}
else:
def get_benchmark_config():
return {
"epochs": 1,
"lr": 0.001, # learning rate
"batch_size": 8,
"criterion": nn.CrossEntropyLoss(),
"checkpoint_activation": True,
"num_microbatches": 4,
"slices": 3,
}
class Offload_Sequential:
def get_model_config():
return {
"inputs": 100,
"outputs": 5,
"hidden": 1000,
"layers": 100,
"clip_value": 0.05,
}
def get_benchmark_config():
return {
"epochs": 1,
"lr": 0.001, # learning rate
"batch_size": 8,
"criterion": nn.CrossEntropyLoss(),
"slices": 3,
"checkpoint_activation": True,
"num_microbatches": 4,
}
class Pipe:
def get_model_config():
return {
"vocab_size": 10000,
"ninp": 2048, # embedding dimension
"nhid": 2048, # the dimension of the feedforward network model in nn.TransformerEncoder
"nhead": 32, # the number of heads in the multiheadattention models
"dropout": 0,
"initrange": 0.1,
"scaler": GradScaler(),
"clip_value": 0.05,
"num_decoder_layers": 10,
"seq_len": 32,
}
def get_benchmark_config():
return {
"avg_wps": 647.404,
"std_dev_wps": 14.51,
"peak_mem_usage": [3305007616, 2578692608, 3304524288, 2578692608],
"epochs": 1,
"lr": 0.001, # learning rate
"batch_size": 8,
"criterion": nn.CrossEntropyLoss(),
}
def get_golden_real_stats(multiprocess=False):
if not multiprocess:
return {
"avg_wps": 703.778,
"std_dev_wps": 5.732,
"peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496],
}
else:
return {
"avg_wps": 647.404,
"std_dev_wps": 14.51,
"peak_mem_usage": [3305007616, 2578692608, 3304524288, 2578692608],
}
def get_golden_synthetic_stats():
# TODO(anj-s): Add support for synthetic regression benchmarks
raise NotImplementedError("Synthetic data benchmarks are not supported.")
def get_golden_synthetic_stats():
# TODO(anj-s): Add support for synthetic regression benchmarks
raise NotImplementedError("Synthetic data benchmarks are not supported.")
@@ -12,7 +12,6 @@ import time
from datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders
from datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders
from golden_configs import lm_wikitext2
from models import transformer_lm
import numpy as np
import torch
@@ -22,6 +21,7 @@ import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from builtins import isinstance
import functools
import logging
from typing import Any, List, Tuple
import torch
from torch import nn
def conditional_amp_fwd_decorator(orig_func): # type: ignore
if hasattr(torch.cuda.amp, "custom_fwd"):
return torch.cuda.amp.custom_fwd(orig_func) # type: ignore
@functools.wraps(orig_func)
def inner_decorator(*args: Any, **kwargs: Any) -> Any:
return orig_func(*args, **kwargs)
return inner_decorator
def conditional_amp_bwd_decorator(orig_func): # type: ignore
if hasattr(torch.cuda.amp, "custom_bwd"):
return torch.cuda.amp.custom_bwd(orig_func) # type: ignore
@functools.wraps(orig_func)
def inner_decorator(*args: Any, **kwargs: Any) -> Any:
return orig_func(*args, **kwargs)
return inner_decorator
def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]:
number_splits = min(len(modules), number_splits)
splits: List[List[nn.Module]] = [[] for _ in range(number_splits)]
# Count the number of parameters per exposed layer, use that as a proxy for memory footprint
total_number_params = sum([sum(p.numel() for p in m.parameters()) for m in modules])
number_parameters_per_shard = total_number_params // number_splits
current_shard = 0
logging.info(
f"This model has {total_number_params/1e6:.2f}M parameters, aiming for {number_parameters_per_shard/1e6:.2f}M parameters per shard"
)
for m in modules:
# Number of parameters in the current shard
current_shard_params = sum(p.numel() for sm in splits[current_shard] for p in sm.parameters())
# This shard is big enough, point to the next one
if (
current_shard_params > 0
and current_shard_params + sum(p.numel() for p in m.parameters()) > number_parameters_per_shard
and current_shard < number_splits - 1
):
current_shard += 1
splits[current_shard].append(m)
for i, split in enumerate(splits):
current_shard_params = sum(p.numel() for sm in split for p in sm.parameters())
logging.info(f"Shard {i} holds {current_shard_params/1e6:.2f}M parameters")
return splits
class ModelShard(nn.Module):
"""
Wrap one shard of the model, make it possible to load parameters on the
fly for the FW and BW pass on the given device.
"""
def __init__(
self, cpu_model_shard: nn.Module, device: torch.device, offload_device: torch.device, index: int,
):
super().__init__()
self.model_shard = cpu_model_shard
self.index = index
# Save all the parameter sizes to be able to restore them
self.device = device
torch.cuda.device(self.device)
self.offload_device = offload_device
self.model_shard.to(offload_device)
self.cuda_stream = torch.cuda.Stream(
device=self.device
) # needed to make sure load/offload really run in parallel with compute
def forward(self, *inputs): # type: ignore
return self.model_shard(*inputs) if isinstance(inputs, tuple) else self.model_shard(inputs)
def to(self, device: torch.device) -> "ModelShard": # type: ignore
# Make sure that the lookahead and lookback shards are not captured by this call
self.model_shard.to(device)
return self
def train(self, mode: bool = True) -> "ModelShard":
# Make sure that the lookahead and lookback shards are not captured by this call
self.model_shard.train(mode)
return self
def to_device(self) -> None:
self.model_shard.to(device=self.device, non_blocking=True)
def forward_load(self, non_blocking: bool = True) -> None:
with torch.cuda.stream(self.cuda_stream):
# Restore all the parameter buffers
self.model_shard.to(device=self.device, non_blocking=non_blocking)
def backward_load(self, non_blocking: bool = True) -> None:
with torch.cuda.stream(self.cuda_stream):
self.model_shard.to(self.device, non_blocking=non_blocking)
def forward_drop(self, non_blocking: bool = True) -> None:
with torch.cuda.stream(self.cuda_stream):
self.model_shard.to(self.offload_device, non_blocking=non_blocking)
def backward_drop(self, non_blocking: bool = True) -> None:
with torch.cuda.stream(self.cuda_stream):
self.model_shard.to(self.offload_device, non_blocking=non_blocking)
class ActivationCheckpointing(torch.autograd.Function):
"""
This Function enables checkpointing of intermediate activations at
shard boundaries by overriding the forward and backward pass of the nn.Module.
- In the FW pass, it drops parameters in the previous shard and
loads parameters for the next shard. No graph is constructed in the FW pass.
This enables us to offload intermediate activations present at the shard
boundaries.
- In the BW pass, it does the reverse. We run the forward pass using the
saved intermediate activations and calculate gradients as needed.
The trade-off is latency vs memory when using activation checkpointing.
- Follows heavily from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint.
NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
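Invoked from ``OffloadModel.forward`` when ``checkpoint_activation=True`` (sketch of the call made later in this file; ``inputs`` is the tuple of model inputs)::
    ActivationCheckpointing.apply(*inputs, model_instance)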
"""
@staticmethod
@conditional_amp_fwd_decorator # type: ignore
def forward(ctx: Any, inputs: Any, model_instance: Any) -> Any:
inputs = inputs if isinstance(inputs, tuple) else (inputs,)
ctx.inputs = inputs
ctx.model_instance = model_instance
# TODO(anj-s): We might need to store this for each boundary activation.
# Currently we assume all boundary activation inputs require grad.
ctx.grad_requirements = tuple(x.requires_grad for x in inputs)
ctx.fwd_rng_state = torch.get_rng_state()
# List of input activations starting with the given input.
model_instance._activations = [inputs]
# Enumerate through layer shards and apply activations from the previous shard.
for index, layer_shard in enumerate(model_instance.model_slices):
# Bring in the current activations onto the device.
model_instance._activations[index] = tuple([a.cuda() for a in list(model_instance._activations[index])])
# Bring in the current layer shard onto the device.
layer_shard.forward_load()
# Apply the FW pass and store the activations on the CPU.
inputs = model_instance._activations[index]
with torch.no_grad():
output_list: List[Any] = []
for given_input in inputs:
given_input_list = torch.chunk(given_input, model_instance._num_microbatches)
given_output_list = []
for inputs in given_input_list:
output = layer_shard(inputs)
given_output_list.append(output)
given_output = torch.cat(given_output_list).squeeze(-1)
output_list.append(given_output)
output = tuple(output_list)
output = output if isinstance(output, tuple) else (output,)
# The last instance will lose the gradient function if we move it to the CPU.
# This is because all grad function are present on the device that ran the FW pass.
if index == len(model_instance.model_slices) - 1:
model_instance._activations.append(output)
else:
model_instance._activations.append(tuple([a.cpu() for a in list(output)]))
# Move the layer shard back to the CPU.
layer_shard.forward_drop()
# TODO(anj-s): Check device of the result to make sure the outputs and targets match device.
result = model_instance._activations[-1]
for r in result:
r.requires_grad = True
return result[0] if len(result) == 1 else result
@staticmethod
@conditional_amp_bwd_decorator
def backward(ctx, *grad_outputs): # type: ignore
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
inputs = ctx.inputs
model_instance = ctx.model_instance
for i, need_grad in enumerate(ctx.grad_requirements):
inputs[i].requires_grad = need_grad
all_grads = [grad_outputs]
final_index = len(model_instance._activations) - 1
for model_shard, activation in zip(
reversed(model_instance.model_slices), reversed(model_instance._activations[:-1])
):
# Move the activation to the device.
activation = tuple([a.cuda() for a in list(activation)])
# One of the inputs to the FW pass must require grad.
for a in activation:
a.requires_grad = True
# Move the model shard to the device.
model_shard.backward_load()
# Store the BW pass state.
bwd_rng_state = torch.get_rng_state()
# TODO(anj-s): Why detach inputs?
activation = torch.utils.checkpoint.detach_variable(activation)
# Get the last gradient calculation.
final_grads = all_grads[-1]
if isinstance(activation, torch.Tensor):
activation = (activation,)
if isinstance(final_grads, torch.Tensor):
final_grads = (final_grads,)
# Iterate through all the inputs/outputs of a shard (there could be multiple).
chunked_grad_list: List[Any] = []
# Chunk the activation and grad based on the number of microbatches that are set.
for chunked_activation, chunked_grad in zip(
torch.chunk(*activation, model_instance._num_microbatches), # type: ignore
torch.chunk(*final_grads, model_instance._num_microbatches), # type: ignore
):
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_rng_state)
if isinstance(chunked_activation, torch.Tensor):
chunked_activation = (chunked_activation,) # type: ignore
if isinstance(chunked_grad, torch.Tensor):
chunked_grad = (chunked_grad,) # type: ignore
# Since we need a grad value of a non leaf element we need to set these properties.
for a in chunked_activation:
a.requires_grad = True
a.retain_grad()
with torch.enable_grad():
# Calculate the output of the current shard w.r.t. the stored activation at the slice boundary.
outputs = model_shard(*chunked_activation)
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_rng_state)
torch.autograd.backward(outputs, chunked_grad)
chunked_grad_list += [a.grad for a in chunked_activation]
# Append the list of grads to the all_grads list and this should be on the CPU.
all_grads.append(torch.cat(chunked_grad_list).squeeze(-1)) # type: ignore
# Move activation back to the CPU.
# TODO(anj-s): Why does moving activations to CPU cause the .grad property to be None?
activation = tuple([a.cpu() for a in list(activation)])
# Move the shard back to the CPU.
model_shard.backward_drop()
detached_inputs = model_instance._activations[0]
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
return (None, None) + grads
class ShardSyncLayer(torch.autograd.Function):
"""
The shard sync layer is a synchronization point between model shards.
- In the forward pass, it drops parameters in the previous shard and
loads parameters for the next shard.
- In the backward pass, it does the reverse.
It does not change or create any outputs at all, instead it just
forwards the input as the output.
NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
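Invoked once per slice boundary from ``OffloadModel.forward`` when activation checkpointing is disabled (sketch of the call made later in this file)::
    inputs = ShardSyncLayer.apply(inputs, index, model_slices, model_instance)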
"""
@staticmethod
@conditional_amp_fwd_decorator # type: ignore
def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any:
drop_index = index
load_index = index + 1
max_slices = len(model_slices)
if drop_index >= 0:
# Move shard from device to offload device.
logging.info(f"Dropping shard {drop_index}")
model_slices[drop_index].forward_drop()
if load_index < max_slices:
# Load shard from offload device to device.
logging.info(f"Loading shard{load_index}")
model_slices[load_index].forward_load()
ctx.index = index
ctx.model_slices = model_slices
ctx.model_instance = model_instance
return inputs if isinstance(inputs, tuple) else (inputs,)
@staticmethod
@conditional_amp_bwd_decorator
def backward(ctx, *grad_outputs): # type: ignore
load_index = ctx.index
drop_index = load_index + 1
model_slices = ctx.model_slices
model_instance = ctx.model_instance
# TODO(anj-s): Are these redundant in the backward pass?
if drop_index == len(model_slices):
# Drop the last activation since it is still on the CPU
# after the loss.backward() call.
model_instance._activations[-1] = tuple([a.cuda() for a in list(model_instance._activations[-1])])
if drop_index < len(model_slices):
# Move shard from device to offload device.
logging.info(f"Backward Dropping shard {drop_index}")
model_slices[drop_index].backward_drop()
model_instance._activations[drop_index] = tuple(
[a.cpu() for a in list(model_instance._activations[drop_index])]
)
if load_index >= 0:
# Load shard from offload device to device.
logging.info(f"Backward Loading shard{load_index}")
model_slices[load_index].backward_load()
model_instance._activations[load_index] = tuple(
[a.cuda() for a in list(model_instance._activations[load_index])]
)
# The returned variables need to mirror the forward inputs
# TODO(anj-s): Why do we need to do this?
if isinstance(grad_outputs, tuple):
return grad_outputs[0], None, None, None
return grad_outputs, None, None, None
class OffloadModel(nn.Module):
"""Wrapper used offload parts of a model to the CPU.
The model is sharded into chunks and at each iteration, a
single chunk is copied from CPU->GPU, FW pass is computed and
the chunk is copied back to CPU. This process is repeated for
all the chunks. In the BW pass, the same process happens in
reverse.
Note: OffloadModel currently only supports nn.Sequential models.
Args:
model_cpu (~torch.nn.Sequential): Module to be offloaded.
device (torch.device):
Device where the active model should reside.
offload_device (torch.device):
Device where the inactive model should reside.
num_slices (int):
Number of slices into which the model should be chunked.
checkpoint_activation (bool):
Boolean to indicate if we want to checkpoint intermediate
activation states on the CPU. Default value is False.
num_microbatches (int):
Number of microbatches which should be run per model
shard on device.
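Example (illustrative sketch only; assumes at least one CUDA device)::

    model = torch.nn.Sequential(
        torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 4)
    )
    offload_model = OffloadModel(
        model_cpu=model,
        device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        num_slices=2,
    )
    output = offload_model(torch.ones(4, 8).cuda())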
"""
def __init__(
self,
model_cpu: nn.Sequential,
device: torch.device,
offload_device: torch.device = torch.device("cpu"),
num_slices: int = 5,
checkpoint_activation: bool = False,
num_microbatches: int = 1,
):
super().__init__()
# TODO(anj-s): Add error checks for cuda and sequential model.
self.device = device
self.offload_device = offload_device
# Slice the model into roughly equivalent sequential shards.
splits = _split(model_cpu, num_slices)
# List of model shards that will be placed on/off the device.
self.model_slices: List[nn.Module] = []
for i, split in enumerate(splits):
# Add one model handling this slice
self.model_slices.append(
ModelShard(
cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device, index=i,
)
)
# Expose a unified view of the slices
self.model = torch.nn.Sequential(*self.model_slices)
# intermediate activations at the slice boundaries.
self._activations: List[Tuple] = []
# Currently we only support microbatches with activation checkpointing.
if not checkpoint_activation and num_microbatches > 1:
raise RuntimeError("We currently only support microbatches with activation checkpointing.")
# Bool indicating if we want to checkpoint activation on the host.
self._checkpoint_activation = checkpoint_activation
# Number of microbatches to run per batch on the device
self._num_microbatches = num_microbatches
def forward(self, *inputs: Any, **_: Any) -> Any:
# At least one of the inputs needs to have `requires_grad` set.
# TODO(anj-s): Should we require users to set this or should we set it here?
set_at_least_once = False
for inp in inputs:
if inp.dtype == torch.long:
continue
inp.requires_grad = True
set_at_least_once = True
if not set_at_least_once:
raise RuntimeError("We need at least one of the inputs to require grads.")
if self._checkpoint_activation:
return ActivationCheckpointing.apply(*inputs, self)
self._activations = []
for index in range(-1, len(self.model_slices)):
if index >= 0:
# TODO(anj-s): This might be a redundant call since we have the previous
# activation on the device already.
self._activations[index] = tuple([a.cuda() for a in list(self._activations[index])])
inputs = self._activations[index]
inputs = self.model_slices[index](*inputs)
# Call the custom autograd hooks (discard/load slices FW and BW)
inputs = ShardSyncLayer.apply(inputs, index, self.model_slices, self)
self._activations.append(inputs)
if index >= 0:
self._activations[index] = tuple([a.cpu() for a in list(self._activations[index])])
# We don't move the last activation/output since the target is present
# on the device.
# TODO(anj-s): It is now a requirement that the target tensors be placed on the
# device.
result = self._activations[-1]
return result[0] if len(result) == 1 else result
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
Testing Offload Module
"""
import contextlib
import copy
import numpy as np
import pytest
import torch
from fairscale.experimental.nn.offload import OffloadModel
from fairscale.utils.testing import skip_if_no_cuda
def _init():
torch.cuda.set_device(0)
torch.manual_seed(0)
np.random.seed(0)
device = torch.device("cuda")
offload_device = torch.device("cpu")
return device, offload_device
@skip_if_no_cuda
def test_single_run():
device, offload_device = _init()
model = _get_model()
offload_model = OffloadModel(model_cpu=model, device=device, offload_device=offload_device, num_slices=2,)
offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
input = torch.ones(2, 2).to(device)
labels = torch.ones(2, 2).to(device)
offload_model.train()
pred = offload_model(input)
loss_fn = torch.nn.MSELoss(reduction="sum")
loss = loss_fn(pred, labels)
loss.backward()
offload_optimizer.step()
def _get_model(num_inputs=2, num_hidden=2, num_layers=1, num_outputs=2):
model = torch.nn.Sequential(
torch.nn.Linear(num_inputs, num_hidden),
*([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]),
torch.nn.Linear(num_hidden, num_outputs),
)
return model
def _check_parity(rmodel, omodel, ropt, oopt, rloss, oloss):
for oparams, rparams in zip(omodel.parameters(), rmodel.parameters()):
assert torch.allclose(oparams, rparams, atol=1e-2), f"Model params are different {oparams} {rparams}"
for o_pg, reg_pg in zip(oopt.param_groups, ropt.param_groups):
for o_param, reg_param in zip(o_pg["params"], reg_pg["params"]):
assert torch.allclose(
o_param, reg_param, atol=1e-2
), f"Model parameters differ between Offload and Vanilla {[o_param]} {reg_param}"
for o_buf, reg_buf in zip(omodel.buffers(), rmodel.buffers()):
assert torch.allclose(o_buf, reg_buf, atol=1e-2), "Model buffers differ in between Offload and Vanilla."
def _get_fp16_context(use_fp16=False):
if use_fp16:
return torch.cuda.amp.autocast()
else:
return contextlib.nullcontext()
def _train(model, optimizer, use_fp16, device):
inputs = torch.ones(32, 2).to(device)
labels = torch.ones(32, 2).to(device)
loss_fn = torch.nn.MSELoss(reduction="sum")
model.train()
with _get_fp16_context(use_fp16):
pred = model(inputs)
loss = loss_fn(pred, labels)
loss.backward()
optimizer.step()
return model, optimizer, loss
def _train_reg_model(model, device, offload_device, use_fp16=False):
reg_model = copy.deepcopy(model)
reg_model = reg_model.cuda()
reg_optimizer = torch.optim.SGD(reg_model.parameters(), lr=0.001)
return _train(reg_model, reg_optimizer, use_fp16, device)
def _train_offload_model(
model, device, offload_device, use_fp16=False, checkpoint_activation=False, num_microbatches=1
):
omodel = copy.deepcopy(model)
offload_model = OffloadModel(
model_cpu=omodel,
device=device,
offload_device=offload_device,
num_slices=2,
checkpoint_activation=checkpoint_activation,
num_microbatches=num_microbatches,
)
offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
return _train(offload_model, offload_optimizer, use_fp16, device)
@skip_if_no_cuda
@pytest.mark.parametrize("use_fp16", [True, False])
@pytest.mark.parametrize("checkpoint_activation", [True, False])
@pytest.mark.parametrize("num_microbatches", [1, 5])
def test_correctness(use_fp16, checkpoint_activation, num_microbatches):
if (use_fp16 or checkpoint_activation) and not hasattr(torch.cuda.amp, "custom_fwd"):
pytest.skip(f"AMP APIs are not supported in torch version {torch.__version__}")
if not checkpoint_activation and num_microbatches > 1:
pytest.skip("We only support microbatches with activation offloading.")
device, offload_device = _init()
model = _get_model()
rmodel, ropt, rloss = _train_reg_model(model, device, offload_device)
omodel, oopt, oloss = _train_offload_model(
model,
device,
offload_device,
use_fp16=use_fp16,
checkpoint_activation=checkpoint_activation,
num_microbatches=num_microbatches,
)
_check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss)