Commit 5a2cf36f authored by Kaushik Shivakumar

Merge remote-tracking branch 'upstream/master' into newavarecords

parents 258ddfc3 a829e648
@@ -21,6 +21,8 @@ from __future__ import print_function
 import tensorflow as tf
+from tensorflow.python.util import deprecation
 _CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]
@@ -57,6 +59,9 @@ class DenseEinsum(tf.keras.layers.Layer):
       `(batch_size, units)`.
   """

+  @deprecation.deprecated(
+      None, "DenseEinsum is deprecated. Please use "
+      "tf.keras.experimental.EinsumDense layer instead.")
   def __init__(self,
                output_shape,
                num_summed_dimensions=1,
...
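The deprecation notice added above points users toward Keras' built-in einsum dense layer. As a rough, hedged sketch of that migration (assuming the layer exported as tf.keras.layers.experimental.EinsumDense in TF 2.x; the equation, shapes, and bias_axes below are illustrative, not taken from this commit):

```python
import tensorflow as tf

# Hedged sketch of the suggested replacement, not code from this commit.
# "abc,cd->abd" contracts the trailing feature axis of a (batch, seq, features)
# input against a (features, units) kernel.
layer = tf.keras.layers.experimental.EinsumDense(
    equation="abc,cd->abd",
    output_shape=(None, 64),  # per-example output shape: (seq_len, units)
    bias_axes="d")            # one bias term per output unit

outputs = layer(tf.ones((2, 16, 32)))
print(outputs.shape)  # (2, 16, 64)
```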
@@ -48,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
     mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
     doc_probs = np.random.randint(
         2, size=(3, num_heads, 4, num_docs)).astype(float)
-    outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
+    outputs = attention_layer(
+        query=from_data,
+        value=to_data,
+        context_attention_weights=doc_probs,
+        attention_mask=mask_data)
     self.assertEqual(outputs.shape, (3, 4, 8))
...
@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
         "hidden_size": self._hidden_size,
         "min_timescale": self._min_timescale,
         "max_timescale": self._max_timescale,
-        "length": self._length,
     }
     base_config = super(RelativePositionEmbedding, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
...
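The hunk above drops a field from the layer's serialized config. As a generic illustration of the get_config pattern being edited (the class below is a made-up stand-in, not the real RelativePositionEmbedding), only constructor arguments are serialized so the layer can be rebuilt from its config:

```python
import tensorflow as tf

class ToyPositionEmbedding(tf.keras.layers.Layer):
  """Minimal stand-in showing the get_config idiom from the hunk above."""

  def __init__(self, hidden_size, min_timescale=1.0, max_timescale=1.0e4,
               **kwargs):
    super(ToyPositionEmbedding, self).__init__(**kwargs)
    self._hidden_size = hidden_size
    self._min_timescale = min_timescale
    self._max_timescale = max_timescale

  def get_config(self):
    # Serialize only constructor arguments; runtime-only attributes stay out.
    config = {
        "hidden_size": self._hidden_size,
        "min_timescale": self._min_timescale,
        "max_timescale": self._max_timescale,
    }
    base_config = super(ToyPositionEmbedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

# Round-trip check: a layer rebuilt from its config has the same settings.
layer = ToyPositionEmbedding(hidden_size=8)
clone = ToyPositionEmbedding.from_config(layer.get_config())
assert clone.get_config() == layer.get_config()
```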
This diff is collapsed.
@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
     else:
       input_tensor, attention_mask = (inputs, None)

-    attention_inputs = [input_tensor, input_tensor]
-
-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=input_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(input_tensor +
                                                   attention_output)
...
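This hunk switches the scaffold's attention call from a positional input list to keyword arguments. A small, hedged sketch of that calling convention using the stock Keras layer (assuming tf.keras.layers.MultiHeadAttention from TF 2.4+; the scaffold's own attention class and these shapes are not taken from this commit):

```python
import tensorflow as tf

# Illustrative shapes only: (batch, seq_len, hidden) inputs, (batch, seq, seq) mask.
inputs = tf.ones((2, 16, 64))
mask = tf.ones((2, 16, 16), dtype=tf.int32)

attention = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=16)
# Self-attention expressed with the keyword-argument call style used above.
outputs = attention(query=inputs, value=inputs, attention_mask=mask)
print(outputs.shape)  # (2, 16, 64)
```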