Commit 069ad593 authored by Scott Zhu's avatar Scott Zhu Committed by A. Unique TensorFlower
Browse files

Prepare for upcoming keras initializer change.

PiperOrigin-RevId: 448506063
parent b8c1bd07
...@@ -18,6 +18,8 @@ import inspect ...@@ -18,6 +18,8 @@ import inspect
from typing import Any, MutableMapping, Optional, Union, Tuple from typing import Any, MutableMapping, Optional, Union, Tuple
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
class GroupConv2D(tf.keras.layers.Conv2D): class GroupConv2D(tf.keras.layers.Conv2D):
"""2D group convolution as a Keras Layer.""" """2D group convolution as a Keras Layer."""
...@@ -168,7 +170,7 @@ class GroupConv2D(tf.keras.layers.Conv2D): ...@@ -168,7 +170,7 @@ class GroupConv2D(tf.keras.layers.Conv2D):
self.add_weight( self.add_weight(
name='kernel_{}'.format(g), name='kernel_{}'.format(g),
shape=self.group_kernel_shape, shape=self.group_kernel_shape,
initializer=self.kernel_initializer, initializer=tf_utils.clone_initializer(self.kernel_initializer),
regularizer=self.kernel_regularizer, regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint, constraint=self.kernel_constraint,
trainable=True, trainable=True,
...@@ -178,7 +180,7 @@ class GroupConv2D(tf.keras.layers.Conv2D): ...@@ -178,7 +180,7 @@ class GroupConv2D(tf.keras.layers.Conv2D):
self.add_weight( self.add_weight(
name='bias_{}'.format(g), name='bias_{}'.format(g),
shape=(self.group_output_channel,), shape=(self.group_output_channel,),
initializer=self.bias_initializer, initializer=tf_utils.clone_initializer(self.bias_initializer),
regularizer=self.bias_regularizer, regularizer=self.bias_regularizer,
constraint=self.bias_constraint, constraint=self.bias_constraint,
trainable=True, trainable=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment