Commit 7b585440 authored by mshoeybi's avatar mshoeybi
Browse files

onlly support pp=1

parent e923ec52
......@@ -236,6 +236,7 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.checkpoint_activations, \
'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations'
assert args.pipeline_model_parallel_size == 1
_print_args(args)
return args
......
......@@ -289,7 +289,8 @@ class CheckpointFunction(torch.autograd.Function):
return (None, None) + grads
def checkpoint(function, *args):
def checkpoint(function, distribute_checkpointed_activations, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function, *args)
return CheckpointFunction.apply(function,
distribute_checkpointed_activations, *args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment