Unverified Commit 292186a3 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Adding LM Head to Transfo-XL and first step to fixing problem with Adaptive...

Adding LM Head to Transfo-XL and first step to fixing problem with Adaptive Embeddings in TransfoXL (#3286)

* first commit

* work in progress

* make language generation task pass

* update to working version for LM

* delete print

* remove dead code

* make style
parent efdb46b6
...@@ -357,6 +357,7 @@ if is_tf_available(): ...@@ -357,6 +357,7 @@ if is_tf_available():
TFTransfoXLModel, TFTransfoXLModel,
TFTransfoXLLMHeadModel, TFTransfoXLLMHeadModel,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
TFAdaptiveEmbedding,
) )
from .modeling_tf_xlnet import ( from .modeling_tf_xlnet import (
......
...@@ -733,6 +733,25 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel): ...@@ -733,6 +733,25 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
return outputs return outputs
class TFTransfoXLLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
@add_start_docstrings( @add_start_docstrings(
"""The Transformer-XL Model with a language modeling head on top """The Transformer-XL Model with a language modeling head on top
(adaptive softmax with weights tied to the adaptive input embeddings)""", (adaptive softmax with weights tied to the adaptive input embeddings)""",
...@@ -743,14 +762,20 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -743,14 +762,20 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
super().__init__(config) super().__init__(config)
self.transformer = TFTransfoXLMainLayer(config, name="transformer") self.transformer = TFTransfoXLMainLayer(config, name="transformer")
self.sample_softmax = config.sample_softmax self.sample_softmax = config.sample_softmax
# use sampled softmax assert (
if config.sample_softmax > 0: self.sample_softmax <= 0
raise NotImplementedError ), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
# use adaptive softmax (including standard softmax)
else: self.crit = TFAdaptiveSoftmaxMask(
self.crit = TFAdaptiveSoftmaxMask( config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit"
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit" )
)
def get_output_embeddings(self):
""" Double-check if you are using adaptive softmax.
"""
if len(self.crit.out_layers) > 0:
return self.crit.out_layers[-1]
return None
def reset_length(self, tgt_len, ext_len, mem_len): def reset_length(self, tgt_len, ext_len, mem_len):
self.transformer.reset_length(tgt_len, ext_len, mem_len) self.transformer.reset_length(tgt_len, ext_len, mem_len)
...@@ -820,13 +845,9 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -820,13 +845,9 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
last_hidden = transformer_outputs[0] last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:] pred_hid = last_hidden[:, -tgt_len:]
outputs = transformer_outputs[1:] outputs = transformer_outputs[1:]
if self.sample_softmax > 0 and training:
raise NotImplementedError softmax_output = self.crit([pred_hid, labels], training=training)
else: outputs = [softmax_output] + outputs
# pred_hid = tf.reshape(pred_hid, (-1, shape_list(pred_hid)[-1]))
softmax_output = self.crit([pred_hid, labels], training=training)
# softmax_output = tf.reshape(softmax_output, (bsz, tgt_len, -1))
outputs = [softmax_output] + outputs
return outputs # logits, new_mems, (all hidden states), (all attentions) return outputs # logits, new_mems, (all hidden states), (all attentions)
......
...@@ -27,7 +27,7 @@ import torch.nn.functional as F ...@@ -27,7 +27,7 @@ import torch.nn.functional as F
from .configuration_transfo_xl import TransfoXLConfig from .configuration_transfo_xl import TransfoXLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_transfo_xl_utilities import LogUniformSampler, ProjectedAdaptiveLogSoftmax, sample_logits from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
...@@ -809,42 +809,37 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): ...@@ -809,42 +809,37 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
super().__init__(config) super().__init__(config)
self.transformer = TransfoXLModel(config) self.transformer = TransfoXLModel(config)
self.sample_softmax = config.sample_softmax self.sample_softmax = config.sample_softmax
# use sampled softmax
if config.sample_softmax > 0: assert (
self.out_layer = nn.Linear(config.d_model, config.vocab_size) self.sample_softmax <= 0
self.sampler = LogUniformSampler(config.vocab_size, config.sample_softmax) ), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
# use adaptive softmax (including standard softmax)
else: self.crit = ProjectedAdaptiveLogSoftmax(
self.crit = ProjectedAdaptiveLogSoftmax( config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val )
)
self.init_weights() self.init_weights()
def tie_weights(self): def tie_weights(self):
""" """
Run this to be sure output and input (adaptive) softmax weights are tied Run this to be sure output and input (adaptive) softmax weights are tied
""" """
# sampled softmax
if self.sample_softmax > 0: if self.config.tie_weight:
if self.config.tie_weight: for i in range(len(self.crit.out_layers)):
self.out_layer.weight = self.transformer.word_emb.weight self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
# adaptive softmax (including standard softmax) if self.config.tie_projs:
else: for i, tie_proj in enumerate(self.config.tie_projs):
if self.config.tie_weight: if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
for i in range(len(self.crit.out_layers)): if self.config.torchscript:
self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i]) self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone())
if self.config.tie_projs: else:
for i, tie_proj in enumerate(self.config.tie_projs): self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed: elif tie_proj and self.config.div_val != 1:
if self.config.torchscript: if self.config.torchscript:
self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone()) self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
else: else:
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0] self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
elif tie_proj and self.config.div_val != 1:
if self.config.torchscript:
self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
else:
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
def reset_length(self, tgt_len, ext_len, mem_len): def reset_length(self, tgt_len, ext_len, mem_len):
self.transformer.reset_length(tgt_len, ext_len, mem_len) self.transformer.reset_length(tgt_len, ext_len, mem_len)
...@@ -908,22 +903,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): ...@@ -908,22 +903,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
last_hidden = transformer_outputs[0] last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:] pred_hid = last_hidden[:, -tgt_len:]
outputs = transformer_outputs[1:] outputs = transformer_outputs[1:]
if self.sample_softmax > 0 and self.training:
assert self.config.tie_weight softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler) if labels is None:
softmax_output = -F.log_softmax(logit, -1)[:, :, 0] softmax_output = softmax_output.view(bsz, tgt_len, -1)
outputs = [softmax_output] + outputs outputs = [softmax_output] + outputs
if labels is not None:
# TODO: This is not implemented
raise NotImplementedError
else: else:
softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels) softmax_output = softmax_output.view(bsz, tgt_len)
if labels is None: outputs = [softmax_output, None] + outputs
softmax_output = softmax_output.view(bsz, tgt_len, -1)
outputs = [softmax_output] + outputs
else:
softmax_output = softmax_output.view(bsz, tgt_len)
outputs = [softmax_output, None] + outputs
return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions) return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
......
...@@ -241,77 +241,3 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): ...@@ -241,77 +241,3 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
out[:, start_idx, stop_idx] = logprob_i out[:, start_idx, stop_idx] = logprob_i
return out return out
class LogUniformSampler(object):
def __init__(self, range_max, n_sample):
"""
Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
`P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
expected count can be approximated by 1 - (1 - p)^n
and we use a numerically stable version -expm1(num_tries * log1p(-p))
Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
"""
with torch.no_grad():
self.range_max = range_max
log_indices = torch.arange(1.0, range_max + 2.0, 1.0).log_()
self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
self.log_q = (-(-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
self.n_sample = n_sample
def sample(self, labels):
"""
labels: [b1, b2]
Return
true_log_probs: [b1, b2]
samp_log_probs: [n_sample]
neg_samples: [n_sample]
"""
# neg_samples = torch.empty(0).long()
n_sample = self.n_sample
n_tries = 2 * n_sample
with torch.no_grad():
neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
device = labels.device
neg_samples = neg_samples.to(device)
true_log_probs = self.log_q[labels].to(device)
samp_log_probs = self.log_q[neg_samples].to(device)
return true_log_probs, samp_log_probs, neg_samples
def sample_logits(embedding, bias, labels, inputs, sampler):
"""
embedding: an nn.Embedding layer
bias: [n_vocab]
labels: [b1, b2]
inputs: [b1, b2, n_emb]
sampler: you may use a LogUniformSampler
Return
logits: [b1, b2, 1 + n_sample]
"""
true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
n_sample = neg_samples.size(0)
b1, b2 = labels.size(0), labels.size(1)
all_ids = torch.cat([labels.view(-1), neg_samples])
all_w = embedding(all_ids)
true_w = all_w[:-n_sample].view(b1, b2, -1)
sample_w = all_w[-n_sample:].view(n_sample, -1)
all_b = bias[all_ids]
true_b = all_b[:-n_sample].view(b1, b2)
sample_b = all_b[-n_sample:]
hit = (labels[:, :, None] == neg_samples).detach()
true_logits = torch.einsum("ijk,ijk->ij", [true_w, inputs]) + true_b - true_log_probs
sample_logits = torch.einsum("lk,ijk->ijl", [sample_w, inputs]) + sample_b - samp_log_probs
sample_logits.masked_fill_(hit, -1e30)
logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
return logits
...@@ -30,7 +30,7 @@ if is_tf_available(): ...@@ -30,7 +30,7 @@ if is_tf_available():
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from transformers import tf_top_k_top_p_filtering from transformers import tf_top_k_top_p_filtering, TFAdaptiveEmbedding
if _tf_gpu_memory_limit is not None: if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU") gpus = tf.config.list_physical_devices("GPU")
...@@ -348,7 +348,7 @@ class TFModelTesterMixin: ...@@ -348,7 +348,7 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) assert isinstance(model.get_input_embeddings(), (tf.keras.layers.Layer, TFAdaptiveEmbedding))
x = model.get_output_embeddings() x = model.get_output_embeddings()
assert x is None or isinstance(x, tf.keras.layers.Layer) assert x is None or isinstance(x, tf.keras.layers.Layer)
......
...@@ -26,7 +26,7 @@ from .utils import CACHE_DIR, require_tf, slow ...@@ -26,7 +26,7 @@ from .utils import CACHE_DIR, require_tf, slow
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
from transformers.modeling_tf_transfo_xl import ( from transformers import (
TFTransfoXLModel, TFTransfoXLModel,
TFTransfoXLLMHeadModel, TFTransfoXLLMHeadModel,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
...@@ -364,7 +364,7 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase): ...@@ -364,7 +364,7 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
0, 0,
] ]
], ],
dtype=tf.int31, dtype=tf.int32,
) )
# In 1991 , the remains of Russian Tsar Nicholas II and his family # In 1991 , the remains of Russian Tsar Nicholas II and his family
# ( except for Alexei and Maria ) are discovered . # ( except for Alexei and Maria ) are discovered .
...@@ -570,8 +570,5 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase): ...@@ -570,8 +570,5 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
# Nicholas II and his family were discovered. The voice of <unk> young son, # Nicholas II and his family were discovered. The voice of <unk> young son,
# Tsarevich Alexei Nikolaevich, narrates the remainder of the story.<eos> # Tsarevich Alexei Nikolaevich, narrates the remainder of the story.<eos>
# TODO: add this test when trasnfo-xl-lmhead is implemented output_ids = model.generate(input_ids, max_length=200, do_sample=False)
with self.assertRaises(NotImplementedError): self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
model.generate(input_ids, max_length=200, do_sample=False)
print(expected_output_ids)
# self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) TODO: (PVP) to add when transfo-xl is implemented
...@@ -129,10 +129,10 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -129,10 +129,10 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
def check_transfo_xl_model_output(self, result): def check_transfo_xl_model_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
...@@ -166,7 +166,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -166,7 +166,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
def check_transfo_xl_lm_head_output(self, result): def check_transfo_xl_lm_head_output(self, result):
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size] list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
...@@ -175,7 +175,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -175,7 +175,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size] list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]), list(list(mem.size()) for mem in result["mems_2"]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment