"examples/vscode:/vscode.git/clone" did not exist on "5ab21b072fa2a122da930386381d23f95de06e28"
Unverified Commit 74ffc9ea authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Reformer] Fix example and error message (#4191)

* fix example reformer

* fix error message and example docstring

* improved error message
parent 96c78396
......@@ -124,8 +124,8 @@ class AxialPositionEmbeddings(nn.Module):
if self.training is True:
assert (
reduce(mul, self.axial_pos_shape) == sequence_length
), "Make sure that config.axial_pos_shape factors: {} multiply to sequence length: {}".format(
self.axial_pos_shape, sequence_length
), "If training, make sure that config.axial_pos_shape factors: {} multiply to sequence length. Got prod({}) != sequence_length: {}. You might want to consider padding your sequence length to {} or changing config.axial_pos_shape.".format(
self.axial_pos_shape, self.axial_pos_shape, sequence_length, reduce(mul, self.axial_pos_shape)
)
if self.dropout > 0:
weights = torch.cat(broadcasted_weights, dim=-1)
......@@ -1515,11 +1515,11 @@ class ReformerModel(ReformerPreTrainedModel):
Examples::
from transformers import ReformerModel, ReformerTokenizer
from transformers import ReformerModelWithLMHead, ReformerTokenizer
import torch
tokenizer = ReformerTokenizer.from_pretrained('bert-base-uncased')
model = ReformerModel.from_pretrained('bert-base-uncased')
tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
model = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
......@@ -1562,7 +1562,7 @@ class ReformerModel(ReformerPreTrainedModel):
if self.training is True:
raise ValueError(
"If training, sequence Length {} has to be a multiple of least common multiple chunk_length {}. Please consider padding the input to a length of {}.".format(
input_shape[-2], least_common_mult_chunk_length, input_shape[-2] + padding_length
input_shape[-1], least_common_mult_chunk_length, input_shape[-1] + padding_length
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment