Commit e4047852 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add internal tests for torch hub

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/911

Differential Revision: D18511627

Pulled By: myleott

fbshipit-source-id: 37d7606ae629f9acf84715dbc9045fb683075db4
parent 0d03aa88
...@@ -227,7 +227,7 @@ class WSCTask(FairseqTask): ...@@ -227,7 +227,7 @@ class WSCTask(FairseqTask):
def get_masked_input(tokens, mask): def get_masked_input(tokens, mask):
masked_tokens = tokens.clone() masked_tokens = tokens.clone()
masked_tokens[mask] = self.mask masked_tokens[mask.bool()] = self.mask
return masked_tokens return masked_tokens
def get_lprobs(tokens, mask): def get_lprobs(tokens, mask):
...@@ -252,7 +252,7 @@ class WSCTask(FairseqTask): ...@@ -252,7 +252,7 @@ class WSCTask(FairseqTask):
best_idx = cand_lprobs.argmax().item() best_idx = cand_lprobs.argmax().item()
full_cand = sample['candidate_tokens'][0][best_idx] full_cand = sample['candidate_tokens'][0][best_idx]
mask = sample['candidate_masks'][0][best_idx] mask = sample['candidate_masks'][0][best_idx]
toks = full_cand[mask] toks = full_cand[mask.bool()]
return self.bpe.decode(self.source_dictionary.string(toks)).strip() return self.bpe.decode(self.source_dictionary.string(toks)).strip()
@property @property
......
...@@ -24,7 +24,7 @@ class SubwordNMTBPE(object): ...@@ -24,7 +24,7 @@ class SubwordNMTBPE(object):
raise ValueError('--bpe-codes is required for --bpe=subword_nmt') raise ValueError('--bpe-codes is required for --bpe=subword_nmt')
codes = file_utils.cached_path(args.bpe_codes) codes = file_utils.cached_path(args.bpe_codes)
try: try:
from subword_nmt.subword_nmt import apply_bpe from subword_nmt import apply_bpe
bpe_parser = apply_bpe.create_parser() bpe_parser = apply_bpe.create_parser()
bpe_args = bpe_parser.parse_args([ bpe_args = bpe_parser.parse_args([
'--codes', codes, '--codes', codes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment