Unverified Commit 39a12a8b authored by anj-s, committed by GitHub

[offload] Add support for multiple streams and fix issue with integer inputs. (#515)



* debugging statements

* fix index inputs and streams

* fix lint errors

* remove print

* lint errors

* address comments

* lint error
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
parent 66dfe606
@@ -58,7 +58,7 @@ def get_synthetic_dataloaders(args, benchmark_config, model_specs):
     """Return synthetic dataloaders for training, testing and validation."""

     def batchify(data):
-        batch_size = args.batch_size
+        batch_size = benchmark_config["batch_size"]
         return _batchify(data, batch_size)

     total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
@@ -125,12 +125,13 @@ def train_seq(model_config, benchmark_config, model_specs, args):
     for batch_inputs, batch_outputs in dataloader:
         batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda")
         start = time.time_ns()
-        with _get_profiler_context() as prof:
+        with _get_profiler_context(args.use_profiler) as prof:
             optimizer.zero_grad()
             inputs = batch_inputs.reshape(-1, model_specs["inputs"] * model_specs["inputs"])
-            with _get_profiler_record_context("model_training"):
+            with _get_profiler_record_context("model_training", args.use_profiler):
                 with _get_fp16_context(use_fp16=args.use_fp16):
                     output = model(inputs)
+                    print(f"output grad_fn {output.grad_fn}")
                     loss = criterion(output, target=batch_outputs)
                     loss.backward()
                     optimizer.step()
@@ -149,6 +150,9 @@ def train_seq(model_config, benchmark_config, model_specs, args):

 def train(model_config, model, benchmark_config, model_specs, args):
+    device = torch.device("cuda")
+    torch.cuda.set_device(0)
+
     lm_dataloader, _, _ = model_config["data"]
     criterion = benchmark_config["criterion"]
     vocab_size = model_specs["vocab_size"]
@@ -179,19 +183,21 @@ def train(model_config, model, benchmark_config, model_specs, args):
         epoch_start_time = time.time()

         source, target = get_batch(batch)
+        source, target = source.cuda(), target.cuda()

         if i > 0:
             total_tokens += source.numel()

-        optimizer.zero_grad()
-        output = model(source)
-        target = target.to("cuda")
-        output = output.to(target.device)
-        loss = criterion(output.view(-1, vocab_size), target.view(-1))
-        loss.backward()
-        torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
-        optimizer.step()
+        with _get_profiler_context(args.use_profiler) as prof:
+            optimizer.zero_grad()
+            with _get_profiler_record_context("FW pass", args.use_profiler):
+                output = model(source)
+            with _get_profiler_record_context("Loss", args.use_profiler):
+                loss = criterion(output.view(-1, vocab_size), target.view(-1))
+            with _get_profiler_record_context("BW pass", args.use_profiler):
+                loss.backward()
+            torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"])
+            with _get_profiler_record_context("Opt step", args.use_profiler):
+                optimizer.step()

         total_loss += loss.item()
@@ -208,6 +214,8 @@ def train(model_config, model, benchmark_config, model_specs, args):
                 total_tokens_per_log_interval = 0
                 total_loss = 0
                 start_time = time.time()
+        prof.export_chrome_trace("/tmp/offload_prof")

     if epoch_start_time != 0:
         wps = total_tokens / (time.time() - epoch_start_time)
     else:
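The benchmark now threads an `args.use_profiler` flag through `_get_profiler_context` and `_get_profiler_record_context` and dumps a Chrome trace at the end. A minimal sketch of how such helpers could be written; the bodies below are an assumption for illustration, not the benchmark's actual implementation:

import contextlib

import torch

def _get_profiler_context(use_profiler: bool = False):
    # Real autograd profiler when requested, otherwise a no-op context manager.
    if use_profiler:
        return torch.autograd.profiler.profile(use_cuda=True)
    return contextlib.nullcontext()

def _get_profiler_record_context(name: str, use_profiler: bool = False):
    # Label a region of the trace ("FW pass", "BW pass", ...) when profiling is on.
    if use_profiler:
        return torch.autograd.profiler.record_function(name)
    return contextlib.nullcontext()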
@@ -379,7 +387,7 @@ def run_benchmark(args):
         model = model_config["model"]

         if args.dry_run:
-            train(model_config, model, benchmark_config, args)
+            train(model_config, model, benchmark_config, model_specs, args)
         else:
             benchmark_language_model(model_config, model, benchmark_config, model_specs, args)
     elif args.model_name == "seq":
@@ -393,16 +401,23 @@ def run_benchmark(args):
     parser = argparse.ArgumentParser(description="benchmark")
-    parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
     parser.add_argument(
-        "--debug", action="store_true", help="Print debugging statements which is more verbose than the default."
+        "--dry_run", default=True, action="store_true", help="Run a sample training run without regression testing."
+    )
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        default=True,
+        help="Print debugging statements which is more verbose than the default.",
     )
     parser.add_argument(
         "--model_name", default="lm", type=str, help="Language Model(LM) used to benchmark nn.pipe.",
     )
-    parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
+    parser.add_argument(
+        "--use_synthetic_data", default=True, action="store_true", help="Uses synthetic data for running benchmarks."
+    )
     parser.add_argument("--use_fp16", action="store_true", default=False)
-    parser.add_argument("--checkpoint_activation", action="store_true", default=False)
+    parser.add_argument("--checkpoint_activation", action="store_true", default=True)
     parser.add_argument("--use_profiler", action="store_true", default=False)
@@ -28,7 +28,7 @@ class Offload_Transformer:
         "batch_size": 8,
         "criterion": nn.CrossEntropyLoss(),
         "checkpoint_activation": True,
-        "num_microbatches": 4,
+        "num_microbatches": 1,
         "slices": 3,
     }
@@ -193,7 +193,6 @@ def train(model_config, model, benchmark_config, model_specs, args):
         if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
             target = target.to(get_device(model, -1))
             output = output.to(target.device)
-
             loss = criterion(output.view(-1, vocab_size), target.view(-1))
             loss.backward()
             del target
@@ -51,6 +51,8 @@ def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]:
     )

     for m in modules:
+        for p in m.parameters():
+            p.data = p.data.pin_memory()
         # Number of parameters in the current shard
         current_shard_params = sum(p.numel() for sm in splits[current_shard] for p in sm.parameters())
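Pinning each parameter's storage is what allows the later `.to(device, non_blocking=True)` copies to be truly asynchronous. A small self-contained sketch of the same pattern; the layer and stream names are illustrative, not part of the library:

import torch
import torch.nn as nn

layer = nn.Linear(1024, 1024)          # a shard kept on the CPU offload device
for p in layer.parameters():
    p.data = p.data.pin_memory()       # page-locked host memory

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):
        # With pinned source memory this host-to-device copy is asynchronous.
        layer.to("cuda", non_blocking=True)
    # Compute on the default stream must wait for the copy to finish.
    torch.cuda.current_stream().wait_stream(copy_stream)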
@@ -91,9 +93,8 @@ class ModelShard(nn.Module):
         self.offload_device = offload_device

         self.model_shard.to(offload_device)
-        self.cuda_stream = torch.cuda.Stream(
-            device=self.device
-        )  # needed to make sure load/offload really run in parallel with compute
+        self._cpu_to_gpu_stream = torch.cuda.Stream(device=self.device)
+        self._gpu_to_cpu_stream = torch.cuda.Stream(device=self.device)

     def forward(self, *inputs):  # type: ignore
         return self.model_shard(*inputs) if isinstance(inputs, tuple) else self.model_shard(inputs)
@@ -112,20 +113,20 @@ class ModelShard(nn.Module):
         self.model_shard.to(device=self.device, non_blocking=True)

     def forward_load(self, non_blocking: bool = True) -> None:
-        with torch.cuda.stream(self.cuda_stream):
+        with torch.cuda.stream(self._cpu_to_gpu_stream):
             # Restore all the parameter buffers
             self.model_shard.to(device=self.device, non_blocking=non_blocking)

     def backward_load(self, non_blocking: bool = True) -> None:
-        with torch.cuda.stream(self.cuda_stream):
+        with torch.cuda.stream(self._cpu_to_gpu_stream):
             self.model_shard.to(self.device, non_blocking=non_blocking)

     def forward_drop(self, non_blocking: bool = True) -> None:
-        with torch.cuda.stream(self.cuda_stream):
+        with torch.cuda.stream(self._gpu_to_cpu_stream):
             self.model_shard.to(self.offload_device, non_blocking=non_blocking)

     def backward_drop(self, non_blocking: bool = True) -> None:
-        with torch.cuda.stream(self.cuda_stream):
+        with torch.cuda.stream(self._gpu_to_cpu_stream):
             self.model_shard.to(self.offload_device, non_blocking=non_blocking)
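Splitting the single `cuda_stream` into a dedicated CPU-to-GPU stream and a dedicated GPU-to-CPU stream lets the prefetch of the next shard overlap with the eviction of the previous one and with compute. A condensed sketch of the idea; this is an illustration, not the library's actual class, and the synchronization call is an assumption about how a caller would use it:

import torch
import torch.nn as nn

class TwoStreamShard(nn.Module):
    def __init__(self, shard: nn.Sequential, device: str = "cuda", offload_device: str = "cpu"):
        super().__init__()
        self.model_shard = shard.to(offload_device)
        self.device = device
        self.offload_device = offload_device
        self._cpu_to_gpu_stream = torch.cuda.Stream(device=device)
        self._gpu_to_cpu_stream = torch.cuda.Stream(device=device)

    def forward_load(self, non_blocking: bool = True) -> None:
        # Prefetch the shard onto the GPU on its own stream.
        with torch.cuda.stream(self._cpu_to_gpu_stream):
            self.model_shard.to(self.device, non_blocking=non_blocking)

    def forward_drop(self, non_blocking: bool = True) -> None:
        # Evict the shard back to the CPU on a second stream.
        with torch.cuda.stream(self._gpu_to_cpu_stream):
            self.model_shard.to(self.offload_device, non_blocking=non_blocking)

    def forward(self, *inputs):
        # Compute must not start before the prefetch stream has delivered the weights.
        torch.cuda.current_stream().wait_stream(self._cpu_to_gpu_stream)
        return self.model_shard(*inputs)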
@@ -150,7 +151,7 @@ class ActivationCheckpointing(torch.autograd.Function):
     @staticmethod
     @conditional_amp_fwd_decorator  # type: ignore
-    def forward(ctx: Any, inputs: Any, model_instance: Any) -> Any:
+    def forward(ctx: Any, inputs: Any, dummy_input: Any, model_instance: Any) -> Any:
         inputs = inputs if isinstance(inputs, tuple) else (inputs,)
         ctx.inputs = inputs
@@ -168,6 +169,7 @@ class ActivationCheckpointing(torch.autograd.Function):
             model_instance._activations[index] = tuple([a.cuda() for a in list(model_instance._activations[index])])
             # Bring in the current layer shard onto the device.
             layer_shard.forward_load()
+
             # Apply the FP and store the activations on the CPU.
             inputs = model_instance._activations[index]
@@ -212,17 +214,9 @@ class ActivationCheckpointing(torch.autograd.Function):
         all_grads = [grad_outputs]

-        final_index = len(model_instance._activations) - 1
-
         for model_shard, activation in zip(
             reversed(model_instance.model_slices), reversed(model_instance._activations[:-1])
         ):
-            # Move the activation to the device.
-            activation = tuple([a.cuda() for a in list(activation)])
-
-            # One of the inputs to the FW pass must require grad.
-            for a in activation:
-                a.requires_grad = True
-
             # Move the model shard to the device.
             model_shard.backward_load()

             # Store the BW pass state.
@@ -232,6 +226,7 @@ class ActivationCheckpointing(torch.autograd.Function):
             activation = torch.utils.checkpoint.detach_variable(activation)
             # Get the last gradient calculation.
             final_grads = all_grads[-1]
+
             if isinstance(activation, torch.Tensor):
                 activation = (activation,)
             if isinstance(final_grads, torch.Tensor):
@@ -253,6 +248,8 @@ class ActivationCheckpointing(torch.autograd.Function):
             # Since we need a grad value of a non leaf element we need to set these properties.
             for a in chunked_activation:
+                if a.dtype == torch.long:
+                    continue
                 a.requires_grad = True
                 a.retain_grad()
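The `torch.long` guard exists because `requires_grad` can only be set on floating-point (or complex) tensors; integer activations such as token ids would otherwise raise a RuntimeError here. A tiny standalone illustration with made-up tensors:

import torch

token_ids = torch.randint(0, 1000, (4, 16))   # dtype=torch.long, cannot require grad
hidden = torch.randn(4, 16, 32)               # float activation, can require grad

for a in (token_ids, hidden):
    if a.dtype == torch.long:
        continue                              # skip integer tensors
    a.requires_grad = True
    a.retain_grad()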
@@ -263,13 +260,16 @@ class ActivationCheckpointing(torch.autograd.Function):
                 # Set the states back to what it was at the start of this function.
                 torch.set_rng_state(bwd_rng_state)
                 torch.autograd.backward(outputs, chunked_grad)
-                chunked_grad_list += [a.grad for a in chunked_activation]
+                intermediate_grads = []
+                for a in chunked_activation:
+                    if a.grad is not None:
+                        intermediate_grads.append(a.grad)
+                if None not in intermediate_grads:
+                    chunked_grad_list += intermediate_grads
+            if chunked_grad_list:
                 # Append the list of grads to the all_grads list and this should be on the CPU.
                 all_grads.append(torch.cat(chunked_grad_list).squeeze(-1))  # type: ignore
-            # Move activation back to the CPU.
             # TODO(anj-s): Why does moving activations to CPU cause the .grad property to be None?
-            activation = tuple([a.cpu() for a in list(activation)])
             # Move the shard back to the CPU.
             model_shard.backward_drop()
         detached_inputs = model_instance._activations[0]
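Collecting only the non-None gradients before concatenating guards against activations that never receive a gradient (for example integer inputs), since `torch.cat` cannot handle `None` entries. A small illustration with made-up tensors:

import torch

float_act = torch.randn(2, 3, requires_grad=True)
long_act = torch.randint(0, 10, (2, 3))       # never receives a gradient
(float_act * 2).sum().backward()

chunked_grad_list = []
intermediate_grads = [a.grad for a in (float_act, long_act) if a.grad is not None]
if None not in intermediate_grads:
    chunked_grad_list += intermediate_grads
if chunked_grad_list:
    merged = torch.cat(chunked_grad_list)     # safe: only real gradients are concatenated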
@@ -432,20 +432,9 @@ class OffloadModel(nn.Module):
         self._num_microbatches = num_microbatches

     def forward(self, *inputs: Any, **_: Any) -> Any:
-        # At least one of the inputs needs to have `requires_grad` set.
-        # TODO(anj-s): Should we require users to set this or should we set it here?
-        set_at_least_once = False
-        for inp in inputs:
-            if inp.dtype == torch.long:
-                continue
-            inp.requires_grad = True
-            set_at_least_once = True
-        if not set_at_least_once:
-            raise RuntimeError("We need at least one of the inputs to require grads.")
+        dummy_input = torch.tensor([], requires_grad=True)
         if self._checkpoint_activation:
-            return ActivationCheckpointing.apply(*inputs, self)
+            return ActivationCheckpointing.apply(*inputs, dummy_input, self)

         self._activations = []
         for index in range(-1, len(self.model_slices)):
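Passing an empty float tensor with `requires_grad=True` into the autograd Function guarantees that its output requires grad even when all real inputs are integer tensors, so the custom backward still runs. A toy stand-in for the pattern; this is an assumption-laden illustration, not the fairscale implementation:

import torch

class _CheckpointLike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, dummy_input, model):
        # Without a differentiable input, the output would not require grad and
        # backward() would never be invoked.
        ctx.model = model
        ctx.inputs = inputs
        with torch.no_grad():
            return model.embedding(inputs).sum()

    @staticmethod
    def backward(ctx, grad_output):
        # Recompute the forward pass with grad enabled and backpropagate into
        # the module's parameters, as activation checkpointing does.
        with torch.enable_grad():
            out = ctx.model.embedding(ctx.inputs).sum()
        torch.autograd.backward(out, grad_output)
        return None, None, None

model = torch.nn.Module()
model.embedding = torch.nn.Embedding(1000, 32)
token_ids = torch.randint(0, 1000, (4, 16))             # all-integer real inputs
dummy_input = torch.tensor([], requires_grad=True)

loss = _CheckpointLike.apply(token_ids, dummy_input, model)
loss.backward()                                          # reached only thanks to dummy_input
assert model.embedding.weight.grad is not None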