Commit 7b585440 authored by mshoeybi
Browse files

only support pp=1

parent e923ec52
...@@ -236,6 +236,7 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -236,6 +236,7 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.checkpoint_activations, \ assert args.checkpoint_activations, \
'for distribute-checkpointed-activations to work you '\ 'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations' 'need to enable checkpoint-activations'
assert args.pipeline_model_parallel_size == 1
_print_args(args) _print_args(args)
return args return args
......
...@@ -289,7 +289,8 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -289,7 +289,8 @@ class CheckpointFunction(torch.autograd.Function):
return (None, None) + grads return (None, None) + grads
def checkpoint(function, *args): def checkpoint(function, distribute_checkpointed_activations, *args):
"""Checkpoint a model or part of the model. """Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint.""" This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function, *args) return CheckpointFunction.apply(function,
distribute_checkpointed_activations, *args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment