Commit a14e8360 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Never checkpoint with grad disabled

parent 67e34d92
...@@ -73,7 +73,7 @@ def checkpoint_blocks( ...@@ -73,7 +73,7 @@ def checkpoint_blocks(
# Avoids mishaps when the blocks take just one argument # Avoids mishaps when the blocks take just one argument
args = wrap(args) args = wrap(args)
if blocks_per_ckpt is None: if blocks_per_ckpt is None or not torch.is_grad_enabled():
return exec(blocks, args) return exec(blocks, args)
elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment