# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import contextlib
from typing import Optional

import numpy as np
from unicore.data import (
    Dictionary,
    MaskTokensDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    LMDBDataset,
    PrependTokenDataset,
    RightPadDataset,
    SortDataset,
    BertTokenizeDataset,
    data_utils,
)
from unicore.tasks import UnicoreTask, register_task

logger = logging.getLogger(__name__)


@register_task("bert")
class BertTask(UnicoreTask):
    """Task for training masked language models (e.g., BERT)."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="colon separated path to data directories list, \
                will be iterated upon during epochs in round-robin manner",
        )
        parser.add_argument(
            "--mask-prob",
            default=0.15,
            type=float,
            help="probability of replacing a token with mask",
        )
        parser.add_argument(
            "--leave-unmasked-prob",
            default=0.1,
            type=float,
            help="probability that a masked token is unmasked",
        )
        parser.add_argument(
            "--random-token-prob",
            default=0.1,
            type=float,
            help="probability of replacing a token with a random token",
        )

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed
        # Register the mask token in the vocabulary and remember its index.
        self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Construct the task, loading the vocabulary from ``args.data``."""
        loaded_dict = Dictionary.load(os.path.join(args.data, "dict.txt"))
        logger.info("dictionary: {} types".format(len(loaded_dict)))
        return cls(args, loaded_dict)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        lmdb_path = os.path.join(self.args.data, split + '.lmdb')
        vocab_path = os.path.join(self.args.data, "dict.txt")

        raw_dataset = LMDBDataset(lmdb_path)
        tokenized = BertTokenizeDataset(
            raw_dataset, vocab_path, max_seq_len=self.args.max_seq_len
        )

        # BERT-style masking: produces the corrupted inputs and the
        # corresponding prediction targets as a pair of datasets.
        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            tokenized,
            self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
        )

        # Deterministic shuffle order derived from the task seed.
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_dataset))

        nested = NestedDictionaryDataset(
            {
                "net_input": {
                    "src_tokens": RightPadDataset(
                        src_dataset,
                        pad_idx=self.dictionary.pad(),
                    )
                },
                "target": RightPadDataset(
                    tgt_dataset,
                    pad_idx=self.dictionary.pad(),
                ),
            },
        )
        self.datasets[split] = SortDataset(nested, sort_order=[shuffle])

    def build_model(self, args):
        from unicore import models

        return models.build_model(args, self)