Unverified Commit a77c56f0 authored by anj-s, committed by GitHub

[fix] Revert change that removed the option to run OffloadModel without activation checkpointing. (#608)

* revert change made

* add tests and revert sync shard changes

* add tests

* remove file checked in by mistake

* inline var

* fix lint errors

* add checkpoint activation

* fix mypy

* use a bigger model

* modify tests for now

* resolve conflicts
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
parent 56506951
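At a glance, this is the behavior being restored: a minimal usage sketch (requires a CUDA device; the toy model, sizes, and import path are assumptions mirroring the updated test below, not part of this diff) of constructing OffloadModel with checkpoint_activation=False so shards are offloaded to CPU without activation checkpointing.

import torch
import torch.nn as nn
from fairscale.experimental.nn.offload import OffloadModel  # assumed import path

# Hypothetical toy model; the real benchmark uses a transformer LM.
model = nn.Sequential(nn.Linear(2, 20), nn.Linear(20, 20), nn.Linear(20, 2))
offload_model = OffloadModel(
    model=model,
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=False,  # the option this PR restores
)
optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
inputs = torch.ones(64, 2).cuda()
labels = torch.ones(64, 2).cuda()
loss = torch.nn.MSELoss(reduction="sum")(offload_model(inputs), labels)
loss.backward()
optimizer.step()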
@@ -170,7 +170,7 @@ run_offload_benchmark: &run_offload_benchmark
   - run:
       name: Run Offload Benchmark
       command: |
-        python benchmarks/experimental/offload.py
+        python benchmarks/experimental/offload.py --checkpoint_activation

 run_pipe_benchmark: &run_pipe_benchmark
   - run:
...
@@ -233,7 +233,7 @@ def train(model_config, model, benchmark_config, model_specs, args):

 def verify_peak_memory(golden_config, std_dev):
-    print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
     current_device_usage = torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]
     golden_ref = golden_config["peak_mem_usage"]
     if not current_device_usage < golden_ref * std_dev:
@@ -246,7 +246,6 @@ def verify_peak_memory(golden_config, std_dev):

 def verify_lm_throughput(wps, golden_config, args):
     """Verify that words per second for a given benchmark run matches the golden data."""
-    print("Throughput(wps) is {:.2f}.".format(wps))
     if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
         raise RuntimeError(
             "Throughput(wps):{:.2f} is below the golden threshold of an "
@@ -272,9 +271,12 @@ def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
         raise RuntimeError(
             f"Golden data verification is only supported for the Transformer(lm) model and not {args.model_name}"
         )
-    golden_config = get_golden_config(args.model_name, args)
-    verify_lm_throughput(wps, golden_config, args)
-    verify_peak_memory(golden_config, 1.1)
+    print("Throughput(wps) is {:.2f}.".format(wps))
+    print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
+    if not args.dry_run:
+        golden_config = get_golden_config(args.model_name, args)
+        verify_lm_throughput(wps, golden_config, args)
+        verify_peak_memory(golden_config, 1.1)


 def get_synthetic_dataloaders(args, device, benchmark_config, model_specs):
@@ -343,11 +345,11 @@ def create_model_config(args, benchmark_config=None, model_specs=None):
         raise RuntimeError(f"Unrecognized args.model_mame {args.model_name}")


-def create_benchmark_config(model_name):
+def create_benchmark_config(args):
     """Return a dict with configurations required for benchmarking `model_name` model."""
     if args.model_name == "lm":
-        return lm_wikitext2.get_benchmark_config()
+        return lm_wikitext2.get_benchmark_config(checkpoint_activation=args.checkpoint_activation)
     elif args.model_name == "seq":
         return offload_seq.get_benchmark_config()
     else:
@@ -383,17 +385,15 @@ def run_benchmark(args):
     init_random_seed(0)

     if args.model_name == "lm":
-        benchmark_config = create_benchmark_config(args.model_name)
+        benchmark_config = create_benchmark_config(args)
         model_specs = get_model_specs(args.model_name)
         model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
         model = model_config["model"]
-        if args.dry_run:
-            train(model_config, model, benchmark_config, model_specs, args)
-        else:
-            benchmark_language_model(model_config, model, benchmark_config, model_specs, args)
+        benchmark_language_model(model_config, model, benchmark_config, model_specs, args)
     elif args.model_name == "seq":
-        benchmark_config = create_benchmark_config(args.model_name)
+        benchmark_config = create_benchmark_config(args)
         model_specs = get_model_specs(args.model_name)
         model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
         model = model_config["model"]
@@ -419,7 +419,7 @@ parser.add_argument(
     "--use_synthetic_data", default=True, action="store_true", help="Uses synthetic data for running benchmarks."
 )
 parser.add_argument("--use_fp16", action="store_true", default=False)
-parser.add_argument("--checkpoint_activation", action="store_true", default=True)
+parser.add_argument("--checkpoint_activation", action="store_true", default=False)
 parser.add_argument("--use_profiler", action="store_true", default=False)
...
@@ -20,14 +20,14 @@ class Offload_Transformer:
         "seq_len": 32,
     }


-def get_benchmark_config():
+def get_benchmark_config(checkpoint_activation=True):
     return {
         "epochs": 1,
         "lr": 0.001,  # learning rate
         "batch_size": 8,
         "criterion": nn.CrossEntropyLoss(),
-        "checkpoint_activation": True,
+        "checkpoint_activation": checkpoint_activation,
         "num_microbatches": 1,
         "slices": 3,
     }
@@ -59,7 +59,7 @@ class Offload_Sequential:
         "criterion": nn.CrossEntropyLoss(),
         "slices": 3,
         "checkpoint_activation": True,
-        "num_microbatches": 4,
+        "num_microbatches": 1,
     }
...
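With get_benchmark_config now parameterized and the CLI default flipped to False (see the argparse hunk above), the lm benchmark's checkpoint behavior is driven by --checkpoint_activation alone. A tiny self-contained sketch of the effect, showing only the relevant key; the other config fields and the surrounding wiring are as in the diff, not reproduced here.

def get_benchmark_config(checkpoint_activation=True):
    # Mirrors the signature change above; the other config keys are omitted.
    return {"checkpoint_activation": checkpoint_activation}

# The benchmark CLI default is now False, so a plain run of the lm benchmark
# disables activation checkpointing, while --checkpoint_activation re-enables it.
assert get_benchmark_config(checkpoint_activation=False)["checkpoint_activation"] is False
assert get_benchmark_config()["checkpoint_activation"] is True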
@@ -292,6 +292,75 @@ class OffloadFunction(torch.autograd.Function):
         return (None, None) + grads

+
+class ShardSyncLayer(torch.autograd.Function):
+    """
+    The shard sync layer is a synchronization point between model shards.
+    - In the forward pass, it drops parameters in the previous shard and
+    loads parameters for the next shard.
+    - In the backward pass, it does the reverse.
+    It does not change or create any outputs at all, instead it just
+    forwards the input as the output.
+    NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
+    """
+
+    @staticmethod
+    @_conditional_amp_fwd_decorator  # type: ignore
+    def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any:
+        drop_index = index
+        load_index = index + 1
+        max_slices = len(model_slices)
+
+        if drop_index >= 0:
+            # Move shard from device to offload device.
+            model_slices[drop_index].forward_drop()
+
+        if load_index < max_slices:
+            # Load shard from offload device to device.
+            model_slices[load_index].forward_load()
+
+        ctx.index = index
+        ctx.model_slices = model_slices
+        ctx.model_instance = model_instance
+
+        return inputs if isinstance(inputs, tuple) else (inputs,)
+
+    @staticmethod
+    @_conditional_amp_bwd_decorator
+    def backward(ctx, *grad_outputs):  # type: ignore
+        load_index = ctx.index
+        drop_index = load_index + 1
+        model_slices = ctx.model_slices
+        model_instance = ctx.model_instance
+
+        # TODO(anj-s): Are these redundant in the backward pass?
+        if drop_index == len(model_slices):
+            # Drop the last activation since it is still on the CPU
+            # after the loss.backward() call.
+            model_instance._activations[-1] = tuple([a.cuda() for a in list(model_instance._activations[-1])])
+
+        if drop_index < len(model_slices):
+            # Move shard from device to offload device.
+            model_slices[drop_index].backward_drop()
+            model_instance._activations[drop_index] = tuple(
+                [a.cpu() for a in list(model_instance._activations[drop_index])]
+            )
+
+        if load_index >= 0:
+            # Load shard from offload device to device.
+            model_slices[load_index].backward_load()
+            model_instance._activations[load_index] = tuple(
+                [a.cuda() for a in list(model_instance._activations[load_index])]
+            )
+
+        # The returned variables need to mirror the forward inputs.
+        # TODO(anj-s): Why do we need to do this?
+        if isinstance(grad_outputs, tuple):
+            return grad_outputs[0], None, None, None
+
+        return grad_outputs, None, None, None
+
+
 class OffloadModel(nn.Module):
     """Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
     to train by offloading majority of the model parameters to the CPU.
@@ -405,4 +474,23 @@ class OffloadModel(nn.Module):

         # We need the second param to be a dummy input to enable the
         # backward pass to be triggered for integer inputs.
-        return OffloadFunction.apply(*inputs, torch.tensor([], requires_grad=True), self)
+        if self._checkpoint_activation:
+            return OffloadFunction.apply(*inputs, torch.tensor([], requires_grad=True), self)
+
+        self._activations = []
+        for index in range(-1, len(self.model_slices)):
+            if index >= 0:
+                # TODO(anj-s): This might be a redundant call since we have the previous
+                # activation on the device already.
+                self._activations[index] = tuple([a.cuda() for a in list(self._activations[index])])
+                inputs = self._activations[index]
+                inputs = self.model_slices[index](*inputs)
+            # Call the custom autograd hooks (discard/load slices FW and BW)
+            inputs = ShardSyncLayer.apply(inputs, index, self.model_slices, self)
+            self._activations.append(inputs)
+            if index >= 0:
+                self._activations[index] = tuple([a.cpu() for a in list(self._activations[index])])
+
+        result = self._activations[-1]
+        result = tuple([r.cuda() for r in result])
+        return result[0] if len(result) == 1 else result
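To make the restored schedule easier to follow, here is a tensor-free sketch (illustration only, not part of the diff) of the order in which the loop above runs each slice and in which ShardSyncLayer drops and loads slice parameters, assuming a hypothetical three-slice model.

num_slices = 3
for index in range(-1, num_slices):
    if index >= 0:
        print(f"run slice {index} on the compute device")
    # ShardSyncLayer.apply(inputs, index, ...) performs the two moves below
    # while passing the activations through unchanged.
    drop_index, load_index = index, index + 1
    if drop_index >= 0:
        print(f"forward_drop: move slice {drop_index} parameters to the offload device")
    if load_index < num_slices:
        print(f"forward_load: move slice {load_index} parameters to the compute device")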
@@ -32,20 +32,38 @@ def test_single_run():
     device, offload_device = _init()
     model = _get_model()
-    offload_model = OffloadModel(model=model, device=device, offload_device=offload_device, num_slices=2,)
-    offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
-
-    input = torch.ones(2, 2).to(device)
-    labels = torch.ones(2, 2).to(device)
-    offload_model.train()
-    pred = offload_model(input)
-    loss_fn = torch.nn.MSELoss(reduction="sum")
-    loss = loss_fn(pred, labels)
-    loss.backward()
-    offload_optimizer.step()
+    peak_mem = {}
+    for checkpoint_activation in [True, False]:
+        offload_model = OffloadModel(
+            model=model,
+            device=device,
+            offload_device=offload_device,
+            num_slices=2,
+            checkpoint_activation=checkpoint_activation,
+        )
+        offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
+
+        input = torch.ones(1000, 2).to(device)
+        labels = torch.ones(1000, 2).to(device)
+        offload_model.train()
+        pred = offload_model(input)
+        loss_fn = torch.nn.MSELoss(reduction="sum")
+        loss = loss_fn(pred, labels)
+        loss.backward()
+        offload_optimizer.step()
+        key = "ca_" + str(checkpoint_activation)
+        peak_mem[key] = torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]
+        print(
+            "Peak allocated bytes on cuda:0 for checkpoint_activation "
+            + str(checkpoint_activation)
+            + ": {:2f}".format(peak_mem[key])
+        )
+
+    # TODO(anj-s): We need a better requirement since this fails on CircleCI right now.
+    assert peak_mem["ca_True"] <= peak_mem["ca_False"]


-def _get_model(num_inputs=2, num_hidden=2, num_layers=1, num_outputs=2):
+def _get_model(num_inputs=2, num_hidden=20, num_layers=10, num_outputs=2):
     model = torch.nn.Sequential(
         torch.nn.Linear(num_inputs, num_hidden),
         *([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]),
...