Commit 7df61696 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add fairseq0.10.2

parents
Pipeline #471 failed with stages
in 0 seconds
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
try:
from fairseq.model_parallel.megatron.mpu import (
ColumnParallelLinear,
RowParallelLinear,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer):
"""Encoder layer block over multiple gpus.
See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
"""
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
if q_noise > 0:
raise NotImplementedError
return ColumnParallelLinear(input_dim, output_dim, gather_output=False)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
if q_noise > 0:
raise NotImplementedError
return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)
def build_self_attention(self, embed_dim, args, **unused_kwargs):
return ModelParallelMultiheadAttention(
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
)
class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer):
"""Decoder layer block.
See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
"""
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
if q_noise > 0:
raise NotImplementedError
return ColumnParallelLinear(input_dim, output_dim, gather_output=False)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
if q_noise > 0:
raise NotImplementedError
return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)
def build_self_attention(self, embed_dim, args, **unused_kwargs):
return ModelParallelMultiheadAttention(
embed_dim=embed_dim,
num_heads=args.decoder_attention_heads,
dropout=args.attention_dropout,
self_attention=not getattr(args, "cross_self_attention", False),
)
def build_encoder_attention(self, embed_dim, args, **unused_kwargs):
return ModelParallelMultiheadAttention(
embed_dim=embed_dim,
num_heads=args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoderLayer
from fairseq.modules import (
LayerNorm,
MultiheadAttention,
PositionalEmbedding,
TransformerSentenceEncoder,
)
try:
from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
class ModelParallelTransformerSentenceEncoder(TransformerSentenceEncoder):
"""
Implementation for a Model Parallel Bi-directional Transformer based
Sentence Encoder used in BERT/XLM style pre-trained models.
"""
def build_embedding(self, vocab_size, embedding_dim, padding_idx):
return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx)
def build_transformer_sentence_encoder_layer(
self,
embedding_dim,
ffn_embedding_dim,
num_attention_heads,
dropout,
attention_dropout,
activation_dropout,
activation_fn,
export,
**unused,
):
return ModelParallelTransformerSentenceEncoderLayer(
embedding_dim=embedding_dim,
ffn_embedding_dim=ffn_embedding_dim,
num_attention_heads=num_attention_heads,
dropout=dropout,
attention_dropout=attention_dropout,
activation_dropout=activation_dropout,
activation_fn=activation_fn,
export=export,
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
from fairseq.modules import TransformerSentenceEncoderLayer
try:
from fairseq.model_parallel.megatron.mpu import (
ColumnParallelLinear,
RowParallelLinear,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
"""
Implements a Model Parallel Transformer Encoder Layer used in
BERT/XLM style pre-trained models.
"""
def build_fc1(self, input_dim, output_dim, **unused):
return ColumnParallelLinear(input_dim, output_dim, gather_output=False)
def build_fc2(self, input_dim, output_dim, **unused):
return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)
def build_self_attention(
self,
embed_dim,
num_attention_heads,
dropout,
**kwargs,
):
return ModelParallelMultiheadAttention(
embed_dim, num_attention_heads, dropout=dropout, self_attention=True
)
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
):
"""
LayerNorm is applied either before or after the self-attention/ffn
modules similar to the original Transformer imlementation.
"""
residual = x
x = self.self_attn_layer_norm(x)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=False,
attn_mask=self_attn_mask,
)
x = self.dropout_module(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = residual + x
return x, None
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import argparse
import importlib
import os
from argparse import Namespace
from typing import Union
import fairseq
from fairseq.dataclass import FairseqDataclass
from omegaconf import DictConfig, OmegaConf
from .composite_encoder import CompositeEncoder
from .distributed_fairseq_model import DistributedFairseqModel
from .fairseq_decoder import FairseqDecoder
from .fairseq_encoder import FairseqEncoder
from .fairseq_incremental_decoder import FairseqIncrementalDecoder
from .fairseq_model import (
BaseFairseqModel,
FairseqEncoderDecoderModel,
FairseqEncoderModel,
FairseqLanguageModel,
FairseqModel,
FairseqMultiModel,
)
MODEL_REGISTRY = {}
MODEL_DATACLASS_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_MODEL_NAME_REGISTRY = {}
ARCH_MODEL_INV_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
__all__ = [
"BaseFairseqModel",
"CompositeEncoder",
"DistributedFairseqModel",
"FairseqDecoder",
"FairseqEncoder",
"FairseqEncoderDecoderModel",
"FairseqEncoderModel",
"FairseqIncrementalDecoder",
"FairseqLanguageModel",
"FairseqModel",
"FairseqMultiModel",
]
def build_model(model_cfg: Union[DictConfig, Namespace], task):
if isinstance(model_cfg, DictConfig):
return ARCH_MODEL_REGISTRY[model_cfg._name].build_model(model_cfg, task)
return ARCH_MODEL_REGISTRY[model_cfg.arch].build_model(model_cfg, task)
def register_model(name, dataclass=None):
"""
New model types can be added to fairseq with the :func:`register_model`
function decorator.
For example::
@register_model('lstm')
class LSTM(FairseqEncoderDecoderModel):
(...)
.. note:: All models must implement the :class:`BaseFairseqModel` interface.
Typically you will extend :class:`FairseqEncoderDecoderModel` for
sequence-to-sequence tasks or :class:`FairseqLanguageModel` for
language modeling tasks.
Args:
name (str): the name of the model
"""
def register_model_cls(cls):
if name in MODEL_REGISTRY:
raise ValueError("Cannot register duplicate model ({})".format(name))
if not issubclass(cls, BaseFairseqModel):
raise ValueError(
"Model ({}: {}) must extend BaseFairseqModel".format(name, cls.__name__)
)
MODEL_REGISTRY[name] = cls
if dataclass is not None and not issubclass(dataclass, FairseqDataclass):
raise ValueError(
"Dataclass {} must extend FairseqDataclass".format(dataclass)
)
cls.__dataclass = dataclass
MODEL_DATACLASS_REGISTRY[name] = dataclass
return cls
return register_model_cls
def register_model_architecture(model_name, arch_name):
"""
New model architectures can be added to fairseq with the
:func:`register_model_architecture` function decorator. After registration,
model architectures can be selected with the ``--arch`` command-line
argument.
For example::
@register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
def lstm_luong_wmt_en_de(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
(...)
The decorated function should take a single argument *args*, which is a
:class:`argparse.Namespace` of arguments parsed from the command-line. The
decorated function should modify these arguments in-place to match the
desired architecture.
Args:
model_name (str): the name of the Model (Model must already be
registered)
arch_name (str): the name of the model architecture (``--arch``)
"""
def arch_override_from_yaml(args, arch):
root_dir = os.path.dirname(os.path.dirname(fairseq.__file__))
yaml_path = os.path.join(root_dir, "config/model/{}.yaml".format(arch))
if not os.path.exists(yaml_path):
raise RuntimeError(f"yaml file {yaml_path} does not exist!")
arch_cfg = OmegaConf.load(yaml_path)
for k, v in arch_cfg.items():
setattr(args, k, getattr(args, k, v))
def register_model_arch_fn(fn):
if model_name not in MODEL_REGISTRY:
raise ValueError(
"Cannot register model architecture for unknown model type ({})".format(
model_name
)
)
if arch_name in ARCH_MODEL_REGISTRY:
raise ValueError(
"Cannot register duplicate model architecture ({})".format(arch_name)
)
if not callable(fn):
raise ValueError(
"Model architecture must be callable ({})".format(arch_name)
)
ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
ARCH_MODEL_NAME_REGISTRY[arch_name] = model_name
ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name)
if type(fn) is type and issubclass(fn, BaseFairseqModel):
# for model classes migrated with hydra
# in this case, we are using this decorator directly on model class since
# we do not need arch overriding functions.
ARCH_CONFIG_REGISTRY[arch_name] = lambda args: arch_override_from_yaml(
args, arch=arch_name
)
else:
ARCH_CONFIG_REGISTRY[arch_name] = fn
return fn
return register_model_arch_fn
# automatically import any Python files in the models/ directory
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
path = os.path.join(models_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
model_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("fairseq.models." + model_name)
# extra `model_parser` for sphinx
if model_name in MODEL_REGISTRY:
parser = argparse.ArgumentParser(add_help=False)
group_archs = parser.add_argument_group("Named architectures")
group_archs.add_argument(
"--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name]
)
group_args = parser.add_argument_group("Additional command-line arguments")
MODEL_REGISTRY[model_name].add_args(group_args)
globals()[model_name + "_parser"] = parser
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment