Commit 33597e5a authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Support different --max-positions and --tokens-per-sample

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/924

Differential Revision: D16548165

Pulled By: myleott

fbshipit-source-id: 49569ece3e54fad7b4f0dfb201ac99123bfdd4f2
parent 2fe45f09
@@ -43,6 +43,10 @@ class RobertaHubInterface(nn.Module):
def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor: def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor:
if tokens.dim() == 1: if tokens.dim() == 1:
tokens = tokens.unsqueeze(0) tokens = tokens.unsqueeze(0)
if tokens.size(-1) > self.model.max_positions():
raise ValueError('tokens exceeds maximum length: {} > {}'.format(
tokens.size(-1), self.model.max_positions()
))
features, extra = self.model( features, extra = self.model(
tokens.to(device=self.device), tokens.to(device=self.device),
features_only=True, features_only=True,
......
@@ -75,6 +75,8 @@ class RobertaModel(FairseqLanguageModel):
help='dropout probability after activation in FFN') help='dropout probability after activation in FFN')
parser.add_argument('--pooler-dropout', type=float, metavar='D', parser.add_argument('--pooler-dropout', type=float, metavar='D',
help='dropout probability in the masked_lm pooler layers') help='dropout probability in the masked_lm pooler layers')
parser.add_argument('--max-positions', type=int,
help='number of positional embeddings to learn')
@classmethod @classmethod
def build_model(cls, args, task): def build_model(cls, args, task):
......
@@ -178,8 +178,6 @@ class MaskedLMTask(FairseqTask):
) )
def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
if self.args.also_lowercase_words:
raise NotImplementedError
src_dataset = PadDataset( src_dataset = PadDataset(
TokenBlockDataset( TokenBlockDataset(
src_tokens, src_tokens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment