Commit 30123e2c, authored by Myle Ott, committed by Facebook Github Bot
Browse files

Fix read_binarized.py script

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/762

Differential Revision: D16427266

Pulled By: myleott

fbshipit-source-id: 9bd9b8c6b4994ae98a62a37b34d03265bd365453
parent a03fe6fa
...@@ -45,7 +45,7 @@ def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_be ...@@ -45,7 +45,7 @@ def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_be
return res return res
def load_indexed_dataset(path, dictionary, dataset_impl=None, combine=False): def load_indexed_dataset(path, dictionary, dataset_impl=None, combine=False, default='cached'):
"""A helper function for loading indexed datasets. """A helper function for loading indexed datasets.
Args: Args:
...@@ -72,7 +72,7 @@ def load_indexed_dataset(path, dictionary, dataset_impl=None, combine=False): ...@@ -72,7 +72,7 @@ def load_indexed_dataset(path, dictionary, dataset_impl=None, combine=False):
dataset = indexed_dataset.make_dataset( dataset = indexed_dataset.make_dataset(
path_k, path_k,
impl=dataset_impl_k or 'cached', impl=dataset_impl_k or default,
fix_lua_indexing=True, fix_lua_indexing=True,
dictionary=dictionary, dictionary=dictionary,
) )
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
import argparse import argparse
from fairseq.data import Dictionary from fairseq.data import data_utils, Dictionary, indexed_dataset
from fairseq.data import indexed_dataset
def get_parser(): def get_parser():
...@@ -30,8 +29,12 @@ def main(): ...@@ -30,8 +29,12 @@ def main():
args = parser.parse_args() args = parser.parse_args()
dictionary = Dictionary.load(args.dict) if args.dict is not None else None dictionary = Dictionary.load(args.dict) if args.dict is not None else None
dataset = indexed_dataset.make_dataset(args.input, impl=args.dataset_impl, dataset = data_utils.load_indexed_dataset(
fix_lua_indexing=True, dictionary=dictionary) args.input,
dictionary,
dataset_impl=args.dataset_impl,
default='lazy',
)
for tensor_line in dataset: for tensor_line in dataset:
if dictionary is None: if dictionary is None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment