Unverified Commit 95dab34d authored by Yusuke Mori's avatar Yusuke Mori Committed by GitHub
Browse files

Add an error message that fires when Reformer is not in training mode, but one...

Add an error message that fires when Reformer is not in training mode, but one runs .backward() (#11117)
parent f1b938fd
......@@ -1512,6 +1512,10 @@ class ReformerLayer(nn.Module):
# Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)
# This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py
assert (
self.training
), "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the model into training mode."
with torch.enable_grad():
next_attn_output.requires_grad = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment