Commit 47fd9852 authored by Myle Ott, committed by Facebook Github Bot

Move Masked LM components to legacy/ -- new ones are coming

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/740

Differential Revision: D16377797

Pulled By: myleott

fbshipit-source-id: f7d6c8b00a77e279ea94376b1f0fcd15087eaf5f
parent 9c89e882
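This commit renames the masked LM criterion and task registrations and moves the BERT-style dataset and dictionary classes into a `legacy` subpackage of `fairseq.data`. As a quick orientation (this sketch is not part of the diff; every name in it is taken from the changes below), the old and new entry points line up as follows:

```python
# Name changes introduced by this commit (old -> new):
#   criterion registry key: 'masked_lm_loss' -> 'legacy_masked_lm_loss'
#   task registry key:      'masked_lm'      -> 'legacy_masked_lm'
#   fairseq.data.block_pair_dataset   -> fairseq.data.legacy.block_pair_dataset
#   fairseq.data.masked_lm_dataset    -> fairseq.data.legacy.masked_lm_dataset
#   fairseq.data.masked_lm_dictionary -> fairseq.data.legacy.masked_lm_dictionary

# Imports that previously went through fairseq.data, e.g.
#   from fairseq.data import MaskedLMDataset, BertDictionary
# must now go through the legacy package:
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import BertDictionary, MaskedLMDictionary
```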
@@ -63,7 +63,7 @@ fairseq-train \
 --optimizer adam --lr-scheduler reduce_lr_on_plateau \
 --lr-shrink 0.5 --lr 0.0001 --min-lr 1e-09 \
 --dropout 0.1 \
---criterion masked_lm_loss \
+--criterion legacy_masked_lm_loss \
 --max-tokens 2048 --tokens-per-sample 256 --attention-dropout 0.1 \
 --dataset-impl lazy --seed 0 \
 --masked-lm-only \
...
@@ -32,8 +32,8 @@ def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
     return loss
 
 
-@register_criterion('masked_lm_loss')
-class MaskedLmLoss(FairseqCriterion):
+@register_criterion('legacy_masked_lm_loss')
+class LegacyMaskedLmLoss(FairseqCriterion):
     """
     Implementation for the loss used in masked language model (MLM) training.
     This optionally also computes the next sentence prediction (NSP) loss and
...
@@ -6,18 +6,15 @@
 # can be found in the PATENTS file in the same directory.
 
 from .dictionary import Dictionary, TruncatedDictionary
-from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
 from .fairseq_dataset import FairseqDataset
 from .audio.raw_audio_dataset import RawAudioDataset
 from .backtranslation_dataset import BacktranslationDataset
-from .block_pair_dataset import BlockPairDataset
 from .concat_dataset import ConcatDataset
 from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
 from .language_pair_dataset import LanguagePairDataset
 from .lm_context_window_dataset import LMContextWindowDataset
-from .masked_lm_dataset import MaskedLMDataset
 from .monolingual_dataset import MonolingualDataset
 from .noising import NoisingDataset
 from .round_robin_zip_datasets import RoundRobinZipDatasets
@@ -34,8 +31,6 @@ from .iterators import (
 __all__ = [
     'BacktranslationDataset',
-    'BertDictionary',
-    'BlockPairDataset',
     'ConcatDataset',
     'CountingIterator',
     'Dictionary',
@@ -47,8 +42,6 @@ __all__ = [
     'IndexedRawTextDataset',
     'LanguagePairDataset',
     'LMContextWindowDataset',
-    'MaskedLMDataset',
-    'MaskedLMDictionary',
     'MMapIndexedDataset',
     'MonolingualDataset',
     'NoisingDataset',
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
from .block_pair_dataset import BlockPairDataset
from .masked_lm_dataset import MaskedLMDataset


__all__ = [
    'BertDictionary',
    'BlockPairDataset',
    'MaskedLMDataset',
    'MaskedLMDictionary',
]
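Judging by its relative imports and by the `fairseq.data.legacy.*` paths used later in this diff, the new file above is the `__init__.py` of the new legacy package. Because it re-exports the moved classes, they can be imported either from the package root or from their defining modules. A minimal sketch (assumes fairseq at this commit is installed):

```python
# Both import forms resolve to the same class, since the legacy package's
# __init__.py re-exports everything it contains.
from fairseq.data.legacy import MaskedLMDictionary
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary as Direct

assert MaskedLMDictionary is Direct  # same object under two paths
```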
@@ -10,7 +10,7 @@ import math
 import numpy as np
 import torch
 
-from . import FairseqDataset
+from fairseq.data import FairseqDataset
 
 
 class BlockPairDataset(FairseqDataset):
...
@@ -12,10 +12,10 @@ import torch
 from typing import Dict, List, Tuple
 
-from . import FairseqDataset, data_utils
+from fairseq.data import FairseqDataset, data_utils
 from fairseq.data import Dictionary
-from fairseq.data.block_pair_dataset import BlockPairDataset
+from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
 from fairseq.data.token_block_dataset import TokenBlockDataset
 from fairseq.data.concat_dataset import ConcatDataset
...
@@ -9,7 +9,7 @@ import os
 from typing import Any, Dict
 
 from fairseq import checkpoint_utils
-from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
+from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
 from fairseq.models import register_model, register_model_architecture
 from fairseq.models.transformer import (
     TransformerDecoder,
...
@@ -13,7 +13,7 @@ from collections import OrderedDict
 import numpy as np
 
 from fairseq import tokenizer
-from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
+from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
 
 from fairseq.data import (
     ConcatDataset,
@@ -23,7 +23,7 @@ from fairseq.data import (
 )
 
 from fairseq.data import Dictionary
-from fairseq.data.masked_lm_dataset import MaskedLMDataset
+from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
 from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
 from . import FairseqTask, register_task
...
@@ -17,15 +17,15 @@ from fairseq.data import (
 )
 
 from fairseq.data import Dictionary
-from fairseq.data.block_pair_dataset import BlockPairDataset
-from fairseq.data.masked_lm_dataset import MaskedLMDataset
-from fairseq.data.masked_lm_dictionary import BertDictionary
+from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
+from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
+from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
 
 from . import FairseqTask, register_task
 
 
-@register_task('masked_lm')
-class MaskedLMTask(FairseqTask):
+@register_task('legacy_masked_lm')
+class LegacyMaskedLMTask(FairseqTask):
     """
     Task for training Masked LM (BERT) model.
 
     Args:
...
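Because the decorators above register the task and criterion under new keys, any config or command line that still uses the old strings will no longer resolve. A minimal sketch of verifying this (assumes fairseq's `TASK_REGISTRY` and `CRITERION_REGISTRY` dicts, which `@register_task` and `@register_criterion` populate):

```python
from fairseq.tasks import TASK_REGISTRY
from fairseq.criterions import CRITERION_REGISTRY

# After this commit, the renamed components resolve only under their new keys.
assert 'legacy_masked_lm' in TASK_REGISTRY             # was 'masked_lm'
assert 'legacy_masked_lm_loss' in CRITERION_REGISTRY   # was 'masked_lm_loss'
```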
@@ -5,7 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
+from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
 from fairseq.tasks.translation import TranslationTask
 
 from . import register_task
...
@@ -263,19 +263,20 @@ class TestLanguageModeling(unittest.TestCase):
 class TestMaskedLanguageModel(unittest.TestCase):
-    def test_masked_lm(self):
+
+    def test_legacy_masked_lm(self):
         with contextlib.redirect_stdout(StringIO()):
-            with tempfile.TemporaryDirectory("test_mlm") as data_dir:
+            with tempfile.TemporaryDirectory("test_legacy_mlm") as data_dir:
                 create_dummy_data(data_dir)
                 preprocess_lm_data(data_dir)
-                train_masked_language_model(data_dir, "masked_lm")
+                train_legacy_masked_language_model(data_dir, "masked_lm")
 
     def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_only):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory("test_mlm") as data_dir:
                 create_dummy_data(data_dir)
                 preprocess_lm_data(data_dir)
-                train_masked_language_model(
+                train_legacy_masked_language_model(
                     data_dir,
                     arch="masked_lm",
                     extra_args=('--encoder-learned-pos',) if learned_pos_emb else ()
@@ -332,7 +333,8 @@ class TestMaskedLanguageModel(unittest.TestCase):
     def test_pretrained_masked_lm_for_translation_encoder_only(self):
         self._test_pretrained_masked_lm_for_translation(True, True)
 
-def train_masked_language_model(data_dir, arch, extra_args=()):
+
+def train_legacy_masked_language_model(data_dir, arch, extra_args=()):
     train_parser = options.get_training_parser()
     # TODO: langs should be in and out right?
     train_args = options.parse_args_and_arch(
@@ -361,7 +363,7 @@ def train_masked_language_model(data_dir, arch, extra_args=()):
         "0.1",
         # MLM args
         "--criterion",
-        "masked_lm_loss",
+        "legacy_masked_lm_loss",
         "--masked-lm-only",
         "--monolingual-langs",
         "in,out",
...