Unverified Commit 74ffc9ea authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Reformer] Fix example and error message (#4191)

* fix example reformer

* fix error message and example docstring

* improved error message
parent 96c78396
...@@ -124,8 +124,8 @@ class AxialPositionEmbeddings(nn.Module): ...@@ -124,8 +124,8 @@ class AxialPositionEmbeddings(nn.Module):
if self.training is True: if self.training is True:
assert ( assert (
reduce(mul, self.axial_pos_shape) == sequence_length reduce(mul, self.axial_pos_shape) == sequence_length
), "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format( ), "If training, make sure that config.axial_pos_shape factors: {} multiply to sequence length. Got prod({}) != sequence_length: {}. You might want to consider padding your sequence length to {} or changing config.axial_pos_shape.".format(
self.axial_pos_shape, sequence_length self.axial_pos_shape, self.axial_pos_shape, sequence_length, reduce(mul, self.axial_pos_shape)
) )
if self.dropout > 0: if self.dropout > 0:
weights = torch.cat(broadcasted_weights, dim=-1) weights = torch.cat(broadcasted_weights, dim=-1)
...@@ -1515,11 +1515,11 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -1515,11 +1515,11 @@ class ReformerModel(ReformerPreTrainedModel):
Examples:: Examples::
from transformers import ReformerModel, ReformerTokenizer from transformers import ReformerModelWithLMHead, ReformerTokenizer
import torch import torch
tokenizer = ReformerTokenizer.from_pretrained('bert-base-uncased') tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
model = ReformerModel.from_pretrained('bert-base-uncased') model = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids) outputs = model(input_ids)
...@@ -1562,7 +1562,7 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -1562,7 +1562,7 @@ class ReformerModel(ReformerPreTrainedModel):
if self.training is True: if self.training is True:
raise ValueError( raise ValueError(
"If training, sequence Length {} has to be a multiple of least common multiple chunk_length {}. Please consider padding the input to a length of {}.".format( "If training, sequence Length {} has to be a multiple of least common multiple chunk_length {}. Please consider padding the input to a length of {}.".format(
input_shape[-2], least_common_mult_chunk_length, input_shape[-2] + padding_length input_shape[-1], least_common_mult_chunk_length, input_shape[-1] + padding_length
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment