Commit 288be7b7 authored by thomwolf

xlm

parent 70887795
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert OpenAI GPT checkpoint."""
from __future__ import absolute_import, division, print_function
import argparse
import json
from io import open
import torch
import numpy
from pytorch_pretrained_bert.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME, XLMConfig, XLMModel)
from pytorch_pretrained_bert.tokenization_xlm import MERGES_NAME, VOCAB_NAME
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
    # Load checkpoint
    chkpt = torch.load(xlm_checkpoint_path, map_location='cpu')

    model = chkpt['model']

    config = chkpt['params']
    config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.Tensor, numpy.ndarray)))

    vocab = chkpt['dico_word2id']
    vocab = dict((s + '</w>' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items())

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
    pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME

    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model, pytorch_weights_dump_path)

    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(config, indent=2) + "\n")

    print("Save vocab file to {}".format(pytorch_vocab_dump_path))
    with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(vocab, indent=2) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--xlm_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the official PyTorch dump.")
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
args = parser.parse_args()
convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)
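# Example invocation (illustrative; the script name and paths below are placeholders,
# not files that ship with this commit):
#
#   python convert_xlm_checkpoint_to_pytorch.py \
#       --xlm_checkpoint_path ./mlm_en_2048.pth \
#       --pytorch_dump_folder_path ./xlm-mlm-en-2048-pytorch
#
# The output folder then contains the weights (WEIGHTS_NAME), the configuration
# (CONFIG_NAME) and the vocabulary (VOCAB_NAME) expected by XLMModel and XLMTokenizer.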
@@ -72,29 +72,22 @@ class XLMConfig(PretrainedConfig):
    def __init__(self,
                 vocab_size_or_config_json_file,
-                 causal=True,
-                 d_model=1024,
-                 n_layer=24,
-                 n_head=16,
-                 d_inner=4096,
-                 ff_activation="gelu",
-                 untie_r=True,
-                 attn_type="bi",
+                 n_special=0,
+                 emb_dim=2048,
+                 n_layers=12,
+                 n_heads=16,
+                 dropout=0.1,
+                 attention_dropout=0.1,
+                 gelu_activation=True,
+                 sinusoidal_embeddings=False,
+                 asm=False,
+                 id2lang={ 0: "en" },
+                 lang2id={ "en": 0 },
+                 n_langs=1,
+                 n_words=30145,
                 max_position_embeddings=512,
                 initializer_range=0.02,
-                 layer_norm_eps=1e-12,
-                 dropout=0.1,
-                 dropatt=0.1,
-                 init="normal",
-                 init_range=0.1,
-                 init_std=0.02,
-                 mem_len=None,
-                 reuse_len=None,
-                 bi_data=False,
-                 clamp_len=-1,
-                 same_length=False):
+                 **kwargs):
        """Constructs XLMConfig.
        Args:
@@ -137,6 +130,8 @@ class XLMConfig(PretrainedConfig):
                -1 means no clamping.
            same_length: bool, whether to use the same attention length for each token.
        """
+        super(XLMConfig, self).__init__(**kwargs)
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
@@ -144,36 +139,41 @@ class XLMConfig(PretrainedConfig):
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
-            self.n_token = vocab_size_or_config_json_file
-            self.causal = causal
-            self.d_model = d_model
-            self.n_layer = n_layer
-            self.n_head = n_head
-            assert d_model % n_head == 0
-            self.d_head = d_model // n_head
-            self.ff_activation = ff_activation
-            self.d_inner = d_inner
-            self.untie_r = untie_r
-            self.attn_type = attn_type
+            self.n_words = vocab_size_or_config_json_file
+            self.n_special = n_special
+            self.emb_dim = emb_dim
+            self.n_layers = n_layers
+            self.n_heads = n_heads
+            self.dropout = dropout
+            self.attention_dropout = attention_dropout
+            self.gelu_activation = gelu_activation
+            self.sinusoidal_embeddings = sinusoidal_embeddings
+            self.asm = asm
+            self.id2lang = id2lang
+            self.lang2id = lang2id
+            self.n_langs = n_langs
            self.max_position_embeddings = max_position_embeddings
            self.initializer_range = initializer_range
-            self.layer_norm_eps = layer_norm_eps
-            self.init = init
-            self.init_range = init_range
-            self.init_std = init_std
-            self.dropout = dropout
-            self.dropatt = dropatt
-            self.mem_len = mem_len
-            self.reuse_len = reuse_len
-            self.bi_data = bi_data
-            self.clamp_len = clamp_len
-            self.same_length = same_length
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")
+    @property
+    def total_tokens_embeddings(self):
+        return self.n_words + self.n_special
+    @property
+    def hidden_size(self):
+        return self.emb_dim
+    @property
+    def num_attention_heads(self):
+        return self.n_heads
+    @property
+    def num_hidden_layers(self):
+        return self.n_layers
try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as XLMLayerNorm
@@ -259,9 +259,10 @@ class MultiHeadAttention(nn.Module):
    NEW_ID = itertools.count()
-    def __init__(self, n_heads, dim, dropout):
+    def __init__(self, n_heads, dim, dropout, output_attentions=False):
        super().__init__()
        self.layer_id = next(MultiHeadAttention.NEW_ID)
+        self.output_attentions = output_attentions
        self.dim = dim
        self.n_heads = n_heads
        self.dropout = dropout
@@ -325,7 +326,10 @@ class MultiHeadAttention(nn.Module):
        context = torch.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)  # (bs, qlen, dim)
-        return self.out_lin(context)
+        outputs = (self.out_lin(context),)
+        if self.output_attentions:
+            outputs = outputs + (weights,)
+        return outputs
class TransformerFFN(nn.Module):
@@ -345,52 +349,6 @@ class TransformerFFN(nn.Module):
        return x
-class BeamHypotheses(object):
-    def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
-        """
-        Initialize n-best list of hypotheses.
-        """
-        self.max_len = max_len - 1  # ignoring <BOS>
-        self.length_penalty = length_penalty
-        self.early_stopping = early_stopping
-        self.n_hyp = n_hyp
-        self.hyp = []
-        self.worst_score = 1e9
-    def __len__(self):
-        """
-        Number of hypotheses in the list.
-        """
-        return len(self.hyp)
-    def add(self, hyp, sum_logprobs):
-        """
-        Add a new hypothesis to the list.
-        """
-        score = sum_logprobs / len(hyp) ** self.length_penalty
-        if len(self) < self.n_hyp or score > self.worst_score:
-            self.hyp.append((score, hyp))
-            if len(self) > self.n_hyp:
-                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
-                del self.hyp[sorted_scores[0][1]]
-                self.worst_score = sorted_scores[1][0]
-            else:
-                self.worst_score = min(score, self.worst_score)
-    def is_done(self, best_sum_logprobs):
-        """
-        If there are enough hypotheses and that none of the hypotheses being generated
-        can become better than the worst one in the heap, then we are done with this sentence.
-        """
-        if len(self) < self.n_hyp:
-            return False
-        elif self.early_stopping:
-            return True
-        else:
-            return self.worst_score >= best_sum_logprobs / self.max_len ** self.length_penalty
class XLMPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
@@ -410,16 +368,11 @@ class XLMPreTrainedModel(PreTrainedModel):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
        elif isinstance(module, XLMLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
-        elif isinstance(module, XLMRelativeAttention):
-            for param in [module.q, module.k, module.v, module.o, module.r,
-                          module.r_r_bias, module.r_s_bias, module.r_w_bias,
-                          module.seg_embed]:
-                param.data.normal_(mean=0.0, std=self.config.initializer_range)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
class XLMModel(XLMPreTrainedModel):
@@ -429,7 +382,7 @@ class XLMModel(XLMPreTrainedModel):
                  'hidden_dim', 'dropout', 'attention_dropout', 'asm',
                  'asm_cutoffs', 'asm_div_value']
-    def __init__(self, params, output_attentions=False, output_hidden_states=False):  #, dico, is_encoder, with_output):
+    def __init__(self, config):  #, dico, is_encoder, with_output):
        """ XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
            Paper: https://arxiv.org/abs/1901.07291
            Original code: https://github.com/facebookresearch/XLM
@@ -481,41 +434,41 @@ class XLMModel(XLMPreTrainedModel):
        all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
        ```
        """
-        super(XLMModel, self).__init__(params)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
+        super(XLMModel, self).__init__(config)
+        self.output_attentions = config.output_attentions
+        self.output_hidden_states = config.output_hidden_states
        # encoder / decoder, output layer
        # self.is_encoder = is_encoder
        # self.is_decoder = not is_encoder
        # self.with_output = with_output
-        self.causal = params.causal
+        self.causal = config.causal
        # dictionary / languages
-        self.n_langs = params.n_langs
-        self.n_words = params.n_words
-        self.eos_index = params.eos_index
-        self.pad_index = params.pad_index
+        self.n_langs = config.n_langs
+        self.n_words = config.n_words
+        self.eos_index = config.eos_index
+        self.pad_index = config.pad_index
        # self.dico = dico
-        self.id2lang = params.id2lang
-        self.lang2id = params.lang2id
+        self.id2lang = config.id2lang
+        self.lang2id = config.lang2id
        # assert len(self.dico) == self.n_words
        assert len(self.id2lang) == len(self.lang2id) == self.n_langs
        # model parameters
-        self.dim = params.emb_dim  # 512 by default
+        self.dim = config.emb_dim  # 512 by default
        self.hidden_dim = self.dim * 4  # 2048 by default
-        self.n_heads = params.n_heads  # 8 by default
-        self.n_layers = params.n_layers
-        self.dropout = params.dropout
-        self.attention_dropout = params.attention_dropout
+        self.n_heads = config.n_heads  # 8 by default
+        self.n_layers = config.n_layers
+        self.dropout = config.dropout
+        self.attention_dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'
        # embeddings
-        self.position_embeddings = Embedding(params.max_position_embeddings, self.dim)
-        if params.sinusoidal_embeddings:
-            create_sinusoidal_embeddings(params.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
-        if params.n_langs > 1:
+        self.position_embeddings = Embedding(config.max_position_embeddings, self.dim)
+        if config.sinusoidal_embeddings:
+            create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
+        if config.n_langs > 1:
            self.lang_embeddings = Embedding(self.n_langs, self.dim)
        self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12)
@@ -535,26 +488,26 @@ class XLMModel(XLMPreTrainedModel):
            if self.is_decoder:
                self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12))
                self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
-            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, dropout=self.dropout, gelu_activation=params.gelu_activation))
+            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, dropout=self.dropout, gelu_activation=config.gelu_activation))
            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))
-    def forward(self, x, lengths, positions=None, langs=None, cache=None, head_mask=None):  # src_enc=None, src_len=None,
+    def forward(self, input_ids, lengths, positions=None, langs=None, cache=None, head_mask=None):  # src_enc=None, src_len=None,
        """
        Inputs:
-            `x` LongTensor(bs, slen), containing word indices
+            `input_ids` LongTensor(bs, slen), containing word indices
            `lengths` LongTensor(bs), containing the length of each sentence
            `causal` Boolean, if True, the attention is only done over previous hidden states
            `positions` LongTensor(bs, slen), containing word positions
            `langs` LongTensor(bs, slen), containing language IDs
        """
-        # lengths = (x != self.pad_index).float().sum(dim=1)
-        # mask = x != self.pad_index
+        # lengths = (input_ids != self.pad_index).float().sum(dim=1)
+        # mask = input_ids != self.pad_index
        # check inputs
-        bs, slen = x.size()
+        bs, slen = input_ids.size()
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen
-        # x = x.transpose(0, 1)  # batch size as dimension 0
+        # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
        # assert (src_enc is None) == (src_len is None)
        # if src_enc is not None:
        #     assert self.is_decoder
@@ -567,7 +520,7 @@ class XLMModel(XLMPreTrainedModel):
        # positions
        if positions is None:
-            positions = x.new(slen).long()
+            positions = input_ids.new(slen).long()
            positions = torch.arange(slen, out=positions).unsqueeze(0)
        else:
            assert positions.size() == (bs, slen)  # (slen, bs)
@@ -581,7 +534,7 @@ class XLMModel(XLMPreTrainedModel):
        # do not recompute cached elements
        if cache is not None:
            _slen = slen - cache['slen']
-            x = x[:, -_slen:]
+            input_ids = input_ids[:, -_slen:]
            positions = positions[:, -_slen:]
            if langs is not None:
                langs = langs[:, -_slen:]
@@ -589,7 +542,7 @@ class XLMModel(XLMPreTrainedModel):
            attn_mask = attn_mask[:, -_slen:]
        # embeddings
-        tensor = self.embeddings(x)
+        tensor = self.embeddings(input_ids)
        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
        if langs is not None:
            tensor = tensor + self.lang_embeddings(langs)
@@ -648,21 +601,21 @@ class XLMPredLayer(nn.Module):
    """
    Prediction layer (cross_entropy or adaptive_softmax).
    """
-    def __init__(self, params):
+    def __init__(self, config):
        super().__init__()
-        self.asm = params.asm
-        self.n_words = params.n_words
-        self.pad_index = params.pad_index
-        dim = params.emb_dim
-        if params.asm is False:
-            self.proj = Linear(dim, params.n_words, bias=True)
+        self.asm = config.asm
+        self.n_words = config.n_words
+        self.pad_index = config.pad_index
+        dim = config.emb_dim
+        if config.asm is False:
+            self.proj = Linear(dim, config.n_words, bias=True)
        else:
            self.proj = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=dim,
-                n_classes=params.n_words,
-                cutoffs=params.asm_cutoffs,
-                div_value=params.asm_div_value,
+                n_classes=config.n_words,
+                cutoffs=config.asm_cutoffs,
+                div_value=config.asm_div_value,
                head_bias=True,  # default is False
            )
@@ -742,15 +695,12 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
        all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
        ```
        """
-    def __init__(self, config, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
        super(XLMLMHeadModel, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
        self.attn_type = config.attn_type
        self.same_length = config.same_length
-        self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
+        self.transformer = XLMModel(config)
        self.pred_layer = XLMPredLayer(config)
        self.apply(self.init_weights)
@@ -761,7 +711,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
        """
        self.pred_layer.proj.weight = self.transformer.embeddings.weight
-    def forward(self, x, lengths, positions=None, langs=None, cache=None,
+    def forward(self, input_ids, lengths, positions=None, langs=None, cache=None,
                labels=None, head_mask=None):
        """
        Args:
@@ -789,7 +739,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
            summary_type: str, "last", "first", "mean", or "attn". The method
                to pool the input to get a vector representation.
        """
-        transformer_outputs = self.transformer(x, lengths, positions=positions, langs=langs, cache=cache, head_mask=head_mask)
+        transformer_outputs = self.transformer(input_ids, lengths, positions=positions, langs=langs, cache=cache, head_mask=head_mask)
        output = transformer_outputs[0]
        logits = self.pred_layer(output, labels)
@@ -905,18 +855,12 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
        all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
        ```
        """
-    def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
-                 output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
        super(XLMForSequenceClassification, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
-        self.summary_type = summary_type
-        self.num_labels = num_labels
-        self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
-        self.sequence_summary = XLMSequenceSummary(config, summary_type=summary_type, use_proj=use_proj)
-        self.logits_proj = nn.Linear(config.d_model, num_labels)
+        self.transformer = XLMModel(config)
+        self.sequence_summary = XLMSequenceSummary(config)
+        self.logits_proj = nn.Linear(config.hidden_size, config.num_labels)
        self.apply(self.init_weights)
@@ -1030,13 +974,12 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
        start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
        ```
        """
-    def __init__(self, config, output_attentions=False, output_hidden_states=False):
+    def __init__(self, config):
        super(XLMForQuestionAnswering, self).__init__(config)
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
-        self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
-        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+        self.transformer = XLMModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        self.apply(self.init_weights)
    def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
...
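The diff above moves every hyper-parameter from the loose `params` namespace onto `XLMConfig` and renames the first `forward` argument from `x` to `input_ids`. A minimal usage sketch of the resulting API, assuming a config.json produced by the conversion script above and skipping weight loading (the paths, ids and lengths are illustrative, not part of this commit):

import torch
from pytorch_pretrained_bert.modeling_xlm import XLMConfig, XLMModel

config = XLMConfig("./xlm-dump/config.json")  # JSON dumped by the conversion script (placeholder path)
model = XLMModel(config)                      # weight loading omitted in this sketch
model.eval()

input_ids = torch.tensor([[0, 42, 7, 1]])     # (bs, slen) word indices
lengths = torch.tensor([4])                   # (bs,) true length of each sentence
hidden_states = model(input_ids, lengths)[0]  # first element of the returned tuple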
# coding=utf-8
# Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import json
import logging
import os
import re
import sys
from io import open
from tqdm import tqdm
from .file_utils import cached_path
from .tokenization import BasicTokenizer
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'xlm-mlm-en-2048': 512,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
INDEX= {
"bos_index": 0,
"eos_index": 1,
"pad_index": 2,
"unk_index": 3,
"mask_index": 5
}
def get_pairs(word):
"""
Return set of symbol pairs in a word.
word is represented as tuple of symbols (symbols being variable-length strings)
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
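# Example: get_pairs(('l', 'o', 'w</w>')) returns {('l', 'o'), ('o', 'w</w>')}.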
def text_standardize(text):
"""
fixes some issues the spacy tokenizer had on books corpus
also does some whitespace standardization
"""
text = text.replace('—', '-')
text = text.replace('–', '-')
text = text.replace('―', '-')
text = text.replace('…', '...')
text = text.replace('´', "'")
text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
text = re.sub(r'\s*\n\s*', ' \n ', text)
text = re.sub(r'[^\S\n]+', ' ', text)
return text.strip()
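# Example (illustrative): text_standardize('hello—world … ok') returns 'hello - world ... ok'.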
class XLMTokenizer(object):
"""
BPE tokenizer for XLM, adapted from the OpenAI BPE tokenizer. Peculiarities:
- lower-cases all inputs
- uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, falls back to BERT's BasicTokenizer if not
- argument special_tokens and function set_special_tokens:
can be used to add additional symbols (ex: "__classify__") to a vocabulary.
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate an XLMTokenizer from a pre-trained vocabulary file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer won't index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
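# Illustrative: XLMTokenizer.from_pretrained('xlm-mlm-en-2048') resolves the vocabulary and
# merges URLs from the archive maps above (downloading and caching them), while a local
# directory containing vocab.json and merges.txt is used as-is.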
def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
try:
import ftfy
import spacy
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
self.fix_text = ftfy.fix_text
except ImportError:
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
self.nlp = BasicTokenizer(do_lower_case=True,
never_split=special_tokens if special_tokens is not None else [])
self.fix_text = None
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v:k for k,v in self.encoder.items()}
merges = open(merges_file, encoding='utf-8').read().split('\n')[:-1]
merges = [tuple(merge.split()[:2]) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
if self.fix_text is None:
# Using BERT's BasicTokenizer: we can update the tokenizer
self.nlp.never_split = special_tokens
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
word = tuple(token[:-1]) + (token[-1] + '</w>',)
if token in self.cache:
return self.cache[token]
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
if word == '\n </w>':
word = '\n</w>'
self.cache[token] = word
return word
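# Illustrative behaviour: with no applicable merges (empty self.bpe_ranks) bpe('low') returns
# 'l o w</w>'; learned merges are applied greedily in rank order before the space-joined
# sub-words are returned.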
def tokenize(self, text):
""" Tokenize a string. """
split_tokens = []
if self.fix_text is None:
# Using BERT's BasicTokenizer
text = self.nlp.tokenize(text)
for token in text:
split_tokens.extend([t for t in self.bpe(token).split(' ')])
else:
# Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
text = self.nlp(text_standardize(self.fix_text(text)))
for token in text:
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
return split_tokens
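# Illustrative: tokenizer.tokenize("Hello world!") yields lower-cased BPE sub-words such as
# ['hello</w>', 'world</w>', '!</w>'] -- the exact split depends on the loaded merges.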
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
"""Converts a sequence of ids in a string."""
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
out_string = ''.join(tokens).replace('</w>', ' ').strip()
if clean_up_tokenization_spaces:
out_string = out_string.replace('<unk>', '')
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
return out_string
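# Illustrative round trip: tokenizer.decode(tokenizer.encode("hello world !")) gives back
# "hello world!" once the '</w>' markers are joined and the punctuation clean-up above runs.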
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
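A hedged end-to-end sketch of the tokenizer added in this file, assuming the 'xlm-mlm-en-2048' vocabulary and merges listed in the archive maps above are reachable and that the output directory already exists (both are illustrative):

from pytorch_pretrained_bert.tokenization_xlm import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
tokens = tokenizer.tokenize("Hello world!")        # lower-cased BPE sub-words ending in '</w>'
ids = tokenizer.convert_tokens_to_ids(tokens)      # unknown tokens map to id 0
text = tokenizer.decode(ids)                       # joins sub-words and cleans up punctuation spacing
tokenizer.save_vocabulary('./saved_tokenizer')     # the directory must already exist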