"docs/source/vscode:/vscode.git/clone" did not exist on "ef2dcdccaa9a115aca44d81f31c6dc4d32bebb3f"
Commit 6b3438df authored by thomwolf

fixing GPT2 double head model and updating the torch version tests

parent e3600372
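Usage sketch (an assumption for illustration, not part of the commit): after this fix the double-head model accepts multiple-choice inputs of shape [batch_size, num_choices, seq_length], as the updated tests below exercise. The package name and padding scheme are assumptions for this era of the repo.

```python
import torch
from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel  # package name assumed

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2DoubleHeadsModel.from_pretrained("gpt2")
model.eval()

# Two candidate continuations for one example; pad to a common length (illustration only).
choices = ["Hello, my dog is cute", "Hello, my cat is cute"]
encoded = [tokenizer.encode(c) for c in choices]
lengths = [len(e) for e in encoded]
max_len = max(lengths)
padded = [e + [0] * (max_len - len(e)) for e in encoded]

input_ids = torch.tensor(padded).unsqueeze(0)            # [1, num_choices, seq_length]
mc_token_ids = torch.tensor([[l - 1 for l in lengths]])  # position of the classification token per choice

with torch.no_grad():
    lm_logits, mc_logits = model(input_ids, mc_token_ids=mc_token_ids)[:2]
# lm_logits: [1, num_choices, seq_length, vocab_size]; mc_logits: [1, num_choices]
```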
@@ -367,6 +367,13 @@ class GPT2Model(GPT2PreTrainedModel):
self.h[layer].attn.prune_heads(heads)
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past is None:
past_length = 0
past = [None] * len(self.h)
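Shape sketch (illustration only) of the reshape added above: any leading dimensions of input_ids are folded into the batch dimension before the transformer runs.

```python
import torch

input_ids = torch.zeros(2, 3, 5, dtype=torch.long)  # [batch, num_choices, seq_length]
input_shape = input_ids.size()
flat = input_ids.view(-1, input_shape[-1])           # [batch * num_choices, seq_length]
print(flat.shape)                                    # torch.Size([6, 5])
```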
@@ -378,6 +385,7 @@ class GPT2Model(GPT2PreTrainedModel):
# Attention mask.
if attention_mask is not None:
attention_mask = attention_mask.view(-1, input_shape[-1])
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
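Sketch of the broadcastable additive mask the comment above describes. The 0.0 / -10000.0 convention is assumed from the surrounding code of that era, not shown in this hunk.

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])             # [batch_size, to_seq_length]
extended_mask = attention_mask[:, None, None, :].float()  # [batch_size, 1, 1, to_seq_length]
extended_mask = (1.0 - extended_mask) * -10000.0          # added to attention scores before softmax
print(extended_mask.shape)                                 # torch.Size([1, 1, 1, 4])
```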
@@ -407,14 +415,9 @@ class GPT2Model(GPT2PreTrainedModel):
else:
head_mask = [None] * self.config.n_layer
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_ids.size(-1))
position_ids = position_ids.view(-1, position_ids.size(-1))
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
token_type_embeds = self.wte(token_type_ids)
else:
token_type_embeds = 0
......
@@ -314,17 +314,16 @@ class TFGPT2Embeddings(tf.keras.layers.Layer):
def _linear(self, inputs):
"""Computes logits by running inputs through a linear layer.
Args:
inputs: A float32 tensor with shape [batch_size, length, hidden_size]
inputs: A float32 tensor with shape [..., hidden_size]
Returns:
float32 tensor with shape [batch_size, length, vocab_size].
float32 tensor with shape [..., vocab_size].
"""
batch_size = shape_list(inputs)[0]
length = shape_list(inputs)[1]
first_dims = shape_list(inputs)[:-1]
x = tf.reshape(inputs, [-1, self.hidden_size])
logits = tf.matmul(x, self.weight, transpose_b=True)
return tf.reshape(logits, [batch_size, length, self.vocab_size])
return tf.reshape(logits, first_dims + [self.vocab_size])
class TFGPT2MainLayer(tf.keras.layers.Layer):
def __init__(self, config, *inputs, **kwargs):
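Standalone sketch of the generalized _linear change above: collapse all leading dimensions, project against the shared embedding matrix, then restore them. Shapes and the tied-weight layout are assumptions for illustration.

```python
import tensorflow as tf

hidden_size, vocab_size = 8, 20
weight = tf.random.normal((vocab_size, hidden_size))   # tied embedding weights (assumed layout)
inputs = tf.random.normal((2, 3, 5, hidden_size))      # e.g. [batch, num_choices, seq_len, hidden]

first_dims = inputs.shape.as_list()[:-1]
x = tf.reshape(inputs, [-1, hidden_size])
logits = tf.matmul(x, weight, transpose_b=True)
logits = tf.reshape(logits, first_dims + [vocab_size])
print(logits.shape)                                    # (2, 3, 5, 20)
```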
@@ -679,10 +678,11 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
@tf.function
def call(self, inputs, training=False):
if not isinstance(inputs, (dict, tuple, list)):
raise ValueError("Inputs should be a list or a dict with at least two elements: 'inputs_ids' and 'mc_token_ids'")
input_ids = inputs
mc_token_ids, past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None, None
elif isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
mc_token_ids = inputs[1]
mc_token_ids = inputs[1] if len(inputs) > 1 else None
past = inputs[2] if len(inputs) > 2 else None
attention_mask = inputs[3] if len(inputs) > 3 else None
token_type_ids = inputs[4] if len(inputs) > 4 else None
@@ -691,7 +691,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs.get('input_ids')
mc_token_ids = inputs.get('mc_token_ids')
mc_token_ids = inputs.get('mc_token_ids', None)
past = inputs.get('past', None)
attention_mask = inputs.get('attention_mask', None)
token_type_ids = inputs.get('token_type_ids', None)
@@ -699,9 +699,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
head_mask = inputs.get('head_mask', None)
assert len(inputs) <= 5, "Too many inputs."
assert len(shape_list(input_ids)) == 3, "Inputs should have 3 dimensions: batch, choices, sequence length"
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
input_shapes = shape_list(input_ids)
seq_length = input_shapes[-1]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
@@ -710,13 +710,16 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
outputs = self.transformer(flat_inputs, training=training)
transformer_outputs = self.transformer(flat_inputs, training=training)
hidden_states = transformer_outputs[0]
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
lm_logits = self.transformer.wte(hidden_states, mode="linear")
mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)
mc_logits = tf.squeeze(mc_logits, axis=-1)
outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
return outputs # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
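Shape walkthrough (illustration, not library code) of the flatten/unflatten pattern in the hunk above: choices are folded into the batch before calling the transformer, and the hidden states are reshaped back afterwards.

```python
import tensorflow as tf

batch, num_choices, seq_len, hidden = 2, 3, 5, 8
input_ids = tf.zeros((batch, num_choices, seq_len), dtype=tf.int32)

input_shapes = input_ids.shape.as_list()                 # [2, 3, 5]
seq_length = input_shapes[-1]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))  # [6, 5], fed to the transformer

hidden_states = tf.zeros((batch * num_choices, seq_len, hidden))      # stand-in for the transformer output
hidden_states = tf.reshape(hidden_states, input_shapes + [hidden])    # back to [2, 3, 5, 8]
print(hidden_states.shape)
```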
@@ -359,13 +359,18 @@ class TFSequenceSummary(tf.keras.layers.Layer):
elif self.summary_type == 'mean':
output = tf.reduce_mean(hidden_states, axis=1)
elif self.summary_type == 'cls_index':
hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims]
if cls_index is None:
cls_index = tf.fill(tf.shape(hidden_states[..., :1, :]), hidden_states.shape[-2]-1, dtype=tf.int32)
else:
cls_index = cls_index[..., tf.newaxis, tf.newaxis]
cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
cls_index = tf.fill(hidden_shape[:-2], hidden_shape[-2] - 1) # A tensor of shape [batch] or [batch, num choices] filled with the last sequence position (seq_length - 1)
cls_shape = shape_list(cls_index)
if len(cls_shape) <= len(hidden_shape) - 2:
cls_index = cls_index[..., tf.newaxis]
# else:
# cls_index = cls_index[..., tf.newaxis]
# cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
output = tf.squeeze(output, axis=len(hidden_shape) - 2) # shape of output: (batch, num choices, hidden_size)
elif self.summary_type == 'attn':
raise NotImplementedError
......
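Small check (illustration only) of the tf.gather(..., batch_dims=...) summary above: pick one position per (batch, choice) pair from hidden_states of shape [batch, choices, seq, hidden].

```python
import tensorflow as tf

batch, num_choices, seq_len, hidden = 2, 3, 5, 4
hidden_states = tf.random.normal((batch, num_choices, seq_len, hidden))
cls_index = tf.fill((batch, num_choices), seq_len - 1)      # default: last position
cls_index = cls_index[..., tf.newaxis]                      # [batch, choices, 1]

output = tf.gather(hidden_states, cls_index, batch_dims=2)  # [batch, choices, 1, hidden]
output = tf.squeeze(output, axis=2)                         # [batch, choices, hidden]
print(output.shape)
```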
@@ -679,7 +679,7 @@ class SequenceSummary(nn.Module):
self.last_dropout = nn.Dropout(config.summary_last_dropout)
def forward(self, hidden_states, cls_index=None):
""" hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
""" hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
cls_index: [optional] position of the classification token if summary_type == 'cls_index',
shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
if summary_type == 'cls_index' and cls_index is None:
......
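PyTorch counterpart, sketched from the docstring above and the gather pattern left commented out in the TF hunk: select the hidden state at cls_index along the sequence dimension for hidden_states of shape [bsz, ..., seq_len, hidden_size]. Names and sizes here are illustrative assumptions.

```python
import torch

bsz, num_choices, seq_len, hidden = 2, 3, 5, 4
hidden_states = torch.randn(bsz, num_choices, seq_len, hidden)
cls_index = torch.full((bsz, num_choices), seq_len - 1, dtype=torch.long)  # e.g. last token

idx = cls_index[..., None, None].expand(-1, -1, 1, hidden)  # [bsz, num_choices, 1, hidden]
output = hidden_states.gather(-2, idx).squeeze(-2)          # [bsz, num_choices, hidden]
print(output.shape)                                         # torch.Size([2, 3, 4])
```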
@@ -46,6 +46,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
use_token_type_ids=True,
use_input_mask=True,
use_labels=True,
use_mc_token_ids=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
@@ -69,6 +70,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
self.use_token_type_ids = use_token_type_ids
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.use_mc_token_ids = use_mc_token_ids
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
@@ -96,6 +98,10 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
mc_token_ids = None
if self.use_mc_token_ids:
mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
sequence_labels = None
token_labels = None
choice_labels = None
@@ -121,7 +127,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
return config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels
return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(
@@ -163,15 +169,27 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
list(result["lm_logits"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
def create_and_check_double_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_double_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args):
model = GPT2DoubleHeadsModel(config)
model.eval()
loss, lm_logits, mc_logits, _ = model(input_ids, token_type_ids=token_type_ids, lm_labels=input_ids)
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
inputs = {'input_ids': multiple_choice_inputs_ids,
'mc_token_ids': mc_token_ids,
'attention_mask': multiple_choice_input_mask,
'token_type_ids': multiple_choice_token_type_ids,
'lm_labels': multiple_choice_inputs_ids}
loss, lm_logits, mc_logits, _ = model(**inputs)
result = {
"loss": loss,
"lm_logits": lm_logits
"lm_logits": lm_logits,
"mc_logits": mc_logits
}
self.parent.assertListEqual(
@@ -179,11 +197,17 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
[])
self.parent.assertListEqual(
list(result["lm_logits"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
[self.batch_size, self.num_choices, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(result["mc_logits"].size()),
[self.batch_size, self.num_choices])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
(config, input_ids, input_mask, head_mask, token_type_ids,
mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
inputs_dict = {
'input_ids': input_ids,
'token_type_ids': token_type_ids,
......
@@ -37,9 +37,9 @@ else:
class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
# all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel,
# TFGPT2DoubleHeadsModel) if is_tf_available() else ()
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel,
TFGPT2DoubleHeadsModel) if is_tf_available() else ()
# all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
class TFGPT2ModelTester(object):
@@ -51,6 +51,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
use_token_type_ids=True,
use_input_mask=True,
use_labels=True,
use_mc_token_ids=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
@@ -74,6 +75,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
self.use_token_type_ids = use_token_type_ids
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.use_mc_token_ids = use_mc_token_ids
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
@@ -101,6 +103,10 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
mc_token_ids = None
if self.use_mc_token_ids:
mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
sequence_labels = None
token_labels = None
choice_labels = None
@@ -126,7 +132,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
return config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels
return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels
def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = TFGPT2Model(config=config)
@@ -162,25 +168,34 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
[self.batch_size, self.seq_length, self.vocab_size])
def create_and_check_gpt2_double_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
pass
# model = TFGPT2DoubleHeadsModel(config=config)
# inputs = {'input_ids': input_ids,
# 'attention_mask': input_mask,
# 'token_type_ids': token_type_ids}
# seq_relationship_score, = model(inputs)[0]
# result = {
# "seq_relationship_score": seq_relationship_score.numpy(),
# }
# self.parent.assertListEqual(
# list(result["seq_relationship_score"].shape),
# [self.batch_size, 2])
def create_and_check_gpt2_double_head(self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args):
model = TFGPT2DoubleHeadsModel(config=config)
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
inputs = {'input_ids': multiple_choice_inputs_ids,
'mc_token_ids': mc_token_ids,
'attention_mask': multiple_choice_input_mask,
'token_type_ids': multiple_choice_token_type_ids}
lm_logits, mc_logits = model(inputs)[:2]
result = {
"lm_logits": lm_logits.numpy(),
"mc_logits": mc_logits.numpy()
}
self.parent.assertListEqual(
list(result["lm_logits"].shape),
[self.batch_size, self.num_choices, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(result["mc_logits"].shape),
[self.batch_size, self.num_choices])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, head_mask, token_type_ids,
sequence_labels, token_labels, choice_labels) = config_and_inputs
mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
return config, inputs_dict
......