Commit f2f26a8b authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 313812017
parent f4baddb3
...@@ -176,16 +176,16 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -176,16 +176,16 @@ class TransformerScaffold(tf.keras.layers.Layer):
if self._feedforward_block is None: if self._feedforward_block is None:
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense( self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"...x,xy->...y", "abc,cd->abd",
output_shape=self._intermediate_size, output_shape=(None, self._intermediate_size),
bias_axes="y", bias_axes="d",
activation=self._intermediate_activation, activation=self._intermediate_activation,
name="intermediate", name="intermediate",
**common_kwargs) **common_kwargs)
self._output_dense = tf.keras.layers.experimental.EinsumDense( self._output_dense = tf.keras.layers.experimental.EinsumDense(
"...x,xy->...y", "abc,cd->abd",
output_shape=hidden_size, output_shape=(None, hidden_size),
bias_axes="y", bias_axes="d",
name="output", name="output",
**common_kwargs) **common_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment