Commit f2f26a8b authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 313812017
parent f4baddb3
......@@ -176,16 +176,16 @@ class TransformerScaffold(tf.keras.layers.Layer):
if self._feedforward_block is None:
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"...x,xy->...y",
output_shape=self._intermediate_size,
bias_axes="y",
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
bias_axes="d",
activation=self._intermediate_activation,
name="intermediate",
**common_kwargs)
self._output_dense = tf.keras.layers.experimental.EinsumDense(
"...x,xy->...y",
output_shape=hidden_size,
bias_axes="y",
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
**common_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment