Unverified Commit 39a12a8b authored by anj-s, committed by GitHub

[offload] Add support for multiple streams and fix issue with integer inputs. (#515)



* debugging statements

* fix index inputs and streams

* fix lint errors

* remove print

* lint errors

* address comments

* lint error
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
parent 66dfe606
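The "integer inputs" part of this change stems from the fact that PyTorch only allows requires_grad on floating-point (or complex) tensors, so token-index inputs of dtype torch.long cannot be forced to require gradients. A minimal illustration (not part of this diff):

import torch

# requires_grad can only be set on floating-point or complex tensors, so
# integer token indices must be skipped by the offload wrapper.
idx = torch.tensor([1, 2, 3], dtype=torch.long)
try:
    idx.requires_grad = True
except RuntimeError as err:
    print(err)  # "only Tensors of floating point and complex dtype can require gradients"

x = torch.randn(3)
x.requires_grad = True  # fine for float tensors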
......@@ -58,7 +58,7 @@ def get_synthetic_dataloaders(args, benchmark_config, model_specs):
"""Return synthetic dataloaders for training, testing and validation."""
def batchify(data):
batch_size = args.batch_size
batch_size = benchmark_config["batch_size"]
return _batchify(data, batch_size)
total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
......
......@@ -125,12 +125,13 @@ def train_seq(model_config, benchmark_config, model_specs, args):
for batch_inputs, batch_outputs in dataloader:
batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda")
start = time.time_ns()
with _get_profiler_context() as prof:
with _get_profiler_context(args.use_profiler) as prof:
optimizer.zero_grad()
inputs = batch_inputs.reshape(-1, model_specs["inputs"] * model_specs["inputs"])
with _get_profiler_record_context("model_training"):
with _get_profiler_record_context("model_training", args.use_profiler):
with _get_fp16_context(use_fp16=args.use_fp16):
output = model(inputs)
print(f"output grad_fn {output.grad_fn}")
loss = criterion(output, target=batch_outputs)
loss.backward()
optimizer.step()
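The profiler helpers are now gated on args.use_profiler. Their bodies are not shown in this diff; a plausible sketch of such helpers, assuming they fall back to a no-op context when profiling is disabled:

import contextlib
import torch

def _get_profiler_context(use_profiler=False):
    # Return a real autograd profiler when requested, otherwise a no-op context.
    if use_profiler:
        return torch.autograd.profiler.profile(use_cuda=True)
    return contextlib.nullcontext()

def _get_profiler_record_context(name, use_profiler=False):
    # Label a named region of the trace only when profiling is enabled.
    if use_profiler:
        return torch.autograd.profiler.record_function(name)
    return contextlib.nullcontext()

With this sketch, prof is None when profiling is disabled, so later calls such as prof.export_chrome_trace("/tmp/offload_prof") would need to be guarded; the actual helpers in the benchmark may handle this differently.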
......@@ -149,6 +150,9 @@ def train_seq(model_config, benchmark_config, model_specs, args):
def train(model_config, model, benchmark_config, model_specs, args):
device = torch.device("cuda")
torch.cuda.set_device(0)
lm_dataloader, _, _ = model_config["data"]
criterion = benchmark_config["criterion"]
vocab_size = model_specs["vocab_size"]
......@@ -179,19 +183,21 @@ def train(model_config, model, benchmark_config, model_specs, args):
epoch_start_time = time.time()
source, target = get_batch(batch)
source, target = source.cuda(), target.cuda()
if i > 0:
total_tokens += source.numel()
with _get_profiler_context(args.use_profiler) as prof:
optimizer.zero_grad()
with _get_profiler_record_context("FW pass", args.use_profiler):
output = model(source)
target = target.to("cuda")
output = output.to(target.device)
with _get_profiler_record_context("Loss", args.use_profiler):
loss = criterion(output.view(-1, vocab_size), target.view(-1))
with _get_profiler_record_context("BW pass", args.use_profiler):
loss.backward()
torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
with _get_profiler_record_context("Opt step", args.use_profiler):
optimizer.step()
total_loss += loss.item()
......@@ -208,6 +214,8 @@ def train(model_config, model, benchmark_config, model_specs, args):
total_tokens_per_log_interval = 0
total_loss = 0
start_time = time.time()
prof.export_chrome_trace("/tmp/offload_prof")
if epoch_start_time != 0:
wps = total_tokens / (time.time() - epoch_start_time)
else:
......@@ -379,7 +387,7 @@ def run_benchmark(args):
model = model_config["model"]
if args.dry_run:
train(model_config, model, benchmark_config, args)
train(model_config, model, benchmark_config, model_specs, args)
else:
benchmark_language_model(model_config, model, benchmark_config, model_specs, args)
elif args.model_name == "seq":
......@@ -393,16 +401,23 @@ def run_benchmark(args):
parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
"--debug", action="store_true", help="Print debugging statements which is more verbose than the default."
"--dry_run", default=True, action="store_true", help="Run a sample training run without regression testing."
)
parser.add_argument(
"--debug",
action="store_true",
default=True,
help="Print debugging statements which is more verbose than the default.",
)
parser.add_argument(
"--model_name", default="lm", type=str, help="Language Model(LM) used to benchmark nn.pipe.",
)
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument(
"--use_synthetic_data", default=True, action="store_true", help="Uses synthetic data for running benchmarks."
)
parser.add_argument("--use_fp16", action="store_true", default=False)
parser.add_argument("--checkpoint_activation", action="store_true", default=False)
parser.add_argument("--checkpoint_activation", action="store_true", default=True)
parser.add_argument("--use_profiler", action="store_true", default=False)
......
......@@ -28,7 +28,7 @@ class Offload_Transformer:
"batch_size": 8,
"criterion": nn.CrossEntropyLoss(),
"checkpoint_activation": True,
"num_microbatches": 4,
"num_microbatches": 1,
"slices": 3,
}
......
......@@ -193,7 +193,6 @@ def train(model_config, model, benchmark_config, model_specs, args):
if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
target = target.to(get_device(model, -1))
output = output.to(target.device)
loss = criterion(output.view(-1, vocab_size), target.view(-1))
loss.backward()
del target
......
......@@ -51,6 +51,8 @@ def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]:
)
for m in modules:
for p in m.parameters():
p.data = p.data.pin_memory()
# Number of parameters in the current shard
current_shard_params = sum(p.numel() for sm in splits[current_shard] for p in sm.parameters())
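Pinning the CPU copies of the shard parameters is what allows the later non_blocking host-to-device copies to run asynchronously on a side stream. A small standalone sketch of the pattern (illustrative, not from this file):

import torch

# Pinned (page-locked) host memory enables cudaMemcpyAsync, so the upload can
# overlap with compute on another stream; pageable memory forces a blocking copy.
cpu_param = torch.randn(1024, 1024).pin_memory()
copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    gpu_param = cpu_param.to("cuda", non_blocking=True)
# Make the default stream wait for the transfer before consuming gpu_param.
torch.cuda.current_stream().wait_stream(copy_stream)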
......@@ -91,9 +93,8 @@ class ModelShard(nn.Module):
self.offload_device = offload_device
self.model_shard.to(offload_device)
self.cuda_stream = torch.cuda.Stream(
device=self.device
) # needed to make sure load/offload really run in parallel with compute
self._cpu_to_gpu_stream = torch.cuda.Stream(device=self.device)
self._gpu_to_cpu_stream = torch.cuda.Stream(device=self.device)
def forward(self, *inputs): # type: ignore
return self.model_shard(*inputs) if isinstance(inputs, tuple) else self.model_shard(inputs)
......@@ -112,20 +113,20 @@ class ModelShard(nn.Module):
self.model_shard.to(device=self.device, non_blocking=True)
def forward_load(self, non_blocking: bool = True) -> None:
with torch.cuda.stream(self.cuda_stream):
with torch.cuda.stream(self._cpu_to_gpu_stream):
# Restore all the parameter buffers
self.model_shard.to(device=self.device, non_blocking=non_blocking)
def backward_load(self, non_blocking: bool = True) -> None:
with torch.cuda.stream(self.cuda_stream):
with torch.cuda.stream(self._cpu_to_gpu_stream):
self.model_shard.to(self.device, non_blocking=non_blocking)
def forward_drop(self, non_blocking: bool = True) -> None:
with torch.cuda.stream(self.cuda_stream):
with torch.cuda.stream(self._gpu_to_cpu_stream):
self.model_shard.to(self.offload_device, non_blocking=non_blocking)
def backward_drop(self, non_blocking: bool = True) -> None:
with torch.cuda.stream(self.cuda_stream):
with torch.cuda.stream(self._gpu_to_cpu_stream):
self.model_shard.to(self.offload_device, non_blocking=non_blocking)
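Replacing the single cuda_stream with separate _cpu_to_gpu_stream and _gpu_to_cpu_stream streams lets the prefetch of the next shard and the eviction of the previous one proceed concurrently instead of queueing on one stream. A simplified sketch of the idea outside the ModelShard class (names are illustrative):

import torch
import torch.nn as nn

cpu_to_gpu = torch.cuda.Stream()  # uploads of the upcoming shard
gpu_to_cpu = torch.cuda.Stream()  # evictions of the finished shard

prev_shard = nn.Linear(1024, 1024).cuda()
next_shard = nn.Linear(1024, 1024)  # still on CPU

with torch.cuda.stream(gpu_to_cpu):
    prev_shard.to("cpu", non_blocking=True)   # eviction
with torch.cuda.stream(cpu_to_gpu):
    next_shard.to("cuda", non_blocking=True)  # prefetch
# Compute on the default stream must wait for the prefetch to complete.
torch.cuda.current_stream().wait_stream(cpu_to_gpu)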
......@@ -150,7 +151,7 @@ class ActivationCheckpointing(torch.autograd.Function):
@staticmethod
@conditional_amp_fwd_decorator # type: ignore
def forward(ctx: Any, inputs: Any, model_instance: Any) -> Any:
def forward(ctx: Any, inputs: Any, dummy_input: Any, model_instance: Any) -> Any:
inputs = inputs if isinstance(inputs, tuple) else (inputs,)
ctx.inputs = inputs
......@@ -168,6 +169,7 @@ class ActivationCheckpointing(torch.autograd.Function):
model_instance._activations[index] = tuple([a.cuda() for a in list(model_instance._activations[index])])
# Bring in the current layer shard onto the device.
layer_shard.forward_load()
# Apply the FP and store the activations on the CPU.
inputs = model_instance._activations[index]
......@@ -212,17 +214,9 @@ class ActivationCheckpointing(torch.autograd.Function):
all_grads = [grad_outputs]
final_index = len(model_instance._activations) - 1
for model_shard, activation in zip(
reversed(model_instance.model_slices), reversed(model_instance._activations[:-1])
):
# Move the activation to the device.
activation = tuple([a.cuda() for a in list(activation)])
# One of the inputs to the FW pass must require grad.
for a in activation:
a.requires_grad = True
# Move the model shard to the device.
model_shard.backward_load()
# Store the BW pass state.
......@@ -232,6 +226,7 @@ class ActivationCheckpointing(torch.autograd.Function):
activation = torch.utils.checkpoint.detach_variable(activation)
# Get the last gradient calculation.
final_grads = all_grads[-1]
if isinstance(activation, torch.Tensor):
activation = (activation,)
if isinstance(final_grads, torch.Tensor):
......@@ -253,6 +248,8 @@ class ActivationCheckpointing(torch.autograd.Function):
# Since we need a grad value of a non leaf element we need to set these properties.
for a in chunked_activation:
if a.dtype == torch.long:
continue
a.requires_grad = True
a.retain_grad()
......@@ -263,13 +260,16 @@ class ActivationCheckpointing(torch.autograd.Function):
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_rng_state)
torch.autograd.backward(outputs, chunked_grad)
chunked_grad_list += [a.grad for a in chunked_activation]
intermediate_grads = []
for a in chunked_activation:
if a.grad is not None:
intermediate_grads.append(a.grad)
if None not in intermediate_grads:
chunked_grad_list += intermediate_grads
if chunked_grad_list:
# Append the list of grads to the all_grads list and this should be on the CPU.
all_grads.append(torch.cat(chunked_grad_list).squeeze(-1)) # type: ignore
# Move activation back to the CPU.
# TODO(anj-s): Why does moving activations to CPU cause the .grad property to be None?
activation = tuple([a.cpu() for a in list(activation)])
# Move the shard back to the CPU.
model_shard.backward_drop()
detached_inputs = model_instance._activations[0]
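In the backward pass, activation gradients are now collected only when they were actually populated; integer-typed activations never receive a .grad because requires_grad is skipped for them. A hedged sketch of that filtering, separate from the real class:

import torch

def collect_chunk_grads(chunked_activation):
    # Integer activations (e.g. token indices) keep .grad == None, so only
    # forward the gradients that actually exist.
    return [a.grad for a in chunked_activation if a.grad is not None]

acts = (torch.randn(4, 8, requires_grad=True), torch.randint(0, 10, (4,)))
(acts[0] * 2).sum().backward()
print(collect_chunk_grads(acts))  # only the float tensor contributed a grad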
......@@ -432,20 +432,9 @@ class OffloadModel(nn.Module):
self._num_microbatches = num_microbatches
def forward(self, *inputs: Any, **_: Any) -> Any:
# At least one of the inputs needs to have `requires_grad` set.
# TODO(anj-s): Should we require users to set this or should we set it here?
set_at_least_once = False
for inp in inputs:
if inp.dtype == torch.long:
continue
inp.requires_grad = True
set_at_least_once = True
if not set_at_least_once:
raise RuntimeError("We need at least one of the inputs to require grads.")
dummy_input = torch.tensor([], requires_grad=True)
if self._checkpoint_activation:
return ActivationCheckpointing.apply(*inputs, self)
return ActivationCheckpointing.apply(*inputs, dummy_input, self)
self._activations = []
for index in range(-1, len(self.model_slices)):
......
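The dummy_input passed to ActivationCheckpointing.apply guarantees that autograd sees at least one float input that requires grad, so the custom backward still runs even when every real input is an integer tensor such as token IDs. A minimal sketch of the same trick with a toy autograd Function (an assumed simplification, not the real class):

import torch

class ToyCheckpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, dummy_input, model):
        ctx.model = model
        ctx.save_for_backward(inputs)
        with torch.no_grad():
            return model(inputs.float())

    @staticmethod
    def backward(ctx, grad_output):
        # Recompute under grad mode and backpropagate; the integer inputs get
        # no gradient of their own, hence the None placeholders below.
        (inputs,) = ctx.saved_tensors
        with torch.enable_grad():
            out = ctx.model(inputs.float())
        out.backward(grad_output)
        return None, None, None

model = torch.nn.Linear(3, 3)
token_ids = torch.randint(0, 5, (2, 3))       # integer input, cannot require grad
dummy = torch.tensor([], requires_grad=True)  # keeps the graph alive
out = ToyCheckpoint.apply(token_ids, dummy, model)
out.sum().backward()  # ToyCheckpoint.backward is still invoked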