Commit c41f2bad authored by thomwolf

WIP XLM + refactoring

parent 288be7b7
@@ -14,8 +14,8 @@ from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
-from pytorch_pretrained_bert.modeling import BertForPreTraining
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.modeling_bert import BertForPreTraining
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")
...
@@ -5,7 +5,7 @@ from tempfile import TemporaryDirectory
import shelve

from random import random, randrange, randint, shuffle, choice
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
import numpy as np
import json
import collections
...
@@ -30,8 +30,8 @@ from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
-from pytorch_pretrained_bert.modeling import BertForPreTraining
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.modeling_bert import BertForPreTraining
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
...
@@ -35,8 +35,8 @@ from torch.nn import CrossEntropyLoss, MSELoss
from tensorboardX import SummaryWriter

from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
-from pytorch_pretrained_bert.modeling import BertForSequenceClassification
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.modeling_bert import BertForSequenceClassification
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
...
@@ -28,8 +28,8 @@ import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

-from pytorch_pretrained_bert.tokenization import BertTokenizer
-from pytorch_pretrained_bert.modeling import BertModel
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
+from pytorch_pretrained_bert.modeling_bert import BertModel

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
...
@@ -34,9 +34,9 @@ from tqdm import tqdm, trange
from tensorboardX import SummaryWriter

from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
-from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
+from pytorch_pretrained_bert.modeling_bert import BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer

from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
...
@@ -33,9 +33,9 @@ from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
-from pytorch_pretrained_bert.modeling import BertForMultipleChoice, BertConfig
+from pytorch_pretrained_bert.modeling_bert import BertForMultipleChoice, BertConfig
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
...
@@ -24,7 +24,7 @@ import math
import collections
from io import open

-from pytorch_pretrained_bert.tokenization import BasicTokenizer, whitespace_tokenize
+from pytorch_pretrained_bert.tokenization_bert import BasicTokenizer, whitespace_tokenize

logger = logging.getLogger(__name__)
...
-from pytorch_pretrained_bert.tokenization import BertTokenizer
-from pytorch_pretrained_bert.modeling import (
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
+from pytorch_pretrained_bert.modeling_bert import (
    BertModel,
    BertForNextSentencePrediction,
    BertForMaskedLM,
...
@@ -3997,9 +3997,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
-"11/16/2018 11:03:05 - INFO - pytorch_pretrained_bert.modeling - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /Users/thomaswolf/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba\n",
-"11/16/2018 11:03:05 - INFO - pytorch_pretrained_bert.modeling - extracting archive file /Users/thomaswolf/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpaqgsm566\n",
-"11/16/2018 11:03:08 - INFO - pytorch_pretrained_bert.modeling - Model config {\n",
+"11/16/2018 11:03:05 - INFO - pytorch_pretrained_bert.modeling_bert - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /Users/thomaswolf/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba\n",
+"11/16/2018 11:03:05 - INFO - pytorch_pretrained_bert.modeling_bert - extracting archive file /Users/thomaswolf/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpaqgsm566\n",
+"11/16/2018 11:03:08 - INFO - pytorch_pretrained_bert.modeling_bert - Model config {\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
...
@@ -375,8 +375,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
-"11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling - loading archive file ../../google_models/uncased_L-12_H-768_A-12/\n",
-"11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling - Model config {\n",
+"11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling_bert - loading archive file ../../google_models/uncased_L-12_H-768_A-12/\n",
+"11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling_bert - Model config {\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
...
__version__ = "0.6.2" __version__ = "0.6.2"
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .tokenization_openai import OpenAIGPTTokenizer from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
from .tokenization_gpt2 import GPT2Tokenizer from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
from .tokenization_xlm import XLMTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining, from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction, BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice, BertForSequenceClassification, BertForMultipleChoice,
BertForTokenClassification, BertForQuestionAnswering, BertForTokenClassification, BertForQuestionAnswering,
...@@ -22,6 +23,9 @@ from .modeling_xlnet import (XLNetConfig, ...@@ -22,6 +23,9 @@ from .modeling_xlnet import (XLNetConfig,
XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel, XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForQuestionAnswering, XLNetForSequenceClassification, XLNetForQuestionAnswering,
load_tf_weights_in_xlnet) load_tf_weights_in_xlnet)
from .modeling_xlm import (XLMConfig, XLMModel,
XLMWithLMHeadModel, XLMForSequenceClassification,
XLMForQuestionAnswering)
from .optimization import BertAdam from .optimization import BertAdam
from .optimization_openai import OpenAIAdam from .optimization_openai import OpenAIAdam
......
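For orientation, the sketch below shows the import surface this refactoring produces: the BERT modules move to `modeling_bert`/`tokenization_bert`, and the new XLM classes are re-exported from the package root. It is an illustrative sketch based only on the renames visible in this diff; the `XLMConfig()` call assumes the new `vocab_size_or_config_json_file=30145` default shown further down and is not code from the commit.

```python
# Sketch only: assumes the module layout introduced by this commit.
from pytorch_pretrained_bert.modeling_bert import BertModel, BertConfig
from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
from pytorch_pretrained_bert import XLMConfig, XLMModel, XLMTokenizer  # re-exported in __init__.py

config = XLMConfig()  # vocabulary size now defaults to 30145; emb_dim=2048, n_layers=12
```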
@@ -25,7 +25,7 @@ import tensorflow as tf
import torch
import numpy as np

-from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
+from pytorch_pretrained_bert.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert

def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
...
@@ -32,7 +32,7 @@ from torch.nn.parameter import Parameter
from .file_utils import cached_path
from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
-from .modeling import BertLayerNorm as LayerNorm
+from .modeling_bert import BertLayerNorm as LayerNorm

logger = logging.getLogger(__name__)
...
@@ -32,7 +32,7 @@ from torch.nn.parameter import Parameter
from .file_utils import cached_path
from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
-from .modeling import BertLayerNorm as LayerNorm
+from .modeling_bert import BertLayerNorm as LayerNorm

logger = logging.getLogger(__name__)
...
@@ -34,7 +34,7 @@ import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

-from .modeling import BertLayerNorm as LayerNorm
+from .modeling_bert import BertLayerNorm as LayerNorm
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import cached_path
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
...
@@ -71,7 +71,7 @@ class XLMConfig(PretrainedConfig):
    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
-                 vocab_size_or_config_json_file,
+                 vocab_size_or_config_json_file=30145,
                 n_special=0,
                 emb_dim=2048,
                 n_layers=12,
@@ -80,13 +80,20 @@ class XLMConfig(PretrainedConfig):
                 attention_dropout=0.1,
                 gelu_activation=True,
                 sinusoidal_embeddings=False,
+                 causal=False,
                 asm=False,
-                 id2lang={ 0: "en" },
-                 lang2id={ "en": 0 },
                 n_langs=1,
-                 n_words=30145,
                 max_position_embeddings=512,
-                 initializer_range=0.02,
+                 embed_init_std=2048 ** -0.5,
+                 init_std=0.02,
+                 summary_type="last",
+                 use_proj=True,
+                 bos_index=0,
+                 eos_index=1,
+                 pad_index=2,
+                 unk_index=3,
+                 mask_index=5,
+                 is_encoder=True,
                 **kwargs):
        """Constructs XLMConfig.
@@ -148,12 +155,20 @@ class XLMConfig(PretrainedConfig):
            self.attention_dropout = attention_dropout
            self.gelu_activation = gelu_activation
            self.sinusoidal_embeddings = sinusoidal_embeddings
+            self.causal = causal
            self.asm = asm
-            self.id2lang = id2lang
-            self.lang2id = lang2id
            self.n_langs = n_langs
+            self.summary_type = summary_type
+            self.use_proj = use_proj
+            self.bos_index = bos_index
+            self.eos_index = eos_index
+            self.pad_index = pad_index
+            self.unk_index = unk_index
+            self.mask_index = mask_index
+            self.is_encoder = is_encoder
            self.max_position_embeddings = max_position_embeddings
-            self.initializer_range = initializer_range
+            self.embed_init_std = embed_init_std
+            self.init_std = init_std
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")
@@ -175,37 +190,21 @@ class XLMConfig(PretrainedConfig):
        return self.n_layers

-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm as XLMLayerNorm
-except ImportError:
-    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
-    class XLMLayerNorm(nn.Module):
-        def __init__(self, d_model, eps=1e-12):
-            """Construct a layernorm module in the TF style (epsilon inside the square root).
-            """
-            super(XLMLayerNorm, self).__init__()
-            self.weight = nn.Parameter(torch.ones(d_model))
-            self.bias = nn.Parameter(torch.zeros(d_model))
-            self.variance_epsilon = eps
-
-        def forward(self, x):
-            u = x.mean(-1, keepdim=True)
-            s = (x - u).pow(2).mean(-1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
-            return self.weight * x + self.bias
-
-def Embedding(num_embeddings, embedding_dim, padding_idx=None):
+def Embedding(num_embeddings, embedding_dim, padding_idx=None, config=None):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+    if config is not None and config.embed_init_std is not None:
+        nn.init.normal_(m.weight, mean=0, std=config.embed_init_std)
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m

-def Linear(in_features, out_features, bias=True):
+def Linear(in_features, out_features, bias=True, config=None):
    m = nn.Linear(in_features, out_features, bias)
-    # nn.init.normal_(m.weight, mean=0, std=1)
+    if config is not None and config.init_std is not None:
+        nn.init.normal_(m.weight, mean=0, std=config.init_std)
+        if bias:
+            nn.init.constant_(m.bias, 0.)
    # nn.init.xavier_uniform_(m.weight)
    # nn.init.constant_(m.bias, 0.)
    return m
@@ -233,14 +232,17 @@ def gelu(x):
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

-def get_masks(slen, lengths, causal):
+def get_masks(slen, lengths, causal, padding_mask=None):
    """
    Generate hidden states mask, and optionally an attention mask.
    """
-    assert lengths.max().item() <= slen
    bs = lengths.size(0)
-    alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
-    mask = alen < lengths[:, None]
+    if padding_mask is not None:
+        mask = padding_mask
+    else:
+        assert lengths.max().item() <= slen
+        alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
+        mask = alen < lengths[:, None]

    # attention mask is the same as mask, or triangular inferior attention (causal)
    if causal:
@@ -259,21 +261,21 @@ class MultiHeadAttention(nn.Module):
    NEW_ID = itertools.count()

-    def __init__(self, n_heads, dim, dropout, output_attentions=False):
+    def __init__(self, n_heads, dim, config):
        super().__init__()
        self.layer_id = next(MultiHeadAttention.NEW_ID)
-        self.output_attentions = output_attentions
+        self.output_attentions = config.output_attentions
        self.dim = dim
        self.n_heads = n_heads
-        self.dropout = dropout
+        self.dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0
-        self.q_lin = Linear(dim, dim)
-        self.k_lin = Linear(dim, dim)
-        self.v_lin = Linear(dim, dim)
-        self.out_lin = Linear(dim, dim)
+        self.q_lin = Linear(dim, dim, config=config)
+        self.k_lin = Linear(dim, dim, config=config)
+        self.v_lin = Linear(dim, dim, config=config)
+        self.out_lin = Linear(dim, dim, config=config)

-    def forward(self, input, mask, kv=None, cache=None):
+    def forward(self, input, mask, kv=None, cache=None, head_mask=None):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).
        """
@@ -323,6 +325,11 @@ class MultiHeadAttention(nn.Module):
        weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)

+        # Mask heads if we want to
+        if head_mask is not None:
+            weights = weights * head_mask
+
        context = torch.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)  # (bs, qlen, dim)
@@ -334,12 +341,12 @@ class MultiHeadAttention(nn.Module):
class TransformerFFN(nn.Module):

-    def __init__(self, in_dim, dim_hidden, out_dim, dropout, gelu_activation):
+    def __init__(self, in_dim, dim_hidden, out_dim, config):
        super().__init__()
-        self.dropout = dropout
-        self.lin1 = Linear(in_dim, dim_hidden)
-        self.lin2 = Linear(dim_hidden, out_dim)
-        self.act = gelu if gelu_activation else F.relu
+        self.dropout = config.dropout
+        self.lin1 = Linear(in_dim, dim_hidden, config=config)
+        self.lin2 = Linear(dim_hidden, out_dim, config=config)
+        self.act = gelu if config.gelu_activation else F.relu

    def forward(self, input):
        x = self.lin1(input)
@@ -365,12 +372,9 @@ class XLMPreTrainedModel(PreTrainedModel):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if isinstance(module, nn.Linear) and module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, XLMLayerNorm):
+            # Weights are initialized in module instantiation (see above)
+            pass
+        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
@@ -439,8 +443,10 @@ class XLMModel(XLMPreTrainedModel):
        self.output_hidden_states = config.output_hidden_states

        # encoder / decoder, output layer
-        # self.is_encoder = is_encoder
-        # self.is_decoder = not is_encoder
+        self.is_encoder = config.is_encoder
+        self.is_decoder = not config.is_encoder
+        if self.is_decoder:
+            raise NotImplementedError("Currently XLM can only be used as an encoder")
        # self.with_output = with_output
        self.causal = config.causal
@@ -450,10 +456,10 @@ class XLMModel(XLMPreTrainedModel):
        self.eos_index = config.eos_index
        self.pad_index = config.pad_index
        # self.dico = dico
-        self.id2lang = config.id2lang
-        self.lang2id = config.lang2id
+        # self.id2lang = config.id2lang
+        # self.lang2id = config.lang2id
        # assert len(self.dico) == self.n_words
-        assert len(self.id2lang) == len(self.lang2id) == self.n_langs
+        # assert len(self.id2lang) == len(self.lang2id) == self.n_langs

        # model parameters
        self.dim = config.emb_dim  # 512 by default
@@ -465,12 +471,12 @@ class XLMModel(XLMPreTrainedModel):
        assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'

        # embeddings
-        self.position_embeddings = Embedding(config.max_position_embeddings, self.dim)
+        self.position_embeddings = Embedding(config.max_position_embeddings, self.dim, config=config)
        if config.sinusoidal_embeddings:
            create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
        if config.n_langs > 1:
-            self.lang_embeddings = Embedding(self.n_langs, self.dim)
-        self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
+            self.lang_embeddings = Embedding(self.n_langs, self.dim, config=config)
+        self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index, config=config)
        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12)

        # transformer layers
@@ -478,29 +484,31 @@
        self.layer_norm1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.layer_norm2 = nn.ModuleList()
-        if self.is_decoder:
-            self.layer_norm15 = nn.ModuleList()
-            self.encoder_attn = nn.ModuleList()
+        # if self.is_decoder:
+        #     self.layer_norm15 = nn.ModuleList()
+        #     self.encoder_attn = nn.ModuleList()

        for _ in range(self.n_layers):
-            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
+            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config))
            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=1e-12))
-            if self.is_decoder:
-                self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12))
-                self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
+            # if self.is_decoder:
+            #     self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12))
+            #     self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
-            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, dropout=self.dropout, gelu_activation=config.gelu_activation))
+            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))

-    def forward(self, input_ids, lengths, positions=None, langs=None, cache=None, head_mask=None):  # src_enc=None, src_len=None,
+    def forward(self, input_ids, lengths=None, positions=None, langs=None,
+                token_type_ids=None, attention_mask=None, cache=None, head_mask=None):  # src_enc=None, src_len=None,
        """
        Inputs:
            `input_ids` LongTensor(bs, slen), containing word indices
            `lengths` LongTensor(bs), containing the length of each sentence
-            `causal` Boolean, if True, the attention is only done over previous hidden states
            `positions` LongTensor(bs, slen), containing word positions
            `langs` LongTensor(bs, slen), containing language IDs
+            `token_type_ids` LongTensor (bs, slen) same as `langs` used for compatibility
        """
-        # lengths = (input_ids != self.pad_index).float().sum(dim=1)
+        if lengths is None:
+            lengths = (input_ids != self.pad_index).float().sum(dim=1)
        # mask = input_ids != self.pad_index

        # check inputs
@@ -514,7 +522,7 @@ class XLMModel(XLMPreTrainedModel):
        #     assert src_enc.size(0) == bs

        # generate masks
-        mask, attn_mask = get_masks(slen, lengths, self.causal)
+        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
        # if self.is_decoder and src_enc is not None:
        #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
@@ -527,10 +535,28 @@ class XLMModel(XLMPreTrainedModel):
        #     positions = positions.transpose(0, 1)

        # langs
+        assert langs is None or token_type_ids is None, "You can only use one among langs and token_type_ids"
+        if token_type_ids is not None:
+            langs = token_type_ids
        if langs is not None:
            assert langs.size() == (bs, slen)  # (slen, bs)
            # langs = langs.transpose(0, 1)

+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+                head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
+        else:
+            head_mask = [None] * self.n_layers
+
        # do not recompute cached elements
        if cache is not None:
            _slen = slen - cache['slen']
@@ -696,9 +722,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
    ```
    """
    def __init__(self, config):
-        super(XLMLMHeadModel, self).__init__(config)
+        super(XLMWithLMHeadModel, self).__init__(config)
-        self.attn_type = config.attn_type
-        self.same_length = config.same_length

        self.transformer = XLMModel(config)
        self.pred_layer = XLMPredLayer(config)
@@ -711,8 +735,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
        """
        self.pred_layer.proj.weight = self.transformer.embeddings.weight

-    def forward(self, input_ids, lengths, positions=None, langs=None, cache=None,
-                labels=None, head_mask=None):
+    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
+                attention_mask=None, cache=None, labels=None, head_mask=None):
        """
        Args:
            inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -739,7 +763,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
            summary_type: str, "last", "first", "mean", or "attn". The method
                to pool the input to get a vector representation.
        """
-        transformer_outputs = self.transformer(input_ids, lengths, positions=positions, langs=langs, cache=cache, head_mask=head_mask)
+        transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
+                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)

        output = transformer_outputs[0]
        logits = self.pred_layer(output, labels)
@@ -759,14 +784,14 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
class XLMSequenceSummary(nn.Module):
-    def __init__(self, config, summary_type="last", use_proj=True):
+    def __init__(self, config):
        super(XLMSequenceSummary, self).__init__()
-        self.summary_type = summary_type
-        if use_proj:
+        self.summary_type = config.summary_type
+        if config.use_proj:
            self.summary = nn.Linear(config.d_model, config.d_model)
        else:
            self.summary = None

-        if summary_type == 'attn':
+        if config.summary_type == 'attn':
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
@@ -859,14 +884,13 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
        super(XLMForSequenceClassification, self).__init__(config)
        self.transformer = XLMModel(config)
        self.sequence_summary = XLMSequenceSummary(config)
-        self.logits_proj = nn.Linear(config.d_model, num_labels)
+        self.logits_proj = nn.Linear(config.d_model, config.num_labels)

        self.apply(self.init_weights)

-    def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
-                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                labels=None, head_mask=None):
+    def forward(self, input_ids, lengths=None, positions=None, langs=None, attention_mask=None,
+                cache=None, labels=None, head_mask=None):
        """
        Args:
            inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -894,8 +918,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
                Only used during pretraining for two-stream attention.
                Set to None during finetuning.
        """
-        transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, inp_q, head_mask)
+        transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
+                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
        output = transformer_outputs[0]
        output = self.sequence_summary(output)
@@ -974,7 +998,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
        start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
-    def __init__(self, CONFIG_NAME):
+    def __init__(self, config):
        super(XLMForQuestionAnswering, self).__init__(config)
        self.transformer = XLMModel(config)
@@ -982,12 +1006,11 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
        self.apply(self.init_weights)

-    def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
-                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                start_positions=None, end_positions=None, head_mask=None):
+    def forward(self, input_ids, lengths=None, positions=None, langs=None, attention_mask=None, cache=None,
+                labels=None, head_mask=None):
-        transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, inp_q, head_mask)
+        transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
+                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
        output = transformer_outputs[0]
        logits = self.qa_outputs(output)
...
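To summarize the XLM API change above: `lengths` is now optional (recomputed from `pad_index` when omitted), an `attention_mask` padding mask is forwarded to `get_masks`, `token_type_ids` is accepted as an alias for `langs`, and `head_mask` is broadcast to one mask per layer. The snippet below is a hedged sketch of that call pattern; the tiny config values (`emb_dim=64`, `n_layers=2`, `n_heads=4`) are hypothetical and not taken from the diff, and this WIP code may still change.

```python
import torch
from pytorch_pretrained_bert import XLMConfig, XLMModel

# Hypothetical tiny configuration, just to illustrate the new forward() keywords.
config = XLMConfig(emb_dim=64, n_layers=2, n_heads=4)
model = XLMModel(config)

input_ids = torch.randint(low=6, high=1000, size=(2, 7))   # (bs, slen)
attention_mask = torch.ones_like(input_ids)                # padding mask; replaces an explicit `lengths`
head_mask = torch.ones(config.n_layers, config.n_heads)    # 1.0 keeps a head; one row per layer

outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask)
hidden_states = outputs[0]                                 # (bs, slen, emb_dim)
# `token_type_ids` could be passed instead of `langs` (the two are mutually exclusive).
```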
@@ -36,7 +36,9 @@ def _create_and_check_initialization(tester, model_classes, config, inputs_dict)
    for model_class in model_classes:
        model = model_class(config=configs_no_init)
        for name, param in model.named_parameters():
-            tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0], msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
+            if param.requires_grad:
+                tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0],
+                                       msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))

def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
    configs_no_init = _config_zero_init(config)
...
@@ -26,7 +26,7 @@ import pytest
import torch

from pytorch_pretrained_bert import PretrainedConfig, PreTrainedModel
-from pytorch_pretrained_bert.modeling import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP
+from pytorch_pretrained_bert.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP

class ModelUtilsTest(unittest.TestCase):
...