Commit 8fa3a1f0 authored by thomwolf

updating tests

parent c41f2bad
@@ -35,7 +35,7 @@ from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss
 from .file_utils import cached_path
-from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
+from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
 logger = logging.getLogger(__name__)
@@ -46,24 +46,6 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
 }
-DECODER_ONLY_PARAMS = [
-    'layer_norm15.%i.weight', 'layer_norm15.%i.bias',
-    'encoder_attn.%i.q_lin.weight', 'encoder_attn.%i.q_lin.bias',
-    'encoder_attn.%i.k_lin.weight', 'encoder_attn.%i.k_lin.bias',
-    'encoder_attn.%i.v_lin.weight', 'encoder_attn.%i.v_lin.bias',
-    'encoder_attn.%i.out_lin.weight', 'encoder_attn.%i.out_lin.bias'
-]
-TRANSFORMER_LAYER_PARAMS = [
-    'attentions.%i.q_lin.weight', 'attentions.%i.q_lin.bias',
-    'attentions.%i.k_lin.weight', 'attentions.%i.k_lin.bias',
-    'attentions.%i.v_lin.weight', 'attentions.%i.v_lin.bias',
-    'attentions.%i.out_lin.weight', 'attentions.%i.out_lin.bias',
-    'layer_norm1.%i.weight', 'layer_norm1.%i.bias',
-    'ffns.%i.lin1.weight', 'ffns.%i.lin1.bias',
-    'ffns.%i.lin2.weight', 'ffns.%i.lin2.bias',
-    'layer_norm2.%i.weight', 'layer_norm2.%i.bias'
-]
 class XLMConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `XLMModel`.
@@ -275,6 +257,24 @@ class MultiHeadAttention(nn.Module):
         self.v_lin = Linear(dim, dim, config=config)
         self.out_lin = Linear(dim, dim, config=config)
+    def prune_heads(self, heads):
+        attention_head_size = self.dim // self.n_heads
+        if len(heads) == 0:
+            return
+        mask = torch.ones(self.n_heads, attention_head_size)
+        for head in heads:
+            mask[head] = 0
+        mask = mask.view(-1).contiguous().eq(1)
+        index = torch.arange(len(mask))[mask].long()
+        # Prune linear layers
+        self.q_lin = prune_linear_layer(self.q_lin, index)
+        self.k_lin = prune_linear_layer(self.k_lin, index)
+        self.v_lin = prune_linear_layer(self.v_lin, index)
+        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
+        # Update hyper params
+        self.n_heads = self.n_heads - len(heads)
+        self.dim = attention_head_size * self.n_heads
     def forward(self, input, mask, kv=None, cache=None, head_mask=None):
         """
         Self-attention (if kv is None) or attention over source sentence (provided by kv).
@@ -286,9 +286,9 @@ class MultiHeadAttention(nn.Module):
             klen = qlen if cache is None else cache['slen'] + qlen
         else:
             klen = kv.size(1)
-        assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
+        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
         n_heads = self.n_heads
-        dim_per_head = dim // n_heads
+        dim_per_head = self.dim // n_heads
         mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)
         def shape(x):
@@ -335,7 +335,7 @@ class MultiHeadAttention(nn.Module):
         outputs = (self.out_lin(context),)
         if self.output_attentions:
-            outputs = outputs + (weights)
+            outputs = outputs + (weights,)
         return outputs
@@ -497,6 +497,14 @@ class XLMModel(XLMPreTrainedModel):
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))
+    def _prune_heads(self, heads_to_prune):
+        """ Prunes heads of the model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+            See base class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.attentions[layer].prune_heads(heads)
     def forward(self, input_ids, lengths=None, positions=None, langs=None,
                 token_type_ids=None, attention_mask=None, cache=None, head_mask=None):  # src_enc=None, src_len=None,
         """
@@ -508,7 +516,7 @@ class XLMModel(XLMPreTrainedModel):
             `token_type_ids` LongTensor (bs, slen) same as `langs` used for compatibility
         """
         if lengths is None:
-            lengths = (input_ids != self.pad_index).float().sum(dim=1)
+            lengths = (input_ids != self.pad_index).sum(dim=1).long()
         # mask = input_ids != self.pad_index
         # check inputs
...
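The pruning added above delegates the actual surgery to prune_linear_layer, which is imported here but whose body is not part of this diff. A minimal, self-contained sketch of the same mask-and-index logic, assuming prune_linear_layer simply rebuilds an nn.Linear keeping only the rows selected by `index` (or columns when dim=1); the helper below is an illustration under that assumption, not the library implementation:

import torch
from torch import nn

def prune_linear_layer_sketch(layer, index, dim=0):
    # Assumed behaviour: copy the kept slices of weight/bias into a smaller Linear.
    W = layer.weight.index_select(dim, index).clone().detach()
    b = layer.bias.clone().detach() if dim == 1 else layer.bias[index].clone().detach()
    new_layer = nn.Linear(W.size(1), W.size(0))
    with torch.no_grad():
        new_layer.weight.copy_(W)
        new_layer.bias.copy_(b)
    return new_layer

# Same mask/index construction as MultiHeadAttention.prune_heads, on toy sizes.
n_heads, head_size, heads_to_prune = 4, 8, [1, 3]
dim = n_heads * head_size
mask = torch.ones(n_heads, head_size)
for head in heads_to_prune:
    mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()

q_lin = prune_linear_layer_sketch(nn.Linear(dim, dim), index)           # prune output rows
out_lin = prune_linear_layer_sketch(nn.Linear(dim, dim), index, dim=1)  # prune input columns
assert q_lin.weight.shape == (dim - len(heads_to_prune) * head_size, dim)
assert out_lin.weight.shape == (dim, dim - len(heads_to_prune) * head_size)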
@@ -68,6 +68,8 @@ def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict
         attentions = outputs[-1]
         hidden_states = outputs[-2]
+        # Remove Nan
         tester.parent.assertIsNotNone(multihead_outputs)
         tester.parent.assertEqual(len(multihead_outputs), tester.num_hidden_layers)
         tester.parent.assertAlmostEqual(
@@ -298,7 +300,11 @@ class GPTModelTester(object):
                 mc_labels, lm_labels, mc_token_ids):
         model = self.base_model_class(config)
         model.eval()
         outputs = model(input_ids, position_ids, token_type_ids)
+        outputs = model(input_ids, position_ids)
+        outputs = model(input_ids)
         hidden_state = outputs[0]
         self.parent.assertListEqual(
             list(hidden_state.size()),
...
@@ -126,6 +126,8 @@ class BertModelTest(unittest.TestCase):
             model = BertModel(config=config)
             model.eval()
             sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
+            sequence_output, pooled_output = model(input_ids, token_type_ids)
+            sequence_output, pooled_output = model(input_ids)
             result = {
                 "sequence_output": sequence_output,
...
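The test additions in this commit follow one pattern: call each model with progressively fewer inputs so the optional-argument defaults are exercised alongside the fully-specified call. A stripped-down illustration of the idea with a stand-in module (the class and sizes below are made up for the example, not taken from the repository):

import torch
from torch import nn

class ToyModel(nn.Module):
    # Stand-in for a model whose extra inputs default to None.
    def __init__(self, vocab_size=100, hidden_size=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        hidden = self.embed(input_ids)
        if token_type_ids is not None:
            hidden = hidden + self.embed(token_type_ids)
        if attention_mask is not None:
            hidden = hidden * attention_mask.unsqueeze(-1)
        return hidden

model = ToyModel()
model.eval()
input_ids = torch.randint(0, 100, (2, 5))
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.ones(2, 5)

# Mirror the added test lines: every call variant must run and keep the output shape.
for args in [(input_ids, token_type_ids, attention_mask),
             (input_ids, token_type_ids),
             (input_ids,)]:
    assert model(*args).shape == (2, 5, 8)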
@@ -96,7 +96,7 @@ class XLMModelTest(unittest.TestCase):
             input_lengths = None
             if self.use_input_lengths:
-                input_lengths = ids_tensor([self.batch_size], vocab_size=self.seq_length-1)
+                input_lengths = ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2  # small variation of seq_length
             token_type_ids = None
             if self.use_token_type_ids:
@@ -139,6 +139,8 @@ class XLMModelTest(unittest.TestCase):
             model = XLMModel(config=config)
             model.eval()
             outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
+            outputs = model(input_ids, langs=token_type_ids)
+            outputs = model(input_ids)
             sequence_output = outputs[0]
             result = {
                 "sequence_output": sequence_output,
@@ -232,7 +234,7 @@ class XLMModelTest(unittest.TestCase):
         def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
-            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_lengths}
+            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
             create_and_check_commons(self, config, inputs_dict)
     def test_default(self):
...
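The new input_lengths line trades a length drawn anywhere in [0, seq_length-2] for one that stays within two tokens of the full sequence length, so the padding code path is still hit without producing mostly-padded batches. Assuming ids_tensor draws uniform integers in [0, vocab_size) (a stand-in is defined below for the check), the new expression can only take the values seq_length-2 and seq_length-1:

import torch

batch_size, seq_length = 13, 7

def ids_tensor(shape, vocab_size):
    # Assumed behaviour of the test helper: uniform integers in [0, vocab_size).
    return torch.randint(0, vocab_size, tuple(shape), dtype=torch.long)

input_lengths = ids_tensor([batch_size], vocab_size=2) + seq_length - 2
assert set(input_lengths.tolist()) <= {seq_length - 2, seq_length - 1}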
@@ -140,7 +140,26 @@ class XLNetModelTest(unittest.TestCase):
             random.seed(self.seed)
             torch.manual_seed(self.seed)
-        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
+        def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
+            model = XLNetModel(config)
+            model.eval()
+            _, _ = model(input_ids_1, token_type_ids=segment_ids)
+            outputs, mems_1 = model(input_ids_1)
+            result = {
+                "mems_1": mems_1,
+                "outputs": outputs,
+            }
+            self.parent.assertListEqual(
+                list(result["outputs"].size()),
+                [self.batch_size, self.seq_length, self.hidden_size])
+            self.parent.assertListEqual(
+                list(list(mem.size()) for mem in result["mems_1"]),
+                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
+        def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
             model = XLNetLMHeadModel(config)
             model.eval()
@@ -150,7 +169,7 @@ class XLNetModelTest(unittest.TestCase):
             logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping, inp_q=inp_q)
-            outputs = {
+            result = {
                 "loss_1": loss_1,
                 "mems_1": mems_1,
                 "all_logits_1": all_logits_1,
@@ -158,9 +177,7 @@ class XLNetModelTest(unittest.TestCase):
                 "mems_2": mems_2,
                 "all_logits_2": all_logits_2,
             }
-            return outputs
-        def check_transfo_xl_lm_head_output(self, result):
             self.parent.assertListEqual(
                 list(result["loss_1"].size()),
                 [])
@@ -203,8 +220,11 @@ class XLNetModelTest(unittest.TestCase):
     def run_tester(self, tester):
         tester.set_seed()
         config_and_inputs = tester.prepare_config_and_inputs()
-        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
-        tester.check_transfo_xl_lm_head_output(output_result)
+        tester.create_and_check_xlnet_base_model(*config_and_inputs)
+        tester.set_seed()
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_xlnet_lm_head(*config_and_inputs)
         tester.set_seed()
         config_and_inputs = tester.prepare_config_and_inputs()
...
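The XLNet tester now uses the same create_and_check_* naming as the other model testers: each check builds the model, runs it, and asserts on shapes directly instead of returning a dict for a separate check_*_output call, and run_tester re-seeds and rebuilds the inputs before each check. The shape assertions in the new base-model check reduce to the following idiom, shown here on dummy tensors with placeholder sizes:

import torch

batch_size, seq_length, hidden_size, num_hidden_layers = 2, 5, 8, 3

# Stand-ins for the outputs checked in create_and_check_xlnet_base_model.
outputs = torch.zeros(batch_size, seq_length, hidden_size)
mems = [torch.zeros(seq_length, batch_size, hidden_size) for _ in range(num_hidden_layers)]

assert list(outputs.size()) == [batch_size, seq_length, hidden_size]
assert [list(mem.size()) for mem in mems] == [[seq_length, batch_size, hidden_size]] * num_hidden_layers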
@@ -304,7 +304,6 @@ class XLMTokenizer(object):
         index = 0
         with open(merge_file, "w", encoding="utf-8") as writer:
-            writer.write(u'#version: 0.2\n')
             for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                 if index != token_index:
                     logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
...