Commit 1a3c83d6 authored by zhanggzh

Add keras-cv models and training code

parent 9846958a
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras
from keras_cv.models.generative.stable_diffusion.__internal__.layers.group_normalization import (
GroupNormalization,
)
from keras_cv.models.generative.stable_diffusion.__internal__.layers.padded_conv2d import (
PaddedConv2D,
)
class DiffusionModel(keras.Model):
def __init__(self, img_height, img_width, max_text_length, name=None):
context = keras.layers.Input((max_text_length, 768))
t_embed_input = keras.layers.Input((320,))
latent = keras.layers.Input((img_height // 8, img_width // 8, 4))
t_emb = keras.layers.Dense(1280)(t_embed_input)
t_emb = keras.layers.Activation("swish")(t_emb)
t_emb = keras.layers.Dense(1280)(t_emb)
# Downsampling flow
outputs = []
x = PaddedConv2D(320, kernel_size=3, padding=1)(latent)
outputs.append(x)
for _ in range(2):
x = ResBlock(320)([x, t_emb])
x = SpatialTransformer(8, 40)([x, context])
outputs.append(x)
x = PaddedConv2D(320, 3, strides=2, padding=1)(x) # Downsample 2x
outputs.append(x)
for _ in range(2):
x = ResBlock(640)([x, t_emb])
x = SpatialTransformer(8, 80)([x, context])
outputs.append(x)
x = PaddedConv2D(640, 3, strides=2, padding=1)(x) # Downsample 2x
outputs.append(x)
for _ in range(2):
x = ResBlock(1280)([x, t_emb])
x = SpatialTransformer(8, 160)([x, context])
outputs.append(x)
x = PaddedConv2D(1280, 3, strides=2, padding=1)(x) # Downsample 2x
outputs.append(x)
for _ in range(2):
x = ResBlock(1280)([x, t_emb])
outputs.append(x)
# Middle flow
x = ResBlock(1280)([x, t_emb])
x = SpatialTransformer(8, 160)([x, context])
x = ResBlock(1280)([x, t_emb])
# Upsampling flow
for _ in range(3):
x = keras.layers.Concatenate()([x, outputs.pop()])
x = ResBlock(1280)([x, t_emb])
x = Upsample(1280)(x)
for _ in range(3):
x = keras.layers.Concatenate()([x, outputs.pop()])
x = ResBlock(1280)([x, t_emb])
x = SpatialTransformer(8, 160)([x, context])
x = Upsample(1280)(x)
for _ in range(3):
x = keras.layers.Concatenate()([x, outputs.pop()])
x = ResBlock(640)([x, t_emb])
x = SpatialTransformer(8, 80)([x, context])
x = Upsample(640)(x)
for _ in range(3):
x = keras.layers.Concatenate()([x, outputs.pop()])
x = ResBlock(320)([x, t_emb])
x = SpatialTransformer(8, 40)([x, context])
# Exit flow
x = GroupNormalization(epsilon=1e-5)(x)
x = keras.layers.Activation("swish")(x)
output = PaddedConv2D(4, kernel_size=3, padding=1)(x)
super().__init__([latent, t_embed_input, context], output, name=name)
class ResBlock(keras.layers.Layer):
def __init__(self, output_dim, **kwargs):
super().__init__(**kwargs)
self.output_dim = output_dim
self.entry_flow = [
GroupNormalization(epsilon=1e-5),
keras.layers.Activation("swish"),
PaddedConv2D(output_dim, 3, padding=1),
]
self.embedding_flow = [
keras.layers.Activation("swish"),
keras.layers.Dense(output_dim),
]
self.exit_flow = [
GroupNormalization(epsilon=1e-5),
keras.layers.Activation("swish"),
PaddedConv2D(output_dim, 3, padding=1),
]
def build(self, input_shape):
if input_shape[0][-1] != self.output_dim:
self.residual_projection = PaddedConv2D(self.output_dim, 1)
else:
self.residual_projection = lambda x: x
def call(self, inputs):
inputs, embeddings = inputs
x = inputs
for layer in self.entry_flow:
x = layer(x)
for layer in self.embedding_flow:
embeddings = layer(embeddings)
x = x + embeddings[:, None, None]
for layer in self.exit_flow:
x = layer(x)
return x + self.residual_projection(inputs)
class SpatialTransformer(keras.layers.Layer):
def __init__(self, num_heads, head_size, **kwargs):
super().__init__(**kwargs)
self.norm = GroupNormalization(epsilon=1e-5)
channels = num_heads * head_size
self.conv1 = PaddedConv2D(num_heads * head_size, 1)
self.transformer_block = BasicTransformerBlock(channels, num_heads, head_size)
self.conv2 = PaddedConv2D(channels, 1)
def call(self, inputs):
inputs, context = inputs
_, h, w, c = inputs.shape
x = self.norm(inputs)
x = self.conv1(x)
x = tf.reshape(x, (-1, h * w, c))
x = self.transformer_block([x, context])
x = tf.reshape(x, (-1, h, w, c))
return self.conv2(x) + inputs
class BasicTransformerBlock(keras.layers.Layer):
def __init__(self, dim, num_heads, head_size, **kwargs):
super().__init__(**kwargs)
self.norm1 = keras.layers.LayerNormalization(epsilon=1e-5)
self.attn1 = CrossAttention(num_heads, head_size)
self.norm2 = keras.layers.LayerNormalization(epsilon=1e-5)
self.attn2 = CrossAttention(num_heads, head_size)
self.norm3 = keras.layers.LayerNormalization(epsilon=1e-5)
self.geglu = GEGLU(dim * 4)
self.dense = keras.layers.Dense(dim)
def call(self, inputs):
inputs, context = inputs
x = self.attn1([self.norm1(inputs), None]) + inputs
x = self.attn2([self.norm2(x), context]) + x
return self.dense(self.geglu(self.norm3(x))) + x
class CrossAttention(keras.layers.Layer):
def __init__(self, num_heads, head_size, **kwargs):
super().__init__(**kwargs)
self.to_q = keras.layers.Dense(num_heads * head_size, use_bias=False)
self.to_k = keras.layers.Dense(num_heads * head_size, use_bias=False)
self.to_v = keras.layers.Dense(num_heads * head_size, use_bias=False)
self.scale = head_size**-0.5
self.num_heads = num_heads
self.head_size = head_size
self.out_proj = keras.layers.Dense(num_heads * head_size)
def call(self, inputs):
inputs, context = inputs
context = inputs if context is None else context
q, k, v = self.to_q(inputs), self.to_k(context), self.to_v(context)
q = tf.reshape(q, (-1, inputs.shape[1], self.num_heads, self.head_size))
k = tf.reshape(k, (-1, context.shape[1], self.num_heads, self.head_size))
v = tf.reshape(v, (-1, context.shape[1], self.num_heads, self.head_size))
q = tf.transpose(q, (0, 2, 1, 3)) # (bs, num_heads, time, head_size)
k = tf.transpose(k, (0, 2, 3, 1)) # (bs, num_heads, head_size, time)
v = tf.transpose(v, (0, 2, 1, 3)) # (bs, num_heads, time, head_size)
score = td_dot(q, k) * self.scale
weights = keras.activations.softmax(score) # (bs, num_heads, time, time)
attn = td_dot(weights, v)
attn = tf.transpose(attn, (0, 2, 1, 3)) # (bs, time, num_heads, head_size)
out = tf.reshape(attn, (-1, inputs.shape[1], self.num_heads * self.head_size))
return self.out_proj(out)
class Upsample(keras.layers.Layer):
def __init__(self, channels, **kwargs):
super().__init__(**kwargs)
self.ups = keras.layers.UpSampling2D(2)
self.conv = PaddedConv2D(channels, 3, padding=1)
def call(self, inputs):
return self.conv(self.ups(inputs))
class GEGLU(keras.layers.Layer):
def __init__(self, output_dim, **kwargs):
super().__init__(**kwargs)
self.output_dim = output_dim
self.dense = keras.layers.Dense(output_dim * 2)
def call(self, inputs):
x = self.dense(inputs)
x, gate = x[..., : self.output_dim], x[..., self.output_dim :]
tanh_res = keras.activations.tanh(
gate * 0.7978845608 * (1 + 0.044715 * (gate**2))
)
return x * 0.5 * gate * (1 + tanh_res)
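# Added note: the expression above computes x * GELU(gate) using the tanh
# approximation of GELU, where `x` and `gate` are the two halves of a single
# Dense projection (the "gated GELU" used in the transformer feed-forward block).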
def td_dot(a, b):
aa = tf.reshape(a, (-1, a.shape[2], a.shape[3]))
bb = tf.reshape(b, (-1, b.shape[2], b.shape[3]))
cc = keras.backend.batch_dot(aa, bb)
return tf.reshape(cc, (-1, a.shape[1], cc.shape[1], cc.shape[2]))
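# Shape-check sketch (added for illustration, not part of the original commit):
# builds the UNet defined above for 128x128 images and runs random tensors
# through one forward pass. The output keeps the latent input shape.
if __name__ == "__main__":
    unet = DiffusionModel(img_height=128, img_width=128, max_text_length=77)
    latent = tf.random.normal((1, 16, 16, 4))   # img size // 8 on each side
    t_emb = tf.random.normal((1, 320))          # timestep embedding input
    context = tf.random.normal((1, 77, 768))    # text-encoder output
    print(unet([latent, t_emb, context]).shape)  # expected: (1, 16, 16, 4)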
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tensorflow import keras
from keras_cv.models.generative.stable_diffusion.__internal__.layers.attention_block import (
AttentionBlock,
)
from keras_cv.models.generative.stable_diffusion.__internal__.layers.group_normalization import (
GroupNormalization,
)
from keras_cv.models.generative.stable_diffusion.__internal__.layers.padded_conv2d import (
PaddedConv2D,
)
from keras_cv.models.generative.stable_diffusion.__internal__.layers.resnet_block import (
ResnetBlock,
)
class ImageEncoder(keras.Sequential):
"""ImageEncoder is the VAE Encoder for StableDiffusion."""
def __init__(self, img_height=512, img_width=512, download_weights=True):
super().__init__(
[
keras.layers.Input((img_height, img_width, 3)),
PaddedConv2D(128, 3, padding=1),
ResnetBlock(128),
ResnetBlock(128),
PaddedConv2D(128, 3, padding=1, strides=2),
ResnetBlock(256),
ResnetBlock(256),
PaddedConv2D(256, 3, padding=1, strides=2),
ResnetBlock(512),
ResnetBlock(512),
PaddedConv2D(512, 3, padding=1, strides=2),
ResnetBlock(512),
ResnetBlock(512),
ResnetBlock(512),
AttentionBlock(512),
ResnetBlock(512),
GroupNormalization(epsilon=1e-5),
keras.layers.Activation("swish"),
PaddedConv2D(8, 3, padding=1),
PaddedConv2D(8, 1),
# TODO(lukewood): can this be refactored to be a Rescaling layer?
# Perhaps some sort of rescale and gather?
# Either way, we may need a lambda to gather the first 4 dimensions.
keras.layers.Lambda(lambda x: x[..., :4] * 0.18215),
]
)
if download_weights:
image_encoder_weights_fpath = keras.utils.get_file(
origin="https://huggingface.co/fchollet/stable-diffusion/blob/main/vae_encoder.h5",
file_hash="f142c8c94c6853cd19d8bfb9c10aa762c057566f54456398beea6a70a639bf48",
)
self.load_weights(image_encoder_weights_fpath)
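# Usage sketch (added for illustration; `download_weights=False` keeps the sketch
# offline by skipping the pretrained VAE weights): 512x512 RGB images are encoded
# into 64x64x4 latents.
if __name__ == "__main__":
    import numpy as np

    encoder = ImageEncoder(img_height=512, img_width=512, download_weights=False)
    images = np.random.uniform(size=(1, 512, 512, 3)).astype("float32")
    print(encoder.predict(images).shape)  # expected: (1, 64, 64, 4)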
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras implementation of StableDiffusion.
Credits:
- Original implementation: https://github.com/CompVis/stable-diffusion
- Initial TF/Keras port: https://github.com/divamgupta/stable-diffusion-tensorflow
The current implementation is a rewrite of the initial TF/Keras port by Divam Gupta.
"""
import math
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras_cv.models.generative.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.generative.stable_diffusion.constants import _ALPHAS_CUMPROD
from keras_cv.models.generative.stable_diffusion.constants import _UNCONDITIONAL_TOKENS
from keras_cv.models.generative.stable_diffusion.decoder import Decoder
from keras_cv.models.generative.stable_diffusion.diffusion_model import DiffusionModel
from keras_cv.models.generative.stable_diffusion.image_encoder import ImageEncoder
from keras_cv.models.generative.stable_diffusion.text_encoder import TextEncoder
MAX_PROMPT_LENGTH = 77
class StableDiffusion:
"""Keras implementation of Stable Diffusion.
Stable Diffusion is a powerful image generation model that can be used,
among other things, to generate pictures according to a short text description
(called a "prompt").
Arguments:
img_height: Height of the images to generate, in pixels. Note that only
multiples of 128 are supported; the value provided will be rounded
to the nearest valid value. Default: 512.
img_width: Width of the images to generate, in pixels. Note that only
multiples of 128 are supported; the value provided will be rounded
to the nearest valid value. Default: 512.
jit_compile: Whether to compile the underlying models to XLA.
This can lead to a significant speedup on some systems. Default: False.
Example:
```python
from keras_cv.models import StableDiffusion
from PIL import Image
model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
img = model.text_to_image(
prompt="A beautiful horse running through a field",
batch_size=1, # How many images to generate at once
num_steps=25, # Number of iterations (controls image quality)
seed=123, # Set this to always get the same image from the same prompt
)
Image.fromarray(img[0]).save("horse.png")
print("saved at horse.png")
```
References:
- [About Stable Diffusion](https://stability.ai/blog/stable-diffusion-announcement)
- [Original implementation](https://github.com/CompVis/stable-diffusion)
"""
def __init__(
self,
img_height=512,
img_width=512,
jit_compile=False,
):
# UNet requires multiples of 2**7 = 128
img_height = round(img_height / 128) * 128
img_width = round(img_width / 128) * 128
self.img_height = img_height
self.img_width = img_width
self.tokenizer = SimpleTokenizer()
# Create models
self.text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
self.diffusion_model = DiffusionModel(img_height, img_width, MAX_PROMPT_LENGTH)
self.decoder = Decoder(img_height, img_width)
# lazy initialize image encoder
self._image_encoder = None
self.jit_compile = jit_compile
if jit_compile:
self.text_encoder.compile(jit_compile=True)
self.diffusion_model.compile(jit_compile=True)
self.decoder.compile(jit_compile=True)
print(
"By using this model checkpoint, you acknowledge that its usage is "
"subject to the terms of the CreativeML Open RAIL-M license at "
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE"
)
# Load weights
text_encoder_weights_fpath = keras.utils.get_file(
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
)
diffusion_model_weights_fpath = keras.utils.get_file(
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5",
file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe",
)
decoder_weights_fpath = keras.utils.get_file(
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5",
file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962",
)
self.text_encoder.load_weights(text_encoder_weights_fpath)
self.diffusion_model.load_weights(diffusion_model_weights_fpath)
self.decoder.load_weights(decoder_weights_fpath)
def text_to_image(
self,
prompt,
batch_size=1,
num_steps=25,
unconditional_guidance_scale=7.5,
seed=None,
):
encoded_text = self.encode_text(prompt)
return self.generate_image(
encoded_text,
batch_size=batch_size,
num_steps=num_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
seed=seed,
)
def encode_text(self, prompt):
"""Encodes a prompt into a latent text encoding.
The encoding produced by this method should be used as the
`encoded_text` parameter of `StableDiffusion.generate_image`. Encoding
text separately from generating an image can be used to arbitrarily
modify the text encoding prior to image generation, e.g. for walking
between two prompts.
Args:
prompt: a string to encode, must be 77 tokens or shorter.
Example:
```python
from keras_cv.models import StableDiffusion
model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
encoded_text = model.encode_text("Tacos at dawn")
img = model.generate_image(encoded_text)
```
"""
# Tokenize prompt (i.e. starting context)
inputs = self.tokenizer.encode(prompt)
if len(inputs) > MAX_PROMPT_LENGTH:
raise ValueError(
f"Prompt is too long (should be <= {MAX_PROMPT_LENGTH} tokens)"
)
phrase = inputs + [49407] * (MAX_PROMPT_LENGTH - len(inputs))
phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
context = self.text_encoder.predict_on_batch([phrase, self._get_pos_ids()])
return context
def generate_image(
self,
encoded_text,
batch_size=1,
num_steps=25,
unconditional_guidance_scale=7.5,
diffusion_noise=None,
seed=None,
):
"""Generates an image based on encoded text.
The encoding passed to this method should be derived from
`StableDiffusion.encode_text`.
Args:
encoded_text: Tensor of shape (`batch_size`, 77, 768), or a Tensor
of shape (77, 768). When the batch axis is omitted, the same encoded
text will be used to produce every generated image.
batch_size: number of images to generate. Default: 1.
num_steps: number of diffusion steps (controls image quality).
Default: 25.
unconditional_guidance_scale: float controlling how closely the image
should adhere to the prompt. Larger values result in more
closely adhering to the prompt, but will make the image noisier.
Default: 7.5.
diffusion_noise: Tensor of shape (`batch_size`, img_height // 8,
img_width // 8, 4), or a Tensor of shape (img_height // 8,
img_width // 8, 4). Optional custom noise to seed the diffusion
process. When the batch axis is omitted, the same noise will be
used to seed diffusion for every generated image.
seed: integer which is used to seed the random generation of
diffusion noise, only to be specified if `diffusion_noise` is
None.
Example:
```python
from keras_cv.models import StableDiffusion
batch_size = 8
model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
e_tacos = model.encode_text("Tacos at dawn")
e_watermelons = model.encode_text("Watermelons at dusk")
e_interpolated = tf.linspace(e_tacos, e_watermelons, batch_size)
images = model.generate_image(e_interpolated, batch_size=batch_size)
```
"""
if diffusion_noise is not None and seed is not None:
raise ValueError(
"`diffusion_noise` and `seed` should not both be passed to "
"`generate_image`. `seed` is only used to generate diffusion "
"noise when it's not already user-specified."
)
encoded_text = tf.squeeze(encoded_text)
if encoded_text.shape.rank == 2:
encoded_text = tf.repeat(
tf.expand_dims(encoded_text, axis=0), batch_size, axis=0
)
context = encoded_text
unconditional_context = tf.repeat(
self._get_unconditional_context(), batch_size, axis=0
)
if diffusion_noise is not None:
diffusion_noise = tf.squeeze(diffusion_noise)
if diffusion_noise.shape.rank == 3:
diffusion_noise = tf.repeat(
tf.expand_dims(diffusion_noise, axis=0), batch_size, axis=0
)
latent = diffusion_noise
else:
latent = self._get_initial_diffusion_noise(batch_size, seed)
# Iterative reverse diffusion stage
timesteps = tf.range(1, 1000, 1000 // num_steps)
alphas, alphas_prev = self._get_initial_alphas(timesteps)
progbar = keras.utils.Progbar(len(timesteps))
iteration = 0
for index, timestep in list(enumerate(timesteps))[::-1]:
latent_prev = latent # Set aside the previous latent vector
t_emb = self._get_timestep_embedding(timestep, batch_size)
unconditional_latent = self.diffusion_model.predict_on_batch(
[latent, t_emb, unconditional_context]
)
latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
latent = unconditional_latent + unconditional_guidance_scale * (
latent - unconditional_latent
)
a_t, a_prev = alphas[index], alphas_prev[index]
pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(a_t)
latent = latent * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0
iteration += 1
progbar.update(iteration)
# Decoding stage
decoded = self.decoder.predict_on_batch(latent)
decoded = ((decoded + 1) / 2) * 255
return np.clip(decoded, 0, 255).astype("uint8")
def _get_unconditional_context(self):
unconditional_tokens = tf.convert_to_tensor(
[_UNCONDITIONAL_TOKENS], dtype=tf.int32
)
unconditional_context = self.text_encoder.predict_on_batch(
[unconditional_tokens, self._get_pos_ids()]
)
return unconditional_context
@property
def image_encoder(self):
"""image_encoder returns the VAE Encoder with pretrained weights.
Usage:
```python
sd = keras_cv.models.StableDiffusion()
my_image = np.ones((1, 512, 512, 3))
latent_representation = sd.image_encoder.predict(my_image)
```
"""
if self._image_encoder is None:
self._image_encoder = ImageEncoder(self.img_height, self.img_width)
if self.jit_compile:
self._image_encoder.compile(jit_compile=True)
return self._image_encoder
def _get_timestep_embedding(self, timestep, batch_size, dim=320, max_period=10000):
half = dim // 2
freqs = tf.math.exp(
-math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
)
args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
embedding = tf.reshape(embedding, [1, -1])
return tf.repeat(embedding, batch_size, axis=0)
def _get_initial_alphas(self, timesteps):
alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
alphas_prev = [1.0] + alphas[:-1]
return alphas, alphas_prev
def _get_initial_diffusion_noise(self, batch_size, seed):
return tf.random.normal(
(batch_size, self.img_height // 8, self.img_width // 8, 4), seed=seed
)
@staticmethod
def _get_pos_ids():
return tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)
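# Standalone sketch (added for illustration) of the classifier-free guidance
# update used inside `generate_image` above: the final prediction starts from the
# unconditional output and moves toward the conditional output by
# `unconditional_guidance_scale`.
if __name__ == "__main__":
    unconditional = tf.constant([0.0, 1.0])
    conditional = tf.constant([1.0, 3.0])
    scale = 7.5
    guided = unconditional + scale * (conditional - unconditional)
    print(guided.numpy())  # -> [7.5, 16.0]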
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow.keras import mixed_precision
from keras_cv.models import StableDiffusion
class StableDiffusionTest(tf.test.TestCase):
def DISABLED_test_end_to_end_golden_value(self):
prompt = "a caterpillar smoking a hookah while sitting on a mushroom"
stablediff = StableDiffusion(128, 128)
# Using TF global random seed to guarantee that subsequent text-to-image
# runs are seeded identically.
tf.random.set_seed(8675309)
img = stablediff.text_to_image(prompt)
self.assertAllClose(img[0][64:65, 64:65, :][0][0], [124, 188, 114], atol=1e-4)
# Verify that the step-by-step creation flow creates an identical output
tf.random.set_seed(8675309)
text_encoding = stablediff.encode_text(prompt)
self.assertAllClose(img, stablediff.generate_image(text_encoding), atol=1e-4)
def DISABLED_test_mixed_precision(self):
mixed_precision.set_global_policy("mixed_float16")
stablediff = StableDiffusion(128, 128)
_ = stablediff.text_to_image("Testing123 haha!")
def DISABLED_test_generate_image_rejects_noise_and_seed(self):
stablediff = StableDiffusion(128, 128)
with self.assertRaisesRegex(
ValueError, r"`diffusion_noise` and `seed` should not both be passed"
):
_ = stablediff.generate_image(
stablediff.encode_text("thou shall not render"),
diffusion_noise=tf.random.normal((1, 16, 16, 4)),
seed=1337,
)
if __name__ == "__main__":
tf.test.main()
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras
from tensorflow.experimental import numpy as tfnp
class TextEncoder(keras.Model):
def __init__(self, max_length, name=None):
tokens = keras.layers.Input(shape=(max_length,), dtype="int32", name="tokens")
positions = keras.layers.Input(
shape=(max_length,), dtype="int32", name="positions"
)
x = CLIPEmbedding(49408, 768, max_length)([tokens, positions])
for _ in range(12):
x = CLIPEncoderLayer()(x)
embedded = keras.layers.LayerNormalization(epsilon=1e-5)(x)
super().__init__([tokens, positions], embedded, name=name)
class CLIPEmbedding(keras.layers.Layer):
def __init__(self, input_dim=49408, output_dim=768, max_length=77, **kwargs):
super().__init__(**kwargs)
self.token_embedding = keras.layers.Embedding(input_dim, output_dim)
self.position_embedding = keras.layers.Embedding(max_length, output_dim)
def call(self, inputs):
tokens, positions = inputs
tokens = self.token_embedding(tokens)
positions = self.position_embedding(positions)
return tokens + positions
class CLIPEncoderLayer(keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.layer_norm1 = keras.layers.LayerNormalization(epsilon=1e-5)
self.clip_attn = CLIPAttention(causal=True)
self.layer_norm2 = keras.layers.LayerNormalization(epsilon=1e-5)
self.fc1 = keras.layers.Dense(3072)
self.fc2 = keras.layers.Dense(768)
def call(self, inputs):
residual = inputs
x = self.layer_norm1(inputs)
x = self.clip_attn(x)
x = residual + x
residual = x
x = self.layer_norm2(x)
x = self.fc1(x)
x = x * tf.sigmoid(x * 1.702) # Quick gelu
x = self.fc2(x)
return x + residual
class CLIPAttention(keras.layers.Layer):
def __init__(self, embed_dim=768, num_heads=12, causal=True, **kwargs):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_heads = num_heads
self.causal = causal
self.head_dim = self.embed_dim // self.num_heads
self.scale = self.head_dim**-0.5
self.q_proj = keras.layers.Dense(self.embed_dim)
self.k_proj = keras.layers.Dense(self.embed_dim)
self.v_proj = keras.layers.Dense(self.embed_dim)
self.out_proj = keras.layers.Dense(self.embed_dim)
def reshape_states(self, x, sequence_length, batch_size):
x = tf.reshape(x, (batch_size, sequence_length, self.num_heads, self.head_dim))
return tf.transpose(x, (0, 2, 1, 3)) # bs, heads, sequence_length, head_dim
def call(self, inputs, attention_mask=None):
if attention_mask is None and self.causal:
length = tf.shape(inputs)[1]
attention_mask = tfnp.triu(
tf.ones((1, 1, length, length), dtype=self.compute_dtype) * -tfnp.inf,
k=1,
)
_, tgt_len, embed_dim = inputs.shape
query_states = self.q_proj(inputs) * self.scale
key_states = self.reshape_states(self.k_proj(inputs), tgt_len, -1)
value_states = self.reshape_states(self.v_proj(inputs), tgt_len, -1)
proj_shape = (-1, tgt_len, self.head_dim)
query_states = self.reshape_states(query_states, tgt_len, -1)
query_states = tf.reshape(query_states, proj_shape)
key_states = tf.reshape(key_states, proj_shape)
src_len = tgt_len
value_states = tf.reshape(value_states, proj_shape)
attn_weights = query_states @ tf.transpose(key_states, (0, 2, 1))
attn_weights = tf.reshape(attn_weights, (-1, self.num_heads, tgt_len, src_len))
attn_weights = attn_weights + attention_mask
attn_weights = tf.reshape(attn_weights, (-1, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights)
attn_output = attn_weights @ value_states
attn_output = tf.reshape(
attn_output, (-1, self.num_heads, tgt_len, self.head_dim)
)
attn_output = tf.transpose(attn_output, (0, 2, 1, 3))
attn_output = tf.reshape(attn_output, (-1, tgt_len, embed_dim))
return self.out_proj(attn_output)
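# Shape-check sketch (added for illustration, not part of the original commit):
# the text encoder maps 77 token ids and 77 position ids to the (batch, 77, 768)
# context tensor consumed by DiffusionModel.
if __name__ == "__main__":
    encoder = TextEncoder(max_length=77)
    tokens = tf.zeros((1, 77), dtype=tf.int32)
    positions = tf.expand_dims(tf.range(77, dtype=tf.int32), axis=0)
    print(encoder([tokens, positions]).shape)  # expected: (1, 77, 768)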
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MLP Mixer models for KerasCV.
Reference:
- [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601)
"""
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend
from tensorflow.keras import layers
from keras_cv.models import utils
def MLPBlock(mlp_dim, name=None):
"""An MLP block consisting of two linear layers with GELU activation in
between.
Args:
mlp_dim: integer, the number of units to be present in the first layer.
name: string, block label.
Returns:
a function that takes an input Tensor representing an MLP block.
"""
if name is None:
name = f"mlp_block_{backend.get_uid('mlp_block')}"
def apply(x):
y = layers.Dense(mlp_dim, name=f"{name}_dense_1")(x)
y = layers.Activation("gelu", name=f"{name}_gelu")(y)
return layers.Dense(x.shape[-1], name=f"{name}_dense_2")(y)
return apply
def MixerBlock(tokens_mlp_dim, channels_mlp_dim, name=None):
"""A mixer block.
Args:
tokens_mlp_dim: integer, number of units to be present in the MLP block
dealing with tokens.
channels_mlp_dim: integer, number of units to be present in the MLP block
dealing with channels.
name: string, block label.
Returns:
a function that takes an input Tensor representing an MLP block.
"""
if name is None:
name = f"mixer_block_{backend.get_uid('mlp_block')}"
def apply(x):
y = layers.LayerNormalization()(x)
y = layers.Permute((2, 1))(y)
y = MLPBlock(tokens_mlp_dim, name=f"{name}_token_mixing")(y)
y = layers.Permute((2, 1))(y)
x = layers.Add()([x, y])
y = layers.LayerNormalization()(x)
y = MLPBlock(channels_mlp_dim, name=f"{name}_channel_mixing")(y)
return layers.Add()([x, y])
return apply
def MLPMixer(
input_shape,
patch_size,
num_blocks,
hidden_dim,
tokens_mlp_dim,
channels_mlp_dim,
include_rescaling,
include_top,
classes=None,
input_tensor=None,
weights=None,
pooling=None,
classifier_activation="softmax",
name=None,
**kwargs,
):
"""Instantiates the MLP Mixer architecture.
Reference:
- [MLP-Mixer: An all-MLP Architecture for Vision (NeurIPS 2021)](https://arxiv.org/abs/2105.01601)
This function returns a Keras MLP Mixer model.
For transfer learning use cases, make sure to read the
[guide to transfer learning & fine-tuning](https://keras.io/guides/transfer_learning/).
Note that the `input_shape` should be fully divisible by the `patch_size`.
Args:
input_shape: tuple denoting the input shape, (224, 224, 3) for example.
patch_size: tuple denoting the size of the patches to be extracted
from the inputs ((16, 16) for example).
num_blocks: number of mixer blocks.
hidden_dim: dimension to which the patches will be linearly projected.
tokens_mlp_dim: dimension of the MLP block responsible for tokens.
channels_mlp_dim: dimension of the MLP block responsible for channels.
include_rescaling: whether or not to rescale the inputs.
If set to True, inputs will be passed through a
`Rescaling(1/255.0)` layer.
include_top: whether to include the fully-connected
layer at the top of the network. If True, `classes` must be provided.
classes: optional number of classes to classify images
into, only to be specified if `include_top` is True.
weights: one of `None` (random initialization), or a pretrained
weight file path.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
pooling: optional pooling mode for feature extraction
when `include_top` is `False`.
- `None` means that the output of the model will be
the 4D tensor output of the
last convolutional block.
- `avg` means that global average pooling
will be applied to the output of the
last convolutional block, and thus
the output of the model will be a 2D tensor.
- `max` means that global max pooling will
be applied.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
When loading pretrained weights, `classifier_activation` can only
be `None` or `"softmax"`.
name: (Optional) name to pass to the model. Defaults to None.
Returns:
A `keras.Model` instance.
"""
if weights and not tf.io.gfile.exists(weights):
raise ValueError(
"The `weights` argument should be either "
"`None` or the path to the weights file to be loaded. "
f"Weights file not found at location: {weights}"
)
if include_top and not classes:
raise ValueError(
"If `include_top` is True, "
"you should specify `classes`. "
f"Received: classes={classes}"
)
if not isinstance(input_shape, tuple) or not isinstance(patch_size, tuple):
raise ValueError("`input_shape` and `patch_size` must both be tuples.")
if len(input_shape) != 3:
raise ValueError(
"`input_shape` needs to contain dimensions for three"
" axes: height, width, and channel ((224, 224, 3) for example)."
)
if len(patch_size) != 2:
raise ValueError(
"`patch_size` needs to contain dimensions for two"
" spatial axes: height, and width ((16, 16) for example)."
)
if input_shape[0] != input_shape[1]:
raise ValueError("Non-uniform resolutions are not supported.")
if patch_size[0] != patch_size[1]:
raise ValueError("Non-uniform patch sizes are not supported.")
if input_shape[0] % patch_size[0] != 0:
raise ValueError("Input resolution should be divisible by the patch size.")
inputs = utils.parse_model_inputs(input_shape, input_tensor)
x = inputs
if include_rescaling:
x = layers.Rescaling(1 / 255.0)(x)
x = layers.Conv2D(
filters=hidden_dim,
kernel_size=patch_size,
strides=patch_size,
padding="valid",
name="patchify_and_projection",
)(x)
x = layers.Reshape((x.shape[1] * x.shape[2], x.shape[3]))(x)
for i in range(num_blocks):
x = MixerBlock(tokens_mlp_dim, channels_mlp_dim, name=f"mixer_block_{i}")(x)
x = layers.LayerNormalization()(x)
if include_top:
x = layers.GlobalAveragePooling1D(name="avg_pool")(x)
x = layers.Dense(classes, activation=classifier_activation, name="predictions")(
x
)
elif pooling == "avg":
x = layers.GlobalAveragePooling1D(name="avg_pool")(x)
elif pooling == "max":
x = layers.GlobalMaxPooling1D(name="max_pool")(x)
model = keras.Model(inputs, x, name=name, **kwargs)
if weights is not None:
model.load_weights(weights)
return model
def MLPMixerB16(
input_shape,
patch_size,
include_rescaling,
include_top,
classes=None,
input_tensor=None,
weights=None,
pooling=None,
name="mlp_mixer_b16",
**kwargs,
):
return MLPMixer(
input_shape=input_shape,
patch_size=patch_size,
num_blocks=12,
hidden_dim=768,
tokens_mlp_dim=384,
channels_mlp_dim=3072,
include_rescaling=include_rescaling,
include_top=include_top,
classes=classes,
input_tensor=input_tensor,
weights=weights,
pooling=pooling,
name=name,
**kwargs,
)
def MLPMixerB32(
input_shape,
patch_size,
include_rescaling,
include_top,
classes=None,
input_tensor=None,
weights=None,
pooling=None,
name="mlp_mixer_b32",
**kwargs,
):
return MLPMixer(
input_shape=input_shape,
patch_size=patch_size,
num_blocks=12,
hidden_dim=768,
tokens_mlp_dim=384,
channels_mlp_dim=3072,
include_rescaling=include_rescaling,
include_top=include_top,
classes=classes,
input_tensor=input_tensor,
weights=weights,
pooling=pooling,
name=name,
**kwargs,
)
def MLPMixerL16(
input_shape,
patch_size,
include_rescaling,
include_top,
classes=None,
input_tensor=None,
weights=None,
pooling=None,
name="mlp_mixer_l16",
**kwargs,
):
return MLPMixer(
input_shape=input_shape,
patch_size=patch_size,
num_blocks=24,
hidden_dim=1024,
tokens_mlp_dim=512,
channels_mlp_dim=4096,
include_rescaling=include_rescaling,
include_top=include_top,
classes=classes,
input_tensor=input_tensor,
weights=weights,
pooling=pooling,
name=name,
**kwargs,
)
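# Usage sketch (added for illustration): an untrained MLPMixerB16 backbone on
# 224x224 inputs. A (16, 16) patch size yields 14 * 14 = 196 tokens of width 768,
# and `pooling="avg"` reduces them to a single 768-d feature vector.
if __name__ == "__main__":
    model = MLPMixerB16(
        input_shape=(224, 224, 3),
        patch_size=(16, 16),
        include_rescaling=True,
        include_top=False,
        pooling="avg",
    )
    print(model.output_shape)  # expected: (None, 768)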
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from absl.testing import parameterized
from keras_cv.models import mlp_mixer
from .models_test import ModelsTest
MODEL_LIST = [
(
mlp_mixer.MLPMixerB16,
768,
{"patch_size": (16, 16), "input_shape": (224, 224, 3)},
),
(
mlp_mixer.MLPMixerB32,
768,
{"patch_size": (32, 32), "input_shape": (224, 224, 3)},
),
(
mlp_mixer.MLPMixerL16,
1024,
{"patch_size": (16, 16), "input_shape": (224, 224, 3)},
),
]
class MLPMixerTest(ModelsTest, tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(*MODEL_LIST)
def test_application_base(self, app, _, args):
super()._test_application_base(app, _, args)
@parameterized.parameters(*MODEL_LIST)
def test_application_with_rescaling(self, app, last_dim, args):
super()._test_application_with_rescaling(app, last_dim, args)
@parameterized.parameters(*MODEL_LIST)
def test_application_pooling(self, app, last_dim, args):
super()._test_application_pooling(app, last_dim, args)
@parameterized.parameters(*MODEL_LIST)
def test_application_variable_input_channels(self, app, last_dim, args):
super()._test_application_variable_input_channels(app, last_dim, args)
@parameterized.parameters(*MODEL_LIST)
def test_model_can_be_used_as_backbone(self, app, last_dim, args):
super()._test_model_can_be_used_as_backbone(app, last_dim, args)
if __name__ == "__main__":
tf.test.main()
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MobileNet v3 models for KerasCV.
References:
- [Searching for MobileNetV3](https://arxiv.org/pdf/1905.02244.pdf) (ICCV 2019)
- [Based on the original keras.applications MobileNetv3](https://github.com/keras-team/keras/blob/master/keras/applications/mobilenet_v3.py)
"""
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend
from tensorflow.keras import layers
from tensorflow.keras.utils import custom_object_scope
from keras_cv import layers as cv_layers
from keras_cv.models import utils
channel_axis = -1
BASE_DOCSTRING = """Instantiates the {name} architecture.
References:
- [Searching for MobileNetV3](https://arxiv.org/abs/1905.02244)
- [Based on the Original keras.applications MobileNetv3](https://github.com/keras-team/keras/blob/master/keras/applications/mobilenet_v3.py)
This function returns a Keras {name} model.
For transfer learning use cases, make sure to read the [guide to transfer
learning & fine-tuning](https://keras.io/guides/transfer_learning/).
Args:
include_rescaling: whether or not to rescale the inputs. If set to True,
inputs will be passed through a `Rescaling(scale=1 / 255)`
layer. Defaults to True.
include_top: whether to include the fully-connected layer at the top of the
network. If True, `classes` must be provided.
classes: optional number of classes to classify images into, only to be
specified if `include_top` is True, and if no `weights` argument is
specified.
weights: one of `None` (random initialization), or a pretrained weight file
path.
input_shape: optional shape tuple, defaults to (None, None, 3).
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
pooling: optional pooling mode for feature extraction
when `include_top` is `False`.
- `None` means that the output of the model will be the 4D tensor output
of the last convolutional block.
- `avg` means that global average pooling will be applied to the output
of the last convolutional block, and thus the output of the model will
be a 2D tensor.
- `max` means that global max pooling will be applied.
alpha: controls the width of the network. This is known as the
depth multiplier in the MobileNetV3 paper, but the name is kept for
consistency with MobileNetV1 in Keras.
- If `alpha` < 1.0, proportionally decreases the number
of filters in each layer.
- If `alpha` > 1.0, proportionally increases the number
of filters in each layer.
- If `alpha` = 1, default number of filters from the paper
are used at each layer.
minimalistic: in addition to large and small models, this module also
contains so-called minimalistic models; these models have the same
per-layer dimensions as MobileNetV3, but they don't
utilize any of the advanced blocks (squeeze-and-excite units, hard-swish,
and 5x5 convolutions). While these models are less efficient on CPU, they
are much more performant on GPU/DSP.
dropout_rate: a float between 0 and 1 denoting the fraction of input units to
drop, defaults to 0.2.
classifier_activation: the activation function to use, defaults to softmax.
name: (Optional) name to pass to the model. Defaults to "{name}".
Returns:
A `keras.Model` instance.
"""
def depth(x, divisor=8, min_value=None):
"""Ensure that all layers have a channel number that is divisble by the `divisor`.
Args:
x: input value.
divisor: integer, the value by which a channel number should be divisible,
defaults to 8.
min_value: float, minimum value for the new tensor.
Returns:
the updated value of the input.
"""
if min_value is None:
min_value = divisor
new_x = max(min_value, int(x + divisor / 2) // divisor * divisor)
# make sure that round down does not go down by more than 10%.
if new_x < 0.9 * x:
new_x += divisor
return new_x
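# Worked examples (added for clarity): depth(37) -> 40 and depth(100, divisor=16)
# -> 96 (rounded to the nearest multiple of the divisor); depth(12) -> 16, since
# ties round upward and the result is never allowed below `min_value`.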
def HardSigmoid(name=None):
"""The Hard Sigmoid function.
Args:
name: string, layer label.
Returns:
a function that takes an input Tensor representing a HardSigmoid layer.
"""
if name is None:
name = f"hard_sigmoid_{backend.get_uid('hard_sigmoid')}"
activation = layers.ReLU(6.0)
def apply(x):
return activation(x + 3.0) * (1.0 / 6.0)
return apply
def HardSwish(name=None):
"""The Hard Swish function.
Args:
name: string, layer label.
Returns:
a function that takes an input Tensor representing a HardSwish layer.
"""
if name is None:
name = f"hard_swish_{backend.get_uid('hard_swish')}"
hard_sigmoid = HardSigmoid()
multiply_layer = layers.Multiply()
def apply(x):
return multiply_layer([x, hard_sigmoid(x)])
return apply
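# Added note: HardSwish(x) == x * HardSigmoid(x) == x * ReLU6(x + 3) / 6, a
# piecewise-linear approximation of swish; for example it returns 0.0 at x = -3
# and 3.0 at x = 3.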
def InvertedResBlock(
expansion, filters, kernel_size, stride, se_ratio, activation, block_id, name=None
):
"""An Inverted Residual Block.
Args:
expansion: integer, the expansion ratio, multiplied with infilters to get the
minimum value passed to depth.
filters: integer, number of filters for convolution layer.
kernel_size: integer, the kernel size for depthwise convolutions.
stride: integer, the stride length for depthwise convolutions.
se_ratio: float, ratio for bottleneck filters. Number of bottleneck
filters = filters * se_ratio.
activation: the activation layer to use.
block_id: integer, a unique identifier for the block; used to prefix the
names of the expanded convolution layers.
name: string, layer label.
Returns:
a function that takes an input Tensor representing an InvertedResBlock.
"""
if name is None:
name = f"inverted_res_block_{backend.get_uid('inverted_res_block')}"
def apply(x):
shortcut = x
prefix = "expanded_conv/"
infilters = backend.int_shape(x)[channel_axis]
if block_id:
prefix = f"expanded_conv_{block_id}"
x = layers.Conv2D(
depth(infilters * expansion),
kernel_size=1,
padding="same",
use_bias=False,
name=prefix + "expand",
)(x)
x = layers.BatchNormalization(
axis=channel_axis,
epsilon=1e-3,
momentum=0.999,
name=prefix + "expand/BatchNorm",
)(x)
x = activation(x)
x = layers.DepthwiseConv2D(
kernel_size,
strides=stride,
padding="same" if stride == 1 else "valid",
use_bias=False,
name=prefix + "depthwise",
)(x)
x = layers.BatchNormalization(
axis=channel_axis,
epsilon=1e-3,
momentum=0.999,
name=prefix + "depthwise/BatchNorm",
)(x)
x = activation(x)
if se_ratio:
with custom_object_scope({"hard_sigmoid": HardSigmoid()}):
x = cv_layers.SqueezeAndExcite2D(
filters=depth(infilters * expansion),
ratio=se_ratio,
squeeze_activation="relu",
excite_activation="hard_sigmoid",
)(x)
x = layers.Conv2D(
filters,
kernel_size=1,
padding="same",
use_bias=False,
name=prefix + "project",
)(x)
x = layers.BatchNormalization(
axis=channel_axis,
epsilon=1e-3,
momentum=0.999,
name=prefix + "project/BatchNorm",
)(x)
if stride == 1 and infilters == filters:
x = layers.Add(name=prefix + "Add")([shortcut, x])
return x
return apply
def MobileNetV3(
stack_fn,
last_point_ch,
include_rescaling,
include_top,
classes=None,
weights=None,
input_shape=(None, None, 3),
input_tensor=None,
pooling=None,
alpha=1.0,
minimalistic=True,
dropout_rate=0.2,
classifier_activation="softmax",
name="MobileNetV3",
**kwargs,
):
"""Instantiates the MobileNetV3 architecture.
References:
- [Searching for MobileNetV3](https://arxiv.org/pdf/1905.02244.pdf) (ICCV 2019)
- [Based on the Original keras.applications MobileNetv3](https://github.com/keras-team/keras/blob/master/keras/applications/mobilenet_v3.py)
This function returns a Keras MobileNetV3 model.
For transfer learning use cases, make sure to read the [guide to transfer
learning & fine-tuning](https://keras.io/guides/transfer_learning/).
Args:
stack_fn: a function that returns tensors passed through Inverted
Residual Blocks.
last_point_ch: the number of filters for the convolution layer.
include_rescaling: whether or not to rescale the inputs. If set to True,
inputs will be passed through a `Rescaling(scale=1 / 255)`
layer. Defaults to True.
include_top: whether to include the fully-connected layer at the top of the
network. If True, `classes` must be provided.
classes: optional number of classes to classify images into, only to be
specified if `include_top` is True, and if no `weights` argument is
specified.
weights: one of `None` (random initialization), or a pretrained weight file
path.
input_shape: optional shape tuple, defaults to (None, None, 3).
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
pooling: optional pooling mode for feature extraction
when `include_top` is `False`.
- `None` means that the output of the model will be the 4D tensor output
of the last convolutional block.
- `avg` means that global average pooling will be applied to the output
of the last convolutional block, and thus the output of the model will
be a 2D tensor.
- `max` means that global max pooling will be applied.
alpha: controls the width of the network. This is known as the
depth multiplier in the MobileNetV3 paper, but the name is kept for
consistency with MobileNetV1 in Keras.
- If `alpha` < 1.0, proportionally decreases the number
of filters in each layer.
- If `alpha` > 1.0, proportionally increases the number
of filters in each layer.
- If `alpha` = 1, default number of filters from the paper
are used at each layer.
minimalistic: in addition to large and small models, this module also
contains so-called minimalistic models; these models have the same
per-layer dimensions as MobileNetV3, but they don't
utilize any of the advanced blocks (squeeze-and-excite units, hard-swish,
and 5x5 convolutions). While these models are less efficient on CPU, they
are much more performant on GPU/DSP.
dropout_rate: a float between 0 and 1 denoting the fraction of input units to
drop, defaults to 0.2.
classifier_activation: the activation function to use, defaults to softmax.
name: (Optional) name to pass to the model. Defaults to "MobileNetV3".
Returns:
A `keras.Model` instance.
Raises:
ValueError: if `weights` represents an invalid path to weights file and is not
None.
ValueError: if `include_top` is True and `classes` is not specified.
"""
if weights and not tf.io.gfile.exists(weights):
raise ValueError(
"The `weights` argument should be either "
"`None` or the path to the weights file to be loaded. "
f"Weights file not found at location: {weights}"
)
if include_top and not classes:
raise ValueError(
"If `include_top` is True, "
"you should specify `classes`. "
f"Received: classes={classes}"
)
if minimalistic:
kernel = 3
activation = layers.ReLU()
se_ratio = None
else:
kernel = 5
activation = HardSwish()
se_ratio = 0.25
inputs = utils.parse_model_inputs(input_shape, input_tensor)
x = inputs
if include_rescaling:
x = layers.Rescaling(scale=1 / 255)(x)
x = layers.Conv2D(
16,
kernel_size=3,
strides=(2, 2),
padding="same",
use_bias=False,
name="Conv",
)(x)
x = layers.BatchNormalization(
axis=channel_axis, epsilon=1e-3, momentum=0.999, name="Conv/BatchNorm"
)(x)
x = activation(x)
x = stack_fn(x, kernel, activation, se_ratio)
last_conv_ch = depth(backend.int_shape(x)[channel_axis] * 6)
# if the width multiplier is greater than 1 we
# increase the number of output channels
if alpha > 1.0:
last_point_ch = depth(last_point_ch * alpha)
x = layers.Conv2D(
last_conv_ch,
kernel_size=1,
padding="same",
use_bias=False,
name="Conv_1",
)(x)
x = layers.BatchNormalization(
axis=channel_axis, epsilon=1e-3, momentum=0.999, name="Conv_1/BatchNorm"
)(x)
x = activation(x)
if include_top:
x = layers.GlobalAveragePooling2D(keepdims=True)(x)
x = layers.Conv2D(
last_point_ch,
kernel_size=1,
padding="same",
use_bias=True,
name="Conv_2",
)(x)
x = activation(x)
if dropout_rate > 0:
x = layers.Dropout(dropout_rate)(x)
x = layers.Conv2D(classes, kernel_size=1, padding="same", name="Logits")(x)
x = layers.Flatten()(x)
x = layers.Activation(activation=classifier_activation, name="Predictions")(x)
elif pooling == "avg":
x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
elif pooling == "max":
x = layers.GlobalMaxPooling2D(name="max_pool")(x)
model = keras.Model(inputs, x, name=name, **kwargs)
if weights is not None:
model.load_weights(weights)
return model
def MobileNetV3Small(
include_rescaling,
include_top,
classes=None,
weights=None,
input_shape=(None, None, 3),
input_tensor=None,
pooling=None,
alpha=1.0,
minimalistic=False,
dropout_rate=0.2,
classifier_activation="softmax",
name="MobileNetV3Small",
**kwargs,
):
def stack_fn(x, kernel, activation, se_ratio):
x = InvertedResBlock(1, depth(16 * alpha), 3, 2, se_ratio, layers.ReLU(), 0)(x)
x = InvertedResBlock(
72.0 / 16, depth(24 * alpha), 3, 2, None, layers.ReLU(), 1
)(x)
x = InvertedResBlock(
88.0 / 24, depth(24 * alpha), 3, 1, None, layers.ReLU(), 2
)(x)
x = InvertedResBlock(4, depth(40 * alpha), kernel, 2, se_ratio, activation, 3)(
x
)
x = InvertedResBlock(6, depth(40 * alpha), kernel, 1, se_ratio, activation, 4)(
x
)
x = InvertedResBlock(6, depth(40 * alpha), kernel, 1, se_ratio, activation, 5)(
x
)
x = InvertedResBlock(3, depth(48 * alpha), kernel, 1, se_ratio, activation, 6)(
x
)
x = InvertedResBlock(3, depth(48 * alpha), kernel, 1, se_ratio, activation, 7)(
x
)
x = InvertedResBlock(6, depth(96 * alpha), kernel, 2, se_ratio, activation, 8)(
x
)
x = InvertedResBlock(6, depth(96 * alpha), kernel, 1, se_ratio, activation, 9)(
x
)
x = InvertedResBlock(6, depth(96 * alpha), kernel, 1, se_ratio, activation, 10)(
x
)
return x
return MobileNetV3(
stack_fn=stack_fn,
last_point_ch=1024,
include_rescaling=include_rescaling,
include_top=include_top,
classes=classes,
weights=weights,
input_shape=input_shape,
input_tensor=input_tensor,
pooling=pooling,
alpha=alpha,
minimalistic=minimalistic,
dropout_rate=dropout_rate,
classifier_activation=classifier_activation,
name=name,
**kwargs,
)
def MobileNetV3Large(
include_rescaling,
include_top,
classes=None,
weights=None,
input_shape=(None, None, 3),
input_tensor=None,
pooling=None,
alpha=1.0,
minimalistic=False,
dropout_rate=0.2,
classifier_activation="softmax",
name="MobileNetV3Large",
**kwargs,
):
def stack_fn(x, kernel, activation, se_ratio):
x = InvertedResBlock(1, depth(16 * alpha), 3, 1, None, layers.ReLU(), 0)(x)
x = InvertedResBlock(4, depth(24 * alpha), 3, 2, None, layers.ReLU(), 1)(x)
x = InvertedResBlock(3, depth(24 * alpha), 3, 1, None, layers.ReLU(), 2)(x)
x = InvertedResBlock(
3, depth(40 * alpha), kernel, 2, se_ratio, layers.ReLU(), 3
)(x)
x = InvertedResBlock(
3, depth(40 * alpha), kernel, 1, se_ratio, layers.ReLU(), 4
)(x)
x = InvertedResBlock(
3, depth(40 * alpha), kernel, 1, se_ratio, layers.ReLU(), 5
)(x)
x = InvertedResBlock(6, depth(80 * alpha), 3, 2, None, activation, 6)(x)
x = InvertedResBlock(2.5, depth(80 * alpha), 3, 1, None, activation, 7)(x)
x = InvertedResBlock(2.3, depth(80 * alpha), 3, 1, None, activation, 8)(x)
x = InvertedResBlock(2.3, depth(80 * alpha), 3, 1, None, activation, 9)(x)
x = InvertedResBlock(6, depth(112 * alpha), 3, 1, se_ratio, activation, 10)(x)
x = InvertedResBlock(6, depth(112 * alpha), 3, 1, se_ratio, activation, 11)(x)
x = InvertedResBlock(
6, depth(160 * alpha), kernel, 2, se_ratio, activation, 12
)(x)
x = InvertedResBlock(
6, depth(160 * alpha), kernel, 1, se_ratio, activation, 13
)(x)
x = InvertedResBlock(
6, depth(160 * alpha), kernel, 1, se_ratio, activation, 14
)(x)
return x
return MobileNetV3(
stack_fn=stack_fn,
last_point_ch=1280,
include_rescaling=include_rescaling,
include_top=include_top,
classes=classes,
weights=weights,
input_shape=input_shape,
input_tensor=input_tensor,
pooling=pooling,
alpha=alpha,
minimalistic=minimalistic,
dropout_rate=dropout_rate,
classifier_activation=classifier_activation,
name=name,
**kwargs,
)
setattr(MobileNetV3Large, "__doc__", BASE_DOCSTRING.format(name="MobileNetV3Large"))
setattr(MobileNetV3Small, "__doc__", BASE_DOCSTRING.format(name="MobileNetV3Small"))
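# Usage sketch (added for illustration): an untrained MobileNetV3Small backbone
# with average pooling; the 576-wide output matches the value exercised by the
# tests that follow.
if __name__ == "__main__":
    model = MobileNetV3Small(
        include_rescaling=True,
        include_top=False,
        input_shape=(224, 224, 3),
        pooling="avg",
    )
    print(model.output_shape)  # expected: (None, 576)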
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from absl.testing import parameterized
from keras_cv.models import mobilenet_v3
from .models_test import ModelsTest
MODEL_LIST = [
(mobilenet_v3.MobileNetV3Small, 576, {}),
(mobilenet_v3.MobileNetV3Large, 960, {}),
]
class MobileNetV3Test(ModelsTest, tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(*MODEL_LIST)
def test_application_base(self, app, _, args):
super()._test_application_base(app, _, args)
@parameterized.parameters(*MODEL_LIST)
def test_application_with_rescaling(self, app, last_dim, args):
super()._test_application_with_rescaling(app, last_dim, args)
@parameterized.parameters(*MODEL_LIST)
def test_application_pooling(self, app, last_dim, args):
super()._test_application_pooling(app, last_dim, args)
@parameterized.parameters(*MODEL_LIST)
def test_application_variable_input_channels(self, app, last_dim, args):
super()._test_application_variable_input_channels(app, last_dim, args)
@parameterized.parameters(*MODEL_LIST)
def test_model_can_be_used_as_backbone(self, app, last_dim, args):
super()._test_model_can_be_used_as_backbone(app, last_dim, args)
if __name__ == "__main__":
tf.test.main()
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for KerasCV models."""
import pytest
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend
class ModelsTest:
def assertShapeEqual(self, shape1, shape2):
self.assertEqual(tf.TensorShape(shape1), tf.TensorShape(shape2))
@pytest.fixture(autouse=True)
def cleanup_global_session(self):
# Code before yield runs before the test
yield
tf.keras.backend.clear_session()
def _test_application_base(self, app, _, args):
# Can be instantiated with default arguments
model = app(include_top=True, classes=1000, include_rescaling=False, **args)
# Can be serialized and deserialized
config = model.get_config()
reconstructed_model = model.__class__.from_config(config)
self.assertEqual(len(model.weights), len(reconstructed_model.weights))
# There is no rescaling layer because include_rescaling=False
with self.assertRaises(ValueError):
model.get_layer(name="rescaling")
def _test_application_with_rescaling(self, app, last_dim, args):
model = app(include_rescaling=True, include_top=False, **args)
self.assertIsNotNone(model.get_layer(name="rescaling"))
def _test_application_pooling(self, app, last_dim, args):
model = app(include_rescaling=False, include_top=False, pooling="avg", **args)
self.assertShapeEqual(model.output_shape, (None, last_dim))
def _test_application_variable_input_channels(self, app, last_dim, args):
# Make a local copy of args because we modify them in the test
args = dict(args)
input_shape = (None, None, 3)
# Avoid passing this parameter twice to the app function
if "input_shape" in args:
input_shape = args["input_shape"]
del args["input_shape"]
single_channel_input_shape = (input_shape[0], input_shape[1], 1)
model = app(
include_rescaling=False,
include_top=False,
input_shape=single_channel_input_shape,
**args
)
output_shape = model.output_shape
if "Mixer" not in app.__name__:
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
elif "MixerB16" in app.__name__ or "MixerL16" in app.__name__:
num_patches = 196
self.assertShapeEqual(output_shape, (None, num_patches, last_dim))
elif "MixerB32" in app.__name__:
num_patches = 49
self.assertShapeEqual(output_shape, (None, num_patches, last_dim))
backend.clear_session()
four_channel_input_shape = (input_shape[0], input_shape[1], 4)
model = app(
include_rescaling=False,
include_top=False,
input_shape=four_channel_input_shape,
**args
)
output_shape = model.output_shape
if "Mixer" not in app.__name__:
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
elif "MixerB16" in app.__name__ or "MixerL16" in app.__name__:
num_patches = 196
self.assertShapeEqual(output_shape, (None, num_patches, last_dim))
elif "MixerB32" in app.__name__:
num_patches = 49
self.assertShapeEqual(output_shape, (None, num_patches, last_dim))
def _test_model_can_be_used_as_backbone(self, app, last_dim, args):
inputs = keras.layers.Input(shape=(224, 224, 3))
backbone = app(
include_rescaling=False,
include_top=False,
input_tensor=inputs,
pooling="avg",
**args
)
x = inputs
x = backbone(x)
backbone_output = backbone.get_layer(index=-1).output
model = keras.Model(inputs=inputs, outputs=[backbone_output])
model.compile()
if __name__ == "__main__":
tf.test.main()
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
import numpy as np
import tensorflow as tf
try:
import pandas as pd
except ImportError:
pd = None
def _get_tensor_types():
if pd is None:
return (tf.Tensor, np.ndarray)
else:
return (tf.Tensor, np.ndarray, pd.Series, pd.DataFrame)
def convert_inputs_to_tf_dataset(x=None, y=None, sample_weight=None, batch_size=None):
if sample_weight is not None:
raise ValueError("RetinaNet does not yet support `sample_weight`.")
if isinstance(x, tf.data.Dataset):
if y is not None or batch_size is not None:
raise ValueError(
"When `x` is a `tf.data.Dataset`, please do not provide a value for "
f"`y` or `batch_size`. Got `y={y}`, `batch_size={batch_size}`."
)
return x
# batch_size defaults to 32, as it does in fit().
batch_size = batch_size or 32
# Parse inputs
inputs = x
if y is not None:
inputs = (x, y)
# Construct tf.data.Dataset
dataset = tf.data.Dataset.from_tensor_slices(inputs)
if batch_size == "full":
dataset = dataset.batch(x.shape[0])
elif batch_size is not None:
dataset = dataset.batch(batch_size)
return dataset
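# Illustrative usage sketch (not part of the original file): wrapping NumPy inputs
# into a batched `tf.data.Dataset`. The array shapes below are arbitrary examples.
#
#   import numpy as np
#   x = np.ones((8, 64, 64, 3), dtype="float32")
#   y = np.ones((8, 4, 5), dtype="float32")
#   dataset = convert_inputs_to_tf_dataset(x=x, y=y, batch_size=4)
#   for images, boxes in dataset.take(1):
#       print(images.shape, boxes.shape)  # (4, 64, 64, 3) (4, 4, 5)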
# TODO(lukewood): remove once exported from Keras core.
def train_validation_split(arrays, validation_split):
"""Split arrays into train and validation subsets in deterministic order.
The last part of data will become validation data.
Args:
arrays: Tensors to split. Allowed inputs are arbitrarily nested structures
of Tensors and NumPy arrays.
validation_split: Float between 0 and 1. The proportion of the dataset to
include in the validation split. The rest of the dataset will be
included in the training split.
Returns:
`(train_arrays, validation_arrays)`
"""
def _can_split(t):
tensor_types = _get_tensor_types()
return isinstance(t, tensor_types) or t is None
flat_arrays = tf.nest.flatten(arrays)
unsplitable = [type(t) for t in flat_arrays if not _can_split(t)]
if unsplitable:
raise ValueError(
"`validation_split` is only supported for Tensors or NumPy "
"arrays, found following types in the input: {}".format(unsplitable)
)
if all(t is None for t in flat_arrays):
return arrays, arrays
first_non_none = None
for t in flat_arrays:
if t is not None:
first_non_none = t
break
# Assumes all arrays have the same batch shape or are `None`.
batch_dim = int(first_non_none.shape[0])
split_at = int(math.floor(batch_dim * (1.0 - validation_split)))
if split_at == 0 or split_at == batch_dim:
raise ValueError(
"Training data contains {batch_dim} samples, which is not "
"sufficient to split it into a validation and training set as "
"specified by `validation_split={validation_split}`. Either "
"provide more data, or a different value for the "
"`validation_split` argument.".format(
batch_dim=batch_dim, validation_split=validation_split
)
)
def _split(t, start, end):
if t is None:
return t
return t[start:end]
train_arrays = tf.nest.map_structure(
functools.partial(_split, start=0, end=split_at), arrays
)
val_arrays = tf.nest.map_structure(
functools.partial(_split, start=split_at, end=batch_dim), arrays
)
return train_arrays, val_arrays
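# Illustrative sketch: an 80/20 split of an (x, y) pair. Per the docstring above,
# the last 20% of samples become the validation arrays. Shapes are arbitrary.
#
#   import numpy as np
#   x = np.arange(10).reshape(10, 1)
#   y = np.arange(10)
#   (train_x, train_y), (val_x, val_y) = train_validation_split(
#       (x, y), validation_split=0.2
#   )
#   print(train_x.shape, val_x.shape)  # (8, 1) (2, 1)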
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from absl import logging
from keras_cv.bounding_box.converters import _decode_deltas_to_boxes
from keras_cv.bounding_box.utils import _clip_boxes
from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator
from keras_cv.layers.object_detection.roi_align import _ROIAligner
from keras_cv.layers.object_detection.roi_generator import ROIGenerator
from keras_cv.layers.object_detection.roi_sampler import _ROISampler
from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder
from keras_cv.ops.box_matcher import ArgmaxBoxMatcher
def _resnet50_backbone(include_rescaling=False):
inputs = tf.keras.layers.Input(shape=(None, None, 3))
x = inputs
if include_rescaling:
x = tf.keras.applications.resnet.preprocess_input(x)
backbone = tf.keras.applications.ResNet50(include_top=False, input_tensor=x)
c2_output, c3_output, c4_output, c5_output = [
backbone.get_layer(layer_name).output
for layer_name in [
"conv2_block3_out",
"conv3_block4_out",
"conv4_block6_out",
"conv5_block3_out",
]
]
return tf.keras.Model(
inputs=inputs, outputs=[c2_output, c3_output, c4_output, c5_output]
)
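# Sketch of the backbone contract (not executed here; `ResNet50` loads ImageNet
# weights by default): the model maps an image batch to the C2-C5 feature maps at
# strides 4, 8, 16 and 32. The 256x256 input size is only an example.
#
#   backbone = _resnet50_backbone()
#   c2, c3, c4, c5 = backbone(tf.ones((1, 256, 256, 3)))
#   print(c2.shape, c3.shape, c4.shape, c5.shape)
#   # (1, 64, 64, 256) (1, 32, 32, 512) (1, 16, 16, 1024) (1, 8, 8, 2048)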
class FeaturePyramid(tf.keras.layers.Layer):
"""Builds the Feature Pyramid with the feature maps from the backbone."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.conv_c2_1x1 = tf.keras.layers.Conv2D(256, 1, 1, "same")
self.conv_c3_1x1 = tf.keras.layers.Conv2D(256, 1, 1, "same")
self.conv_c4_1x1 = tf.keras.layers.Conv2D(256, 1, 1, "same")
self.conv_c5_1x1 = tf.keras.layers.Conv2D(256, 1, 1, "same")
self.conv_c2_3x3 = tf.keras.layers.Conv2D(256, 3, 1, "same")
self.conv_c3_3x3 = tf.keras.layers.Conv2D(256, 3, 1, "same")
self.conv_c4_3x3 = tf.keras.layers.Conv2D(256, 3, 1, "same")
self.conv_c5_3x3 = tf.keras.layers.Conv2D(256, 3, 1, "same")
self.conv_c6_3x3 = tf.keras.layers.Conv2D(256, 3, 1, "same")
self.conv_c6_pool = tf.keras.layers.MaxPool2D()
self.upsample_2x = tf.keras.layers.UpSampling2D(2)
def call(self, inputs, training=None):
c2_output, c3_output, c4_output, c5_output = inputs
c6_output = self.conv_c6_pool(c5_output)
p6_output = c6_output
p5_output = self.conv_c5_1x1(c5_output)
p4_output = self.conv_c4_1x1(c4_output)
p3_output = self.conv_c3_1x1(c3_output)
p2_output = self.conv_c2_1x1(c2_output)
p4_output = p4_output + self.upsample_2x(p5_output)
p3_output = p3_output + self.upsample_2x(p4_output)
p2_output = p2_output + self.upsample_2x(p3_output)
p6_output = self.conv_c6_3x3(p6_output)
p5_output = self.conv_c5_3x3(p5_output)
p4_output = self.conv_c4_3x3(p4_output)
p3_output = self.conv_c3_3x3(p3_output)
p2_output = self.conv_c2_3x3(p2_output)
return {2: p2_output, 3: p3_output, 4: p4_output, 5: p5_output, 6: p6_output}
def get_config(self):
config = {}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
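# Sketch with assumed feature-map shapes: the pyramid consumes (C2, C3, C4, C5)
# and returns a dict of 256-channel maps keyed by level 2-6, where level 6 is a
# max-pooled, convolved copy of C5.
#
#   fpn = FeaturePyramid()
#   c2 = tf.ones((1, 64, 64, 256))
#   c3 = tf.ones((1, 32, 32, 512))
#   c4 = tf.ones((1, 16, 16, 1024))
#   c5 = tf.ones((1, 8, 8, 2048))
#   pyramid = fpn((c2, c3, c4, c5))
#   print({level: p.shape for level, p in pyramid.items()})
#   # {2: (1, 64, 64, 256), 3: (1, 32, 32, 256), 4: (1, 16, 16, 256),
#   #  5: (1, 8, 8, 256), 6: (1, 4, 4, 256)}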
class RPNHead(tf.keras.layers.Layer):
def __init__(
self,
num_anchors_per_location=3,
**kwargs,
):
super().__init__(**kwargs)
self.num_anchors = num_anchors_per_location
def build(self, input_shape):
if isinstance(input_shape, (dict, list, tuple)):
input_shape = tf.nest.flatten(input_shape)
input_shape = input_shape[0]
filters = input_shape[-1]
self.conv = tf.keras.layers.Conv2D(
filters=filters,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="truncated_normal",
)
self.objectness_logits = tf.keras.layers.Conv2D(
filters=self.num_anchors * 1,
kernel_size=1,
strides=1,
padding="same",
kernel_initializer="truncated_normal",
)
self.anchor_deltas = tf.keras.layers.Conv2D(
filters=self.num_anchors * 4,
kernel_size=1,
strides=1,
padding="same",
kernel_initializer="truncated_normal",
)
def call(self, feature_map):
def call_single_level(f_map):
batch_size = f_map.get_shape().as_list()[0]
if batch_size is None:
raise ValueError("Cannot handle static shape")
# [BS, H, W, C]
t = self.conv(f_map)
# [BS, H, W, K]
rpn_scores = self.objectness_logits(t)
# [BS, H, W, K * 4]
rpn_boxes = self.anchor_deltas(t)
# [BS, H*W*K, 4]
rpn_boxes = tf.reshape(rpn_boxes, [batch_size, -1, 4])
# [BS, H*W*K, 1]
rpn_scores = tf.reshape(rpn_scores, [batch_size, -1, 1])
return rpn_boxes, rpn_scores
if not isinstance(feature_map, (dict, list, tuple)):
return call_single_level(feature_map)
elif isinstance(feature_map, (list, tuple)):
rpn_boxes = []
rpn_scores = []
for f_map in feature_map:
rpn_box, rpn_score = call_single_level(f_map)
rpn_boxes.append(rpn_box)
rpn_scores.append(rpn_score)
return rpn_boxes, rpn_scores
else:
rpn_boxes = {}
rpn_scores = {}
for lvl, f_map in feature_map.items():
rpn_box, rpn_score = call_single_level(f_map)
rpn_boxes[lvl] = rpn_box
rpn_scores[lvl] = rpn_score
return rpn_boxes, rpn_scores
def get_config(self):
config = {
"num_anchors_per_location": self.num_anchors,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
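# Sketch with an assumed feature-map size: with the default 3 anchors per
# location, a single 32x32 map produces 32 * 32 * 3 = 3072 anchor predictions.
# A static batch dimension is required (see the ValueError in `call` above).
#
#   rpn_head = RPNHead()
#   rpn_boxes, rpn_scores = rpn_head(tf.ones((1, 32, 32, 256)))
#   print(rpn_boxes.shape, rpn_scores.shape)  # (1, 3072, 4) (1, 3072, 1)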
# Class-agnostic box regression (a single set of 4 deltas shared across all classes)
class RCNNHead(tf.keras.layers.Layer):
def __init__(
self,
classes,
conv_dims=[],
fc_dims=[1024, 1024],
**kwargs,
):
super().__init__(**kwargs)
self.num_classes = classes
self.conv_dims = conv_dims
self.fc_dims = fc_dims
self.convs = []
for conv_dim in conv_dims:
layer = tf.keras.layers.Conv2D(
filters=conv_dim,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
)
self.convs.append(layer)
self.fcs = []
for fc_dim in fc_dims:
layer = tf.keras.layers.Dense(units=fc_dim, activation="relu")
self.fcs.append(layer)
self.box_pred = tf.keras.layers.Dense(units=4)
self.cls_score = tf.keras.layers.Dense(units=classes + 1, activation="softmax")
def call(self, feature_map):
x = feature_map
for conv in self.convs:
x = conv(x)
for fc in self.fcs:
x = fc(x)
rcnn_boxes = self.box_pred(x)
rcnn_scores = self.cls_score(x)
return rcnn_boxes, rcnn_scores
def get_config(self):
config = {
"classes": self.num_classes,
"conv_dims": self.conv_dims,
"fc_dims": self.fc_dims,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
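# Sketch with assumed ROI feature shapes: pooled per-ROI features of shape
# [batch, num_rois, flattened_pool_dims] map to 4 box deltas and `classes + 1`
# scores (the extra column is the background class).
#
#   rcnn_head = RCNNHead(classes=20)
#   roi_features = tf.ones((2, 64, 7 * 7 * 256))
#   rcnn_boxes, rcnn_scores = rcnn_head(roi_features)
#   print(rcnn_boxes.shape, rcnn_scores.shape)  # (2, 64, 4) (2, 64, 21)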
# TODO(tanzheny): add more configurations
class FasterRCNN(tf.keras.Model):
"""A Keras model implementing the FasterRCNN architecture.
Implements the FasterRCNN architecture for object detection. The constructor
requires `classes` and `bounding_box_format`; a custom `backbone` is optional and
defaults to ResNet50.
References:
- [FasterRCNN](https://arxiv.org/pdf/1506.01497.pdf)
Usage:
```python
faster_rcnn = keras_cv.models.FasterRCNN(
classes=20,
bounding_box_format="xywh",
backbone="resnet50",
include_rescaling=False,
)
```
Args:
classes: the number of classes in your dataset excluding the background
class. Classes should be represented by integers in the range
[0, classes).
bounding_box_format: The format of the bounding boxes the model outputs. Refer
[to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
for more details on supported bounding box formats.
backbone: Either `"resnet50"` or a custom backbone model. For now, only backbones
with per-level dict outputs are supported. Defaults to ResNet50 with FPN, which
uses the last conv block from stage 2 to stage 6 and adds a max pooling at
stage 7.
include_rescaling: Required if the provided backbone is a pre-configured model.
If set to `True`, inputs will be passed through a `Rescaling(1/255.0)`
layer. Defaults to `False`.
anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. It is used
in the model to match ground truth boxes and labels with anchors, or with
region proposals. By default it uses the sizes and ratios from the paper,
which are optimized for image sizes between [640, 800]. Users should pass
their own anchor generator if the input image size differs from the paper.
For now, only anchor generators with per-level dict outputs are supported.
rpn_head: (Optional) a `keras.layers.Layer` that takes an input feature map and
returns box delta predictions (relative to anchors) and binary foreground vs.
background predictions; only per-level dict outputs are supported. By default
it uses the RPN head from the paper: a 3x3 conv followed by one box regressor
and one binary classifier.
rcnn_head: (Optional) a `keras.layers.Layer` that takes an input feature map and
returns box delta predictions (relative to ROIs) and multi-class predictions
(all foreground classes plus one background class). By default it uses the
R-CNN head from the paper: two 1024-unit FC layers, one box regressor, and one
softmax classifier.
"""
def __init__(
self,
classes,
bounding_box_format,
backbone=None,
include_rescaling=False,
anchor_generator=None,
rpn_head=None,
rcnn_head=None,
**kwargs,
):
self.bounding_box_format = bounding_box_format
super().__init__(**kwargs)
scales = [2**x for x in [0]]
aspect_ratios = [0.5, 1.0, 2.0]
self.anchor_generator = anchor_generator or AnchorGenerator(
bounding_box_format="yxyx",
sizes={2: 32.0, 3: 64.0, 4: 128.0, 5: 256.0, 6: 512.0},
scales=scales,
aspect_ratios=aspect_ratios,
strides={i: 2**i for i in range(2, 7)},
clip_boxes=True,
)
self.rpn_head = rpn_head or RPNHead(
num_anchors_per_location=len(scales) * len(aspect_ratios)
)
self.roi_generator = ROIGenerator(
bounding_box_format="yxyx",
nms_score_threshold_train=float("-inf"),
nms_score_threshold_test=float("-inf"),
)
self.box_matcher = ArgmaxBoxMatcher(
thresholds=[0.0, 0.5], match_values=[-2, -1, 1]
)
self.roi_sampler = _ROISampler(
bounding_box_format="yxyx",
roi_matcher=self.box_matcher,
background_class=classes,
num_sampled_rois=512,
)
self.roi_pooler = _ROIAligner(bounding_box_format="yxyx")
self.rcnn_head = rcnn_head or RCNNHead(classes)
self.backbone = backbone or _resnet50_backbone(include_rescaling)
self.feature_pyramid = FeaturePyramid()
self.rpn_labeler = _RpnLabelEncoder(
anchor_format="yxyx",
ground_truth_box_format="yxyx",
positive_threshold=0.7,
negative_threshold=0.3,
samples_per_image=256,
positive_fraction=0.5,
)
def _call_rpn(self, images, anchors, training=None):
image_shape = tf.shape(images[0])
feature_map = self.backbone(images, training=training)
feature_map = self.feature_pyramid(feature_map, training=training)
# [BS, num_anchors, 4], [BS, num_anchors, 1]
rpn_boxes, rpn_scores = self.rpn_head(feature_map)
# the decoded format is center_xywh, convert to yxyx
decoded_rpn_boxes = _decode_deltas_to_boxes(
anchors=anchors,
boxes_delta=rpn_boxes,
anchor_format="yxyx",
box_format="yxyx",
variance=[0.1, 0.1, 0.2, 0.2],
)
rois, _ = self.roi_generator(decoded_rpn_boxes, rpn_scores, training=training)
rois = _clip_boxes(rois, "yxyx", image_shape)
rpn_boxes = tf.concat(tf.nest.flatten(rpn_boxes), axis=1)
rpn_scores = tf.concat(tf.nest.flatten(rpn_scores), axis=1)
return rois, feature_map, rpn_boxes, rpn_scores
def _call_rcnn(self, rois, feature_map):
feature_map = self.roi_pooler(feature_map, rois)
# [BS, H*W*K, pool_shape*C]
feature_map = tf.reshape(
feature_map, tf.concat([tf.shape(rois)[:2], [-1]], axis=0)
)
# [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1]
rcnn_box_pred, rcnn_cls_pred = self.rcnn_head(feature_map)
return rcnn_box_pred, rcnn_cls_pred
def call(self, images, training=None):
image_shape = tf.shape(images[0])
anchors = self.anchor_generator(image_shape=image_shape)
rois, feature_map, _, _ = self._call_rpn(images, anchors, training=training)
box_pred, cls_pred = self._call_rcnn(rois, feature_map)
if not training:
# box_pred is in "center_yxhw" format; convert to the target format.
box_pred = _decode_deltas_to_boxes(
anchors=rois,
boxes_delta=box_pred,
anchor_format="yxyx",
box_format=self.bounding_box_format,
variance=[0.1, 0.1, 0.2, 0.2],
)
return box_pred, cls_pred
# TODO(tanzhenyu): Support compile with metrics.
def compile(
self,
box_loss=None,
classification_loss=None,
rpn_box_loss=None,
rpn_classification_loss=None,
weight_decay=0.0001,
loss=None,
**kwargs,
):
# TODO(tanzhenyu): Add metrics support once COCOMap issue is addressed.
# https://github.com/keras-team/keras-cv/issues/915
if "metrics" in kwargs.keys():
raise ValueError("currently metrics support is not supported intentionally")
if loss is not None:
raise ValueError(
"`FasterRCNN` does not accept a `loss` to `compile()`. "
"Instead, please pass `box_loss` and `classification_loss`. "
"`loss` will be ignored during training."
)
box_loss = _validate_and_get_loss(box_loss, "box_loss")
classification_loss = _validate_and_get_loss(
classification_loss, "classification_loss"
)
rpn_box_loss = _validate_and_get_loss(rpn_box_loss, "rpn_box_loss")
if rpn_classification_loss == "BinaryCrossentropy":
rpn_classification_loss = tf.keras.losses.BinaryCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.SUM
)
rpn_classification_loss = _validate_and_get_loss(
rpn_classification_loss, "rpn_cls_loss"
)
if not rpn_classification_loss.from_logits:
raise ValueError(
"`rpn_classification_loss` must come with `from_logits`=True"
)
self.rpn_box_loss = rpn_box_loss
self.rpn_cls_loss = rpn_classification_loss
self.box_loss = box_loss
self.cls_loss = classification_loss
self.weight_decay = weight_decay
losses = {
"box": self.box_loss,
"cls": self.cls_loss,
"rpn_box": self.rpn_box_loss,
"rpn_cls": self.rpn_cls_loss,
}
super().compile(loss=losses, **kwargs)
def compute_loss(self, images, gt_boxes, gt_classes, training):
image_shape = tf.shape(images[0])
local_batch = images.get_shape().as_list()[0]
if tf.distribute.has_strategy():
num_sync = tf.distribute.get_strategy().num_replicas_in_sync
else:
num_sync = 1
global_batch = local_batch * num_sync
anchors = self.anchor_generator(image_shape=image_shape)
(
rpn_box_targets,
rpn_box_weights,
rpn_cls_targets,
rpn_cls_weights,
) = self.rpn_labeler(
tf.concat(tf.nest.flatten(anchors), axis=0), gt_boxes, gt_classes
)
rpn_box_weights /= self.rpn_labeler.samples_per_image * global_batch * 0.25
rpn_cls_weights /= self.rpn_labeler.samples_per_image * global_batch
rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn(
images, anchors, training=training
)
rois = tf.stop_gradient(rois)
rois, box_targets, box_weights, cls_targets, cls_weights = self.roi_sampler(
rois, gt_boxes, gt_classes
)
box_weights /= self.roi_sampler.num_sampled_rois * global_batch * 0.25
cls_weights /= self.roi_sampler.num_sampled_rois * global_batch
box_pred, cls_pred = self._call_rcnn(rois, feature_map)
y_true = {
"rpn_box": rpn_box_targets,
"rpn_cls": rpn_cls_targets,
"box": box_targets,
"cls": cls_targets,
}
y_pred = {
"rpn_box": rpn_box_pred,
"rpn_cls": rpn_cls_pred,
"box": box_pred,
"cls": cls_pred,
}
weights = {
"rpn_box": rpn_box_weights,
"rpn_cls": rpn_cls_weights,
"box": box_weights,
"cls": cls_weights,
}
return super().compute_loss(
x=images, y=y_true, y_pred=y_pred, sample_weight=weights
)
def train_step(self, data):
images, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
if sample_weight is not None:
raise ValueError("`sample_weight` is currently not supported.")
gt_boxes = y["gt_boxes"]
gt_classes = y["gt_classes"]
with tf.GradientTape() as tape:
total_loss = self.compute_loss(images, gt_boxes, gt_classes, training=True)
reg_losses = []
if self.weight_decay:
for var in self.trainable_variables:
if "bn" not in var.name:
reg_losses.append(self.weight_decay * tf.nn.l2_loss(var))
l2_loss = tf.math.add_n(reg_losses)
total_loss += l2_loss
self.optimizer.minimize(total_loss, self.trainable_variables, tape=tape)
return self.compute_metrics(images, {}, {}, sample_weight={})
def test_step(self, data):
images, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
if sample_weight is not None:
raise ValueError("`sample_weight` is currently not supported.")
gt_boxes = y["gt_boxes"]
gt_classes = y["gt_classes"]
self.compute_loss(images, gt_boxes, gt_classes, training=False)
return self.compute_metrics(images, {}, {}, sample_weight={})
def _validate_and_get_loss(loss, loss_name):
if isinstance(loss, str):
loss = tf.keras.losses.get(loss)
if loss is None or not isinstance(loss, tf.keras.losses.Loss):
raise ValueError(
f"FasterRCNN only accepts `tf.keras.losses.Loss` for {loss_name}, got {loss}"
)
if loss.reduction != tf.keras.losses.Reduction.SUM:
logging.info(
f"FasterRCNN only accepts `SUM` reduction, got {loss.reduction}, automatically converted."
)
loss.reduction = tf.keras.losses.Reduction.SUM
return loss
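# Illustrative compile() sketch that satisfies the validation above: every loss
# must be a `tf.keras.losses.Loss` with `SUM` reduction, and the RPN classifier
# loss must use `from_logits=True`. Losses and optimizer settings here are only
# examples, not recommended hyper-parameters.
#
#   model = FasterRCNN(classes=20, bounding_box_format="xywh")
#   model.compile(
#       optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),
#       box_loss=tf.keras.losses.Huber(reduction=tf.keras.losses.Reduction.SUM),
#       classification_loss=tf.keras.losses.CategoricalCrossentropy(
#           reduction=tf.keras.losses.Reduction.SUM
#       ),
#       rpn_box_loss=tf.keras.losses.Huber(reduction=tf.keras.losses.Reduction.SUM),
#       rpn_classification_loss=tf.keras.losses.BinaryCrossentropy(
#           from_logits=True, reduction=tf.keras.losses.Reduction.SUM
#       ),
#   )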
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from keras_cv.models.object_detection.faster_rcnn import FasterRCNN
class FasterRCNNTest(tf.test.TestCase):
def test_faster_rcnn_infer(self):
model = FasterRCNN(classes=80, bounding_box_format="xyxy")
images = tf.random.normal([2, 512, 512, 3])
outputs = model(images, training=False)
# 1000 proposals in inference
self.assertAllEqual([2, 1000, 81], outputs[1].shape)
self.assertAllEqual([2, 1000, 4], outputs[0].shape)
def test_faster_rcnn_train(self):
model = FasterRCNN(classes=80, bounding_box_format="xyxy")
images = tf.random.normal([2, 512, 512, 3])
outputs = model(images, training=True)
self.assertAllEqual([2, 1000, 81], outputs[1].shape)
self.assertAllEqual([2, 1000, 4], outputs[0].shape)
def test_invalid_compile(self):
model = FasterRCNN(classes=80, bounding_box_format="yxyx")
with self.assertRaisesRegex(ValueError, "only accepts"):
model.compile(rpn_box_loss="binary_crossentropy")
with self.assertRaisesRegex(ValueError, "only accepts"):
model.compile(
rpn_classification_loss=tf.keras.losses.BinaryCrossentropy(
from_logits=False
)
)
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras
from keras_cv import bounding_box
from keras_cv.models.object_detection.__internal__ import convert_inputs_to_tf_dataset
from keras_cv.models.object_detection.__internal__ import train_validation_split
class ObjectDetectionBaseModel(keras.Model):
"""ObjectDetectionBaseModel performs asynchonous label encoding.
ObjectDetectionBaseModel invokes the provided `label_encoder` in the `tf.data`
pipeline to ensure optimal training performance. This is done by overriding the
methods `train_on_batch()`, `fit()`, `test_on_batch()`, and `evaluate()`.
"""
def __init__(self, bounding_box_format, label_encoder, **kwargs):
super().__init__(**kwargs)
self.bounding_box_format = bounding_box_format
self.label_encoder = label_encoder
def fit(
self,
x=None,
y=None,
validation_data=None,
validation_split=None,
sample_weight=None,
batch_size=None,
**kwargs,
):
if validation_split and validation_data is None:
(x, y, sample_weight,), validation_data = train_validation_split(
(x, y, sample_weight), validation_split=validation_split
)
dataset = convert_inputs_to_tf_dataset(
x=x, y=y, sample_weight=sample_weight, batch_size=batch_size
)
if validation_data is not None:
val_x, val_y, val_sample = keras.utils.unpack_x_y_sample_weight(
validation_data
)
validation_data = convert_inputs_to_tf_dataset(
x=val_x, y=val_y, sample_weight=val_sample, batch_size=batch_size
)
dataset = dataset.map(self.encode_data, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return super().fit(x=dataset, validation_data=validation_data, **kwargs)
def train_on_batch(self, x, y=None, **kwargs):
x, y = self.encode_data(x, y)
return super().train_on_batch(x=x, y=y, **kwargs)
def test_on_batch(self, x, y=None, **kwargs):
x, y = self.encode_data(x, y)
return super().test_on_batch(x=x, y=y, **kwargs)
def evaluate(
self,
x=None,
y=None,
sample_weight=None,
batch_size=None,
_use_cached_eval_dataset=None,
**kwargs,
):
dataset = convert_inputs_to_tf_dataset(
x=x, y=y, sample_weight=sample_weight, batch_size=batch_size
)
dataset = dataset.map(self.encode_data, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
# force _use_cached_eval_dataset=False
# this is required to override evaluate().
# We can remove _use_cached_eval_dataset=False when
# https://github.com/keras-team/keras/issues/16958
# is fixed
return super().evaluate(x=dataset, _use_cached_eval_dataset=False, **kwargs)
def encode_data(self, x, y):
y_for_metrics = y
y = bounding_box.convert_format(
y,
source=self.bounding_box_format,
target=self.label_encoder.bounding_box_format,
images=x,
)
y_training_target = self.label_encoder(x, y)
y_training_target = bounding_box.convert_format(
y_training_target,
source=self.label_encoder.bounding_box_format,
target=self.bounding_box_format,
images=x,
)
return x, (y_for_metrics, y_training_target)
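# Note (illustrative): after `encode_data` is mapped over the dataset in `fit()`
# and `evaluate()`, each batch yields `(x, (y_for_metrics, y_training_target))`,
# so subclasses typically unpack their `train_step`/`test_step` data as:
#
#   def train_step(self, data):
#       x, (y_for_metrics, y_training_target) = data
#       ...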
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import tensorflow as tf
from keras_cv import layers as cv_layers
from keras_cv.models.object_detection.object_detection_base_model import (
ObjectDetectionBaseModel,
)
class ObjectDetectionBaseModelTest(tf.test.TestCase):
def test_raises_error_when_y_provided_with_dataset(self):
x = tf.data.Dataset.from_tensor_slices(
(tf.ones((8, 512, 512, 3)), tf.ones((8, 4, 5)))
)
y = tf.ones((8, 4, 5))
model = ObjectDetectionBaseModel(
bounding_box_format="xywh", label_encoder=_default_encoder()
)
with self.assertRaisesRegex(ValueError, "When `x` is a `tf.data.Dataset`,"):
model.fit(x=x, y=y)
def test_numpy_array(self):
model = DummySubclass()
model.compile()
x = np.ones((8, 512, 512, 3))
y = np.ones((8, 4, 5))
model.fit(x, y, validation_data=(x, y))
model.evaluate(np.ones((8, 512, 512, 3)), np.ones((8, 4, 5)))
model.train_on_batch(x, y)
model.test_on_batch(x, y)
def test_validation_split(self):
model = DummySubclass()
model.compile()
x = np.ones((8, 512, 512, 3))
y = np.ones((8, 4, 5))
model.fit(x, y, validation_split=0.2)
model.evaluate(np.ones((8, 512, 512, 3)), np.ones((8, 4, 5)))
def test_tf_dataset(self):
model = DummySubclass()
model.compile()
my_ds = tf.data.Dataset.from_tensor_slices(
(np.ones((8, 512, 512, 3)), np.ones((8, 4, 5)))
)
my_ds = my_ds.batch(8)
model.fit(my_ds, validation_data=my_ds)
model.evaluate(np.ones((8, 512, 512, 3)), np.ones((8, 4, 5)))
def test_with_sample_weight(self):
pass
class DummySubclass(ObjectDetectionBaseModel):
def __init__(self, **kwargs):
super().__init__(
bounding_box_format="xywh", label_encoder=_default_encoder(), **kwargs
)
def train_step(self, data):
x, y = data
y_for_metrics, y_training_target = y
return {"loss": 0}
def test_step(self, data):
x, y = data
y_for_metrics, y_training_target = y
return {"loss": 0}
def _default_encoder():
strides = [2**i for i in range(3, 8)]
scales = [2**x for x in [0, 1 / 3, 2 / 3]]
sizes = [32.0, 64.0, 128.0, 256.0, 512.0]
aspect_ratios = [0.5, 1.0, 2.0]
anchor_generator = cv_layers.AnchorGenerator(
bounding_box_format="xywh",
sizes=sizes,
aspect_ratios=aspect_ratios,
scales=scales,
strides=strides,
clip_boxes=True,
)
return cv_layers.RetinaNetLabelEncoder(
bounding_box_format="xywh", anchor_generator=anchor_generator
)
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_cv.models.object_detection.retina_net.__internal__.layers.feature_pyramid import (
FeaturePyramid,
)
from keras_cv.models.object_detection.retina_net.__internal__.layers.prediction_head import (
PredictionHead,
)
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras
@tf.keras.utils.register_keras_serializable(package="keras_cv")
class FeaturePyramid(keras.layers.Layer):
"""Builds the Feature Pyramid with the feature maps from the backbone."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.conv_c3_1x1 = keras.layers.Conv2D(256, 1, 1, "same")
self.conv_c4_1x1 = keras.layers.Conv2D(256, 1, 1, "same")
self.conv_c5_1x1 = keras.layers.Conv2D(256, 1, 1, "same")
self.conv_c3_3x3 = keras.layers.Conv2D(256, 3, 1, "same")
self.conv_c4_3x3 = keras.layers.Conv2D(256, 3, 1, "same")
self.conv_c5_3x3 = keras.layers.Conv2D(256, 3, 1, "same")
self.conv_c6_3x3 = keras.layers.Conv2D(256, 3, 2, "same")
self.conv_c7_3x3 = keras.layers.Conv2D(256, 3, 2, "same")
self.upsample_2x = keras.layers.UpSampling2D(2)
def call(self, inputs, training=False):
c3_output, c4_output, c5_output = inputs
p3_output = self.conv_c3_1x1(c3_output, training=training)
p4_output = self.conv_c4_1x1(c4_output, training=training)
p5_output = self.conv_c5_1x1(c5_output, training=training)
p4_output = p4_output + self.upsample_2x(p5_output, training=training)
p3_output = p3_output + self.upsample_2x(p4_output, training=training)
p3_output = self.conv_c3_3x3(p3_output, training=training)
p4_output = self.conv_c4_3x3(p4_output, training=training)
p5_output = self.conv_c5_3x3(p5_output, training=training)
p6_output = self.conv_c6_3x3(c5_output, training=training)
p7_output = self.conv_c7_3x3(tf.nn.relu(p6_output), training=training)
return p3_output, p4_output, p5_output, p6_output, p7_output
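# Sketch with assumed backbone shapes: this pyramid consumes (C3, C4, C5) and
# emits P3-P7, all with 256 channels; P6 and P7 are computed from C5 via
# stride-2 convolutions.
#
#   fpn = FeaturePyramid()
#   c3 = tf.ones((1, 64, 64, 512))
#   c4 = tf.ones((1, 32, 32, 1024))
#   c5 = tf.ones((1, 16, 16, 2048))
#   p3, p4, p5, p6, p7 = fpn((c3, c4, c5))
#   print(p3.shape, p5.shape, p7.shape)
#   # (1, 64, 64, 256) (1, 16, 16, 256) (1, 4, 4, 256)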