Commit 7e3070ae authored by thomwolf

add from_pretrained method to all configuration classes

parent 93e9971c
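The new classmethod gives every configuration class the same entry point for loading a pretrained configuration. A minimal usage sketch (assuming `xlnet-large-cased` resolves through XLNetConfig's pretrained archive map, as its docstring below indicates; the local path is hypothetical):

    from pytorch_pretrained_bert import XLNetConfig

    # Download the hosted config.json for a named model (or hit the local cache)
    config = XLNetConfig.from_pretrained('xlnet-large-cased')

    # Or point at a local directory containing a `config.json`
    config = XLNetConfig.from_pretrained('/path/to/model_dir')

    # Keyword arguments matching existing config attributes override loaded values
    config = XLNetConfig.from_pretrained('xlnet-large-cased', dropout=0.2)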
......@@ -18,7 +18,7 @@ from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHe
from .modeling_gpt2 import (GPT2Config, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2MultipleChoiceHead,
load_tf_weights_in_gpt2)
from .modeling_xlnet import (XLNetBaseConfig, XLNetConfig, XLNetRunConfig,
from .modeling_xlnet import (XLNetConfig,
XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForQuestionAnswering,
load_tf_weights_in_xlnet)
......@@ -26,5 +26,6 @@ from .modeling_xlnet import (XLNetBaseConfig, XLNetConfig, XLNetRunConfig,
from .optimization import BertAdam
from .optimization_openai import OpenAIAdam
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path,
WEIGHTS_NAME, CONFIG_NAME)
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
from .model_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig)
......@@ -23,7 +23,7 @@ import argparse
import torch
from pytorch_pretrained_bert.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
XLNetConfig, XLNetRunConfig,
XLNetConfig,
XLNetLMHeadModel, XLNetForQuestionAnswering,
XLNetForSequenceClassification,
load_tf_weights_in_xlnet)
......
......@@ -44,9 +44,6 @@ except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
default_cache_path)
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
import json
import copy
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
logger = logging.getLogger(__name__)
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
class PretrainedConfig(object):
""" An abstract class to handle dowloading a model pretrained config.
"""
pretrained_config_archive_map = {}
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Instantiate a PretrainedConfig from a pre-trained model configuration.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `xlnet-large-cased`
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
cache_dir: an optional path to a folder in which the pre-trained model configuration will be cached.
"""
cache_dir = kwargs.pop('cache_dir', None)
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
else:
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_config_archive_map.keys()),
config_file))
return None
if resolved_config_file == config_file:
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = cls.from_json_file(resolved_config_file)
# Update config with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config {}".format(config))
return config
@classmethod
def from_dict(cls, json_object):
"""Constructs a `Config` from a Python dictionary of parameters."""
config = cls(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())
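Taken together, these helpers make any concrete subclass JSON round-trippable. A quick sketch (the subclass and path are hypothetical; note that `from_dict` instantiates via `cls(vocab_size_or_config_json_file=-1)`, so subclasses need to accept that argument):

    class MyConfig(PretrainedConfig):
        def __init__(self, vocab_size_or_config_json_file=-1, hidden_size=768):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size

    config = MyConfig(vocab_size_or_config_json_file=1000)
    config.to_json_file('/tmp/my_config.json')         # indented, key-sorted JSON
    restored = MyConfig.from_json_file('/tmp/my_config.json')
    assert restored.to_dict() == config.to_dict()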
def prune_linear_layer(layer, index, dim=0):
""" Prune a linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
if layer.bias is not None:
if dim == 1:
b = layer.bias.clone().detach()
else:
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
if layer.bias is not None:
new_layer.bias.requires_grad = False
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
class Conv1D(nn.Module):
""" Conv1D layer as defined by Alec Radford for GPT (and also used in GPT-2)
Basically works like a Linear layer but the weights are transposed
"""
def __init__(self, nf, nx):
super(Conv1D, self).__init__()
self.nf = nf
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.weight = nn.Parameter(w)
self.bias = nn.Parameter(torch.zeros(nf))
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(*size_out)
return x
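Shape-wise, Conv1D is equivalent to an nn.Linear carrying the transposed weight matrix; a small sketch of that equivalence (shapes are illustrative):

    import torch
    from torch import nn

    conv = Conv1D(nf=12, nx=4)            # weight shape (nx, nf) = (4, 12)
    x = torch.randn(2, 7, 4)              # (batch, seq, nx)
    y = conv(x)                           # -> (2, 7, 12)

    lin = nn.Linear(4, 12)                # weight shape (out, in) = (12, 4)
    with torch.no_grad():
        lin.weight.copy_(conv.weight.t())
        lin.bias.copy_(conv.bias)
    assert torch.allclose(y, lin(x), atol=1e-6)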
def prune_conv1d_layer(layer, index, dim=1):
""" Prune a Conv1D layer (a model parameters) to keep only entries in index.
A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
if dim == 0:
b = layer.bias.clone().detach()
else:
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
new_layer.bias.requires_grad = False
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
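Both pruning helpers take a 1-D LongTensor of indices to keep; note the opposite `dim` conventions, since output features sit on dim 0 of a Linear weight but on dim 1 of a Conv1D weight. A usage sketch (shapes are illustrative):

    import torch
    from torch import nn

    keep = torch.tensor([0, 2, 3], dtype=torch.long)

    lin = nn.Linear(8, 4)                                # weight (out=4, in=8)
    lin_pruned = prune_linear_layer(lin, keep, dim=0)    # -> Linear(8, 3)

    conv = Conv1D(nf=4, nx=8)                            # weight (nx=8, nf=4)
    conv_pruned = prune_conv1d_layer(conv, keep, dim=1)  # -> Conv1D(nf=3, nx=8)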
......@@ -29,7 +29,8 @@ import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
from .file_utils import cached_path
from .model_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, prune_linear_layer
logger = logging.getLogger(__name__)
......@@ -66,30 +67,6 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
BERT_CONFIG_NAME = 'bert_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
def prune_linear_layer(layer, index, dim=0):
""" Prune a linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
if layer.bias is not None:
if dim == 1:
b = layer.bias.clone().detach()
else:
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
if layer.bias is not None:
new_layer.bias.requires_grad = False
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
def load_tf_weights_in_bert(model, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model
......@@ -174,9 +151,11 @@ def swish(x):
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class BertConfig(object):
class BertConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `BertModel`.
"""
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
......@@ -238,37 +217,6 @@ class BertConfig(object):
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
......
......@@ -31,7 +31,8 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
from .file_utils import cached_path
from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, prune_conv1d_layer
from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
......@@ -41,30 +42,6 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.hugging
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
def prune_conv1d_layer(layer, index, dim=1):
""" Prune a Conv1D layer (a model parameters) to keep only entries in index.
A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
if dim == 0:
b = layer.bias.clone().detach()
else:
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
new_layer.bias.requires_grad = False
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
......@@ -123,9 +100,10 @@ def gelu(x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
class GPT2Config(object):
class GPT2Config(PretrainedConfig):
"""Configuration class to store the configuration of a `GPT2Model`.
"""
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(
self,
......@@ -194,54 +172,6 @@ class GPT2Config(object):
def total_tokens_embeddings(self):
return self.vocab_size + self.n_special
@classmethod
def from_dict(cls, json_object):
"""Constructs a `GPT2Config` from a Python dictionary of parameters."""
config = GPT2Config(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `GPT2Config` from a json file of parameters."""
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())
class Conv1D(nn.Module):
def __init__(self, nf, nx):
super(Conv1D, self).__init__()
self.nf = nf
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.weight = Parameter(w)
self.bias = Parameter(torch.zeros(nf))
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(*size_out)
return x
class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
......
......@@ -31,9 +31,9 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
from .file_utils import cached_path
from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, prune_conv1d_layer
from .modeling import BertLayerNorm as LayerNorm
from .modeling_gpt2 import prune_conv1d_layer
logger = logging.getLogger(__name__)
......@@ -122,9 +122,10 @@ def swish(x):
ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
class OpenAIGPTConfig(object):
class OpenAIGPTConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `OpenAIGPTModel`.
"""
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(
self,
......@@ -197,61 +198,6 @@ class OpenAIGPTConfig(object):
def total_tokens_embeddings(self):
return self.vocab_size + self.n_special
@classmethod
def from_dict(cls, json_object):
"""Constructs a `OpenAIGPTConfig` from a Python dictionary of parameters."""
config = OpenAIGPTConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `OpenAIGPTConfig` from a json file of parameters."""
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())
class Conv1D(nn.Module):
def __init__(self, nf, rf, nx):
super(Conv1D, self).__init__()
self.rf = rf
self.nf = nf
if rf == 1: # faster 1x1 conv
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.weight = Parameter(w)
self.bias = Parameter(torch.zeros(nf))
else: # was used to train LM
raise NotImplementedError
def forward(self, x):
if self.rf == 1:
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(*size_out)
else:
raise NotImplementedError
return x
class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
......@@ -268,8 +214,8 @@ class Attention(nn.Module):
self.keep_multihead_output = keep_multihead_output
self.multihead_output = None
self.c_attn = Conv1D(n_state * 3, 1, nx)
self.c_proj = Conv1D(n_state, 1, nx)
self.c_attn = Conv1D(n_state * 3, nx)
self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
......@@ -348,8 +294,8 @@ class MLP(nn.Module):
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
super(MLP, self).__init__()
nx = config.n_embd
self.c_fc = Conv1D(n_state, 1, nx)
self.c_proj = Conv1D(nx, 1, n_state)
self.c_fc = Conv1D(n_state, nx)
self.c_proj = Conv1D(nx, n_state)
self.act = ACT_FNS[config.afn]
self.dropout = nn.Dropout(config.resid_pdrop)
......
......@@ -37,7 +37,8 @@ from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
from .file_utils import cached_path
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig
logger = logging.getLogger(__name__)
......@@ -178,9 +179,11 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
return model
class TransfoXLConfig(object):
class TransfoXLConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `TransfoXLModel`.
"""
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=267735,
cutoffs=[20000, 40000, 200000],
......@@ -285,38 +288,6 @@ class TransfoXLConfig(object):
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `TransfoXLConfig` from a Python dictionary of parameters."""
config = TransfoXLConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `TransfoXLConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())
class PositionalEmbedding(nn.Module):
def __init__(self, demb):
......
......@@ -32,7 +32,9 @@ from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
from .file_utils import cached_path
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig
logger = logging.getLogger(__name__)
......@@ -192,48 +194,12 @@ def swish(x):
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class XLNetBaseConfig(object):
@classmethod
def from_dict(cls, json_object):
"""Constructs a `XLNetBaseConfig` from a Python dictionary of parameters."""
config = cls(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `XLNetBaseConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def update(self, other):
dict_b = other.to_dict()
for key, value in dict_b.items():
self.__dict__[key] = value
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())
class XLNetConfig(XLNetBaseConfig):
class XLNetConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `XLNetModel`.
"""
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file,
d_model=1024,
......@@ -337,53 +303,6 @@ class XLNetConfig(XLNetBaseConfig):
"or the path to a pretrained model config file (str)")
class XLNetRunConfig(XLNetBaseConfig):
"""XLNetRunConfig contains hyperparameters that could be different
between pretraining and finetuning.
These hyperparameters can also be changed from run to run.
We store them separately from XLNetConfig for flexibility.
"""
def __init__(self,
dropout=0.1,
dropatt=0.1,
init="normal",
init_range=0.1,
init_std=0.02,
mem_len=None,
reuse_len=None,
bi_data=False,
clamp_len=-1,
same_length=False):
"""
Args:
dropout: float, dropout rate.
dropatt: float, dropout rate on attention probabilities.
init: str, the initialization scheme, either "normal" or "uniform".
init_range: float, initialize the parameters with a uniform distribution
in [-init_range, init_range]. Only effective when init="uniform".
init_std: float, initialize the parameters with a normal distribution
with mean 0 and stddev init_std. Only effective when init="normal".
mem_len: int, the number of tokens to cache.
reuse_len: int, the number of tokens in the currect batch to be cached
and reused in the future.
bi_data: bool, whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
"""
self.init = init
self.init_range = init_range
self.init_std = init_std
self.dropout = dropout
self.dropatt = dropatt
self.mem_len = mem_len
self.reuse_len = reuse_len
self.bi_data = bi_data
self.clamp_len = clamp_len
self.same_length = same_length
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
except ImportError:
......@@ -637,9 +556,9 @@ class XLNetPreTrainedModel(nn.Module):
"""
def __init__(self, config, *inputs, **kwargs):
super(XLNetPreTrainedModel, self).__init__()
if not isinstance(config, XLNetBaseConfig):
if not isinstance(config, XLNetConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `XLNetBaseConfig`. "
"Parameter config in `{}(config)` should be an instance of class `XLNetConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
......
......@@ -25,7 +25,7 @@ import pytest
import torch
from pytorch_pretrained_bert import (XLNetConfig, XLNetRunConfig, XLNetModel, XLNetLMHeadModel)
from pytorch_pretrained_bert import (XLNetConfig, XLNetModel, XLNetLMHeadModel)
from pytorch_pretrained_bert.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
class XLNetModelTest(unittest.TestCase):
......@@ -117,17 +117,13 @@ class XLNetModelTest(unittest.TestCase):
d_inner=self.d_inner,
n_layer=self.n_layer,
untie_r=self.untie_r,
max_position_embeddings=self.max_position_embeddings)
run_config = XLNetRunConfig(
max_position_embeddings=self.max_position_embeddings,
mem_len=self.mem_len,
clamp_len=self.clamp_len,
same_length=self.same_length,
reuse_len=self.reuse_len,
bi_data=self.bi_data)
config.update(run_config)
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels)
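With XLNetRunConfig folded into XLNetConfig, the former run-time hyperparameters are now passed directly to the single config class, as the updated test above shows. A consolidated sketch (argument values are illustrative; remaining constructor arguments elided):

    config = XLNetConfig(
        vocab_size_or_config_json_file=32000,
        mem_len=512,
        clamp_len=-1,
        same_length=False,
        reuse_len=256,
        bi_data=False)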
def set_seed(self):
......