Commit 70887795 authored by thomwolf

updating tests and models, adding weights initialization test

parent 99ae5ab8
@@ -191,6 +191,8 @@ def get_from_cache(url, cache_dir=None):
         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
+    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
+        cache_dir = str(cache_dir)
     if not os.path.exists(cache_dir):
         os.makedirs(cache_dir)
...
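The hunk above normalizes `cache_dir` to a plain `str` on both Python 2 and Python 3 before it is handed to `os.path` calls. A minimal standalone sketch of that normalization (the function name `normalize_cache_dir` and the default path are illustrative, not part of the library):

import sys
from pathlib import Path

def normalize_cache_dir(cache_dir, default="/tmp/pretrained_cache"):
    # Fall back to a default cache location when none is given
    # (the library uses PYTORCH_PRETRAINED_BERT_CACHE here).
    if cache_dir is None:
        cache_dir = default
    # Python 3: downstream os / requests calls expect a str, not a pathlib.Path.
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    # Python 2: coerce unicode / other path-like objects to str as well.
    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
        cache_dir = str(cache_dir)
    return cache_dir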
@@ -60,8 +60,7 @@ class PretrainedConfig(object):
                 . `config.json` a configuration file for the model
             cache_dir: an optional path to a folder in which the pre-trained model configuration will be cached.
         """
-        cache_dir = kwargs.get('cache_dir', None)
-        kwargs.pop('cache_dir', None)
+        cache_dir = kwargs.pop('cache_dir', None)
         if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
             config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
...
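Replacing `kwargs.get` plus a separate `kwargs.pop` with a single `kwargs.pop('cache_dir', None)` reads and removes the key in one step, so the remaining `kwargs` can be forwarded without `cache_dir` leaking through as an unexpected argument. A small sketch of the pattern (the function name is hypothetical):

def from_pretrained_sketch(name_or_path, **kwargs):
    # pop() returns the value and removes the key, so whatever is left in
    # kwargs can be passed on (e.g. to a config constructor) untouched.
    cache_dir = kwargs.pop('cache_dir', None)
    print("cache_dir:", cache_dir, "remaining kwargs:", kwargs)

from_pretrained_sketch('bert-base-uncased', cache_dir='/tmp/cache', output_attentions=True)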
@@ -17,7 +17,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

-import copy
 import json
 import logging
 import math
...
@@ -422,8 +421,7 @@ class BertEncoder(nn.Module):
         super(BertEncoder, self).__init__()
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        layer = BertLayer(config)
-        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, attention_mask, head_mask=None):
         all_hidden_states = []
...
@@ -539,10 +537,12 @@ class BertPreTrainedModel(PreTrainedModel):
     """
     config_class = BertConfig
     pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"

+    def __init__(self, *inputs, **kwargs):
+        super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)
+
     def init_weights(self, module):
         """ Initialize the weights.
         """
...
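The encoder now constructs each `BertLayer` directly inside the `nn.ModuleList` comprehension instead of deep-copying a single layer, so every entry is built (and randomly initialized) independently and the `copy` import becomes unnecessary. A minimal sketch of the pattern with a toy layer standing in for `BertLayer`:

import torch.nn as nn

class ToyLayer(nn.Module):  # stand-in for BertLayer
    def __init__(self, hidden_size):
        super(ToyLayer, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.dense(x)

class ToyEncoder(nn.Module):
    def __init__(self, hidden_size=8, num_hidden_layers=3):
        super(ToyEncoder, self).__init__()
        # One independent construction per layer: every entry has its own freshly
        # initialized parameters (a deepcopy of a single layer would also give
        # separate parameters, but all starting from identical values).
        self.layer = nn.ModuleList([ToyLayer(hidden_size) for _ in range(num_hidden_layers)])

    def forward(self, x):
        for layer in self.layer:
            x = layer(x)
        return x

encoder = ToyEncoder()
# Parameters are registered per layer as distinct objects.
print(len(list(encoder.parameters())))  # 6: weight + bias for each of the 3 layers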
@@ -18,7 +18,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import collections
-import copy
 import json
 import logging
 import math
...
@@ -378,18 +377,21 @@ class GPT2PreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_gpt2
     base_model_prefix = "transformer"

+    def __init__(self, *inputs, **kwargs):
+        super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
+
     def init_weights(self, module):
         """ Initialize the weights.
         """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
+                module.bias.data.zero_()
         elif isinstance(module, LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
...
@@ -489,8 +491,7 @@ class GPT2Model(GPT2PreTrainedModel):
         self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True)
-        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
+        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
         self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

         self.apply(self.init_weights)
...
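`init_weights` runs on every submodule through `Module.apply`, and with this change GPT-2's `Conv1D` projections get the same normal-distribution weight init and zeroed bias as `nn.Linear`. A self-contained sketch of the `apply`-based pattern, using only built-in modules (no `Conv1D`) and a stand-in `initializer_range`:

import torch.nn as nn

initializer_range = 0.02  # stands in for config.initializer_range

def init_weights(module):
    # Mirrors the structure of the diff: normal init for Linear/Embedding
    # weights, zeroed biases, and (1.0, 0.0) for LayerNorm weight/bias.
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 4), nn.LayerNorm(4))
model.apply(init_weights)  # apply() recurses over every submodule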
@@ -18,7 +18,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import collections
-import copy
 import json
 import logging
 import math
...
@@ -405,18 +404,21 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"

+    def __init__(self, *inputs, **kwargs):
+        super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
+
     def init_weights(self, module):
         """ Initialize the weights.
         """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
+                module.bias.data.zero_()
         elif isinstance(module, LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
...
@@ -513,8 +515,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True)
-        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
+        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

         self.apply(self.init_weights)
...
@@ -21,7 +21,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import os
-import copy
 import json
 import math
 import logging
...
@@ -843,6 +842,9 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_transfo_xl
     base_model_prefix = "transformer"

+    def __init__(self, *inputs, **kwargs):
+        super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
+
     def _init_weight(self, weight):
         if self.config.init == 'uniform':
             nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
...
@@ -883,7 +885,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
             nn.init.normal_(m.weight, 1.0, self.config.init_std)
             if hasattr(m, 'bias') and m.bias is not None:
                 self._init_bias(m.bias)
-        elif classname.find('TransformerLM') != -1:
+        else:
             if hasattr(m, 'r_emb'):
                 self._init_weight(m.r_emb)
             if hasattr(m, 'r_w_bias'):
...
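Transformer-XL dispatches its initialization on the module's class-name string; with the `else` branch above, a module that matches none of the earlier name patterns still gets its relative-attention parameters (`r_emb`, `r_w_bias`, ...) initialized when they exist. A rough sketch of that name-based dispatch, simplified and with an illustrative `init_std`:

import torch.nn as nn

def init_weights_by_name(m, init_std=0.02):
    # Transformer-XL style: dispatch on the class name string, not isinstance.
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        if hasattr(m, 'weight') and m.weight is not None:
            nn.init.normal_(m.weight, 0.0, init_std)
        if hasattr(m, 'bias') and m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('LayerNorm') != -1:
        nn.init.normal_(m.weight, 1.0, init_std)
        nn.init.constant_(m.bias, 0.0)
    else:
        # Fallback branch (the `else` introduced above): catch modules that carry
        # relative-attention parameters such as r_emb / r_w_bias / r_r_bias.
        if hasattr(m, 'r_w_bias'):
            nn.init.normal_(m.r_w_bias, 0.0, init_std)

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
model.apply(init_weights_by_name)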
@@ -18,7 +18,6 @@ from __future__ import (absolute_import, division, print_function,
                         unicode_literals)
 from __future__ import absolute_import, division, print_function, unicode_literals

-import copy
 import json
 import logging
 import math
...
@@ -19,7 +19,6 @@ from __future__ import (absolute_import, division, print_function,
                         unicode_literals)
 from __future__ import absolute_import, division, print_function, unicode_literals

-import copy
 import json
 import logging
 import math
...
@@ -598,6 +597,8 @@ class XLNetPreTrainedModel(PreTrainedModel):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
         elif isinstance(module, XLNetLayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
...
@@ -606,8 +607,8 @@ class XLNetPreTrainedModel(PreTrainedModel):
                           module.r_r_bias, module.r_s_bias, module.r_w_bias,
                           module.seg_embed]:
                 param.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if isinstance(module, nn.Linear) and module.bias is not None:
-                module.bias.data.zero_()
+        elif isinstance(module, XLNetModel):
+            module.mask_emb.data.normal_(mean=0.0, std=self.config.initializer_range)


 class XLNetModel(XLNetPreTrainedModel):
...
@@ -627,10 +628,11 @@ class XLNetModel(XLNetPreTrainedModel):
         self.word_embedding = nn.Embedding(config.n_token, config.d_model)
         self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
-        layer = XLNetLayer(config)
-        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
+        self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
         self.dropout = nn.Dropout(config.dropout)

+        self.apply(self.init_weights)
+
     def _prune_heads(self, heads_to_prune):
         logger.info("Head pruning is not implemented for XLNet")
         pass
...
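`mask_emb` is a bare `nn.Parameter` on `XLNetModel`, so the type-based branches for `nn.Linear` / `nn.Embedding` / `XLNetLayerNorm` never touch it; the new `elif isinstance(module, XLNetModel)` branch initializes it explicitly, and the added `self.apply(self.init_weights)` makes the hook actually run at construction time. A toy illustration of the same idea (the class name and sizes are made up):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, d_model=8):
        super(ToyModel, self).__init__()
        self.proj = nn.Linear(d_model, d_model)
        # A bare Parameter: not a submodule, so isinstance checks on nn.Linear /
        # nn.Embedding / LayerNorm will never initialize it.
        self.mask_emb = nn.Parameter(torch.Tensor(1, 1, d_model))
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, ToyModel):
            # Explicit branch for the module that owns the bare Parameter.
            module.mask_emb.data.normal_(mean=0.0, std=0.02)

    def forward(self, x):
        return self.proj(x) + self.mask_emb

model = ToyModel()
print(model.mask_emb.mean().item())  # a finite, initialized value, not uninitialized memory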
@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import copy
 import os
 import shutil
 import json
@@ -23,59 +24,66 @@ import random

 import torch


-def create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
+def _config_zero_init(config):
+    configs_no_init = copy.deepcopy(config)
+    for key in configs_no_init.__dict__.keys():
+        if '_range' in key or '_std' in key:
+            setattr(configs_no_init, key, 0.0)
+    return configs_no_init
+
+
+def _create_and_check_initialization(tester, model_classes, config, inputs_dict):
+    configs_no_init = _config_zero_init(config)
+    for model_class in model_classes:
+        model = model_class(config=configs_no_init)
+        for name, param in model.named_parameters():
+            tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0], msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
+
+
+def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
+    configs_no_init = _config_zero_init(config)
     for model_class in model_classes:
+        config.output_attentions = True
         config.output_hidden_states = True
-        model = model_class(config=config)
+        model = model_class(config=configs_no_init)
         model.eval()
-        head_mask = torch.zeros(tester.num_hidden_layers, tester.num_attention_heads)
-        # Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
+
+        # Prepare head_mask
+        # Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
+        head_mask = torch.ones(tester.num_hidden_layers, tester.num_attention_heads)
+        head_mask[0, 0] = 0
+        head_mask[-1, :-1] = 0
         head_mask.requires_grad_(requires_grad=True)
-        outputs = model(**inputs_dict, head_mask=head_mask)
+        inputs = inputs_dict.copy()
+        inputs['head_mask'] = head_mask
+        outputs = model(**inputs)

-        # Compute some gradients
+        # Test that we can get a gradient back for importance score computation
         output = sum(t.sum() for t in outputs[0])
         output = output.sum()
         output.backward()
         multihead_outputs = head_mask.grad

-        tester.parent.assertIsNotNone(multihead_outputs)
+        attentions = outputs[-1]
+        hidden_states = outputs[-2]
+
         tester.parent.assertEqual(len(multihead_outputs), tester.num_hidden_layers)
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[0].size()),
-        #     [self.batch_size, self.num_attention_heads,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
-        #     0)
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[0][:, 0, :, :].nonzero()),
-        #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
-        #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[1].size()),
-        #     [self.batch_size, self.num_attention_heads,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[1].nonzero()),
-        #     multihead_outputs[1].numel())
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[-1].size()),
-        #     [self.batch_size, self.num_attention_heads,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
-        #     0)
-        # self.parent.assertEqual(
-        #     len(multihead_outputs[-1][:, 0, :, :].nonzero()),
-        #     self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+        tester.parent.assertAlmostEqual(
+            attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
+        tester.parent.assertNotEqual(
+            attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
+        tester.parent.assertNotEqual(
+            attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
+        tester.parent.assertAlmostEqual(
+            attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
+        tester.parent.assertNotEqual(
+            attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)


-def create_and_check_for_head_pruning(tester, model_classes, config, inputs_dict):
+def _create_and_check_for_head_pruning(tester, model_classes, config, inputs_dict):
     for model_class in model_classes:
+        config.output_attentions = True
+        config.output_hidden_states = False
         model = model_class(config=config)
         model.eval()
         heads_to_prune = {0: list(range(1, tester.num_attention_heads)),
@@ -83,27 +91,17 @@ def create_and_check_for_head_pruning(tester, model_classes, config, inputs_dict):
         model.prune_heads(heads_to_prune)
         outputs = model(**inputs_dict)

-        # output = sum(t.sum() for t in outputs[0])
-        # output = output.sum()
-        # output.backward()
-        # multihead_outputs = bert_model.get_multihead_outputs()
-
-        # self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[0].size()),
-        #     [self.batch_size, 1,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[1].size()),
-        #     [self.batch_size, self.num_attention_heads,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
-        # self.parent.assertListEqual(
-        #     list(multihead_outputs[-1].size()),
-        #     [self.batch_size, self.num_attention_heads-1,
-        #      self.seq_length, self.hidden_size // self.num_attention_heads])
+        attentions = outputs[-1]
+
+        tester.parent.assertEqual(
+            attentions[0].shape[-3], 1)
+        tester.parent.assertEqual(
+            attentions[1].shape[-3], tester.num_attention_heads)
+        tester.parent.assertEqual(
+            attentions[-1].shape[-3], tester.num_attention_heads - 1)


-def create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
+def _create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
     for model_class in model_classes:
         config.output_attentions = True
         config.output_hidden_states = False
@@ -139,7 +137,7 @@ def create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
                 tester.seq_length,
                 tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])


-def create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
+def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
     for model_class in model_classes:
         config.output_hidden_states = True
         config.output_attentions = False
...
@@ -155,11 +153,13 @@ def create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
             [tester.seq_length, tester.hidden_size])


-def create_and_check_commons(tester, config, inputs_dict):
-    create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
-    create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
-    create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)
-    create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
+def create_and_check_commons(tester, config, inputs_dict, test_pruning=True):
+    _create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict)
+    _create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
+    _create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
+    _create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
+    if test_pruning:
+        _create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)


 def ids_tensor(shape, vocab_size, rng=None, name=None):
...
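The new initialization check deep-copies the config, zeroes every `*_range` / `*_std` field, rebuilds each model, and then expects every parameter mean to be exactly 0.0 or 1.0 (a normal draw with std 0 collapses to its mean; LayerNorm weights are filled with 1.0), so any parameter that bypasses `init_weights` shows up with a random, non-zero mean. A standalone sketch of the same check, with a toy config and model in place of the real testers:

import copy
import torch.nn as nn

class ToyConfig(object):
    def __init__(self, hidden_size=8, initializer_range=0.02):
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range

class ToyModel(nn.Module):
    def __init__(self, config):
        super(ToyModel, self).__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.norm = nn.LayerNorm(config.hidden_size)
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

def config_zero_init(config):
    # Copy the config and zero every std/range hyper-parameter, so any parameter
    # that goes through init_weights ends up with mean exactly 0.0 (or 1.0 for
    # LayerNorm weights); parameters that were missed keep a random mean.
    configs_no_init = copy.deepcopy(config)
    for key in configs_no_init.__dict__.keys():
        if '_range' in key or '_std' in key:
            setattr(configs_no_init, key, 0.0)
    return configs_no_init

model = ToyModel(config_zero_init(ToyConfig()))
for name, param in model.named_parameters():
    assert param.data.mean().item() in [0.0, 1.0], "{} not properly initialized".format(name)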
@@ -28,9 +28,7 @@ import torch
 from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
                                      GPT2LMHeadModel, GPT2DoubleHeadsModel)

-from .model_tests_commons import (create_and_check_for_attentions, create_and_check_for_head_pruning,
-                                  create_and_check_for_headmasking, create_and_check_for_hidden_states,
-                                  ConfigTester, GPTModelTester)
+from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)


 class GPT2ModelTest(unittest.TestCase):
...
@@ -28,9 +28,7 @@ import torch
 from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel,
                                      OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)

-from .model_tests_commons import (create_and_check_for_attentions, create_and_check_for_head_pruning,
-                                  create_and_check_for_headmasking, create_and_check_for_hidden_states,
-                                  ConfigTester, GPTModelTester)
+from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)


 class OpenAIModelTest(unittest.TestCase):
...
@@ -173,7 +173,7 @@ class TransfoXLModelTest(unittest.TestCase):
         def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
             inputs_dict = {'input_ids': input_ids_1}
-            create_and_check_commons(self, config, inputs_dict)
+            create_and_check_commons(self, config, inputs_dict, test_pruning=False)

     def test_default(self):
         self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))
...
@@ -52,6 +52,7 @@ class XLNetModelTest(unittest.TestCase):
                      untie_r=True,
                      bi_data=False,
                      same_length=False,
+                     initializer_range=0.05,
                      seed=1,
                      type_vocab_size=2,
                      all_model_classes=(XLNetModel, XLNetLMHeadModel,
...
@@ -76,6 +77,7 @@ class XLNetModelTest(unittest.TestCase):
             self.bi_data = bi_data
             self.untie_r = untie_r
             self.same_length = same_length
+            self.initializer_range = initializer_range
             self.seed = seed
             self.type_vocab_size = type_vocab_size
             self.all_model_classes = all_model_classes
...
@@ -129,7 +131,8 @@ class XLNetModelTest(unittest.TestCase):
                 clamp_len=self.clamp_len,
                 same_length=self.same_length,
                 reuse_len=self.reuse_len,
-                bi_data=self.bi_data)
+                bi_data=self.bi_data,
+                initializer_range=self.initializer_range)

             return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels)
...
@@ -180,7 +183,7 @@ class XLNetModelTest(unittest.TestCase):
         def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
             inputs_dict = {'input_ids': input_ids_1}
-            create_and_check_commons(self, config, inputs_dict)
+            create_and_check_commons(self, config, inputs_dict, test_pruning=False)

     def test_default(self):
         self.run_tester(XLNetModelTest.XLNetModelTester(self))
...