Commit e468192e authored by thomwolf

Merge branch 'pytorch-transformers' into xlnet

parents 9dd2c860 4ce237c8
......@@ -23,14 +23,13 @@ from io import open
import torch
-import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
-from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
+import pytorch_transformers.tokenization_transfo_xl as data_utils
+from pytorch_transformers.modeling_transfo_xl import (CONFIG_NAME,
WEIGHTS_NAME,
TransfoXLConfig,
TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl)
-from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
-                                                              VOCAB_NAME)
+from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)
if sys.version_info[0] == 2:
import cPickle as pickle
......@@ -53,7 +52,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
with open(transfo_xl_dataset_file, "rb") as fp:
corpus = pickle.load(fp, encoding="latin1")
# Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
-pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
+pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
corpus_vocab_dict = corpus.vocab.__dict__
torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
......
......@@ -23,8 +23,8 @@ from io import open
import torch
import numpy
-from pytorch_pretrained_bert.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME, XLMConfig, XLMModel)
-from pytorch_pretrained_bert.tokenization_xlm import MERGES_NAME, VOCAB_NAME
+from pytorch_transformers.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME, XLMConfig, XLMModel)
+from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
......@@ -42,7 +42,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
# Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
-pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
+pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
torch.save(model, pytorch_weights_dump_path)
......
......@@ -22,7 +22,7 @@ import os
import argparse
import torch
-from pytorch_pretrained_bert.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
+from pytorch_transformers.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
XLNetConfig,
XLNetLMHeadModel, XLNetForQuestionAnswering,
XLNetForSequenceClassification,
......
......@@ -29,7 +29,7 @@ except ImportError:
torch_cache_home = os.path.expanduser(
os.getenv('TORCH_HOME', os.path.join(
os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
-default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert')
+default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers')
try:
from urllib.parse import urlparse
......
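For orientation, a minimal sketch (not part of this commit) of where the renamed default cache resolves when neither TORCH_HOME nor XDG_CACHE_HOME is set; note that files previously downloaded under the old pytorch_pretrained_bert directory are not picked up from the new location:

    import os

    # Same fallback logic as the hunk above, assuming no environment overrides.
    torch_cache_home = os.path.expanduser(os.path.join('~/.cache', 'torch'))
    print(os.path.join(torch_cache_home, 'pytorch_transformers'))
    # e.g. /home/<user>/.cache/torch/pytorch_transformers (previously .../pytorch_pretrained_bert)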
......@@ -28,12 +28,11 @@ import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
-from .model_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
+from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
logger = logging.getLogger(__name__)
-PRETRAINED_MODEL_ARCHIVE_MAP = {
+BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
......@@ -49,7 +48,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
......@@ -545,7 +544,7 @@ class BertPreTrainedModel(PreTrainedModel):
a simple interface for dowloading and loading pretrained models.
"""
config_class = BertConfig
-pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
......
......@@ -30,16 +30,15 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .file_utils import cached_path
-from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
+from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
PreTrainedModel, prune_conv1d_layer, SequenceSummary)
from .modeling_bert import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
......@@ -103,7 +102,7 @@ def gelu(x):
class GPT2Config(PretrainedConfig):
"""Configuration class to store the configuration of a `GPT2Model`.
"""
-pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(
self,
......@@ -120,11 +119,13 @@ class GPT2Config(PretrainedConfig):
layer_norm_epsilon=1e-5,
initializer_range=0.02,
predict_special_tokens=True,
+num_labels=1,
summary_type='token_ids',
summary_use_proj=True,
-summary_num_classes=1,
summary_activation=None,
-summary_dropout=0.1,
+summary_proj_to_labels=True,
+summary_first_dropout=0.1,
**kwargs
):
"""Constructs GPT2Config.
......@@ -170,11 +171,13 @@ class GPT2Config(PretrainedConfig):
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.predict_special_tokens = predict_special_tokens
+self.num_labels = num_labels
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
-self.summary_num_classes = summary_num_classes
self.summary_activation = summary_activation
-self.summary_dropout = summary_dropout
+self.summary_first_dropout = summary_first_dropout
+self.summary_proj_to_labels = summary_proj_to_labels
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
......@@ -358,7 +361,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
a simple interface for dowloading and loading pretrained models.
"""
config_class = GPT2Config
-pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_gpt2
base_model_prefix = "transformer"
......
......@@ -30,15 +30,14 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .file_utils import cached_path
-from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
-                          PreTrainedModel, prune_conv1d_layer, SequenceSummary)
+from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
+                             PreTrainedModel, prune_conv1d_layer, SequenceSummary)
from .modeling_bert import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
......@@ -130,7 +129,7 @@ ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
class OpenAIGPTConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `OpenAIGPTModel`.
"""
-pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(
self,
......@@ -148,11 +147,13 @@ class OpenAIGPTConfig(PretrainedConfig):
layer_norm_epsilon=1e-5,
initializer_range=0.02,
predict_special_tokens=True,
+num_labels=1,
summary_type='token_ids',
summary_use_proj=True,
-summary_num_classes=1,
summary_activation=None,
-summary_dropout=0.1,
+summary_proj_to_labels=True,
+summary_first_dropout=0.1,
**kwargs
):
"""Constructs OpenAIGPTConfig.
......@@ -201,11 +202,13 @@ class OpenAIGPTConfig(PretrainedConfig):
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.predict_special_tokens = predict_special_tokens
+self.num_labels = num_labels
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
-self.summary_num_classes = summary_num_classes
self.summary_activation = summary_activation
-self.summary_dropout = summary_dropout
+self.summary_first_dropout = summary_first_dropout
+self.summary_proj_to_labels = summary_proj_to_labels
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
......@@ -384,7 +387,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
a simple interface for dowloading and loading pretrained models.
"""
config_class = OpenAIGPTConfig
-pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_openai_gpt
base_model_prefix = "transformer"
......
......@@ -36,15 +36,14 @@ from torch.nn.parameter import Parameter
from .modeling_bert import BertLayerNorm as LayerNorm
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import cached_path
-from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
+from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
logger = logging.getLogger(__name__)
-PRETRAINED_MODEL_ARCHIVE_MAP = {
+TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
}
......@@ -179,7 +178,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
class TransfoXLConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `TransfoXLModel`.
"""
-pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=267735,
......@@ -838,7 +837,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
a simple interface for dowloading and loading pretrained models.
"""
config_class = TransfoXLConfig
-pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_transfo_xl
base_model_prefix = "transformer"
......
......@@ -25,7 +25,7 @@ from io import open
import torch
from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss, functional as F
+from torch.nn import CrossEntropyLoss, functional as F
from .file_utils import cached_path
......@@ -169,6 +169,22 @@ class PreTrainedModel(nn.Module):
model_to_prune = getattr(self, self.base_model_prefix, self) # get the base model if needed
model_to_prune._prune_heads(heads_to_prune)
+def save_pretrained(self, save_directory):
+    """ Save a model with its configuration file to a directory, so that it
+        can be re-loaded using the `from_pretrained(save_directory)` class method.
+    """
+    assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
+    # Only save the model it-self if we are using distributed training
+    model_to_save = self.module if hasattr(self, 'module') else self
+    # If we save using the predefined names, we can load using `from_pretrained`
+    output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
+    output_config_file = os.path.join(save_directory, CONFIG_NAME)
+    torch.save(model_to_save.state_dict(), output_model_file)
+    model_to_save.config.to_json_file(output_config_file)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
......@@ -193,7 +209,8 @@ class PreTrainedModel(nn.Module):
"""
state_dict = kwargs.pop('state_dict', None)
cache_dir = kwargs.pop('cache_dir', None)
-from_tf = kwargs.pop('from_tf', None)
+from_tf = kwargs.pop('from_tf', False)
+output_loading_info = kwargs.pop('output_loading_info', False)
# Load config
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
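Putting the new save_pretrained together with from_pretrained, a minimal usage sketch (illustrative, not part of this diff; the save directory is hypothetical and must already exist because save_pretrained asserts on it):

    import os
    from pytorch_transformers import BertModel

    save_dir = '/tmp/my_bert'  # hypothetical path
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)  # save_pretrained expects an existing directory

    model = BertModel.from_pretrained('bert-base-uncased')
    model.save_pretrained(save_dir)                 # writes WEIGHTS_NAME and CONFIG_NAME into save_dir
    reloaded = BertModel.from_pretrained(save_dir)  # reload from the predefined file names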
......@@ -239,6 +256,21 @@ class PreTrainedModel(nn.Module):
# Directly load from a TensorFlow checkpoint
return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
+# Convert old format to new format if needed from a PyTorch state_dict
+old_keys = []
+new_keys = []
+for key in state_dict.keys():
+    new_key = None
+    if 'gamma' in key:
+        new_key = key.replace('gamma', 'weight')
+    if 'beta' in key:
+        new_key = key.replace('beta', 'bias')
+    if new_key:
+        old_keys.append(key)
+        new_keys.append(new_key)
+for old_key, new_key in zip(old_keys, new_keys):
+    state_dict[new_key] = state_dict.pop(old_key)
# Load from a PyTorch state_dict
missing_keys = []
unexpected_keys = []
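To make the key conversion above concrete, a small sketch (not part of this diff) of what the renaming does to an old-style state_dict; the keys and values are made up:

    # Hypothetical old-format entries using 'gamma'/'beta' for LayerNorm parameters.
    state_dict = {'encoder.LayerNorm.gamma': 1.0, 'encoder.LayerNorm.beta': 0.0}

    # Same renaming as above: gamma -> weight, beta -> bias.
    for key in list(state_dict.keys()):
        new_key = key.replace('gamma', 'weight').replace('beta', 'bias')
        if new_key != key:
            state_dict[new_key] = state_dict.pop(key)

    print(sorted(state_dict))  # ['encoder.LayerNorm.bias', 'encoder.LayerNorm.weight']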
......@@ -279,6 +311,10 @@ class PreTrainedModel(nn.Module):
if hasattr(model, 'tie_weights'):
model.tie_weights() # make sure word embedding weights are still tied
+if output_loading_info:
+    loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
+    return model, loading_info
return model
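A short sketch of the new output_loading_info flag (illustrative, not part of this diff; BertForSequenceClassification is imported the same way as in the tests further down):

    from pytorch_transformers import BertForSequenceClassification

    # When output_loading_info=True, from_pretrained also reports which checkpoint keys
    # were missing or unused -- handy when a freshly initialized task head is expected.
    model, loading_info = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', output_loading_info=True)
    print(loading_info['missing_keys'])     # e.g. classifier parameters absent from the checkpoint
    print(loading_info['unexpected_keys'])  # checkpoint keys the model did not use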
......@@ -478,10 +514,10 @@ class SequenceSummary(nn.Module):
- 'token_ids' => supply a Tensor of classification token indices (GPT/GPT-2)
- 'attn' => Not implemented now, use multi-head attention
summary_use_proj: Add a projection after the vector extraction
-summary_num_classes: If > 0: the projection outputs to n classes (otherwise to hidden_size)
-summary_activation:
-    'tanh' => add a tanh activation to the output
-    None => no activation
+summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
+summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
+summary_first_dropout: Add a dropout before the projection and activation
+summary_last_dropout: Add a dropout after the projection and activation
"""
def __init__(self, config):
super(SequenceSummary, self).__init__()
......@@ -495,8 +531,8 @@ class SequenceSummary(nn.Module):
self.summary = nn.Identity()
if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
-if hasattr(config, 'summary_num_classes') and config.summary_num_classes > 0:
-    num_classes = config.summary_num_classes
+if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
+    num_classes = config.num_labels
else:
num_classes = config.hidden_size
self.summary = nn.Linear(config.hidden_size, num_classes)
......@@ -505,7 +541,13 @@ class SequenceSummary(nn.Module):
if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
self.activation = nn.Tanh()
-self.dropout = nn.Dropout(config.summary_dropout)
+self.first_dropout = nn.Identity()
+if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
+    self.first_dropout = nn.Dropout(config.summary_first_dropout)
+self.last_dropout = nn.Identity()
+if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
+    self.last_dropout = nn.Dropout(config.summary_last_dropout)
def forward(self, hidden_states, token_ids=None):
""" hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
......@@ -531,9 +573,10 @@ class SequenceSummary(nn.Module):
elif self.summary_type == 'attn':
raise NotImplementedError
+output = self.first_dropout(output)
output = self.summary(output)
output = self.activation(output)
-output = self.dropout(output)
+output = self.last_dropout(output)
return output
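Read together, the summary head now applies first_dropout, then the optional projection, then the activation, then last_dropout. A hedged sketch (attribute values are illustrative, and the dummy config relies on the hasattr checks above):

    import torch
    from types import SimpleNamespace
    from pytorch_transformers.modeling_utils import SequenceSummary

    config = SimpleNamespace(
        summary_type='last',         # use the last token's hidden state
        summary_use_proj=True,
        summary_proj_to_labels=True,
        num_labels=2,
        hidden_size=16,
        summary_activation='tanh',
        summary_first_dropout=0.1,   # dropout before the projection
        summary_last_dropout=0.0,    # dropout after the activation
    )

    head = SequenceSummary(config)
    hidden_states = torch.rand(3, 5, 16)  # [bsz, seq_len, hidden_size]
    summary = head(hidden_states)         # expected shape [3, 2] given the projection to num_labels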
......@@ -598,9 +641,3 @@ def prune_layer(layer, index, dim=None):
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
else:
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
-def clean_up_tokenization(out_string):
-    out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
-                ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
-                ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
-    return out_string
......@@ -14,18 +14,14 @@
# limitations under the License.
""" PyTorch XLM model.
"""
-from __future__ import (absolute_import, division, print_function,
-                        unicode_literals)
+from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
+import math
-import os
import sys
from io import open
-import math
import itertools
import numpy as np
......@@ -34,16 +30,15 @@ from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
-from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
-                          prune_linear_layer, SequenceSummary, SQuADHead)
+from .modeling_utils import (PretrainedConfig, PreTrainedModel,
+                             prune_linear_layer, SequenceSummary, SQuADHead)
logger = logging.getLogger(__name__)
-PRETRAINED_MODEL_ARCHIVE_MAP = {
+XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
}
......@@ -51,7 +46,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
class XLMConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `XLMModel`.
"""
-pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=30145,
......@@ -79,10 +74,11 @@ class XLMConfig(PretrainedConfig):
finetuning_task=None,
num_labels=2,
-summary_type='last',
+summary_type='first',
summary_use_proj=True,
-summary_activation='tanh',
-summary_dropout=0.1,
+summary_activation=None,
+summary_proj_to_labels=True,
+summary_first_dropout=0.1,
start_n_top=5,
end_n_top=5,
**kwargs):
......@@ -164,7 +160,8 @@ class XLMConfig(PretrainedConfig):
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
-self.summary_dropout = summary_dropout
+self.summary_proj_to_labels = summary_proj_to_labels
+self.summary_first_dropout = summary_first_dropout
self.start_n_top = start_n_top
self.end_n_top = end_n_top
else:
......@@ -204,7 +201,7 @@ def gelu(x):
GELU activation
https://arxiv.org/abs/1606.08415
https://github.com/huggingface/pytorch-openai-transformer-lm/blob/master/model_pytorch.py#L14
-https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/modeling.py
+https://github.com/huggingface/pytorch-transformers/blob/master/modeling.py
"""
# return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
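As a quick sanity check (not part of this diff), the exact erf-based form returned here and the tanh approximation in the commented-out line agree closely:

    import math
    import torch

    x = torch.linspace(-3, 3, steps=7)
    exact = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
    approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    print((exact - approx).abs().max())  # small, on the order of 1e-3 or below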
......@@ -357,7 +354,7 @@ class XLMPreTrainedModel(PreTrainedModel):
a simple interface for dowloading and loading pretrained models.
"""
config_class = XLMConfig
-pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = None
base_model_prefix = "transformer"
......
......@@ -31,17 +31,16 @@ from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
-from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
-                          SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
+from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
+                             SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
logger = logging.getLogger(__name__)
-PRETRAINED_MODEL_ARCHIVE_MAP = {
+XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
}
......@@ -195,7 +194,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class XLNetConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `XLNetModel`.
"""
-pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=32000,
......@@ -227,7 +226,7 @@ class XLNetConfig(PretrainedConfig):
summary_type='last',
summary_use_proj=True,
summary_activation='tanh',
-summary_dropout=0.1,
+summary_last_dropout=0.1,
start_n_top=5,
end_n_top=5,
**kwargs):
......@@ -314,7 +313,7 @@ class XLNetConfig(PretrainedConfig):
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
-self.summary_dropout = summary_dropout
+self.summary_last_dropout = summary_last_dropout
self.start_n_top = start_n_top
self.end_n_top = end_n_top
else:
......@@ -593,7 +592,7 @@ class XLNetPreTrainedModel(PreTrainedModel):
a simple interface for dowloading and loading pretrained models.
"""
config_class = XLNetConfig
-pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+pretrained_model_archive_map = XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_xlnet
base_model_prefix = "transformer"
......
......@@ -20,13 +20,13 @@ import unittest
import shutil
import pytest
-from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
+from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification, BertForMultipleChoice)
-from pytorch_pretrained_bert.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
-from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
+from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
class BertModelTest(unittest.TestCase):
......@@ -266,8 +266,8 @@ class BertModelTest(unittest.TestCase):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
......