Unverified Commit 95dab34d authored by Yusuke Mori's avatar Yusuke Mori Committed by GitHub
Browse files

Add an error message that fires when Reformer is not in training mode, but one...

Add an error message that fires when Reformer is not in training mode, but one runs .backward() (#11117)
parent f1b938fd
...@@ -1512,6 +1512,10 @@ class ReformerLayer(nn.Module): ...@@ -1512,6 +1512,10 @@ class ReformerLayer(nn.Module):
# Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0) # Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)
# This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py # This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py
assert (
self.training
), "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the model into training mode."
with torch.enable_grad(): with torch.enable_grad():
next_attn_output.requires_grad = True next_attn_output.requires_grad = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment