Unverified Commit 292186a3 authored by Patrick von Platen, committed by GitHub

Adding LM Head to Transfo-XL and first step to fixing problem with Adaptive Embeddings in TransfoXL (#3286)

* first commit

* work in progress

* make language generation task pass

* update to working version for LM

* delete print

* remove dead code

* make style
parent efdb46b6
@@ -357,6 +357,7 @@ if is_tf_available():
TFTransfoXLModel,
TFTransfoXLLMHeadModel,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
TFAdaptiveEmbedding,
)
from .modeling_tf_xlnet import (
@@ -733,6 +733,25 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
return outputs
class TFTransfoXLLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
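
For context on the new layer: it reuses the input embedding as the output projection (calling it in "linear" mode) and only learns an extra per-token output bias. A minimal, self-contained sketch of that tying pattern, using a plain `tf.keras.layers.Embedding` stand-in instead of the adaptive embedding (the names `TiedLMHead` and `word_emb` are illustrative, not from this codebase):

```python
import tensorflow as tf


class TiedLMHead(tf.keras.layers.Layer):
    """Project hidden states to vocabulary logits with the transposed input
    embedding matrix plus a learned per-token output bias."""

    def __init__(self, embedding_layer, vocab_size, **kwargs):
        super().__init__(**kwargs)
        self.embedding_layer = embedding_layer  # shared reference, weights are not copied
        self.vocab_size = vocab_size

    def build(self, input_shape):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def call(self, hidden_states):
        # hidden_states: (..., d_model) -> logits: (..., vocab_size)
        flat = tf.reshape(hidden_states, [-1, hidden_states.shape[-1]])
        logits = tf.matmul(flat, self.embedding_layer.embeddings, transpose_b=True) + self.bias
        new_shape = tf.concat([tf.shape(hidden_states)[:-1], [self.vocab_size]], axis=0)
        return tf.reshape(logits, new_shape)


word_emb = tf.keras.layers.Embedding(input_dim=1000, output_dim=32)
word_emb.build((None,))                         # materialize the (1000, 32) embedding matrix
head = TiedLMHead(word_emb, vocab_size=1000)
logits = head(tf.random.normal((2, 5, 32)))     # -> shape (2, 5, 1000)
```
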
@add_start_docstrings(
"""The Transformer-XL Model with a language modeling head on top
(adaptive softmax with weights tied to the adaptive input embeddings)""",
@@ -743,15 +762,21 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
super().__init__(config)
self.transformer = TFTransfoXLMainLayer(config, name="transformer")
self.sample_softmax = config.sample_softmax
# use sampled softmax
if config.sample_softmax > 0:
raise NotImplementedError
# use adaptive softmax (including standard softmax)
else:
assert (
self.sample_softmax <= 0
), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
self.crit = TFAdaptiveSoftmaxMask(
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit"
)
def get_output_embeddings(self):
""" Double-check if you are using adaptive softmax.
"""
if len(self.crit.out_layers) > 0:
return self.crit.out_layers[-1]
return None
def reset_length(self, tgt_len, ext_len, mem_len):
self.transformer.reset_length(tgt_len, ext_len, mem_len)
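
The `crit` module constructed above is an adaptive softmax: `config.cutoffs` splits the vocabulary into a frequent head cluster plus tail clusters whose projections shrink by `div_val`, and `get_output_embeddings` returns the last of those output layers. To illustrate only the cutoff mechanism, here is a sketch using PyTorch's built-in `nn.AdaptiveLogSoftmaxWithLoss`, a simpler (non-projected) cousin of `TFAdaptiveSoftmaxMask` / `ProjectedAdaptiveLogSoftmax`; the sizes are made up:

```python
import torch
import torch.nn as nn

# Vocabulary of 10,000 tokens split into a head cluster (ids < 500) and two
# tail clusters; tail clusters use smaller hidden projections (div_value).
crit = nn.AdaptiveLogSoftmaxWithLoss(in_features=64, n_classes=10000, cutoffs=[500, 2000], div_value=4.0)

hidden = torch.randn(32, 64)                 # flattened (batch * tgt_len, d_model)
labels = torch.randint(0, 10000, (32,))

out = crit(hidden, labels)
print(out.output.shape)                      # torch.Size([32]) - log-likelihood of each label
print(out.loss)                              # scalar mean negative log-likelihood
print(crit.log_prob(hidden).shape)           # torch.Size([32, 10000]) - full log-probabilities
```
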
@@ -820,12 +845,8 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:]
outputs = transformer_outputs[1:]
if self.sample_softmax > 0 and training:
raise NotImplementedError
else:
# pred_hid = tf.reshape(pred_hid, (-1, shape_list(pred_hid)[-1]))
softmax_output = self.crit([pred_hid, labels], training=training)
# softmax_output = tf.reshape(softmax_output, (bsz, tgt_len, -1))
outputs = [softmax_output] + outputs
return outputs # logits, new_mems, (all hidden states), (all attentions)
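
With the LM head and `get_output_embeddings` wired up, the TF model can be driven end to end, including through `generate` (which is what the updated test below exercises). A hedged usage sketch, assuming the `transfo-xl-wt103` checkpoint and the tuple return layout at the time of this commit:

```python
import tensorflow as tf
from transformers import TransfoXLTokenizer, TFTransfoXLLMHeadModel

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
model = TFTransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")

input_ids = tf.constant([tokenizer.encode("In 1991 , the remains of Russian Tsar")])

outputs = model(input_ids)
scores, mems = outputs[0], outputs[1]          # per-token output scores and the new memory states

generated = model.generate(input_ids, max_length=50, do_sample=False)
print(tokenizer.decode(generated[0].numpy().tolist()))
```
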
@@ -27,7 +27,7 @@ import torch.nn.functional as F
from .configuration_transfo_xl import TransfoXLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_transfo_xl_utilities import LogUniformSampler, ProjectedAdaptiveLogSoftmax, sample_logits
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax
from .modeling_utils import PreTrainedModel
@@ -809,27 +809,22 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
super().__init__(config)
self.transformer = TransfoXLModel(config)
self.sample_softmax = config.sample_softmax
# use sampled softmax
if config.sample_softmax > 0:
self.out_layer = nn.Linear(config.d_model, config.vocab_size)
self.sampler = LogUniformSampler(config.vocab_size, config.sample_softmax)
# use adaptive softmax (including standard softmax)
else:
assert (
self.sample_softmax <= 0
), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
self.crit = ProjectedAdaptiveLogSoftmax(
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
)
self.init_weights()
def tie_weights(self):
"""
Run this to be sure output and input (adaptive) softmax weights are tied
"""
# sampled softmax
if self.sample_softmax > 0:
if self.config.tie_weight:
self.out_layer.weight = self.transformer.word_emb.weight
# adaptive softmax (including standard softmax)
else:
if self.config.tie_weight:
for i in range(len(self.crit.out_layers)):
self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
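
`_tie_or_clone_weights` makes each adaptive-softmax output layer share its weight tensor with the matching embedding layer (or clone it when torchscript is enabled). The sharing mechanism in isolation, with illustrative names:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
out_layer = nn.Linear(4, 10, bias=False)
out_layer.weight = emb.weight                        # share the same nn.Parameter, no copy

with torch.no_grad():
    emb.weight.fill_(0.5)
assert out_layer.weight.data_ptr() == emb.weight.data_ptr()
assert torch.equal(out_layer.weight, emb.weight)     # any update shows up in both modules
```
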
@@ -908,15 +903,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:]
outputs = transformer_outputs[1:]
if self.sample_softmax > 0 and self.training:
assert self.config.tie_weight
logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
outputs = [softmax_output] + outputs
if labels is not None:
# TODO: This is not implemented
raise NotImplementedError
else:
softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
if labels is None:
softmax_output = softmax_output.view(bsz, tgt_len, -1)
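
After this change the PyTorch head rejects sampled softmax up front and always goes through the adaptive softmax: without `labels` the first output is a `(batch, tgt_len, vocab_size)` tensor of per-token scores, with `labels` it is the `(batch, tgt_len)` negative log-likelihoods that the updated tests assert. A usage sketch under the same checkpoint assumption as above:

```python
import torch
from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")

input_ids = torch.tensor([tokenizer.encode("Hello , my dog is cute")])

outputs = model(input_ids)                       # no labels
prediction_scores = outputs[0]                   # (1, seq_len, vocab_size) per-token scores
mems = outputs[1]                                # memory states to feed back in as mems=...

nll = model(input_ids, labels=input_ids)[0]      # (1, seq_len) per-token negative log-likelihoods
print(prediction_scores.shape, nll.shape)
```
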
@@ -241,77 +241,3 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
out[:, start_idx:stop_idx] = logprob_i
return out
class LogUniformSampler(object):
def __init__(self, range_max, n_sample):
"""
Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
`P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
expected count can be approximated by 1 - (1 - p)^n
and we use a numerically stable version -expm1(num_tries * log1p(-p))
Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
"""
with torch.no_grad():
self.range_max = range_max
log_indices = torch.arange(1.0, range_max + 2.0, 1.0).log_()
self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
self.log_q = (-(-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
self.n_sample = n_sample
def sample(self, labels):
"""
labels: [b1, b2]
Return
true_log_probs: [b1, b2]
samp_log_probs: [n_sample]
neg_samples: [n_sample]
"""
# neg_samples = torch.empty(0).long()
n_sample = self.n_sample
n_tries = 2 * n_sample
with torch.no_grad():
neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
device = labels.device
neg_samples = neg_samples.to(device)
true_log_probs = self.log_q[labels].to(device)
samp_log_probs = self.log_q[neg_samples].to(device)
return true_log_probs, samp_log_probs, neg_samples
def sample_logits(embedding, bias, labels, inputs, sampler):
"""
embedding: an nn.Embedding layer
bias: [n_vocab]
labels: [b1, b2]
inputs: [b1, b2, n_emb]
sampler: you may use a LogUniformSampler
Return
logits: [b1, b2, 1 + n_sample]
"""
true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
n_sample = neg_samples.size(0)
b1, b2 = labels.size(0), labels.size(1)
all_ids = torch.cat([labels.view(-1), neg_samples])
all_w = embedding(all_ids)
true_w = all_w[:-n_sample].view(b1, b2, -1)
sample_w = all_w[-n_sample:].view(n_sample, -1)
all_b = bias[all_ids]
true_b = all_b[:-n_sample].view(b1, b2)
sample_b = all_b[-n_sample:]
hit = (labels[:, :, None] == neg_samples).detach()
true_logits = torch.einsum("ijk,ijk->ij", [true_w, inputs]) + true_b - true_log_probs
sample_logits = torch.einsum("lk,ijk->ijl", [sample_w, inputs]) + sample_b - samp_log_probs
sample_logits.masked_fill_(hit, -1e30)
logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
return logits
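
For reference on what is being deleted: the sampler drew negatives from a log-uniform (Zipf-like) distribution, `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`. A standalone check of that distribution (nothing in the library depends on this after the commit; sampled softmax is tracked in issue #3310):

```python
import torch

range_max = 10
log_indices = torch.arange(1.0, range_max + 2.0).log()
dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]

print(dist)          # highest mass on the most frequent (lowest-id) classes
print(dist.sum())    # ~1.0 - a valid probability distribution
samples = torch.multinomial(dist, 20, replacement=True)   # how the sampler drew negative ids
```
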
@@ -30,7 +30,7 @@ if is_tf_available():
import tensorflow as tf
import numpy as np
from transformers import tf_top_k_top_p_filtering
from transformers import tf_top_k_top_p_filtering, TFAdaptiveEmbedding
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
@@ -348,7 +348,7 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes:
model = model_class(config)
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
assert isinstance(model.get_input_embeddings(), (tf.keras.layers.Layer, TFAdaptiveEmbedding))
x = model.get_output_embeddings()
assert x is None or isinstance(x, tf.keras.layers.Layer)
@@ -26,7 +26,7 @@ from .utils import CACHE_DIR, require_tf, slow
if is_tf_available():
import tensorflow as tf
from transformers.modeling_tf_transfo_xl import (
from transformers import (
TFTransfoXLModel,
TFTransfoXLLMHeadModel,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
@@ -364,7 +364,7 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
0,
]
],
dtype=tf.int31,
dtype=tf.int32,
)
# In 1991 , the remains of Russian Tsar Nicholas II and his family
# ( except for Alexei and Maria ) are discovered .
@@ -570,8 +570,5 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
# Nicholas II and his family were discovered. The voice of <unk> young son,
# Tsarevich Alexei Nikolaevich, narrates the remainder of the story.<eos>
# TODO: add this test when transfo-xl-lmhead is implemented
with self.assertRaises(NotImplementedError):
model.generate(input_ids, max_length=200, do_sample=False)
print(expected_output_ids)
# self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) TODO: (PVP) to add when transfo-xl is implemented
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@@ -129,10 +129,10 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
def check_transfo_xl_model_output(self, result):
self.parent.assertListEqual(
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size],
)
self.parent.assertListEqual(
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
@@ -166,7 +166,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
def check_transfo_xl_lm_head_output(self, result):
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
@@ -175,7 +175,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]),