Commit 19d2e2ac authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398292356
parent 350f4854
...@@ -228,13 +228,17 @@ class NASFPN(tf.keras.Model): ...@@ -228,13 +228,17 @@ class NASFPN(tf.keras.Model):
if input_level < target_level: if input_level < target_level:
stride = int(2 ** (target_level - input_level)) stride = int(2 ** (target_level - input_level))
x = tf.keras.layers.MaxPool2D( return tf.keras.layers.MaxPool2D(
pool_size=stride, strides=stride, padding='same')(x) pool_size=stride, strides=stride, padding='same')(x)
elif input_level > target_level: if input_level > target_level:
scale = int(2 ** (input_level - target_level)) scale = int(2 ** (input_level - target_level))
x = spatial_transform_ops.nearest_upsampling(x, scale=scale) return spatial_transform_ops.nearest_upsampling(x, scale=scale)
return x # Force output x to be the same dtype as mixed precision policy. This avoids
# dtype mismatch when one input (by default float32 dtype) does not meet all
# the above conditions and is output unchanged, while other inputs are
# processed to have different dtype, e.g., using bfloat16 on TPU.
return tf.cast(x, dtype=tf.keras.layers.Layer().dtype_policy.compute_dtype)
def _global_attention(self, feat0, feat1): def _global_attention(self, feat0, feat1):
m = tf.math.reduce_max(feat0, axis=[1, 2], keepdims=True) m = tf.math.reduce_max(feat0, axis=[1, 2], keepdims=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment