Commit 47fd9852 authored by Myle Ott, committed by Facebook Github Bot

Move Masked LM components to legacy/ -- new ones are coming

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/740

Differential Revision: D16377797

Pulled By: myleott

fbshipit-source-id: f7d6c8b00a77e279ea94376b1f0fcd15087eaf5f
parent 9c89e882
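
For downstream code that imports these components directly, the move means updating import paths from `fairseq.data` to `fairseq.data.legacy` and switching to the renamed task and criterion identifiers. A minimal before/after sketch, with paths taken from the hunks below:

```python
# Before this commit:
# from fairseq.data import BlockPairDataset, MaskedLMDataset
# from fairseq.data.masked_lm_dictionary import BertDictionary, MaskedLMDictionary

# After this commit, the same components live under fairseq.data.legacy:
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import BertDictionary, MaskedLMDictionary
```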
@@ -63,7 +63,7 @@ fairseq-train \
 --optimizer adam --lr-scheduler reduce_lr_on_plateau \
 --lr-shrink 0.5 --lr 0.0001 --min-lr 1e-09 \
 --dropout 0.1 \
---criterion masked_lm_loss \
+--criterion legacy_masked_lm_loss \
 --max-tokens 2048 --tokens-per-sample 256 --attention-dropout 0.1 \
 --dataset-impl lazy --seed 0 \
 --masked-lm-only \
......
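
Existing training setups only need the criterion name updated; everything else in the command is unchanged. A quick sanity check that the new name resolves, assuming `fairseq.criterions` exposes the `CRITERION_REGISTRY` dict populated by the `@register_criterion` decorator:

```python
# Verify the renamed criterion is registered under its new CLI name.
from fairseq.criterions import CRITERION_REGISTRY

assert 'legacy_masked_lm_loss' in CRITERION_REGISTRY
# At this commit, the old name no longer resolves:
assert 'masked_lm_loss' not in CRITERION_REGISTRY
```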
@@ -32,8 +32,8 @@ def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
     return loss
 
-@register_criterion('masked_lm_loss')
-class MaskedLmLoss(FairseqCriterion):
+@register_criterion('legacy_masked_lm_loss')
+class LegacyMaskedLmLoss(FairseqCriterion):
     """
     Implementation for the loss used in masked language model (MLM) training.
     This optionally also computes the next sentence prediction (NSP) loss and
......
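
The string passed to `@register_criterion` is the only name the CLI accepts for `--criterion`, which is why the class rename above is paired with a registry-name change. A simplified sketch of the registry pattern (not fairseq's actual code, just the mechanism it relies on):

```python
CRITERION_REGISTRY = {}

def register_criterion(name):
    """Decorator mapping a CLI-facing name to a criterion class."""
    def register_criterion_cls(cls):
        if name in CRITERION_REGISTRY:
            raise ValueError('Cannot register duplicate criterion ({})'.format(name))
        CRITERION_REGISTRY[name] = cls
        return cls
    return register_criterion_cls

# '--criterion legacy_masked_lm_loss' then resolves roughly via:
# criterion_cls = CRITERION_REGISTRY[args.criterion]
```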
@@ -6,18 +6,15 @@
 # can be found in the PATENTS file in the same directory.
 
 from .dictionary import Dictionary, TruncatedDictionary
-from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
 
 from .fairseq_dataset import FairseqDataset
 
 from .audio.raw_audio_dataset import RawAudioDataset
 from .backtranslation_dataset import BacktranslationDataset
-from .block_pair_dataset import BlockPairDataset
 from .concat_dataset import ConcatDataset
 from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
 from .language_pair_dataset import LanguagePairDataset
 from .lm_context_window_dataset import LMContextWindowDataset
-from .masked_lm_dataset import MaskedLMDataset
 from .monolingual_dataset import MonolingualDataset
 from .noising import NoisingDataset
 from .round_robin_zip_datasets import RoundRobinZipDatasets
@@ -34,8 +31,6 @@ from .iterators import (
 
 __all__ = [
     'BacktranslationDataset',
-    'BertDictionary',
-    'BlockPairDataset',
     'ConcatDataset',
     'CountingIterator',
     'Dictionary',
@@ -47,8 +42,6 @@ __all__ = [
     'IndexedRawTextDataset',
     'LanguagePairDataset',
     'LMContextWindowDataset',
-    'MaskedLMDataset',
-    'MaskedLMDictionary',
     'MMapIndexedDataset',
     'MonolingualDataset',
     'NoisingDataset',
......
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
+
+from .block_pair_dataset import BlockPairDataset
+from .masked_lm_dataset import MaskedLMDataset
+
+__all__ = [
+    'BertDictionary',
+    'BlockPairDataset',
+    'MaskedLMDataset',
+    'MaskedLMDictionary',
+]
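
This new `__init__.py` re-exports the moved components, so they can be imported either from the submodules directly (as the updated call sites below do) or from the `legacy` package itself. A quick usage sketch, assuming `MaskedLMDictionary` keeps its no-argument constructor:

```python
# Both import styles reach the same classes after the move.
from fairseq.data.legacy import BertDictionary, MaskedLMDictionary
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset

# MaskedLMDictionary extends the base Dictionary with a mask symbol.
dictionary = MaskedLMDictionary()
```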
@@ -10,7 +10,7 @@ import math
 import numpy as np
 import torch
 
-from . import FairseqDataset
+from fairseq.data import FairseqDataset
 
 
 class BlockPairDataset(FairseqDataset):
......
@@ -12,10 +12,10 @@ import torch
 from typing import Dict, List, Tuple
 
-from . import FairseqDataset, data_utils
+from fairseq.data import FairseqDataset, data_utils
 from fairseq.data import Dictionary
-from fairseq.data.block_pair_dataset import BlockPairDataset
+from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
 from fairseq.data.token_block_dataset import TokenBlockDataset
 from fairseq.data.concat_dataset import ConcatDataset
......
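
The switch from `from . import FairseqDataset` to `from fairseq.data import FairseqDataset` in the two hunks above is forced by the move: inside `fairseq/data/legacy/`, a relative `.` now refers to the `legacy` package rather than to `fairseq.data`. An equivalent relative form would need one extra level:

```python
# Inside fairseq/data/legacy/block_pair_dataset.py, these are equivalent:
from fairseq.data import FairseqDataset   # absolute import, used by this commit
# from .. import FairseqDataset           # relative import, one package level up
```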
@@ -9,7 +9,7 @@ import os
 from typing import Any, Dict
 
 from fairseq import checkpoint_utils
-from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
+from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
 from fairseq.models import register_model, register_model_architecture
 from fairseq.models.transformer import (
     TransformerDecoder,
......
@@ -13,7 +13,7 @@ from collections import OrderedDict
 import numpy as np
 
 from fairseq import tokenizer
-from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
+from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
 from fairseq.data import (
     ConcatDataset,
@@ -23,7 +23,7 @@ from fairseq.data import (
 )
 from fairseq.data import Dictionary
-from fairseq.data.masked_lm_dataset import MaskedLMDataset
+from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
 from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
 
 from . import FairseqTask, register_task
......
@@ -17,15 +17,15 @@ from fairseq.data import (
 )
 from fairseq.data import Dictionary
-from fairseq.data.block_pair_dataset import BlockPairDataset
-from fairseq.data.masked_lm_dataset import MaskedLMDataset
-from fairseq.data.masked_lm_dictionary import BertDictionary
+from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
+from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
+from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
 
 from . import FairseqTask, register_task
 
 
-@register_task('masked_lm')
-class MaskedLMTask(FairseqTask):
+@register_task('legacy_masked_lm')
+class LegacyMaskedLMTask(FairseqTask):
     """
     Task for training Masked LM (BERT) model.
 
     Args:
......
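
As with the criterion, the task's registered name changes, so `--task masked_lm` becomes `--task legacy_masked_lm` for this code path. A quick check, assuming `fairseq.tasks` exposes the `TASK_REGISTRY` dict populated by `@register_task`:

```python
from fairseq.tasks import TASK_REGISTRY

# The moved task registers under its new name.
assert 'legacy_masked_lm' in TASK_REGISTRY
print(TASK_REGISTRY['legacy_masked_lm'].__name__)  # LegacyMaskedLMTask
```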
@@ -5,7 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
+from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
 from fairseq.tasks.translation import TranslationTask
 
 from . import register_task
......
@@ -263,19 +263,20 @@ class TestLanguageModeling(unittest.TestCase):
 class TestMaskedLanguageModel(unittest.TestCase):
 
-    def test_masked_lm(self):
+    def test_legacy_masked_lm(self):
         with contextlib.redirect_stdout(StringIO()):
-            with tempfile.TemporaryDirectory("test_mlm") as data_dir:
+            with tempfile.TemporaryDirectory("test_legacy_mlm") as data_dir:
                 create_dummy_data(data_dir)
                 preprocess_lm_data(data_dir)
-                train_masked_language_model(data_dir, "masked_lm")
+                train_legacy_masked_language_model(data_dir, "masked_lm")
 
     def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_only):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory("test_mlm") as data_dir:
                 create_dummy_data(data_dir)
                 preprocess_lm_data(data_dir)
-                train_masked_language_model(
+                train_legacy_masked_language_model(
                     data_dir,
                     arch="masked_lm",
                     extra_args=('--encoder-learned-pos',) if learned_pos_emb else ()
@@ -332,7 +333,8 @@ class TestMaskedLanguageModel(unittest.TestCase):
     def test_pretrained_masked_lm_for_translation_encoder_only(self):
         self._test_pretrained_masked_lm_for_translation(True, True)
 
-def train_masked_language_model(data_dir, arch, extra_args=()):
+
+def train_legacy_masked_language_model(data_dir, arch, extra_args=()):
     train_parser = options.get_training_parser()
     # TODO: langs should be in and out right?
     train_args = options.parse_args_and_arch(
@@ -361,7 +363,7 @@ def train_masked_language_model(data_dir, arch, extra_args=()):
         "0.1",
         # MLM args
         "--criterion",
-        "masked_lm_loss",
+        "legacy_masked_lm_loss",
         "--masked-lm-only",
         "--monolingual-langs",
         "in,out",
......
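
The renamed test helper builds its training command programmatically through fairseq's options parser rather than via the CLI. A condensed sketch of that pattern, using the `options` API the helper already relies on; the argument values are illustrative, and the `cross_lingual_lm` task name is an assumption suggested by the `--monolingual-langs` flag:

```python
from fairseq import options

def parse_legacy_mlm_args(data_dir):
    # Mirrors train_legacy_masked_language_model: build a training parser and
    # resolve architecture-specific defaults for the masked_lm model.
    train_parser = options.get_training_parser()
    return options.parse_args_and_arch(
        train_parser,
        [
            "--task", "cross_lingual_lm",  # assumed task for this helper
            data_dir,
            "--arch", "masked_lm",
            "--criterion", "legacy_masked_lm_loss",
            "--masked-lm-only",
            "--monolingual-langs", "in,out",
            "--max-epoch", "1",  # illustrative value
        ],
    )
```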