Commit 8fa3a1f0 authored by thomwolf's avatar thomwolf

updating tests

parent c41f2bad
@@ -35,7 +35,7 @@ from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
logger = logging.getLogger(__name__)
@@ -46,24 +46,6 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
}
DECODER_ONLY_PARAMS = [
'layer_norm15.%i.weight', 'layer_norm15.%i.bias',
'encoder_attn.%i.q_lin.weight', 'encoder_attn.%i.q_lin.bias',
'encoder_attn.%i.k_lin.weight', 'encoder_attn.%i.k_lin.bias',
'encoder_attn.%i.v_lin.weight', 'encoder_attn.%i.v_lin.bias',
'encoder_attn.%i.out_lin.weight', 'encoder_attn.%i.out_lin.bias'
]
TRANSFORMER_LAYER_PARAMS = [
'attentions.%i.q_lin.weight', 'attentions.%i.q_lin.bias',
'attentions.%i.k_lin.weight', 'attentions.%i.k_lin.bias',
'attentions.%i.v_lin.weight', 'attentions.%i.v_lin.bias',
'attentions.%i.out_lin.weight', 'attentions.%i.out_lin.bias',
'layer_norm1.%i.weight', 'layer_norm1.%i.bias',
'ffns.%i.lin1.weight', 'ffns.%i.lin1.bias',
'ffns.%i.lin2.weight', 'ffns.%i.lin2.bias',
'layer_norm2.%i.weight', 'layer_norm2.%i.bias'
]
class XLMConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `XLMModel`.
@@ -275,6 +257,24 @@ class MultiHeadAttention(nn.Module):
self.v_lin = Linear(dim, dim, config=config)
self.out_lin = Linear(dim, dim, config=config)
def prune_heads(self, heads):
attention_head_size = self.dim // self.n_heads
if len(heads) == 0:
return
mask = torch.ones(self.n_heads, attention_head_size)
for head in heads:
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
# Prune linear layers
self.q_lin = prune_linear_layer(self.q_lin, index)
self.k_lin = prune_linear_layer(self.k_lin, index)
self.v_lin = prune_linear_layer(self.v_lin, index)
self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
# Update hyper params
self.n_heads = self.n_heads - len(heads)
self.dim = attention_head_size * self.n_heads
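The new prune_heads method builds an (n_heads, head_size) mask, zeroes the rows of the heads to drop, flattens it into an index of surviving dimensions, and hands that index to prune_linear_layer: q_lin, k_lin and v_lin are sliced along their output features and out_lin along its input features (dim=1). As a rough, assumption-based sketch of what such a helper has to do (an illustration, not the actual model_utils implementation), slicing an nn.Linear with an index tensor looks like this:

import torch
from torch import nn

def prune_linear_layer_sketch(layer, index, dim=0):
    # Keep only the entries listed in `index` along `dim` of an nn.Linear.
    # dim=0 drops output features (rows of the weight matrix); dim=1 drops
    # input features (columns). The bias only shrinks when outputs are dropped.
    index = index.to(layer.weight.device)
    weight = layer.weight.index_select(dim, index).clone().detach()
    if layer.bias is not None:
        bias = layer.bias.clone().detach() if dim == 1 else layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
    with torch.no_grad():
        new_layer.weight.copy_(weight.contiguous())
        if layer.bias is not None:
            new_layer.bias.copy_(bias.contiguous())
    return new_layer

Because the index comes from flattening a per-head mask, whole heads disappear at once while the remaining weights are left untouched.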
def forward(self, input, mask, kv=None, cache=None, head_mask=None):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
@@ -286,9 +286,9 @@ class MultiHeadAttention(nn.Module):
klen = qlen if cache is None else cache['slen'] + qlen
else:
klen = kv.size(1)
assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
n_heads = self.n_heads
dim_per_head = dim // n_heads
dim_per_head = self.dim // n_heads
mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)
def shape(x):
@@ -335,7 +335,7 @@ class MultiHeadAttention(nn.Module):
outputs = (self.out_lin(context),)
if self.output_attentions:
outputs = outputs + (weights)
outputs = outputs + (weights,)
return outputs
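The one-character change in this hunk matters because (weights) is just weights wrapped in parentheses, while (weights,) is a one-element tuple, which is what the + concatenation expects. A tiny illustration (weights here is a placeholder string, not the real attention tensor):

weights = "attn_weights"          # placeholder standing in for the attention probabilities
outputs = ("context",)
print(outputs + (weights,))       # ('context', 'attn_weights') -- proper tuple concatenation
# outputs + (weights) raises TypeError: can only concatenate tuple (not "str") to tuple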
@@ -497,6 +497,14 @@ class XLMModel(XLMPreTrainedModel):
self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.attentions[layer].prune_heads(heads)
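Head pruning is driven from the model level with a {layer_index: [head_indices]} dict, which _prune_heads dispatches to each layer's MultiHeadAttention.prune_heads. A hedged usage sketch, assuming the PreTrainedModel base class exposes a public prune_heads entry point that forwards to _prune_heads, as the docstring above suggests:

model = XLMModel.from_pretrained('xlm-mlm-en-2048')
model.prune_heads({0: [0, 1], 2: [3]})   # drop heads 0 and 1 in layer 0, and head 3 in layer 2
# The affected layers now run with fewer heads and correspondingly smaller
# q/k/v/out projections; all other checkpoint weights are unchanged.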
def forward(self, input_ids, lengths=None, positions=None, langs=None,
token_type_ids=None, attention_mask=None, cache=None, head_mask=None): # src_enc=None, src_len=None,
"""
@@ -508,7 +516,7 @@ class XLMModel(XLMPreTrainedModel):
`token_type_ids` LongTensor (bs, slen) same as `langs` used for compatibility
"""
if lengths is None:
lengths = (input_ids != self.pad_index).float().sum(dim=1)
lengths = (input_ids != self.pad_index).sum(dim=1).long()
# mask = input_ids != self.pad_index
# check inputs
......
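The lengths change above swaps a float sum for an integer one: comparing input_ids against pad_index yields a boolean mask, and summing it per row with .long() gives integer token counts that can be used directly for indexing and mask construction. A small illustration, assuming a pad_index of 0 purely for the example:

import torch

pad_index = 0                     # assumed value, for illustration only
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [3, 4, 0, 0, 0]])
lengths = (input_ids != pad_index).sum(dim=1).long()
print(lengths)                    # tensor([3, 2]) -- integer lengths rather than floats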
@@ -68,6 +68,8 @@ def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict
attentions = outputs[-1]
hidden_states = outputs[-2]
# Remove Nan
tester.parent.assertIsNotNone(multihead_outputs)
tester.parent.assertEqual(len(multihead_outputs), tester.num_hidden_layers)
tester.parent.assertAlmostEqual(
@@ -298,7 +300,11 @@ class GPTModelTester(object):
mc_labels, lm_labels, mc_token_ids):
model = self.base_model_class(config)
model.eval()
outputs = model(input_ids, position_ids, token_type_ids)
outputs = model(input_ids, position_ids)
outputs = model(input_ids)
hidden_state = outputs[0]
self.parent.assertListEqual(
list(hidden_state.size()),
......
@@ -126,6 +126,8 @@ class BertModelTest(unittest.TestCase):
model = BertModel(config=config)
model.eval()
sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
sequence_output, pooled_output = model(input_ids, token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
......
@@ -96,7 +96,7 @@ class XLMModelTest(unittest.TestCase):
input_lengths = None
if self.use_input_lengths:
input_lengths = ids_tensor([self.batch_size], vocab_size=self.seq_length-1)
input_lengths = ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2 # small variation of seq_length
token_type_ids = None
if self.use_token_type_ids:
@@ -139,6 +139,8 @@ class XLMModelTest(unittest.TestCase):
model = XLMModel(config=config)
model.eval()
outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
outputs = model(input_ids, langs=token_type_ids)
outputs = model(input_ids)
sequence_output = outputs[0]
result = {
"sequence_output": sequence_output,
@@ -232,7 +234,7 @@ class XLMModelTest(unittest.TestCase):
def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_lengths}
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
create_and_check_commons(self, config, inputs_dict)
def test_default(self):
......
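One detail worth spelling out from the input_lengths hunk earlier in this test file: assuming ids_tensor is the usual test helper that draws uniform integers in [0, vocab_size), the new expression ids_tensor([batch_size], vocab_size=2) + seq_length - 2 only ever produces seq_length - 2 or seq_length - 1, i.e. a small variation around the full sequence length instead of arbitrary lengths. A simplified stand-in for the helper makes this concrete:

import torch

def ids_tensor(shape, vocab_size):
    # simplified stand-in for the helper used in these tests
    return torch.randint(0, vocab_size, tuple(shape), dtype=torch.long)

batch_size, seq_length = 4, 7
input_lengths = ids_tensor([batch_size], vocab_size=2) + seq_length - 2
print(input_lengths)              # e.g. tensor([5, 6, 6, 5]) -- values in {5, 6}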
@@ -140,7 +140,26 @@ class XLNetModelTest(unittest.TestCase):
random.seed(self.seed)
torch.manual_seed(self.seed)
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
model = XLNetModel(config)
model.eval()
_, _ = model(input_ids_1, token_type_ids=segment_ids)
outputs, mems_1 = model(input_ids_1)
result = {
"mems_1": mems_1,
"outputs": outputs,
}
self.parent.assertListEqual(
list(result["outputs"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
model = XLNetLMHeadModel(config)
model.eval()
@@ -150,7 +169,7 @@ class XLNetModelTest(unittest.TestCase):
logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping, inp_q=inp_q)
outputs = {
result = {
"loss_1": loss_1,
"mems_1": mems_1,
"all_logits_1": all_logits_1,
@@ -158,9 +177,7 @@ class XLNetModelTest(unittest.TestCase):
"mems_2": mems_2,
"all_logits_2": all_logits_2,
}
return outputs
def check_transfo_xl_lm_head_output(self, result):
self.parent.assertListEqual(
list(result["loss_1"].size()),
[])
@@ -203,8 +220,11 @@ class XLNetModelTest(unittest.TestCase):
def run_tester(self, tester):
tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
tester.check_transfo_xl_lm_head_output(output_result)
tester.create_and_check_xlnet_base_model(*config_and_inputs)
tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_xlnet_lm_head(*config_and_inputs)
tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
......
@@ -304,7 +304,6 @@ class XLMTokenizer(object):
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
......