Commit adb5c79f authored by thomwolf, committed by Lysandre Debut

update all tf.shape and tensor.shape to shape_list

parent 1ab8dc44
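
For context: shape_list, defined in modeling_tf_utils (see the hunk near the end of this diff), returns a tensor's shape as a Python list that keeps statically known dimensions as plain ints and falls back to dynamic tf.shape entries only where a dimension is unknown. A minimal sketch of why the call sites below prefer it, with the dynamic fallback written as tf.shape as in the helper's original body; the @tf.function signature and the 12x64 head split are illustrative, not part of this commit:

import tensorflow as tf

def shape_list(x):
    """Static dims where known, dynamic tf.shape entries where not."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

@tf.function(input_signature=[tf.TensorSpec([None, 128, 768], tf.float32)])
def split_heads(hidden_states):
    batch_size = shape_list(hidden_states)[0]   # unknown at trace time -> scalar tensor
    seq_length = shape_list(hidden_states)[1]   # known at trace time -> plain int 128
    # tf.shape(...)[1] would also work but makes every dimension dynamic;
    # hidden_states.shape[1] alone breaks as soon as the dimension is None.
    return tf.reshape(hidden_states, (batch_size, seq_length, 12, 64))

print(split_heads(tf.zeros([2, 128, 768])).shape)   # (2, 128, 12, 64)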
@@ -32,7 +32,7 @@ import numpy as np
 import tensorflow as tf

 from .configuration_xxx import XxxConfig
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
 from .file_utils import add_start_docstrings

 logger = logging.getLogger(__name__)
@@ -121,9 +121,9 @@ class TFXxxMainLayer(tf.keras.layers.Layer):
             input_ids = inputs

         if attention_mask is None:
-            attention_mask = tf.fill(tf.shape(input_ids), 1)
+            attention_mask = tf.fill(shape_list(input_ids), 1)
         if token_type_ids is None:
-            token_type_ids = tf.fill(tf.shape(input_ids), 0)
+            token_type_ids = tf.fill(shape_list(input_ids), 0)

         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
...
@@ -118,7 +118,7 @@ if is_torch_available():

 # TensorFlow
 if is_tf_available():
-    from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary
+    from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
     from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
                                    TFAutoModelWithLMHead)
...
@@ -16,18 +16,13 @@
 """ TF 2.0 ALBERT model. """

 from __future__ import absolute_import, division, print_function, unicode_literals

-import json
 import logging
-import math
-import os
 import sys
-from io import open

-import numpy as np
 import tensorflow as tf

 from .configuration_albert import AlbertConfig
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
 from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
 from .file_utils import add_start_docstrings
@@ -110,9 +105,9 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
         input_ids, position_ids, token_type_ids, inputs_embeds = inputs

         if input_ids is not None:
-            input_shape = tf.shape(input_ids)
+            input_shape = shape_list(input_ids)
         else:
-            input_shape = tf.shape(inputs_embeds)[:-1]
+            input_shape = shape_list(inputs_embeds)[:-1]

         seq_length = input_shape[1]
         if position_ids is None:
@@ -137,8 +132,8 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
         Returns:
             float32 tensor with shape [batch_size, length, vocab_size].
         """
-        batch_size = tf.shape(inputs)[0]
-        length = tf.shape(inputs)[1]
+        batch_size = shape_list(inputs)[0]
+        length = shape_list(inputs)[1]
         x = tf.reshape(inputs, [-1, self.config.embedding_size])
         logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
         return tf.reshape(logits, [batch_size, length, self.config.vocab_size])
@@ -183,7 +178,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
     def call(self, inputs, training=False):
         hidden_states, attention_mask, head_mask = inputs

-        batch_size = tf.shape(hidden_states)[0]
+        batch_size = shape_list(hidden_states)[0]
         mixed_query_layer = self.query(hidden_states)
         mixed_key_layer = self.key(hidden_states)
         mixed_value_layer = self.value(hidden_states)
@@ -196,7 +191,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
         # (batch size, num_heads, seq_len_q, seq_len_k)
         attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
         # scale attention_scores
-        dk = tf.cast(tf.shape(key_layer)[-1], tf.float32)
+        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)
         attention_scores = attention_scores / tf.math.sqrt(dk)

         if attention_mask is not None:
@@ -264,7 +259,7 @@ class TFAlbertAttention(TFBertSelfAttention):
     def call(self, inputs, training=False):
         input_tensor, attention_mask, head_mask = inputs

-        batch_size = tf.shape(input_tensor)[0]
+        batch_size = shape_list(input_tensor)[0]
         mixed_query_layer = self.query(input_tensor)
         mixed_key_layer = self.key(input_tensor)
         mixed_value_layer = self.value(input_tensor)
@@ -277,7 +272,7 @@ class TFAlbertAttention(TFBertSelfAttention):
         # (batch size, num_heads, seq_len_q, seq_len_k)
         attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
         # scale attention_scores
-        dk = tf.cast(tf.shape(key_layer)[-1], tf.float32)
+        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)
         attention_scores = attention_scores / tf.math.sqrt(dk)

         if attention_mask is not None:
@@ -645,9 +640,9 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = tf.shape(input_ids)
+            input_shape = shape_list(input_ids)
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.shape[:-1]
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
...
@@ -28,7 +28,7 @@ import numpy as np
 import tensorflow as tf

 from .configuration_bert import BertConfig
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
 from .file_utils import add_start_docstrings

 logger = logging.getLogger(__name__)
@@ -145,9 +145,9 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         input_ids, position_ids, token_type_ids, inputs_embeds = inputs

         if input_ids is not None:
-            input_shape = tf.shape(input_ids)
+            input_shape = shape_list(input_ids)
         else:
-            input_shape = tf.shape(inputs_embeds)[:-1]
+            input_shape = shape_list(inputs_embeds)[:-1]

         seq_length = input_shape[1]
         if position_ids is None:
@@ -172,8 +172,8 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
         Returns:
             float32 tensor with shape [batch_size, length, vocab_size].
         """
-        batch_size = tf.shape(inputs)[0]
-        length = tf.shape(inputs)[1]
+        batch_size = shape_list(inputs)[0]
+        length = shape_list(inputs)[1]

         x = tf.reshape(inputs, [-1, self.hidden_size])
         logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
@@ -214,7 +214,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
     def call(self, inputs, training=False):
         hidden_states, attention_mask, head_mask = inputs

-        batch_size = tf.shape(hidden_states)[0]
+        batch_size = shape_list(hidden_states)[0]
         mixed_query_layer = self.query(hidden_states)
         mixed_key_layer = self.key(hidden_states)
         mixed_value_layer = self.value(hidden_states)
@@ -225,7 +225,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)  # (batch size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(tf.shape(key_layer)[-1], tf.float32)  # scale attention_scores
+        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)  # scale attention_scores
         attention_scores = attention_scores / tf.math.sqrt(dk)

         if attention_mask is not None:
@@ -502,9 +502,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.shape
+            input_shape = shape_list(input_ids)
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.shape[:-1]
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
@@ -939,11 +939,11 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
             input_ids = inputs

         if input_ids is not None:
-            num_choices = tf.shape(input_ids)[1]
-            seq_length = tf.shape(input_ids)[2]
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
         else:
-            num_choices = tf.shape(inputs_embeds)[1]
-            seq_length = tf.shape(inputs_embeds)[2]
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]

         flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
...
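
The TFBertMainLayer hunk above is the case where the static attribute actually breaks: input_ids.shape can contain None for the batch dimension, and tf.fill cannot build a tensor from a partially known shape, while the mixed static/dynamic list returned by shape_list can be consumed by tf.fill and tf.reshape alike. A self-contained sketch of that failure mode and the fix (the helper is repeated so the snippet runs on its own; the input signature is an assumption):

import tensorflow as tf

def shape_list(x):
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

@tf.function(input_signature=[tf.TensorSpec([None, None], tf.int32)])
def default_masks(input_ids):
    # While tracing, input_ids.shape is (None, None), so tf.fill(input_ids.shape, 1)
    # would raise "Cannot convert a partially known TensorShape to a Tensor".
    input_shape = shape_list(input_ids)
    attention_mask = tf.fill(input_shape, 1)
    token_type_ids = tf.fill(input_shape, 0)
    return attention_mask, token_type_ids

default_masks(tf.constant([[7, 8, 9]], dtype=tf.int32))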
@@ -95,7 +95,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
     def call(self, inputs, training=False):
         v, k, q, mask, layer_past, attention_mask, head_mask = inputs
-        batch_size = q.shape[0]
+        batch_size = shape_list(q)[0]

         q = self.Wq(q)
         k = self.Wk(k)
...
@@ -137,9 +137,9 @@ class TFEmbeddings(tf.keras.layers.Layer):
         input_ids, position_ids = inputs

         if input_ids is not None:
-            seq_length = tf.shape(input_ids)[1]
+            seq_length = shape_list(input_ids)[1]
         else:
-            seq_length = tf.shape(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[1]

         if position_ids is None:
             position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
@@ -160,8 +160,8 @@ class TFEmbeddings(tf.keras.layers.Layer):
         Returns:
             float32 tensor with shape [batch_size, length, vocab_size].
         """
-        batch_size = tf.shape(inputs)[0]
-        length = tf.shape(inputs)[1]
+        batch_size = shape_list(inputs)[0]
+        length = shape_list(inputs)[1]

         x = tf.reshape(inputs, [-1, self.dim])
         logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
...
@@ -92,7 +92,7 @@ class TFAttention(tf.keras.layers.Layer):
         # q, k, v have shape [batch, heads, sequence, features]
         w = tf.matmul(q, k, transpose_b=True)
         if self.scale:
-            dk = tf.cast(tf.shape(k)[-1], tf.float32)  # scale attention_scores
+            dk = tf.cast(shape_list(k)[-1], tf.float32)  # scale attention_scores
             w = w / tf.math.sqrt(dk)

         # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
...
@@ -98,7 +98,7 @@ class TFAttention(tf.keras.layers.Layer):
         # q, k, v have shape [batch, heads, sequence, features]
         w = tf.matmul(q, k, transpose_b=True)
         if self.scale:
-            dk = tf.cast(tf.shape(k)[-1], tf.float32)  # scale attention_scores
+            dk = tf.cast(shape_list(k)[-1], tf.float32)  # scale attention_scores
             w = w / tf.math.sqrt(dk)

         # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
...
@@ -24,7 +24,7 @@ import numpy as np
 import tensorflow as tf

 from .configuration_roberta import RobertaConfig
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer
+from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
 from .file_utils import add_start_docstrings
 from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu, gelu_new
@@ -51,9 +51,9 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
         input_ids, position_ids, token_type_ids, inputs_embeds = inputs

         if input_ids is not None:
-            seq_length = tf.shape(input_ids)[1]
+            seq_length = shape_list(input_ids)[1]
         else:
-            seq_length = tf.shape(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[1]

         if position_ids is None:
             position_ids = tf.range(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=tf.int32)[tf.newaxis, :]
...
@@ -337,7 +337,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                 emb_i = tf.einsum('id,de->ie', emb_i, self.emb_projs[i])

                 mask_idx = tf.cast(tf.where(mask_i), dtype=tf.int64)
-                emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(tf.shape(emb_flat), dtype=tf.int64))
+                emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(shape_list(emb_flat), dtype=tf.int64))

             embed_shape = shape_list(inp) + [self.d_proj]
             embed = tf.reshape(emb_flat, embed_shape)
...
@@ -105,7 +105,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
     @staticmethod
     def _gather_logprob(logprob, target):
-        lp_size = tf.shape(logprob)
+        lp_size = shape_list(logprob)
         r = tf.range(lp_size[0])
         idx = tf.stack([r, target], 1)
         return tf.gather_nd(logprob, idx)
@@ -159,7 +159,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
                     cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
                     cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
                 if target is not None:
-                    loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(tf.shape(loss), dtype=tf.int64))
+                    loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(shape_list(loss), dtype=tf.int64))
         out = tf.concat(out, axis=-1)

         if target is not None:
...
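
In the two transfo-xl hunks above the result of shape_list is passed through tf.cast before reaching tf.scatter_nd, which expects the output shape as a 1-D integer tensor of the same dtype as the indices; the mixed static/dynamic list packs into such a tensor cleanly. A small sketch of that pattern under assumed toy shapes and indices (none of these names come from the diff):

import tensorflow as tf

def shape_list(x):
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def scatter_into(emb_flat):
    mask_idx = tf.constant([[0, 1], [1, 2]], dtype=tf.int64)   # positions to update
    updates = tf.constant([10.0, 20.0])
    # [dynamic_batch, 4] packs into a 1-D int32 tensor, then casts to int64.
    out_shape = tf.cast(shape_list(emb_flat), dtype=tf.int64)
    return emb_flat + tf.scatter_nd(mask_idx, updates, out_shape)

print(scatter_into(tf.zeros([2, 4])))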
@@ -494,7 +494,7 @@ class TFSequenceSummary(tf.keras.layers.Layer):
 def shape_list(x):
     """Deal with dynamic shape in tensorflow cleanly."""
     static = x.shape.as_list()
-    dynamic = tf.shape(x)
+    dynamic = shape_list(x)
     return [dynamic[i] if s is None else s for i, s in enumerate(static)]

 def get_initializer(initializer_range=0.02):
...
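
A usage note on the helper itself (not part of the diff): in eager mode every dimension is known, so shape_list degenerates to a plain list of Python ints, while inside a traced function only the unknown dimensions come back as scalar tensors; either way the result indexes and unpacks like an ordinary list, which is what the call sites above rely on. A short illustration, again writing the dynamic fallback as tf.shape as in the helper's original body:

import tensorflow as tf

def shape_list(x):
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

print(shape_list(tf.zeros([2, 5])))          # eager: [2, 5], plain Python ints

@tf.function(input_signature=[tf.TensorSpec([None, 5], tf.float32)])
def show(x):
    bsz, dim = shape_list(x)                 # unpacks like a list
    tf.print("dynamic batch:", bsz, "static dim:", dim)
    return x

show(tf.zeros([3, 5]))                       # prints: dynamic batch: 3 static dim: 5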
@@ -112,8 +112,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
     def prune_heads(self, heads):
         raise NotImplementedError

-    @staticmethod
-    def rel_shift(x, klen=-1):
+    def rel_shift(self, x, klen=-1):
         """perform relative shift to form the relative attention score."""
         x_size = shape_list(x)
@@ -135,7 +134,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
         # position based attention score
         bd = tf.einsum('ibnd,jbnd->ijbn', q_head + self.r_r_bias, k_head_r)
-        bd = self.rel_shift(bd, klen=ac.shape[1])
+        bd = self.rel_shift(bd, klen=shape_list(ac)[1])

         # segment based attention score
         if seg_mat is None:
@@ -192,7 +191,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
         if g is not None:
             ###### Two-stream attention with relative positional encoding.
             # content based attention score
-            if mems is not None and mems.shape.ndims > 1:
+            if mems is not None and len(shape_list(mems)) > 1:
                 cat = tf.concat([mems, h], axis=0)
             else:
                 cat = h
@@ -252,7 +251,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
         else:
             ###### Multi-head attention with relative positional encoding
-            if mems is not None and mems.shape.ndims > 1:
+            if mems is not None and len(shape_list(mems)) > 1:
                 cat = tf.concat([mems, h], axis=0)
             else:
                 cat = h
@@ -565,7 +564,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         if data_mask is not None:
             # all mems can be attended to
-            mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
+            mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz],
                                  dtype=dtype_float)
             data_mask = tf.concat([mems_mask, data_mask], axis=1)

         if attn_mask is None:
@@ -590,7 +589,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             word_emb_k = self.word_embedding(input_ids)
         output_h = self.dropout(word_emb_k, training=training)
         if target_mapping is not None:
-            word_emb_q = tf.tile(self.mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
+            word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
         # else:  # We removed the inp_q input which was same as target mapping
         #     inp_q_ext = inp_q[:, :, None]
         #     word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
...