Commit 33597e5a authored by Myle Ott, committed by Facebook Github Bot
Browse files

Support different --max-positions and --tokens-per-sample

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/924

Differential Revision: D16548165

Pulled By: myleott

fbshipit-source-id: 49569ece3e54fad7b4f0dfb201ac99123bfdd4f2
parent 2fe45f09
......@@ -43,6 +43,10 @@ class RobertaHubInterface(nn.Module):
def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor:
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
if tokens.size(-1) > self.model.max_positions():
raise ValueError('tokens exceeds maximum length: {} > {}'.format(
tokens.size(-1), self.model.max_positions()
))
features, extra = self.model(
tokens.to(device=self.device),
features_only=True,
......
......@@ -75,6 +75,8 @@ class RobertaModel(FairseqLanguageModel):
help='dropout probability after activation in FFN')
parser.add_argument('--pooler-dropout', type=float, metavar='D',
help='dropout probability in the masked_lm pooler layers')
parser.add_argument('--max-positions', type=int,
help='number of positional embeddings to learn')
@classmethod
def build_model(cls, args, task):
......
......@@ -178,8 +178,6 @@ class MaskedLMTask(FairseqTask):
)
def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
if self.args.also_lowercase_words:
raise NotImplementedError
src_dataset = PadDataset(
TokenBlockDataset(
src_tokens,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment