"vscode:/vscode.git/clone" did not exist on "b64375c88f19dbea3fdc61472051337bb811731e"
Commit 0e0a94a6 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Remove compute_output_shape.

Keras: "manual" shape inference is only required if the layer is dynamic (otherwise we use TF's static shape inference capabilities)

PiperOrigin-RevId: 290821518
parent ac97f01b
......@@ -118,14 +118,6 @@ class Attention(tf.keras.layers.Layer):
self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def compute_output_shape(self, input_shape):
    """Computes the attention output shape.

    Only the first input (the "from" tensor) determines the output shape;
    the result is (batch, from_tensor_length, num_heads, head_size).
    """
    # TODO(momernick): validate tensor dimensions.
    from_shape = tf.TensorShape(input_shape[0])
    return tf.TensorShape(
        (from_shape[0], from_shape[1], self._num_heads, self._head_size))
def get_config(self):
config = {
"num_heads":
......
......@@ -143,18 +143,6 @@ class DenseEinsum(tf.keras.layers.Layer):
self._bias = None
super(DenseEinsum, self).build(input_shape)
def compute_output_shape(self, input_shape):
    """Computes the einsum output shape.

    The trailing `self._num_summed_dimensions` dims of the input are
    contracted away and replaced by `self._units`.

    Raises:
      ValueError: if any of the summed (trailing) dimensions is undefined.
    """
    input_shape = tf.TensorShape(input_shape)
    input_shape = input_shape.with_rank_at_least(self._num_summed_dimensions +
                                                 1)
    # Validate the trailing dims that the einsum contracts. NOTE(review):
    # the previous loop indexed input_shape[-1 * i], which checks index 0
    # (the leading dim) when i == 0 and never checks the last summed dim;
    # -(i + 1) walks the trailing dims as intended.
    for i in range(self._num_summed_dimensions):
      if tf.dimension_value(input_shape[-(i + 1)]) is None:
        raise ValueError(
            "The %s dimension of input_shape must be defined, but saw: %s" %
            (-(i + 1), input_shape))
    return input_shape[:-1 * self._num_summed_dimensions].concatenate(
        self._units)
def get_config(self):
config = {
"output_shape":
......
......@@ -158,13 +158,6 @@ class Transformer(tf.keras.layers.Layer):
super(Transformer, self).build(input_shape)
def compute_output_shape(self, input_shape):
    """Computes the transformer output shape.

    The shape is driven by the first (data) input tensor:
    (batch, sequence_length, output_einsum_shape).
    """
    data_shape = tf.TensorShape(input_shape[0])
    return tf.TensorShape(
        (data_shape[0], data_shape[1], self._output_einsum_shape))
def get_config(self):
config = {
"num_attention_heads":
......
......@@ -175,13 +175,6 @@ class TransformerScaffold(tf.keras.layers.Layer):
super(TransformerScaffold, self).build(input_shape)
def compute_output_shape(self, input_shape):
    """Computes the scaffold output shape.

    Mirrors the first (data) input's batch and sequence dims:
    (batch, sequence_length, output_einsum_shape).
    """
    shape = tf.TensorShape(input_shape[0])
    batch_size, seq_length = shape[0], shape[1]
    return tf.TensorShape((batch_size, seq_length, self._output_einsum_shape))
def get_config(self):
config = {
"attention_cls":
......
......@@ -168,9 +168,6 @@ class Bias(tf.keras.layers.Layer):
super(Bias, self).build(input_shape)
def compute_output_shape(self, input_shape):
    # Bias addition is elementwise, so the output shape is the input shape.
    return input_shape
def get_config(self):
config = {
'activation': tf.keras.activations.serialize(self._activation),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment