Unverified commit 7a7fdf71 authored by Sam Shleifer, committed by GitHub

Multilingual BART - (#3602)

- support mbart-en-ro weights
- add MBartTokenizer
parent f98d0ef2
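
For reference, a minimal usage sketch (not part of this diff) showing how the new pieces fit together. It is based on the tests added below: the checkpoint name `mbart-large-en-ro`, the beam size, and the decode flags are taken from `test_enro_generate`, so treat it as an illustration rather than documented API.

import torch
from transformers import BartForConditionalGeneration, MBartTokenizer

# Load the new English->Romanian checkpoint and its sentencepiece tokenizer.
tokenizer = MBartTokenizer.from_pretrained("mbart-large-en-ro")
model = BartForConditionalGeneration.from_pretrained("mbart-large-en-ro")

english = " UN Chief Says There Is No Military Solution in Syria"
batch = tokenizer.batch_encode_plus([english], return_tensors="pt")

with torch.no_grad():
    translated_tokens = model.generate(input_ids=batch["input_ids"], num_beams=5)

# Per test_enro_generate, this should decode to the Romanian translation of the input.
print(tokenizer.decode(translated_tokens[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))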
......@@ -283,4 +283,7 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) |
| | | | bart-large base architecture finetuned on cnn summarization task |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``mbart-large-en-ro`` | | 12-layer, 1024-hidden, 16-heads, 880M parameters |
| |                                                            | | bart-large architecture pretrained on cc25 multilingual data, finetuned on WMT English-Romanian translation.                         |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
......@@ -122,7 +122,7 @@ from .pipelines import (
)
from .tokenization_albert import AlbertTokenizer
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from .tokenization_bart import BartTokenizer
from .tokenization_bart import BartTokenizer, MBartTokenizer
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
from .tokenization_camembert import CamembertTokenizer
......
......@@ -27,6 +27,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
"mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
}
......@@ -61,6 +62,9 @@ class BartConfig(PretrainedConfig):
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
normalize_before=False,
add_final_layer_norm=False,
scale_embedding=False,
**common_kwargs
):
r"""
......@@ -90,6 +94,11 @@ class BartConfig(PretrainedConfig):
self.max_position_embeddings = max_position_embeddings
self.init_std = init_std # Normal(0, this parameter)
self.activation_function = activation_function
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
# True for mbart, False otherwise
self.normalize_before = normalize_before # combo of fairseq's encoder_ and decoder_normalize_before
self.add_final_layer_norm = add_final_layer_norm
# 3 Types of Dropout
self.attention_dropout = attention_dropout
......@@ -100,9 +109,17 @@ class BartConfig(PretrainedConfig):
self.classif_dropout = classifier_dropout
@property
def num_attention_heads(self):
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
@property
def hidden_size(self):
def hidden_size(self) -> int:
return self.d_model
def is_valid_mbart(self) -> bool:
"""Is the configuration aligned with the MBART paper."""
if self.normalize_before and self.add_final_layer_norm and self.scale_embedding:
return True
if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
logger.info("This configuration is a mixture of MBART and BART settings")
return False
......@@ -45,13 +45,24 @@ logger = logging.getLogger(__name__)
SAMPLE_TEXT = " Hello world! cécé herlolip"
rename_keys = [
mnli_rename_keys = [
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
]
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version", "_float_tensor"]
def remove_ignore_keys_(state_dict):
ignore_keys = [
"encoder.version",
"decoder.version",
"model.encoder.version",
"model.decoder.version",
"_float_tensor",
]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(dct, old, new):
......@@ -67,6 +78,19 @@ def load_xsum_checkpoint(checkpoint_path):
return hub_interface
def convert_checkpoint_from_disk(checkpoint_path, **config_kwargs):
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
remove_ignore_keys_(state_dict)
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
mbart_config = BartConfig(vocab_size=vocab_size, **config_kwargs)
model = BartForConditionalGeneration(mbart_config)
model.model.load_state_dict(state_dict)
if hasattr(model, "lm_head"):
model.lm_head = _make_linear_from_emb(model.model.shared)
return model
@torch.no_grad()
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
"""
......@@ -89,7 +113,7 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
state_dict = bart.state_dict()
remove_ignore_keys_(state_dict)
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
for src, dest in rename_keys:
for src, dest in mnli_rename_keys:
rename_key(state_dict, src, dest)
model = BartForSequenceClassification(config).eval()
model.load_state_dict(state_dict)
......@@ -118,11 +142,6 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
model.save_pretrained(pytorch_dump_folder_path)
def remove_ignore_keys_(state_dict):
for k in IGNORE_KEYS:
state_dict.pop(k, None)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
......
......@@ -14,6 +14,7 @@
# limitations under the License.
"""PyTorch BART model, ported from the fairseq repo."""
import logging
import math
import random
from typing import Dict, List, Optional, Tuple
......@@ -35,6 +36,7 @@ BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/pytorch_model.bin",
"mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/pytorch_model.bin",
}
BART_START_DOCSTRING = r"""
......@@ -180,6 +182,7 @@ class EncoderLayer(nn.Module):
self.self_attn = SelfAttention(
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
......@@ -201,19 +204,25 @@ class EncoderLayer(nn.Module):
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, attn_weights = self.self_attn(
query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.final_layer_norm(x)
return x, attn_weights
......@@ -236,6 +245,7 @@ class BartEncoder(nn.Module):
self.output_hidden_states = config.output_hidden_states
embed_dim = embed_tokens.embedding_dim
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = config.max_position_embeddings
......@@ -244,6 +254,8 @@ class BartEncoder(nn.Module):
self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = LayerNorm(embed_dim)
# mbart has one extra layer_norm
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
def forward(
self, input_ids, attention_mask=None,
......@@ -267,7 +279,7 @@ class BartEncoder(nn.Module):
if attention_mask is not None:
attention_mask = invert_mask(attention_mask)
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_ids)
x = inputs_embeds + embed_pos
x = self.layernorm_embedding(x)
......@@ -290,6 +302,8 @@ class BartEncoder(nn.Module):
if self.output_attentions:
all_attentions.append(attn)
if self.layer_norm:
x = self.layer_norm(x)
if self.output_hidden_states:
encoder_states.append(x)
......@@ -311,6 +325,7 @@ class DecoderLayer(nn.Module):
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.encoder_attn = SelfAttention(
......@@ -337,21 +352,28 @@ class DecoderLayer(nn.Module):
if layer_state is None:
layer_state = {}
# next line mutates layer state
if self.normalize_before:
x = self.self_attn_layer_norm(x)
# Self Attention
x, self_attn_weights = self.self_attn(
query=x,
key=x,
layer_state=layer_state,
layer_state=layer_state, # adds keys to layer state
key_padding_mask=decoder_padding_mask,
attn_mask=causal_mask,
need_weights=self.output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
# Cross attention
residual = x
assert self.encoder_attn.cache_key != self.self_attn.cache_key
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
x, _ = self.encoder_attn(
query=x,
key=encoder_hidden_states,
......@@ -360,15 +382,19 @@ class DecoderLayer(nn.Module):
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
# Fully Connected
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.final_layer_norm(x)
return (
x,
......@@ -394,6 +420,7 @@ class BartDecoder(nn.Module):
self.layerdrop = config.decoder_layerdrop
self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = embed_tokens
self.embed_positions = LearnedPositionalEmbedding(
config.max_position_embeddings, config.d_model, self.padding_idx,
......@@ -402,6 +429,7 @@ class BartDecoder(nn.Module):
[DecoderLayer(config) for _ in range(config.decoder_layers)]
) # type: List[DecoderLayer]
self.layernorm_embedding = LayerNorm(config.d_model)
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
def forward(
self,
......@@ -444,9 +472,8 @@ class BartDecoder(nn.Module):
positions = positions[:, -1:] # happens after we embed them
assert input_ids.ne(self.padding_idx).any()
x = self.embed_tokens(input_ids)
x = self.embed_tokens(input_ids) * self.embed_scale
x += positions
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
......@@ -458,14 +485,16 @@ class BartDecoder(nn.Module):
all_hidden_states = ()
all_self_attns = ()
next_decoder_cache = []
for i, decoder_layer in enumerate(self.layers):
decoder_layer # type: DecoderLayer
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if self.output_hidden_states:
all_hidden_states += (x,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None
x, layer_self_attn, layer_past = decoder_layer(
x,
encoder_hidden_states,
......@@ -477,12 +506,13 @@ class BartDecoder(nn.Module):
if use_cache:
next_decoder_cache.append(layer_past.copy())
if self.output_hidden_states:
all_hidden_states += (x,)
if self.layer_norm and (idx == len(self.layers) - 1): # last layer of mbart
x = self.layer_norm(x)
if self.output_attentions:
all_self_attns += (layer_self_attn,)
# Convert to standart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
......
......@@ -13,7 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from .tokenization_roberta import RobertaTokenizer
from .tokenization_xlm_roberta import XLMRobertaTokenizer
logger = logging.getLogger(__name__)
# vocab and merges same as roberta
......@@ -21,6 +27,8 @@ vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-v
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn", "bart-large-xsum"]
VOCAB_FILES_NAMES = {"vocab_file": "sentence.bpe.model"}
class BartTokenizer(RobertaTokenizer):
# merges and vocab same as Roberta
......@@ -29,3 +37,13 @@ class BartTokenizer(RobertaTokenizer):
"vocab_file": {m: vocab_url for m in _all_bart_models},
"merges_file": {m: merges_url for m in _all_bart_models},
}
_all_mbart_models = ["mbart-large-en-ro"]
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
class MBartTokenizer(XLMRobertaTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
......@@ -34,6 +34,8 @@ if is_torch_available():
BartForConditionalGeneration,
BartForSequenceClassification,
BartConfig,
BartTokenizer,
MBartTokenizer,
)
from transformers.modeling_bart import (
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
......@@ -41,7 +43,6 @@ if is_torch_available():
invert_mask,
_prepare_bart_decoder_inputs,
)
from transformers.tokenization_bart import BartTokenizer
@require_torch
......@@ -55,10 +56,10 @@ class ModelTester:
self.is_training = True
self.use_labels = False
self.vocab_size = 99
self.hidden_size = 32
self.num_hidden_layers = 5
self.hidden_size = 16
self.num_hidden_layers = 2
self.num_attention_heads = 4
self.intermediate_size = 37
self.intermediate_size = 4
self.hidden_act = "gelu"
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
......@@ -105,7 +106,6 @@ def prepare_bart_inputs_dict(
@require_torch
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
)
......@@ -196,8 +196,114 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class BartHeadTests(unittest.TestCase):
class BartTranslationTests(unittest.TestCase):
_model = None
@classmethod
def setUpClass(cls):
checkpoint_name = "mbart-large-en-ro"
cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
cls.pad_token_id = 1
net_input = {
"input_ids": _long_tensor(
[
[3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004],
[64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004],
]
),
"decoder_input_ids": _long_tensor(
[
[250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1],
[250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2],
]
),
"generation_mode": False,
}
net_input["attention_mask"] = net_input["input_ids"].ne(cls.pad_token_id)
cls.net_input = net_input
return cls
@property
def model(self):
"""Only load the model if needed."""
if self._model is None:
model = BartForConditionalGeneration.from_pretrained("mbart-large-en-ro")
self._model = model
return self._model
@slow
def test_enro_forward(self):
model = self.model
with torch.no_grad():
logits, *other_stuff = model(**self.net_input)
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787])
result_slice = logits[0][0][:3]
self.assertTrue(torch.allclose(expected_slice, result_slice, atol=TOLERANCE))
@slow
def test_enro_generate(self):
model = self.model
# example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
# inputs: dict = tokenizer.batch_encode_plus([example_english_phrase], return_tensors="pt",)
expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
inputs = {
"input_ids": torch.LongTensor(
[[8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]] # 250004
)
}
translated_tokens = model.generate(input_ids=inputs["input_ids"].to(torch_device), num_beams=5,)
decoded = [
self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for g in translated_tokens
]
self.assertEqual(expected_translation_romanian, decoded[0])
def test_mbart_enro_config(self):
mbart_models = ["mbart-large-en-ro"]
expected = {"scale_embedding": True, "output_past": True}
for name in mbart_models:
config = BartConfig.from_pretrained(name)
self.assertTrue(config.is_valid_mbart())
for k, v in expected.items():
try:
self.assertEqual(v, getattr(config, k))
except AssertionError as e:
e.args += (name, k)
raise
def test_enro_tokenizer(self):
raw = "UN Chief Says There Is No Military Solution in Syria"
ids = self.tokenizer.batch_encode_plus([raw])["input_ids"][0]
expected_result = [0, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
# TODO(SS): should be [8274, ..., 2, 250020]
self.assertListEqual(expected_result, ids)
def test_mbart_fast_forward(self):
config = BartConfig(
vocab_size=99,
d_model=24,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=32,
decoder_ffn_dim=32,
max_position_embeddings=48,
add_final_layer_norm=True,
)
lm_model = BartForConditionalGeneration(config).to(torch_device)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
@require_torch
class BartHeadTests(unittest.TestCase):
vocab_size = 99
def _get_config_and_data(self):
......@@ -263,13 +369,13 @@ class BartHeadTests(unittest.TestCase):
def test_lm_uneven_forward(self):
config = BartConfig(
vocab_size=self.vocab_size,
d_model=24,
d_model=14,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=32,
decoder_ffn_dim=32,
encoder_ffn_dim=8,
decoder_ffn_dim=8,
max_position_embeddings=48,
)
lm_model = BartForConditionalGeneration(config).to(torch_device)
......@@ -462,6 +568,7 @@ class BartModelIntegrationTests(unittest.TestCase):
@slow
def test_xsum_summarization_same_as_fairseq(self):
model = BartForConditionalGeneration.from_pretrained("bart-large-xsum").to(torch_device)
self.assertFalse(model.config.is_valid_mbart())
tok = BartTokenizer.from_pretrained("bart-large")
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
......