Unverified commit 57f25f4b authored by Mitch Naylor, committed by GitHub

Add Mega: Moving Average Equipped Gated Attention (#21766)



* add mega file structure and plain pytorch version of mega source code

* added config class with old naming conventions

* filled in mega documentation

* added config class and embeddings with optional token types

* updated notes

* starting the conversion process, deleted intermediate and added use_cache back to config

* renamed config attributes in modeling_mega.py

* checkpointing before refactoring incremental decoding functions

* removed stateful incremental key/values for EMA and self-attention

* refactored MovingAverageGatedAttention to remove stateful k/v history and use unified attention mask

* MovingAverageGatedAttention works with incremental decoding + past values, added sequence length enforcement

* more comments in MovingAverageGatedAttention + checkpointing before GatedCrossAttention

* bug fix in attention mask handling in MovingAverageGatedAttention

* removed incremental state from GatedCrossAttention and removed IncrementalState class

* finished gated cross attention and got MegaLayer working

* fixed causal masking in mega decoder

* fixed how padding and causal masks are passed through MegaLayer with and without k/v caching

* finished MegaModel; tested with encoder, decoder-only, and cross-attention type inputs; started work on downstream classes; removed mentions of position_ids

* added optional dense hidden layer for masked and causal LM classes

* docstring updates in MultiHeadEMA and GatedCrossAttention, removed unnecessary inputs in cross-attention

* removed before_attn_fn in Mega class and updated docstrings and comments up to there

* bug fix in MovingAverageGatedAttention masking

* working conversion of MLM checkpoint in scratchpad script -- perfect matches

* moved arg for hidden dense layer in LM head to config; discovered issue where from_pretrained is renaming gamma and beta parameters

* renamed gamma and beta parameters to avoid HF renaming when loading from checkpoint

* finished checkpoint conversion script

* cleanup old class in mega config script

* removed 'copied from' statements and passing integration tests

* added num_attention_heads=1 to config for integration compatibility, decoder tests working, generation tests failing

* fixed tuple output of megamodel

* all common tests passing after fixing issues in decoder, gradient retention, and initialization

* added mega-specific tests, ready for more documentation and style checks

* updated docstrings; checkpoint before style fixes

* style and quality checks, fixed initialization problem in float_tensor, ready for PR

* added mega to toctree

* removed unnecessary arg in megaconfig

* removed unused arg and fixed code samples with leftover roberta models

* Apply suggestions from code review

Applied all suggestions except the one renaming a class, as I'll need to update that throughout.
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixed issue where .view breaks batch dimension, conversion script fixed with absolute imports, updated readme with Mega->MEGA

* removed asserts in Mega code, renamed sequencenorm, gatedcrossattention, and NFFN, replaced get_activation_fn with ACTFN, and added sequencenorm to layer norms

* reformatted .forward() docstrings to match style and removed unused mask input in cross-attention

* removed all reset_parameters() methods and rolled into MegaPreTrainedModel._init_weights()

* renamed all single-letter variables and improved readability in tensor size comments, Mega->MEGA in 2 documentation files

* variable names in NFFN

* manual Mega->MEGA changes in docs

* Mega->MEGA in config auto

* style and quality fixes

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* renamed parameters and variables with confusing names, added copied from statements, moved fft conv to its own method, other cleanup from PR comments

* commit before dealing with merge conflicts

* made new attention activation functions available in ACT2FN and added generation test from OPT

* style and quality in activations and tests

* documentation fixes, renaming variables in dropout and rotary positions, used built-in causal masking, encoders->layers in MegaModel, moved comments into docstrings

* style and quality fixes after latest updates, before rotary position ids

* causal mask in MegaBlock docstring + added missing device passing

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update README.md
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* added Mega prefixes where missing, reverted MegaSequenceNorm to if-else, other module renaming requested in PR

* style and quality fixes + readme updates pointing to main

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 0fa46524
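A minimal usage sketch of the classes added in this commit, assuming the `mnaylor/mega-base-wikitext` checkpoint referenced in this diff is available on the Hub with its tokenizer files (MEGA reuses the RoBERTa tokenizer per the tokenizer mapping below):

```python
import torch
from transformers import AutoTokenizer, MegaConfig, MegaModel, MegaForMaskedLM

# Randomly initialized encoder from the default configuration
config = MegaConfig()
model = MegaModel(config)

# Converted masked-LM checkpoint; MEGA is mapped to the RoBERTa tokenizer classes
tokenizer = AutoTokenizer.from_pretrained("mnaylor/mega-base-wikitext")
mlm = MegaForMaskedLM.from_pretrained("mnaylor/mega-base-wikitext")

inputs = tokenizer("Moving average equipped gated attention", return_tensors="pt")
with torch.no_grad():
    logits = mlm(**inputs).logits  # (batch_size, sequence_length, vocab_size)
```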
@@ -122,6 +122,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("maskformer-swin", "MaskFormerSwinConfig"),
("mbart", "MBartConfig"),
("mctct", "MCTCTConfig"),
("mega", "MegaConfig"),
("megatron-bert", "MegatronBertConfig"),
("mgp-str", "MgpstrConfig"),
("mobilebert", "MobileBertConfig"),
@@ -299,6 +300,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mctct", "MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mega", "MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mgp-str", "MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mobilenet_v1", "MOBILENET_V1_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -484,6 +486,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("mbart", "mBART"),
("mbart50", "mBART-50"),
("mctct", "M-CTC-T"),
("mega", "MEGA"),
("megatron-bert", "Megatron-BERT"),
("megatron_gpt2", "Megatron-GPT2"),
("mgp-str", "MGP-STR"),
@@ -120,6 +120,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("maskformer-swin", "MaskFormerSwinModel"),
("mbart", "MBartModel"),
("mctct", "MCTCTModel"),
("mega", "MegaModel"),
("megatron-bert", "MegatronBertModel"),
("mgp-str", "MgpstrForSceneTextRecognition"),
("mobilebert", "MobileBertModel"),
@@ -228,6 +229,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("longformer", "LongformerForMaskedLM"),
("luke", "LukeForMaskedLM"),
("lxmert", "LxmertForPreTraining"),
("mega", "MegaForMaskedLM"),
("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"),
("mpnet", "MPNetForMaskedLM"),
@@ -302,6 +304,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("luke", "LukeForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"),
("marian", "MarianMTModel"),
("mega", "MegaForMaskedLM"),
("megatron-bert", "MegatronBertForCausalLM"),
("mobilebert", "MobileBertForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
@@ -363,6 +366,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("llama", "LlamaForCausalLM"),
("marian", "MarianForCausalLM"),
("mbart", "MBartForCausalLM"),
("mega", "MegaForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"),
("mvp", "MvpForCausalLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
@@ -531,6 +535,7 @@ MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
("longformer", "LongformerForMaskedLM"),
("luke", "LukeForMaskedLM"),
("mbart", "MBartForConditionalGeneration"),
("mega", "MegaForMaskedLM"),
("megatron-bert", "MegatronBertForMaskedLM"),
("mobilebert", "MobileBertForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
@@ -657,6 +662,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("luke", "LukeForSequenceClassification"),
("markuplm", "MarkupLMForSequenceClassification"),
("mbart", "MBartForSequenceClassification"),
("mega", "MegaForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
("mpnet", "MPNetForSequenceClassification"),
@@ -719,6 +725,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("lxmert", "LxmertForQuestionAnswering"),
("markuplm", "MarkupLMForQuestionAnswering"),
("mbart", "MBartForQuestionAnswering"),
("mega", "MegaForQuestionAnswering"),
("megatron-bert", "MegatronBertForQuestionAnswering"),
("mobilebert", "MobileBertForQuestionAnswering"),
("mpnet", "MPNetForQuestionAnswering"),
@@ -796,6 +803,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("longformer", "LongformerForTokenClassification"),
("luke", "LukeForTokenClassification"),
("markuplm", "MarkupLMForTokenClassification"),
("mega", "MegaForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
("mpnet", "MPNetForTokenClassification"),
@@ -838,6 +846,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
("ibert", "IBertForMultipleChoice"),
("longformer", "LongformerForMultipleChoice"),
("luke", "LukeForMultipleChoice"),
("mega", "MegaForMultipleChoice"),
("megatron-bert", "MegatronBertForMultipleChoice"),
("mobilebert", "MobileBertForMultipleChoice"),
("mpnet", "MPNetForMultipleChoice"),
@@ -194,6 +194,7 @@ else:
"MBart50TokenizerFast" if is_tokenizers_available() else None,
),
),
("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("mgp-str", ("MgpstrTokenizer", None)),
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_mega": ["MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegaConfig", "MegaOnnxConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mega"] = [
"MEGA_PRETRAINED_MODEL_ARCHIVE_LIST",
"MegaForCausalLM",
"MegaForMaskedLM",
"MegaForMultipleChoice",
"MegaForQuestionAnswering",
"MegaForSequenceClassification",
"MegaForTokenClassification",
"MegaModel",
"MegaPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_mega import MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP, MegaConfig, MegaOnnxConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mega import (
MEGA_PRETRAINED_MODEL_ARCHIVE_LIST,
MegaForCausalLM,
MegaForMaskedLM,
MegaForMultipleChoice,
MegaForQuestionAnswering,
MegaForSequenceClassification,
MegaForTokenClassification,
MegaModel,
MegaPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2023 The Mega Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MEGA configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
logger = logging.get_logger(__name__)
MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"mnaylor/mega-base-wikitext": "https://huggingface.co/mnaylor/mega-base-wikitext/resolve/main/config.json",
}
class MegaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MegaModel`]. It is used to instantiate a Mega
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Mega
[mnaylor/mega-base-wikitext](https://huggingface.co/mnaylor/mega-base-wikitext) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 30522):
Vocabulary size of the Mega model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`MegaModel`].
hidden_size (`int`, *optional*, defaults to 128):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 4):
Number of hidden layers in the Mega encoder.
intermediate_size (`int`, *optional*, defaults to 256):
Dimensionality of the hidden size (self-attention value projection) within the Mega encoder
ema_projection_size (`int`, *optional*, defaults to 16):
Dimensionality of the MegaMultiDimensionDampedEma
bidirectional (`bool`, *optional*, defaults to `True`):
Whether the MegaMultiDimensionDampedEma used in Mega's self-attention should work bidirectionally (`True`)
or unidirectionally (`False`). Bidirectional EMA is incompatible with causal decoding, so this should be
False if you intend to use the model as a decoder.
shared_representation_size (`int`, *optional*, defaults to 64):
Dimensionality of the linear projection for shared representation of self-attention queries and keys
use_chunking (`bool`, *optional*, defaults to `False`):
Whether to chunk inputs for linear self-attention complexity (described as Mega-chunk in the paper)
chunk_size (`int`, *optional*, defaults to -1):
If `use_chunking` is set to `True`, determines the size of the chunks to apply to the input sequence. If
chunking is used, input sequences must be padded to a multiple of `chunk_size`
truncation (`int`, *optional*):
If specified, the sequence length for which to truncate MegaMultiDimensionDampedEma
normalize_before_mega (`bool`, *optional*, defaults to `True`):
Whether to normalize before (`True`) or after (`False`) passing through Mega encoder blocks
normalization_type (`str`, *optional*, defaults to `"scalenorm"`):
Type of normalization to use in Mega encoder blocks. Choose one of `"scalenorm"`, `"layernorm"`,
`"rmsnorm"`, `"batchnorm"`, or `"syncbatchnorm"` (GPU required for syncbatchnorm)
norm_affine (`bool`, *optional*, defaults to `True`):
If `True`, applies a parameterized affine transformation to inputs during normalization
activation (`str`, *optional*, defaults to `"silu"`):
Activation function to apply within Mega encoder blocks. Choose one of `"silu"`, `"relu"`, `"linear"`,
`"gelu"`, or `"gelu_accurate"`
attention_activation (`str`, *optional*, defaults to `"softmax"`):
Activation function to apply for single-headed self-attention (a la Transformer). Choose one of
`"softmax"`, `"laplace"`, or `"relu2"`
dropout_prob (`float`, *optional*, defaults to 0.1):
The dropout probability for EMA self-attention
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention probabilities.
use_feature_dropout (`bool`, *optional*, defaults to `False`):
Whether to use feature-based (`True`) or standard dropout (`False`)
use_normalized_ffn (`bool`, *optional*, defaults to `True`):
Whether to use the normalized feed-forward sub-layer in Mega blocks (`True`) or pass Mega encoder output
as-is (`False`)
nffn_hidden_size (`int`, *optional*, defaults to 256):
If using the normalized feed-forward network (NFFN) layer within Mega (`use_normalized_ffn = True`), this
is the hidden size of the NFFN
normalize_before_ffn (`bool`, *optional*, defaults to `True`):
Whether to normalize before (`True`) or after (`False`) the feed-forward portion of NFFN
nffn_activation_dropout_prob (`float`, *optional*, defaults to 0.1):
The dropout ratio for the NFFN component.
max_positions (`int`, *optional*, defaults to 2048):
The maximum sequence length to use for positional representations. For `"simple"` relative positional bias,
this is a hard limit on input length; `"rotary"` relative positional bias will extrapolate to longer
sequences
add_token_type_embeddings (`bool`, *optional*, defaults to `False`):
Whether to account for token types in embeddings. Left as optional to maintain compatibility with original
implementation while adding support for token types.
type_vocab_size (`int`, *optional*, defaults to 2):
The vocabulary size of the `token_type_ids` passed when calling [`MegaModel`]. Only used if
`add_token_type_embeddings = True`
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
ema_delta_alpha_range (`float`, *optional*, defaults to 0.2):
The standard deviation for initializing the delta (damping factor) and alpha (decay factor) parameters in
MegaMultiDimensionDampedEma.
ema_beta_range (`float`, *optional*, defaults to 0.02):
The standard deviation for initializing the beta parameter (expansion matrix) in
MegaMultiDimensionDampedEma.
ema_gamma_omega_range (`float`, *optional*, defaults to 1.0):
The standard deviation for initializing the gamma (projection matrix) and omega (residual weight)
parameters in MegaMultiDimensionDampedEma.
relative_positional_bias (`str`, *optional*, defaults to `"rotary"`):
Type of relative positional encoding. Choose one of `"rotary"` or `"simple"`. If `"simple"` is selected,
`max_positions` is used as a limit on input size, while `"rotary"` extrapolates beyond `max_positions`.
is_decoder (`bool`, *optional*, defaults to `False`):
Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
classifier_dropout (`float`, *optional*):
The dropout ratio for the classification head.
add_lm_hidden_dense_layer (`bool`, *optional*, defaults to `True`):
Whether to include a hidden layer for projection between encoder outputs and LM heads (`True`) or pass
hidden states directly to LM head (`False`). Remains optional for compatibility with original
implementation
Examples:
```python
>>> from transformers import MegaConfig, MegaModel
>>> # Initializing a Mega configuration
>>> configuration = MegaConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = MegaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mega"
def __init__(
self,
vocab_size=30522,
hidden_size=128,
num_hidden_layers=4,
intermediate_size=256,
ema_projection_size=16,
bidirectional=True,
shared_representation_size=64,
use_chunking=False,
chunk_size=-1,
truncation=None,
normalize_before_mega=True,
normalization_type="scalenorm",
norm_affine=True,
activation="silu",
attention_activation="softmax",
dropout_prob=0.1,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
use_feature_dropout=False,
use_normalized_ffn=True,
nffn_hidden_size=256,
normalize_before_ffn=True,
nffn_activation_dropout_prob=0.1,
max_positions=2048,
add_token_type_embeddings=False,
type_vocab_size=2,
initializer_range=0.02,
ema_delta_alpha_range=0.2,
ema_beta_range=0.02,
ema_gamma_omega_range=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
relative_positional_bias="rotary",
classifier_dropout=None,
use_cache=True,
add_lm_hidden_dense_layer=True,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.activation = activation
self.attention_activation = attention_activation
self.intermediate_size = intermediate_size
self.ema_projection_size = ema_projection_size
self.bidirectional = bidirectional
self.shared_representation_size = shared_representation_size
self.use_chunking = use_chunking
self.chunk_size = chunk_size
self.truncation = truncation
self.normalize_before_mega = normalize_before_mega
self.normalization_type = normalization_type
self.norm_affine = norm_affine
self.dropout_prob = dropout_prob
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.use_feature_dropout = use_feature_dropout
self.use_normalized_ffn = use_normalized_ffn
self.nffn_hidden_size = nffn_hidden_size
self.normalize_before_ffn = normalize_before_ffn
self.nffn_activation_dropout_prob = nffn_activation_dropout_prob
self.max_positions = max_positions
self.add_token_type_embeddings = add_token_type_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.ema_delta_alpha_range = ema_delta_alpha_range
self.ema_beta_range = ema_beta_range
self.ema_gamma_omega_range = ema_gamma_omega_range
self.relative_positional_bias = relative_positional_bias
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
self.add_lm_hidden_dense_layer = add_lm_hidden_dense_layer
self.num_attention_heads = 1 # not used but required by Hugging Face
class MegaOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
]
)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert Mega pretrained checkpoint. Built to convert the Masked LM checkpoint located at
https://huggingface.co/mnaylor/mega-wikitext-103
Requirements:
- clone the Mega repo and install fairseq from there
1. git clone https://github.com/facebookresearch/mega.git
2. cd mega && pip install -e .
- clone the pretrained weights for the original implementation from the Hugging Face repo
* use this location as the path for pretrained weights
"""
import argparse
# utilities to import the model weights and config file
import os
import pickle as pkl
# PyTorch + new model classes
import torch
from torch import nn
from transformers import AutoTokenizer, MegaConfig, MegaForMaskedLM
# import the EncoderLayer class used to pretrain
# !! NOTE !! this requires the version of fairseq that is built when you install the Mega source
try:
from fairseq.modules.mega_layer import MegaEncoderLayer
except ImportError:
raise ImportError("You need to install the version of fairseq from the Mega repo!")
# define the wrapper classes used to train the MLM (see colab notebook below)
# https://colab.research.google.com/drive/1qfUO6o5HRdxBblWlw058HVyvaEPhPpH8?usp=sharing
# MegaLM outputs hidden states
class MegaLM(nn.Module):
"The base class for our Mega encoder - given input IDs, embed text and return encoder output"
def __init__(self, mega_args, depth, vocab_size):
super().__init__()
self.mega_args = mega_args
self.embedding_layer = nn.Embedding(vocab_size, self.mega_args.encoder_embed_dim)
self.encoders = nn.ModuleList([MegaEncoderLayer(self.mega_args) for _ in range(depth)])
self.depth = depth
def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0):
"""
Code for a forward pass - expects input_ids and attention_mask to come from a Hugging Face tokenizer as PyTorch
tensors, and returns the encoder hidden states as a tensor of size (batch size, sequence length, hidden size)
when `batch_first=True`, or (sequence length, batch size, hidden size) otherwise
Other options:
- batch_first: boolean indicating whether the batch dimension is first in input_ids (default: True, which
aligns with the HF tokenizer behavior)
- ignore_mask_value: the value in attention_mask that identifies tokens that should be ignored (default: 0,
which aligns with HF tokenizer)
"""
# Mega expects embeddings to be (time, batch, embedding size), but
# Hugging Face returns tokens as (batch, time)
if batch_first:
input_ids = input_ids.T
# to make things more confusing, Mega expects the attention mask to
# be (batch, time), but with values of 0 (normal token) and 1 (ignore token)
# which is the opposite of what HF returns
if ignore_mask_value == 0:
attention_mask = 1 - attention_mask
# get token embeddings from IDs
embeds = self.embedding_layer(input_ids)
# pass through the Mega layers
# input is (time, batch, encoder dim) and output is the same
for encoder in self.encoders:
embeds = encoder(embeds, attention_mask)
# return according to the shape specified
if batch_first:
# (T, B, H) --> (B, T, H)
return torch.transpose(embeds, 0, 1)
else:
return embeds
# renamed from MegaForMaskedLM to avoid confusion with new module
class OriginalMegaForMaskedLM(nn.Module):
"A wrapper class for doing masked language modeling with Mega"
def __init__(self, mega_args, depth, vocab_size):
super().__init__()
self.mega = MegaLM(mega_args, depth, vocab_size)
self.mlm_head = nn.Linear(mega_args.encoder_embed_dim, vocab_size)
self.dropout = nn.Dropout(p=0.1)
def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0):
"""
Perform a forward pass through the Mega encoder and the masked LM head. Returns logits for each vocabulary
entry.
If `batch_first` (default to align with Hugging Face tokenizer behavior), output will have the shape (Batch
size, Sequence length, Vocab size); otherwise (S, B, V)
"""
encoder_output = self.mega(input_ids, attention_mask, batch_first, ignore_mask_value)
return self.mlm_head(self.dropout(encoder_output))
# code to convert the checkpoint located in the user-specified location
def convert_checkpoint_to_huggingface(pretrained_checkpoint_path, output_path, includes_tokenizer):
with open(os.path.join(pretrained_checkpoint_path, "model_args.pkl"), "rb") as f:
mega_original_args = pkl.load(f)
# load the original encoder
original_mlm = OriginalMegaForMaskedLM(**mega_original_args).eval()
# load its weights
print(
"Original Mega encoder:",
original_mlm.mega.load_state_dict(
torch.load(os.path.join(pretrained_checkpoint_path, "encoder_weights.pt"), map_location="cpu")
),
)
print(
"Original Mega MLM layer:",
original_mlm.mlm_head.load_state_dict(
torch.load(os.path.join(pretrained_checkpoint_path, "mlm_head_weights.pt"), map_location="cpu")
),
)
# create a new config from the old one
hf_config = MegaConfig(
num_hidden_layers=mega_original_args["depth"],
vocab_size=mega_original_args["vocab_size"],
hidden_size=mega_original_args["mega_args"].encoder_embed_dim,
shared_representation_size=mega_original_args["mega_args"].encoder_z_dim,
intermediate_size=mega_original_args["mega_args"].encoder_hidden_dim,
ema_projection_size=mega_original_args["mega_args"].encoder_n_dim,
dropout_prob=mega_original_args["mega_args"].dropout,
attention_probs_dropout_prob=mega_original_args["mega_args"].attention_dropout,
hidden_dropout_prob=mega_original_args["mega_args"].hidden_dropout,
activation=mega_original_args["mega_args"].activation_fn,
attention_activation=mega_original_args["mega_args"].attention_activation_fn,
bidirectional=mega_original_args["mega_args"].bidirectional,
use_chunking=mega_original_args["mega_args"].encoder_chunk_size > 0,
chunk_size=mega_original_args["mega_args"].encoder_chunk_size,
truncation=mega_original_args["mega_args"].truncation_length,
normalization_type=mega_original_args["mega_args"].normalization_type,
normalize_before_mega=True,
norm_affine=True,
use_feature_dropout=mega_original_args["mega_args"].feature_dropout,
relative_positional_bias=mega_original_args["mega_args"].rel_pos_bias,
max_positions=mega_original_args["mega_args"].max_source_positions,
nffn_hidden_size=mega_original_args["mega_args"].encoder_ffn_embed_dim,
normalize_before_ffn=mega_original_args["mega_args"].normalize_before,
# new arguments added for HF implementation
nffn_activation_dropout_prob=0.0,
add_token_type_embeddings=False,
add_lm_hidden_dense_layer=False,
)
hf_mlm = MegaForMaskedLM(hf_config).eval()
# the original checkpoint just uses nn.Embedding for the word embeddings
# we use a wrapper module for embeddings to add support for positional embeddings
hf_mlm.mega.embedding_layer.word_embeddings.weight = original_mlm.mega.embedding_layer.weight
# modify the state dictionary of the original checkpoint to account for naming issues in the Hugging Face
# ecosystem -- any names containing "beta" or "gamma" aren't safe to use and are renamed upon _load_pretrained,
# also renaming previously confusing parameter names
original_state_dict = original_mlm.mega.encoders.state_dict()
updated_keys = {}
for module_name in original_state_dict.keys():
new_module_name = None
# have to handle gamma, beta, and alpha differently due to their use
# in multiple modules within the original repository;
# beta is used in EMA, MovingAverageGatedAttention, and RotaryRelativePositionalBias, and must be renamed due to flax/tf weights
# the EMA sublayer was renamed from "move" to "ema_gate" for readability, so that is also done here
if "beta" in module_name:
# EMA sub-layers were always called "move" in the original repo
if "move.beta" in module_name:
new_module_name = module_name.replace("move.beta", "ema_gate.ema_expansion_matrix")
elif "mega_layer.beta" in module_name:
new_module_name = module_name.replace("beta", "qk_bias")
else:
new_module_name = module_name.replace("beta", "b_param")
# gamma is used in EMA and MovingAverageGatedAttention, and must be renamed due to flax/tf weights
elif "gamma" in module_name:
if "move.gamma" in module_name:
new_module_name = module_name.replace("move.gamma", "ema_gate.kernel_projection_matrix")
elif "mega_layer.gamma" in module_name:
new_module_name = module_name.replace("gamma", "qk_weight")
else:
new_module_name = module_name.replace("gamma", "g_param")
# alpha is used in EMA and positional bias; renaming to improve readability
elif "move.alpha" in module_name:
new_module_name = module_name.replace("move.alpha", "ema_gate.decay_factor")
# delta is only used in EMA; renaming to improve readability
elif "move.delta" in module_name:
new_module_name = module_name.replace("move.delta", "ema_gate.damping_factor")
# omega is only used in EMA; renaming to improve readability
elif "omega" in module_name:
new_module_name = module_name.replace("move.omega", "ema_gate.residual_weight")
if new_module_name:
updated_keys[module_name] = new_module_name
if len(updated_keys) != 0:
print(f"Renaming these keys: {updated_keys.keys()}")
else:
print("No need to rename state dict entries")
for old, new in updated_keys.items():
original_state_dict[new] = original_state_dict.pop(old)
# now attempt to load the state dictionary with updated names
# note that we now call it `mega.layers` instead of `mega.encoders` due to hugging face style
print("HF Mega encoder:", hf_mlm.mega.layers.load_state_dict(original_state_dict))
# load the MLM head weights directly
print(
"HF Mega MLM layer:",
hf_mlm.mlm_head.load_state_dict(
torch.load(os.path.join(pretrained_checkpoint_path, "mlm_head_weights.pt"), map_location="cpu")
),
)
# test on a randomly generated input sequence
input_ids = torch.randint(0, hf_config.vocab_size, size=(4, 256))
input_mask = torch.ones_like(input_ids)
# mask a few tokens to make sure masking is applied appropriately :)
input_mask[:, -10:] = 0
# run forward passes
original_output = original_mlm(input_ids, input_mask, batch_first=True, ignore_mask_value=0)
hf_output = hf_mlm(input_ids, input_mask)[0]
# print shapes and diff
print(f"original output {original_output.shape}")
print(f"hf output {hf_output.shape}")
print(f"max diff: {(original_output - hf_output).max()}") # 0.0
success = torch.allclose(original_output, hf_output, atol=1e-3)
if success:
print("Yay!")
hf_mlm.save_pretrained(output_path)
else:
raise RuntimeError(f"Something's broken :(\nOriginal:\n{original_output}\n\nHF\n{hf_output}\n{hf_mlm}")
if includes_tokenizer:
print("Transferring tokenizer")
tokenizer = AutoTokenizer.from_pretrained(pretrained_checkpoint_path)
tokenizer.save_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--pretrained_checkpoint_path",
default=None,
type=str,
required=True,
help="Point to the directory containing your model weights using the official Mega repo",
)
parser.add_argument(
"--output_path", default=None, type=str, required=True, help="Location to save the Hugging Face version"
)
parser.add_argument(
"--includes_tokenizer",
action="store_true",
help="Use this flag if there is a Hugging Face tokenizer in the original checkpoint repo",
)
args = parser.parse_args()
convert_checkpoint_to_huggingface(args.pretrained_checkpoint_path, args.output_path, args.includes_tokenizer)
@@ -4205,6 +4205,65 @@ class MCTCTPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
MEGA_PRETRAINED_MODEL_ARCHIVE_LIST = None
class MegaForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MegaForMaskedLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MegaForMultipleChoice(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MegaForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MegaForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MegaForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MegaModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MegaPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None