Commit 5498e94a authored by suily

Initial commit

parent 14530156
Pipeline #1635 failed with stages in 0 seconds
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Fine-tunes a Vision Transformer / Hybrid from AugReg checkpoint.
Example for fine-tuning an R+Ti/16 on oxford_iiit_pet:
python -m vit_jax.main --workdir=/tmp/vit \
--config=$(pwd)/vit_jax/configs/augreg.py:R_Ti_16 \
--config.dataset=oxford_iiit_pet \
--config.pp.train='train[:90%]' \
--config.base_lr=0.01
Note that by default, the best i21k pre-trained checkpoint by upstream
validation accuracy is chosen. You can also manually select a model by
specifying the full name (without ".npz" extension):
python -m vit_jax.main --workdir=/tmp/vit \
--config=$(pwd)/vit_jax/configs/augreg.py:R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0 \
--config.dataset=oxford_iiit_pet \
--config.pp.train='train[:90%]' \
--config.base_lr=0.01
"""
import ml_collections
from vit_jax.configs import common
from vit_jax.configs import models
def get_config(model_or_filename):
"""Returns default parameters for finetuning ViT `model` on `dataset`."""
config = common.get_config()
config.pretrained_dir = 'gs://vit_models/augreg'
config.model_or_filename = model_or_filename
model = model_or_filename.split('-')[0]
if model not in models.AUGREG_CONFIGS:
raise ValueError(f'Unknown AugReg model "{model}" - '
f'not found in {set(models.AUGREG_CONFIGS.keys())}')
config.model = models.AUGREG_CONFIGS[model].copy_and_resolve_references()
config.model.transformer.dropout_rate = 0 # No AugReg during fine-tuning.
# These values are often overridden on the command line.
config.base_lr = 0.03
config.total_steps = 500
config.warmup_steps = 100
config.pp = ml_collections.ConfigDict()
config.pp.train = 'train'
config.pp.test = 'test'
config.pp.resize = 448
config.pp.crop = 384
# This value MUST be overridden on the command line.
config.dataset = ''
return config
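
A minimal usage sketch of the factory above (the dataset and overrides are illustrative and mirror the docstring example; they are not part of the original file):

from vit_jax.configs import augreg

# 'R_Ti_16' selects the best i21k R+Ti/16 checkpoint; a full filename also works.
config = augreg.get_config('R_Ti_16')
config.dataset = 'oxford_iiit_pet'  # must be overridden; empty by default
config.pp.train = 'train[:90%]'     # keep 10% of train for validation
config.base_lr = 0.01               # override the 0.03 default
print(config.model.transformer.num_layers, config.total_steps)
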
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterable, Tuple, Union
import ml_collections
def get_config():
"""Returns config values other than model parameters."""
config = ml_collections.ConfigDict()
# Where to search for pretrained ViT models.
# Can be downloaded from gs://vit_models/imagenet21k
config.pretrained_dir = '.'
# Which dataset to finetune on. This can be the name of a tfds dataset
# (see https://www.tensorflow.org/datasets/catalog/overview), or the path to
# a directory with the following structure ($filename can be arbitrary):
# "{train,test}/$class_name/$filename.jpg"
config.dataset = ''
# Path to manually downloaded dataset
config.tfds_manual_dir = None
# Path to tensorflow_datasets directory
config.tfds_data_dir = None
# Number of steps; determined by hyper module if not specified.
config.total_steps = None
# Resizes global gradients.
config.grad_norm_clip = 1.0
# Datatype to use for momentum state ("bfloat16" or "float32").
config.optim_dtype = 'bfloat16'
# Accumulate gradients over multiple steps to save on memory.
config.accum_steps = 8
# Batch size for training.
config.batch = 512
# Batch size for evaluation.
config.batch_eval = 512
# Shuffle buffer size.
config.shuffle_buffer = 50_000
# Run prediction on validation set every so many steps
config.eval_every = 100
# Log progress every so many steps.
config.progress_every = 10
# How often to write checkpoints. Specifying 0 disables checkpointing.
config.checkpoint_every = 1_000
# Number of batches to prefetch to device.
config.prefetch = 2
# Base learning-rate for fine-tuning.
config.base_lr = 0.03
# How to decay the learning rate ("cosine" or "linear").
config.decay_type = 'cosine'
# Number of warmup steps for the learning rate.
config.warmup_steps = 500
# Alternatives: inference_time.
config.trainer = 'train'
# Will be set from ./models.py
config.model = None
# Only used in ./augreg.py configs
config.model_or_filename = None
# Must be set via `with_dataset()`
config.dataset = None
config.pp = None
return config.lock()
# We leave out a subset of training for validation purposes (if needed).
DATASET_PRESETS = {
'cifar10': ml_collections.ConfigDict(
{'total_steps': 10_000,
'pp': ml_collections.ConfigDict(
{'train': 'train[:98%]',
'test': 'test',
'crop': 384})
}),
'cifar100': ml_collections.ConfigDict(
{'total_steps': 10_000,
'pp': ml_collections.ConfigDict(
{'train': 'train[:98%]',
'test': 'test',
'crop': 384})
}),
'imagenet2012': ml_collections.ConfigDict(
{'total_steps': 20_000,
'pp': ml_collections.ConfigDict(
{'train': 'train[:99%]',
'test': 'validation',
'crop': 384})
}),
}
def with_dataset(config: ml_collections.ConfigDict,
dataset: str) -> ml_collections.ConfigDict:
config = ml_collections.ConfigDict(config.to_dict())
config.dataset = dataset
config.update(DATASET_PRESETS[dataset])
return config
def flatten(
config: Union[ml_collections.ConfigDict, Dict[str, Any]],
prefix: Tuple[str, ...] = ('config',)
) -> Iterable[Tuple[str, Any]]:
"""Returns a flat representation of `config`, e.g. for use in sweeps."""
for k, v in config.items():
if isinstance(v, (dict, ml_collections.ConfigDict)):
yield from flatten(v, prefix + (k,))
else:
yield ('.'.join(prefix + (k,)), v)
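
A short sketch of how `with_dataset()` applies a preset and how `flatten()` turns the result into sweep-style key/value pairs (illustrative only):

from vit_jax.configs import common

config = common.with_dataset(common.get_config(), 'cifar10')
for key, value in common.flatten(config):
  print(key, value)  # e.g. "config.pp.train train[:98%]", "config.total_steps 10000"
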
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ml_collections
def get_config():
"""Returns a configuration for inference_time.py."""
config = ml_collections.ConfigDict()
# Which model to use -- see ./models.py
config.model_name = 'ViT-B_32'
# Where to store training logs.
config.log_dir = '.'
# Number of steps to measure.
config.steps = 30
# Number of steps before measuring.
config.initial_steps = 10
# Batch size
config.batch = 0
# Number of output classes.
config.num_classes = 0
# Image size (width=height).
config.image_size = 0
config.trainer = 'inference_time'
return config
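
`batch`, `num_classes` and `image_size` default to 0 and are asserted to be non-zero in `inference_time.py`, so they have to be set on the command line. An illustrative invocation in the style of the docstring examples above (the workdir path is arbitrary):

python -m vit_jax.main --workdir=/tmp/vit_timing \
--config=$(pwd)/vit_jax/configs/inference_time.py \
--config.batch=64 --config.num_classes=1000 --config.image_size=224
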
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ml_collections
from vit_jax.configs import common
from vit_jax.configs import models
def get_config():
"""Returns config for training Mixer-B/16 on cifar10."""
config = common.get_config()
config.model_type = 'Mixer'
config.model = models.get_mixer_b16_config()
config.dataset = 'cifar10'
config.total_steps = 10_000
config.pp = ml_collections.ConfigDict(
{'train': 'train[:98%]', 'test': 'test', 'crop': 224})
return config
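
Assuming this config lives at `vit_jax/configs/mixer_base16_cifar10.py` and that `pretrained_dir` contains a matching `Mixer-B_16.npz` checkpoint (both are assumptions, not stated in this file), a run would follow the same pattern as the ViT examples above:

python -m vit_jax.main --workdir=/tmp/mixer \
--config=$(pwd)/vit_jax/configs/mixer_base16_cifar10.py \
--config.pretrained_dir=/path/to/mixer_checkpoints
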
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ml_collections
# The key of this dictionary refers to basename in the directory:
# https://console.cloud.google.com/storage/vit_models/
# Note that some names (e.g. "testing", but also some models only available in
# the AugReg paper) are not actually present in that directory.
MODEL_CONFIGS = {}
# The key of this dictionary refers to the first part (delimited by "-") of the
# filename of the checkpoint in:
# https://console.cloud.google.com/storage/vit_models/augreg/index.csv
AUGREG_CONFIGS = {}
def _register(get_config):
"""Adds reference to model config into MODEL_CONFIGS and AUGREG_CONFIGS."""
config = get_config().lock()
name = config.get('model_name')
MODEL_CONFIGS[name] = config
if 'Mixer' not in name and name not in ('testing', 'ViT-L_32', 'R50+ViT-B_16',
'ViT-H_14'):
# Note: we're using stricter filenames for AugReg checkpoints so they can be
# used both as filesystem filenames and URIs without escaping.
augreg_name = name.replace('ViT-', '').replace('+', '_')
AUGREG_CONFIGS[augreg_name] = config
return get_config
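
The decorator derives the AugReg key from the model name by stripping 'ViT-' and replacing '+'; a tiny standalone sketch of that mapping (names taken from the configs below):

for name in ('ViT-Ti_16', 'R+ViT-Ti_16', 'ViT-B_16'):
  augreg_name = name.replace('ViT-', '').replace('+', '_')
  print(name, '->', augreg_name)  # Ti_16, R_Ti_16, B_16
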
@_register
def get_testing_config():
"""Returns a simple config used for testing."""
config = ml_collections.ConfigDict()
# Only used for testing.
config.model_name = 'testing'
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 10
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 10
config.transformer.num_heads = 2
config.transformer.num_layers = 1
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'token'
config.representation_size = None
return config
@_register
def get_testing_unpooled_config():
"""Returns a simple config used for testing unpooled version."""
config = get_testing_config()
# Only used for testing.
config.model_name = 'testing-unpooled'
config.classifier = 'unpooled'
return config
# ViT-X/16 & ViT-H/14
#####################
@_register
def get_ti16_config():
"""Returns the ViT-Ti/16 configuration."""
config = ml_collections.ConfigDict()
config.model_name = 'ViT-Ti_16'
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 192
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 768
config.transformer.num_heads = 3
config.transformer.num_layers = 12
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.0
config.classifier = 'token'
config.representation_size = None
return config
@_register
def get_s16_config():
"""Returns the ViT-S/16 configuration."""
config = ml_collections.ConfigDict()
config.model_name = 'ViT-S_16'
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 384
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 1536
config.transformer.num_heads = 6
config.transformer.num_layers = 12
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.0
config.classifier = 'token'
config.representation_size = None
return config
@_register
def get_b16_config():
"""Returns the ViT-B/16 configuration."""
config = ml_collections.ConfigDict()
config.model_name = 'ViT-B_16'
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 768
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 3072
config.transformer.num_heads = 12
config.transformer.num_layers = 12
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.0
config.classifier = 'token'
config.representation_size = None
return config
@_register
def get_l16_config():
"""Returns the ViT-L/16 configuration."""
config = ml_collections.ConfigDict()
config.model_name = 'ViT-L_16'
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 1024
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 4096
config.transformer.num_heads = 16
config.transformer.num_layers = 24
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'token'
config.representation_size = None
return config
@_register
def get_h14_config():
"""Returns the ViT-H/14 configuration."""
config = ml_collections.ConfigDict()
config.model_name = 'ViT-H_14'
config.patches = ml_collections.ConfigDict({'size': (14, 14)})
config.hidden_size = 1280
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 5120
config.transformer.num_heads = 16
config.transformer.num_layers = 32
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'token'
config.representation_size = None
return config
@_register
def get_s16_gap_norep_config():
"""Returns ViT-S/16 with classifier=gap, representation=None."""
config = get_s16_config()
config.model_name = 'ViT-S_16-gap-norep'
config.classifier = 'gap'
config.representation_size = None
return config
@_register
def get_b16_gap_norep_config():
"""Returns ViT-B/16 with classifier=gap, representation=None."""
config = get_b16_config()
config.model_name = 'ViT-B_16-gap-norep'
config.classifier = 'gap'
config.representation_size = None
return config
# ViT-X/8
#########
@_register
def get_b8_config():
"""Returns the ViT-B/8 configuration."""
config = get_b16_config()
config.model_name = 'ViT-B_8'
config.patches.size = (8, 8)
return config
# ViT-X/32
##########
@_register
def get_s32_config():
"""Returns the ViT-S/32 configuration."""
config = get_s16_config()
config.model_name = 'ViT-S_32'
config.patches.size = (32, 32)
return config
@_register
def get_b32_config():
"""Returns the ViT-B/32 configuration."""
config = get_b16_config()
config.model_name = 'ViT-B_32'
config.patches.size = (32, 32)
return config
@_register
def get_l32_config():
"""Returns the ViT-L/32 configuration."""
config = get_l16_config()
config.transformer.dropout_rate = 0.0
config.model_name = 'ViT-L_32'
config.patches.size = (32, 32)
return config
@_register
def get_s32_gap_norep_config():
"""Returns ViT-S/32 with classifier=gap, representation=None."""
config = get_s32_config()
config.model_name = 'ViT-S_32-gap-norep'
config.classifier = 'gap'
config.representation_size = None
return config
@_register
def get_b32_gap_norep_config():
"""Returns ViT-B/32 with classifier=gap, representation=None."""
config = get_b32_config()
config.model_name = 'ViT-B_32-gap-norep'
config.classifier = 'gap'
config.representation_size = None
return config
# Hybrids R+ViT-X/16
####################
@_register
def get_r_ti16_config():
"""Returns the Resnet stem + ViT-Ti/16 configuration."""
config = get_ti16_config()
config.model_name = 'R+ViT-Ti_16'
config.patches.size = (8, 8)
config.resnet = ml_collections.ConfigDict()
# The resnet stem alone downscales 2x, making /16 with 8x8 patches.
config.resnet.num_layers = ()
config.resnet.width_factor = 1
return config
@_register
def get_r50_b16_config():
"""Returns the Resnet50 + ViT-B/16 configuration."""
config = get_b16_config()
config.transformer.dropout_rate = 0.1
config.model_name = 'R50+ViT-B_16'
config.patches.size = (1, 1)
config.resnet = ml_collections.ConfigDict()
# Note that the "real" Resnet50 has (3, 4, 6, 3) bottleneck blocks. Here
# we're using (3, 4, 9) configuration so we get a downscaling of 2^(1 + 3)=16
# which results in an effective patch size of /16.
config.resnet.num_layers = (3, 4, 9)
config.resnet.width_factor = 1
return config
# Hybrids R+ViT-X/32
####################
@_register
def get_r26_b32_config():
"""Returns the Resnet26 + ViT-B/32 configuration."""
config = get_b32_config()
config.model_name = 'R26+ViT-B_32'
config.patches.size = (1, 1)
config.resnet = ml_collections.ConfigDict()
# Using four bottleneck blocks results in a downscaling of 2^(1 + 4)=32 which
# results in an effective patch size of /32.
config.resnet.num_layers = (2, 2, 2, 2)
config.resnet.width_factor = 1
return config
@_register
def get_r26_s32_config():
"""Returns the Resnet26 + ViT-S/32 configuration."""
config = get_s16_config()
config.model_name = 'R26+ViT-S_32'
config.patches.size = (1, 1)
config.resnet = ml_collections.ConfigDict()
# Using four bottleneck blocks results in a downscaling of 2^(1 + 4)=32 which
# results in an effective patch size of /32.
config.resnet.num_layers = (2, 2, 2, 2)
config.resnet.width_factor = 1
return config
@_register
def get_r50_l32_config():
"""Returns the Resnet50 + ViT-L/32 configuration."""
config = get_l16_config()
config.model_name = 'R50+ViT-L_32'
config.patches.size = (1, 1)
config.resnet = ml_collections.ConfigDict()
# Using four bottleneck blocks results in a downscaling of 2^(1 + 4)=32 which
# results in an effective patch size of /32.
config.resnet.num_layers = (3, 4, 6, 3)
config.resnet.width_factor = 1
return config
# Mixers
########
@_register
def get_mixer_b16_config():
"""Returns Mixer-B/16 configuration."""
config = ml_collections.ConfigDict()
config.model_name = 'Mixer-B_16'
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_dim = 768
config.num_blocks = 12
config.tokens_mlp_dim = 384
config.channels_mlp_dim = 3072
return config
@_register
def get_mixer_b32_config():
"""Returns Mixer-B/32 configuration."""
config = get_mixer_b16_config()
config.model_name = 'Mixer-B_32'
config.patches = ml_collections.ConfigDict({'size': (32, 32)})
return config
@_register
def get_mixer_l16_config():
"""Returns Mixer-L/16 configuration."""
config = ml_collections.ConfigDict()
config.model_name = 'Mixer-L_16'
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_dim = 1024
config.num_blocks = 24
config.tokens_mlp_dim = 512
config.channels_mlp_dim = 4096
return config
# LiT
#####
@_register
def get_lit_b16b_config():
"""Returns a LiT model with ViT-Base and BERT-Base towers."""
config = ml_collections.ConfigDict()
config.model_name = 'LiT-B16B'
config.out_dim = (768, 768)
config.image = get_b16_config()
config.text_model = 'bert'
config.text = {}
config.text.config = 'base'
config.pp = {}
config.pp.tokenizer_name = 'bert'
config.pp.size = 224
config.pp.max_len = 16
return config
@_register
def get_lit_b16b_2_config():
"""Returns an improved LiT model with ViT-Base and BERT-Base towers."""
config = get_lit_b16b_config()
config.model_name = 'LiT-B16B_2'
config.out_dim = (None, 768)
return config
@_register
def get_lit_l16l_config():
"""Returns a LiT model with ViT-Large and BERT-Large towers."""
config = ml_collections.ConfigDict()
config.model_name = 'LiT-L16L'
config.out_dim = (None, 1024)
config.image = get_l16_config()
config.text_model = 'bert'
config.text = {}
config.text.config = 'large'
config.pp = {}
config.pp.tokenizer_name = 'bert'
config.pp.size = 224
config.pp.max_len = 16
return config
@_register
def get_lit_l16s_config():
"""Returns a LiT model with ViT-Large and small text towers."""
config = ml_collections.ConfigDict()
config.model_name = 'LiT-L16S'
config.out_dim = (None, 1024)
config.image = get_l16_config()
config.text_model = 'text_transformer'
config.text = {}
config.text.width = 384
config.text.num_layers = 12
config.text.mlp_dim = 1536
config.text.num_heads = 6
config.text.vocab_size = 16_000
config.pp = {}
config.pp.tokenizer_name = 'sentencepiece'
config.pp.size = 224
config.pp.max_len = 16
return config
@_register
def get_lit_l16ti_config():
"""Returns a LiT model with ViT-Large and tiny text towers."""
config = ml_collections.ConfigDict()
config.model_name = 'LiT-L16Ti'
config.out_dim = (None, 1024)
config.image = get_l16_config()
config.text_model = 'text_transformer'
config.text = {}
config.text.width = 192
config.text.num_layers = 12
config.text.mlp_dim = 768
config.text.num_heads = 3
config.text.vocab_size = 16_000
config.pp = {}
config.pp.tokenizer_name = 'sentencepiece'
config.pp.size = 224
config.pp.max_len = 16
return config
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Fine-tunes a Vision Transformer.
Example for fine-tuning a ViT-B/16 on CIFAR10:
python -m vit_jax.main --workdir=/tmp/vit \
--config=$(pwd)/vit_jax/configs/vit.py:b16,cifar10 \
--config.pretrained_dir='gs://vit_models/imagenet21k'
"""
from vit_jax.configs import common
from vit_jax.configs import models
def get_config(model_dataset):
"""Returns default parameters for finetuning ViT `model` on `dataset`."""
print(model_dataset)
model, dataset = model_dataset.split(',')
config = common.with_dataset(common.get_config(), dataset)
get_model_config = getattr(models, f'get_{model}_config')
config.model = get_model_config()
if model == 'b16' and dataset == 'cifar10':
config.base_lr = 0.01
return config
torch<2.1.0
jax<0.4.23
jaxlib<0.4.23
tensorflow<2.13.1
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import time
from absl import logging
from clu import metric_writers
import flax
import flax.jax_utils as flax_utils
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import tensorflow as tf
from vit_jax import checkpoint
from vit_jax import models
from vit_jax.configs import models as config_lib
def inference_time(config: ml_collections.ConfigDict, workdir: str):
"""Runs a number of steps and measures inference time."""
assert config.batch, f'Expected --config.batch={config.batch} > 0'
assert config.num_classes, (
f'Expected --config.num_classes={config.num_classes} > 0')
assert config.image_size, (
f'Expected --config.image_size={config.image_size} > 0')
# Build VisionTransformer architecture
model_config = config_lib.MODEL_CONFIGS[config.model_name]
model = models.VisionTransformer(
num_classes=config.num_classes, **model_config)
# Make sure initial model parameters (before replication) are on CPU only.
@functools.partial(jax.jit, backend='cpu')
def init(rng):
return model.init(
rng,
# Discard the "num_local_devices" dimension for initialization.
inputs=jnp.ones([1, config.image_size, config.image_size, 3],
jnp.float32),
train=False)
variables = init(jax.random.PRNGKey(0))
params_repl = flax_utils.replicate(variables['params'])
# pmap replicates the models over all TPUs/GPUs
vit_fn_repl = jax.pmap(functools.partial(model.apply, train=False))
images = jnp.ones([
jax.local_device_count(), config.batch // jax.local_device_count(),
config.image_size, config.image_size, 3
], jnp.float32)
writer = metric_writers.create_default_writer(workdir, asynchronous=False)
writer.write_hparams(config.to_dict())
logging.info('Starting training loop; initial compile can take a while...')
logits = vit_fn_repl(flax.core.FrozenDict(params=params_repl), images)
logits.block_until_ready()
logging.info('Done.')
logging.info('Going to run %d inferences WITHOUT measuring...',
config.initial_steps)
for _ in range(config.initial_steps):
logits = vit_fn_repl(flax.core.FrozenDict(params=params_repl), images)
logits.block_until_ready()
logging.info('Going to run %d inferences measuring...', config.steps)
times = []
for _ in range(config.steps):
t0 = time.time()
logits = vit_fn_repl(flax.core.FrozenDict(params=params_repl), images)
logits.block_until_ready()
times.append(time.time() - t0)
logging.info('times=%s', times)
imgs_sec_core = config.batch / jax.local_device_count() / np.array(times)
logging.info('imgs_sec_core_min=%f', imgs_sec_core.min())
logging.info('imgs_sec_core_max=%f', imgs_sec_core.max())
logging.info('imgs_sec_core_mean=%f', imgs_sec_core.mean())
logging.info('imgs_sec_core_std=%f', imgs_sec_core.std())
writer.write_scalars(
0,
dict(
imgs_sec_core_min=imgs_sec_core.min(),
imgs_sec_core_max=imgs_sec_core.max(),
imgs_sec_core_mean=imgs_sec_core.mean(),
imgs_sec_core_std=imgs_sec_core.std(),
))
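
The scalars written above are per-core throughputs; a small sketch of the same arithmetic with made-up timings (not measurements):

import numpy as np

batch = 512                             # config.batch (global batch size)
local_devices = 8                       # jax.local_device_count()
times = np.array([0.21, 0.20, 0.22])    # seconds per replicated forward pass
imgs_sec_core = batch / local_devices / times
print(imgs_sec_core.mean())             # ~305 images/sec/core for these numbers
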
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import tempfile
from absl.testing import absltest
from vit_jax import inference_time
from vit_jax import test_utils
from vit_jax.configs import inference_time as config_lib
from vit_jax.configs import models
class InferenceTimeTest(absltest.TestCase):
def test_main(self):
config = config_lib.get_config()
config.num_classes = 10
config.image_size = 224
config.batch = 8
config.model_name = 'testing'
model_config = models.get_testing_config()
workdir = tempfile.gettempdir()
config.pretrained_dir = workdir
test_utils.create_checkpoint(model_config, f'{workdir}/testing.npz')
inference_time.inference_time(config, workdir)
self.assertNotEmpty(glob.glob(f'{workdir}/events.out.tfevents.*'))
if __name__ == '__main__':
absltest.main()
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
from absl import logging
import flax
import jax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
# tf.config.set_visible_devices([], 'GPU')
import sys
if sys.platform != 'darwin':
# A workaround to avoid a crash because tfds may open too many files.
import resource
low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
# Adjust depending on the available RAM.
MAX_IN_MEMORY = 200_000
def get_tfds_info(dataset, split):
"""Returns information about tfds dataset -- see `get_dataset_info()`."""
data_builder = tfds.builder(dataset)
return dict(
num_examples=data_builder.info.splits[split].num_examples,
num_classes=data_builder.info.features['label'].num_classes,
int2str=data_builder.info.features['label'].int2str,
examples_glob=None,
)
def get_directory_info(directory):
"""Returns information about directory dataset -- see `get_dataset_info()`."""
examples_glob = f'{directory}/*/*.jpg'
paths = glob.glob(examples_glob)
get_classname = lambda path: path.split('/')[-2]
class_names = sorted(set(map(get_classname, paths)))
return dict(
num_examples=len(paths),
num_classes=len(class_names),
int2str=lambda id_: class_names[id_],
examples_glob=examples_glob,
)
def get_dataset_info(dataset, split):
"""Returns information about a dataset.
Args:
dataset: Name of tfds dataset or directory -- see `./configs/common.py`
split: Which split to return data for (e.g. "test", or "train"; tfds also
supports splits like "test[:90%]").
Returns:
A dictionary with the following keys:
- num_examples: Number of examples in dataset/mode.
- num_classes: Number of classes in dataset.
- int2str: Function converting class id to class name.
- examples_glob: Glob to select all files, or None (for tfds dataset).
"""
directory = os.path.join(dataset, split)
if os.path.isdir(directory):
return get_directory_info(directory)
return get_tfds_info(dataset, split)
def get_datasets(config):
"""Returns `ds_train, ds_test` for specified `config`."""
if os.path.isdir(config.dataset):
train_dir = os.path.join(config.dataset, 'train')
test_dir = os.path.join(config.dataset, 'test')
if not os.path.isdir(train_dir):
raise ValueError('Expected to find directories "{}" and "{}"'.format(
train_dir,
test_dir,
))
logging.info('Reading dataset from directories "%s" and "%s"', train_dir,
test_dir)
ds_train = get_data_from_directory(
config=config, directory=train_dir, mode='train')
ds_test = get_data_from_directory(
config=config, directory=test_dir, mode='test')
else:
logging.info('Reading dataset from tfds "%s"', config.dataset)
ds_train = get_data_from_tfds(config=config, mode='train')
ds_test = get_data_from_tfds(config=config, mode='test')
return ds_train, ds_test
def get_data_from_directory(*, config, directory, mode):
"""Returns dataset as read from specified `directory`."""
dataset_info = get_directory_info(directory)
data = tf.data.Dataset.list_files(dataset_info['examples_glob'])
class_names = [
dataset_info['int2str'](id_) for id_ in range(dataset_info['num_classes'])
]
def _pp(path):
return dict(
image=path,
label=tf.where(
tf.strings.split(path, '/')[-2] == class_names
)[0][0],
)
image_decoder = lambda path: tf.image.decode_jpeg(tf.io.read_file(path), 3)
return get_data(
data=data,
mode=mode,
num_classes=dataset_info['num_classes'],
image_decoder=image_decoder,
repeats=None if mode == 'train' else 1,
batch_size=config.batch_eval if mode == 'test' else config.batch,
image_size=config.pp['crop'],
shuffle_buffer=min(dataset_info['num_examples'], config.shuffle_buffer),
preprocess=_pp)
def get_data_from_tfds(*, config, mode):
"""Returns dataset as read from tfds dataset `config.dataset`."""
data_builder = tfds.builder(config.dataset, data_dir=config.tfds_data_dir)
data_builder.download_and_prepare(
download_config=tfds.download.DownloadConfig(
manual_dir=config.tfds_manual_dir))
data = data_builder.as_dataset(
split=config.pp[mode],
# Reduces memory footprint in shuffle buffer.
decoders={'image': tfds.decode.SkipDecoding()},
shuffle_files=mode == 'train')
image_decoder = data_builder.info.features['image'].decode_example
dataset_info = get_tfds_info(config.dataset, config.pp[mode])
return get_data(
data=data,
mode=mode,
num_classes=dataset_info['num_classes'],
image_decoder=image_decoder,
repeats=None if mode == 'train' else 1,
batch_size=config.batch_eval if mode == 'test' else config.batch,
image_size=config.pp['crop'],
shuffle_buffer=min(dataset_info['num_examples'], config.shuffle_buffer))
def get_data(*,
data,
mode,
num_classes,
image_decoder,
repeats,
batch_size,
image_size,
shuffle_buffer,
preprocess=None):
"""Returns dataset for training/eval.
Args:
data: tf.data.Dataset to read data from.
mode: Must be "train" or "test".
num_classes: Number of classes (used for one-hot encoding).
image_decoder: Applied to `features['image']` after shuffling. Decoding the
image after shuffling allows for a larger shuffle buffer.
repeats: How many times the dataset should be repeated. For indefinite
repeats specify None.
batch_size: Global batch size. Note that the returned dataset will have
dimensions [local_devices, batch_size / local_devices, ...].
image_size: Image size after cropping (for training) / resizing (for
evaluation).
shuffle_buffer: Number of elements to preload the shuffle buffer with.
preprocess: Optional preprocess function. This function will be applied to
the dataset just after repeat/shuffling, and before the data augmentation
preprocess step is applied.
"""
def _pp(data):
im = image_decoder(data['image'])
if mode == 'train':
channels = im.shape[-1]
begin, size, _ = tf.image.sample_distorted_bounding_box(
tf.shape(im),
tf.zeros([0, 0, 4], tf.float32),
area_range=(0.05, 1.0),
min_object_covered=0, # Don't enforce a minimum area.
use_image_if_no_bounding_boxes=True)
im = tf.slice(im, begin, size)
# Unfortunately, the above operation loses the depth-dimension. So we
# need to restore it the manual way.
im.set_shape([None, None, channels])
im = tf.image.resize(im, [image_size, image_size])
if tf.random.uniform(shape=[]) > 0.5:
im = tf.image.flip_left_right(im)
else:
im = tf.image.resize(im, [image_size, image_size])
im = (im - 127.5) / 127.5
label = tf.one_hot(data['label'], num_classes) # pylint: disable=no-value-for-parameter
return {'image': im, 'label': label}
data = data.repeat(repeats)
if mode == 'train':
data = data.shuffle(shuffle_buffer)
if preprocess is not None:
data = data.map(preprocess, tf.data.experimental.AUTOTUNE)
data = data.map(_pp, tf.data.experimental.AUTOTUNE)
data = data.batch(batch_size, drop_remainder=True)
# Shard data such that it can be distributed across devices.
num_devices = jax.local_device_count()
def _shard(data):
data['image'] = tf.reshape(data['image'],
[num_devices, -1, image_size, image_size,
data['image'].shape[-1]])
data['label'] = tf.reshape(data['label'],
[num_devices, -1, num_classes])
return data
if num_devices is not None:
data = data.map(_shard, tf.data.experimental.AUTOTUNE)
return data.prefetch(1)
def prefetch(dataset, n_prefetch):
"""Prefetches data to device and converts to numpy array."""
ds_iter = iter(dataset)
ds_iter = map(lambda x: jax.tree_util.tree_map(lambda t: np.asarray(memoryview(t)), x),
ds_iter)
if n_prefetch:
ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
return ds_iter
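
A minimal sketch wiring the functions above together (assuming this module is `vit_jax/input_pipeline.py`; the config fields mirror `configs/common.py` and the values are illustrative):

import ml_collections
from vit_jax import input_pipeline

config = ml_collections.ConfigDict()
config.dataset = 'cifar10'
config.tfds_data_dir = None
config.tfds_manual_dir = None
config.batch = 64
config.batch_eval = 64
config.shuffle_buffer = 1_000
config.pp = ml_collections.ConfigDict(
    {'train': 'train[:98%]', 'test': 'test', 'crop': 224})

ds_train, ds_test = input_pipeline.get_datasets(config)
batch = next(iter(input_pipeline.prefetch(ds_train, n_prefetch=2)))
# Images come back sharded: (num_local_devices, 64 // num_local_devices, 224, 224, 3).
print(batch['image'].shape)
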
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl import app
from absl import flags
from absl import logging
from clu import platform
import jax
from ml_collections import config_flags
import tensorflow as tf
from vit_jax import inference_time
from vit_jax import train
from vit_jax import utils
from jax.lib import xla_bridge
FLAGS = flags.FLAGS
_WORKDIR = flags.DEFINE_string('workdir', None,
'Directory in which to store logs and model data.')
config_flags.DEFINE_config_file(
'config',
None,
'File path to the training hyperparameter configuration.',
lock_config=True)
flags.mark_flags_as_required(['config', 'workdir'])
# Flags --jax_backend_target and --jax_xla_backend are available through JAX.
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
utils.add_gfile_logger(_WORKDIR.value)
# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make it unavailable to JAX.
tf.config.experimental.set_visible_devices([], 'GPU')
jax.config.update('jax_log_compiles', True)
logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
logging.info('JAX local devices: %r', jax.local_devices())
jax_xla_backend = ('None' if FLAGS.jax_xla_backend is None else
FLAGS.jax_xla_backend)
logging.info('Using JAX XLA backend %s', jax_xla_backend)
logging.info('Config: %s', FLAGS.config)
# Add a note so that we can tell which task is which JAX host.
# (Depending on the platform task 0 is not guaranteed to be host 0)
platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, '
f'process_count: {jax.process_count()}')
platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
_WORKDIR.value, 'workdir')
if FLAGS.config.trainer == 'train':
train.train_and_evaluate(FLAGS.config, _WORKDIR.value)
elif FLAGS.config.trainer == 'inference_time':
inference_time.inference_time(FLAGS.config, _WORKDIR.value)
else:
raise app.UsageError(f'Unknown trainer: {FLAGS.config.trainer}')
if __name__ == '__main__':
# Provide access to --jax_backend_target and --jax_xla_backend flags.
backend = xla_bridge.get_backend().platform
print(backend)
if backend != 'gpu':
exit()
jax.config.config_with_absl()
app.run(main)
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vit_jax import models_lit
from vit_jax import models_mixer
from vit_jax import models_vit
from vit_jax.configs import models as model_configs
# Note that you probably want to import the individual modules separately
# instead (e.g. not depending on tensorflow_text required by models_lit if
# you're only interested in image models).
AddPositionEmbs = models_vit.AddPositionEmbs
MlpBlock = models_vit.MlpBlock
Encoder1DBlock = models_vit.Encoder1DBlock
Encoder = models_vit.Encoder
LitModel = models_lit.LitModel
MlpMixer = models_mixer.MlpMixer
VisionTransformer = models_vit.VisionTransformer
def get_model(name, **kw):
"""Returns a model as specified in `model_configs.MODEL_CONFIGS`."""
if name.startswith('Mixer-'):
return MlpMixer(**model_configs.MODEL_CONFIGS[name], **kw)
elif name.startswith('LiT-'):
return LitModel(**model_configs.MODEL_CONFIGS[name], **kw)
else:
return VisionTransformer(**model_configs.MODEL_CONFIGS[name], **kw)
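
A short usage sketch of `get_model()`, mirroring the shapes exercised in `models_test.py` below:

import jax
import jax.numpy as jnp
from vit_jax import models

model = models.get_model('ViT-B_16', num_classes=1000)
images = jnp.ones([2, 224, 224, 3], jnp.float32)
variables = model.init(jax.random.PRNGKey(0), images, train=False)
logits = model.apply(variables, images, train=False)
print(logits.shape)  # (2, 1000)
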
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Models from Locked-image text Tuning.
See paper https://arxiv.org/abs/2111.07991
"""
import dataclasses
import os
from typing import Optional, Tuple
import flax.linen as nn
import jax.numpy as jnp
import ml_collections
from vit_jax import checkpoint
from vit_jax import models_vit
from vit_jax import preprocess
from flaxformer.architectures.bert import bert
from flaxformer.architectures.bert import configs
BASE_PATH = 'gs://vit_models/lit'
class BertModel(nn.Module):
"""BERT encoder with linear projection on last layer CLS token."""
config: str
num_classes: Optional[int] = None
@nn.compact
def __call__(self, tokens):
out = {}
batch_size, max_len = tokens.shape
bert_model = bert.BertEncoder(**dataclasses.asdict({
'base': configs.BertBaseConfig(),
'large': configs.BertLargeConfig(),
}[self.config]))
x = out['transformed'] = bert_model(
token_ids=tokens,
position_ids=jnp.tile(
jnp.arange(0, max_len, dtype=jnp.int32), [batch_size, 1]),
segment_ids=jnp.zeros([batch_size, max_len], dtype=jnp.int32),
input_mask=tokens.astype(jnp.bool_).astype(jnp.int32),
enable_dropout=False,
)
x = out['pre_logits'] = x[:, 0] # CLS token
if self.num_classes:
x = out['logits'] = nn.Dense(self.num_classes, name='head')(x)
return x, out
class TextTransformer(nn.Module):
"""Simple text transformer."""
num_classes: int
width: int = 512
num_layers: int = 12
mlp_dim: int = 2048
num_heads: int = 8
dropout_rate: float = 0.0
vocab_size: int = 32_000
@nn.compact
def __call__(self, x):
out = {}
embedding = nn.Embed(num_embeddings=self.vocab_size, features=self.width)
x = out['embedded'] = embedding(x)
# Add posemb
n, l, d = x.shape # pylint: disable=unused-variable
x = x + self.param('pos_embedding',
nn.initializers.normal(stddev=1 / jnp.sqrt(d)),
(1, l, d), x.dtype)
x = models_vit.Encoder(
num_layers=self.num_layers,
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
dropout_rate=self.dropout_rate,
attention_dropout_rate=0,
add_position_embedding=False)(
x, train=False)
x = out['pre_logits'] = x[:, -1, :] # note that we take *last* token
x = out['logits'] = nn.Dense(self.num_classes, name='head')(x)
return x, out
class LitModel(nn.Module):
"""Locked-image text Tuning model.
See paper https://arxiv.org/abs/2111.07991
For examples, refer to Colab
https://colab.research.google.com/github/google-research/vision_transformer/blob/main/lit.ipynb
Attributes:
image: Configuration for ViT image tower.
text: Configuration for text tower.
pp: Preprocessing configuration.
out_dim: Size of optional image/text heads that are added to the towers.
model_name: Refers to the key in `model_configs.MODEL_CONFIGS`.
"""
image: ml_collections.ConfigDict
text_model: str
text: ml_collections.ConfigDict
pp: ml_collections.ConfigDict
out_dim: Tuple[Optional[int], Optional[int]]
model_name: str
def load_variables(self, path=None, cache=True):
"""Loads variables.
Args:
path: Path to load params from. If not specified, then the params will be
loaded from the default public Cloud storage path, unless they exist in
the current working directory.
cache: If set to `True` and `path` is not specified (the default), then
the files will be copied from Cloud and stored in the current working
directory.
Returns:
The module variables, to be used with `model.apply()`
"""
if path is None:
local_path = f'{self.model_name}.npz'
if not os.path.exists(local_path):
path = f'{BASE_PATH}/{self.model_name}.npz'
print('Loading params from cloud:', path)
if cache:
checkpoint.copy(path, local_path)
if os.path.exists(local_path):
print('\n⚠️ Reusing local copy:', local_path)
path = local_path
return {'params': checkpoint.load(path)}
@property
def vocab_path(self):
ext = {
'bert': 'txt',
'sentencepiece': 'model',
}[self.pp.tokenizer_name]
return f'{BASE_PATH}/{self.model_name}.{ext}'
def get_pp(self, crop=False):
"""Returns a preprocessing function suitable for `tf.data.Dataset.map()`."""
return preprocess.get_pp(
tokenizer_name=self.pp.tokenizer_name,
vocab_path=self.vocab_path,
max_len=self.pp.max_len,
size=self.pp.size,
crop=crop)
def get_tokenizer(self):
"""Returns a tokenizer."""
return preprocess.get_tokenizer(self.pp.tokenizer_name)(
vocab_path=self.vocab_path,
max_len=self.pp.max_len)
def get_image_preprocessing(self, crop=False):
"""Returns a function to pre-process images (resize, value range)."""
return preprocess.PreprocessImages(size=self.pp.size, crop=crop)
@nn.compact
def __call__(self, *, images=None, tokens=None):
"""Embeds images and/or tokens.
Args:
images: Batch of images, prepared with the function returned by
`get_image_preprocessing()` or `get_pp()`.
tokens: Batch of tokens, prepared with the function returned by
`get_tokenizer()` or `get_pp()`.
Returns:
A tuple of `(zimg, ztxt, out)`, where `zimg` is a batch of embeddings for
the images (or `None`, if images were not specified), `ztxt` is a batch
of embeddings for the tokens (or `None`, if tokens were not specified),
and `out` is a dictionary of additional values, such as `out['t']` that
is the temperature multiplied with the vector dot products before the
softmax is applied.
"""
# Support calling without text or without images, for example for few-shot.
ztxt, zimg = None, None
out = {}
out_dims = self.out_dim
if isinstance(out_dims, int):
out_dims = (out_dims, out_dims)
if tokens is not None:
# Embed the text:
model_class = {
'bert': BertModel,
'text_transformer': TextTransformer,
}[self.text_model]
text_model = model_class(
**{
'num_classes': out_dims[1],
**(self.text or {})
}, name='txt')
ztxt, out_txt = text_model(tokens)
for k, v in out_txt.items():
out[f'txt/{k}'] = v
# Normalize the embeddings the models give us.
out['txt/norm'] = jnp.linalg.norm(ztxt, axis=1, keepdims=True)
out['txt/normalized'] = ztxt = ztxt / (out['txt/norm'] + 1e-8)
if images is not None:
image_model = models_vit.VisionTransformer(
**{
**self.image,
'num_classes': out_dims[0],
}, name='img') # pylint: disable=not-a-mapping
zimg = image_model(images, train=False)
# Normalize the embeddings the models give us.
out['img/norm'] = jnp.linalg.norm(zimg, axis=1, keepdims=True)
out['img/normalized'] = zimg = zimg / (out['img/norm'] + 1e-8)
t = self.param('t', nn.initializers.zeros, (1,), jnp.float32)
out['t'] = jnp.exp(t)
return zimg, ztxt, out
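
A hedged sketch of embedding a batch with the LiT model above (the dummy inputs stand in for properly preprocessed images/tokens; `load_variables()` downloads the public checkpoint):

import jax.numpy as jnp
from vit_jax import models

model = models.get_model('LiT-B16B_2')
variables = model.load_variables()   # copies gs://vit_models/lit/LiT-B16B_2.npz locally
tokens = jnp.ones([2, model.pp.max_len], jnp.int32)      # normally via get_tokenizer()
images = jnp.ones([2, model.pp.size, model.pp.size, 3])  # normally via get_image_preprocessing()
zimg, ztxt, out = model.apply(variables, images=images, tokens=tokens)
print(zimg.shape, ztxt.shape, out['t'])
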
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import einops
import flax.linen as nn
import jax.numpy as jnp
class MlpBlock(nn.Module):
mlp_dim: int
@nn.compact
def __call__(self, x):
y = nn.Dense(self.mlp_dim)(x)
y = nn.gelu(y)
return nn.Dense(x.shape[-1])(y)
class MixerBlock(nn.Module):
"""Mixer block layer."""
tokens_mlp_dim: int
channels_mlp_dim: int
@nn.compact
def __call__(self, x):
y = nn.LayerNorm()(x)
y = jnp.swapaxes(y, 1, 2)
y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
y = jnp.swapaxes(y, 1, 2)
x = x + y
y = nn.LayerNorm()(x)
return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)
class MlpMixer(nn.Module):
"""Mixer architecture."""
patches: Any
num_classes: int
num_blocks: int
hidden_dim: int
tokens_mlp_dim: int
channels_mlp_dim: int
model_name: Optional[str] = None
@nn.compact
def __call__(self, inputs, *, train):
del train
x = nn.Conv(self.hidden_dim, self.patches.size,
strides=self.patches.size, name='stem')(inputs)
x = einops.rearrange(x, 'n h w c -> n (h w) c')
for _ in range(self.num_blocks):
x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
x = nn.LayerNorm(name='pre_head_layer_norm')(x)
x = jnp.mean(x, axis=1)
if self.num_classes:
x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
name='head')(x)
return x
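
A minimal sketch instantiating the Mixer directly from its registered config (avoiding `models.get_model`, which also pulls in the LiT dependencies):

import jax
import jax.numpy as jnp
from vit_jax import models_mixer
from vit_jax.configs import models as config_lib

mixer = models_mixer.MlpMixer(num_classes=10, **config_lib.get_mixer_b16_config())
images = jnp.ones([2, 224, 224, 3], jnp.float32)
variables = mixer.init(jax.random.PRNGKey(0), images, train=False)
logits = mixer.apply(variables, images, train=False)
print(logits.shape)  # (2, 10)
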
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Sequence, TypeVar
from flax import linen as nn
import jax.numpy as jnp
T = TypeVar('T')
def weight_standardize(w, axis, eps):
"""Subtracts mean and divides by standard deviation."""
w = w - jnp.mean(w, axis=axis)
w = w / (jnp.std(w, axis=axis) + eps)
return w
class StdConv(nn.Conv):
"""Convolution with weight standardization."""
def param(self,
name: str,
init_fn: Callable[..., T],
*init_args) -> T:
param = super().param(name, init_fn, *init_args)
if name == 'kernel':
param = weight_standardize(param, axis=[0, 1, 2], eps=1e-5)
return param
class ResidualUnit(nn.Module):
"""Bottleneck ResNet block."""
features: int
strides: Sequence[int] = (1, 1)
@nn.compact
def __call__(self, x):
needs_projection = (
x.shape[-1] != self.features * 4 or self.strides != (1, 1))
residual = x
if needs_projection:
residual = StdConv(
features=self.features * 4,
kernel_size=(1, 1),
strides=self.strides,
use_bias=False,
name='conv_proj')(
residual)
residual = nn.GroupNorm(name='gn_proj')(residual)
y = StdConv(
features=self.features,
kernel_size=(1, 1),
use_bias=False,
name='conv1')(
x)
y = nn.GroupNorm(name='gn1')(y)
y = nn.relu(y)
y = StdConv(
features=self.features,
kernel_size=(3, 3),
strides=self.strides,
use_bias=False,
name='conv2')(
y)
y = nn.GroupNorm(name='gn2')(y)
y = nn.relu(y)
y = StdConv(
features=self.features * 4,
kernel_size=(1, 1),
use_bias=False,
name='conv3')(
y)
y = nn.GroupNorm(name='gn3', scale_init=nn.initializers.zeros)(y)
y = nn.relu(residual + y)
return y
class ResNetStage(nn.Module):
"""A ResNet stage."""
block_size: Sequence[int]
nout: int
first_stride: Sequence[int]
@nn.compact
def __call__(self, x):
x = ResidualUnit(self.nout, strides=self.first_stride, name='unit1')(x)
for i in range(1, self.block_size):
x = ResidualUnit(self.nout, strides=(1, 1), name=f'unit{i + 1}')(x)
return x
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from vit_jax import models
from vit_jax.configs import models as config_lib
MODEL_SIZES = {
'LiT-B16B': 195_871_489,
'LiT-B16B_2': 195_280_897,
'LiT-L16L': 638_443_521,
'LiT-L16S': 331_140_353,
'LiT-L16Ti': 311_913_089,
'Mixer-B_16': 59_880_472,
'Mixer-B_32': 60_293_428,
'Mixer-L_16': 208_196_168,
'R+ViT-Ti_16': 6_337_704,
'R26+ViT-B_32': 101_383_976,
'R26+ViT-S_32': 36_431_912,
'R50+ViT-B_16': 98_659_112,
'R50+ViT-L_32': 328_994_856,
'ViT-B_8': 86_576_872,
'ViT-B_16': 86_567_656,
'ViT-B_16-gap-norep': 86_566_120,
'ViT-B_32': 88_224_232,
'ViT-B_32-gap-norep': 88_222_696,
'ViT-H_14': 632_045_800,
'ViT-L_16': 304_326_632,
'ViT-L_32': 306_535_400,
'ViT-S_16': 22_050_664,
'ViT-S_16-gap-norep': 22_049_896,
'ViT-S_32': 22_878_952,
'ViT-S_32-gap-norep': 22_878_184,
'ViT-Ti_16': 5_717_416,
'testing': 21_390,
'testing-unpooled': 21_370,
}
class ModelsTest(parameterized.TestCase):
def test_all_tested(self):
self.assertEmpty(set(config_lib.MODEL_CONFIGS).difference(MODEL_SIZES))
@parameterized.parameters(*list(MODEL_SIZES.items()))
def test_can_instantiate(self, name, size):
rng = jax.random.PRNGKey(0)
kw = {} if name.startswith('LiT-') else dict(num_classes=1_000)
model = models.get_model(name, **kw)
batch_size = 2
images = jnp.ones([batch_size, 224, 224, 3], jnp.float32)
if name.startswith('LiT-'):
tokens = jnp.ones([batch_size, model.pp.max_len], jnp.int32)
variables = model.init(rng, images=images, tokens=tokens)
zimg, ztxt, _ = model.apply(variables, images=images, tokens=tokens)
self.assertEqual(zimg.shape[0], batch_size)
self.assertEqual(zimg.shape, ztxt.shape)
else:
variables = model.init(rng, images, train=False)
outputs = model.apply(variables, images, train=False)
if 'unpooled' in name:
self.assertEqual((2, 196, 1000), outputs.shape)
else:
self.assertEqual((2, 1000), outputs.shape)
param_count = sum(p.size for p in jax.tree.flatten(variables)[0])
self.assertEqual(
size, param_count,
f'Expected {name} to have {size} params, found {param_count}.')
if __name__ == '__main__':
absltest.main()
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Tuple, Type
import flax.linen as nn
import jax.numpy as jnp
from vit_jax import models_resnet
Array = Any
PRNGKey = Any
Shape = Tuple[int]
Dtype = Any
class IdentityLayer(nn.Module):
"""Identity layer, convenient for giving a name to an array."""
@nn.compact
def __call__(self, x):
return x
class AddPositionEmbs(nn.Module):
"""Adds learned positional embeddings to the inputs.
Attributes:
posemb_init: positional embedding initializer.
"""
posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
param_dtype: Dtype = jnp.float32
@nn.compact
def __call__(self, inputs):
"""Applies the AddPositionEmbs module.
Args:
inputs: Inputs to the layer.
Returns:
Output tensor with shape `(bs, timesteps, in_dim)`.
"""
# inputs.shape is (batch_size, seq_len, emb_dim).
assert inputs.ndim == 3, ('Number of dimensions should be 3,'
' but it is: %d' % inputs.ndim)
pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
pe = self.param(
'pos_embedding', self.posemb_init, pos_emb_shape, self.param_dtype)
return inputs + pe
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
mlp_dim: int
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
out_dim: Optional[int] = None
dropout_rate: float = 0.1
kernel_init: Callable[[PRNGKey, Shape, Dtype],
Array] = nn.initializers.xavier_uniform()
bias_init: Callable[[PRNGKey, Shape, Dtype],
Array] = nn.initializers.normal(stddev=1e-6)
@nn.compact
def __call__(self, inputs, *, deterministic):
"""Applies Transformer MlpBlock module."""
actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
x = nn.Dense(
features=self.mlp_dim,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init)( # pytype: disable=wrong-arg-types
inputs)
x = nn.gelu(x)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
output = nn.Dense(
features=actual_out_dim,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init)( # pytype: disable=wrong-arg-types
x)
output = nn.Dropout(
rate=self.dropout_rate)(
output, deterministic=deterministic)
return output
class Encoder1DBlock(nn.Module):
"""Transformer encoder layer.
Attributes:
inputs: input data.
mlp_dim: dimension of the mlp on top of attention block.
dtype: the dtype of the computation (default: float32).
dropout_rate: dropout rate.
attention_dropout_rate: dropout for attention heads.
deterministic: bool, deterministic or not (to apply dropout).
num_heads: Number of heads in nn.MultiHeadDotProductAttention
"""
mlp_dim: int
num_heads: int
dtype: Dtype = jnp.float32
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
@nn.compact
def __call__(self, inputs, *, deterministic):
"""Applies Encoder1DBlock module.
Args:
inputs: Inputs to the layer.
deterministic: Dropout will not be applied when set to true.
Returns:
output after transformer encoder block.
"""
# Attention block.
assert inputs.ndim == 3, f'Expected (batch, seq, hidden) got {inputs.shape}'
x = nn.LayerNorm(dtype=self.dtype)(inputs)
x = nn.MultiHeadDotProductAttention(
dtype=self.dtype,
kernel_init=nn.initializers.xavier_uniform(),
broadcast_dropout=False,
deterministic=deterministic,
dropout_rate=self.attention_dropout_rate,
num_heads=self.num_heads)(
x, x)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
x = x + inputs
# MLP block.
y = nn.LayerNorm(dtype=self.dtype)(x)
y = MlpBlock(
mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(
y, deterministic=deterministic)
return x + y
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation.
Attributes:
num_layers: number of layers
mlp_dim: dimension of the mlp on top of attention block
num_heads: Number of heads in nn.MultiHeadDotProductAttention
dropout_rate: dropout rate.
attention_dropout_rate: dropout rate in self attention.
"""
num_layers: int
mlp_dim: int
num_heads: int
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
add_position_embedding: bool = True
@nn.compact
def __call__(self, x, *, train):
"""Applies Transformer model on the inputs.
Args:
x: Inputs to the layer.
train: Set to `True` when training.
Returns:
output of a transformer encoder.
"""
assert x.ndim == 3 # (batch, len, emb)
if self.add_position_embedding:
x = AddPositionEmbs(
posemb_init=nn.initializers.normal(stddev=0.02), # from BERT.
name='posembed_input')(
x)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
# Input Encoder
for lyr in range(self.num_layers):
x = Encoder1DBlock(
mlp_dim=self.mlp_dim,
dropout_rate=self.dropout_rate,
attention_dropout_rate=self.attention_dropout_rate,
name=f'encoderblock_{lyr}',
num_heads=self.num_heads)(
x, deterministic=not train)
encoded = nn.LayerNorm(name='encoder_norm')(x)
return encoded
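# A small sketch (an added example, not part of the original module) showing
# the sub-module names the Encoder creates; these are the keys that appear in
# its parameter tree. Sizes are illustrative assumptions.
def _example_encoder_param_names():  # pragma: no cover
  import jax
  encoder = Encoder(num_layers=2, mlp_dim=64, num_heads=2)
  x = jnp.zeros([1, 5, 16])  # (batch, tokens, hidden)
  variables = encoder.init(jax.random.PRNGKey(0), x, train=False)
  print(sorted(variables['params']))
  # -> ['encoder_norm', 'encoderblock_0', 'encoderblock_1', 'posembed_input']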
class VisionTransformer(nn.Module):
"""VisionTransformer."""
num_classes: int
patches: Any
transformer: Any
hidden_size: int
resnet: Optional[Any] = None
representation_size: Optional[int] = None
classifier: str = 'token'
head_bias_init: float = 0.
encoder: Type[nn.Module] = Encoder
model_name: Optional[str] = None
@nn.compact
def __call__(self, inputs, *, train):
x = inputs
# (Possibly partial) ResNet root.
if self.resnet is not None:
width = int(64 * self.resnet.width_factor)
# Root block.
x = models_resnet.StdConv(
features=width,
kernel_size=(7, 7),
strides=(2, 2),
use_bias=False,
name='conv_root')(
x)
x = nn.GroupNorm(name='gn_root')(x)
x = nn.relu(x)
x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding='SAME')
# ResNet stages.
if self.resnet.num_layers:
x = models_resnet.ResNetStage(
block_size=self.resnet.num_layers[0],
nout=width,
first_stride=(1, 1),
name='block1')(
x)
for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
x = models_resnet.ResNetStage(
block_size=block_size,
nout=width * 2**i,
first_stride=(2, 2),
name=f'block{i + 1}')(
x)
n, h, w, c = x.shape
# We can merge s2d+emb into a single conv; it's the same.
x = nn.Conv(
features=self.hidden_size,
kernel_size=self.patches.size,
strides=self.patches.size,
padding='VALID',
name='embedding')(
x)
# Here, x is a grid of embeddings.
# (Possibly partial) Transformer.
if self.transformer is not None:
n, h, w, c = x.shape
x = jnp.reshape(x, [n, h * w, c])
# If we want to add a class token, add it here.
if self.classifier in ['token', 'token_unpooled']:
cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
cls = jnp.tile(cls, [n, 1, 1])
x = jnp.concatenate([cls, x], axis=1)
x = self.encoder(name='Transformer', **self.transformer)(x, train=train)
if self.classifier == 'token':
x = x[:, 0]
elif self.classifier == 'gap':
x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2)
elif self.classifier in ['unpooled', 'token_unpooled']:
pass
else:
raise ValueError(f'Invalid classifier={self.classifier}')
if self.representation_size is not None:
x = nn.Dense(features=self.representation_size, name='pre_logits')(x)
x = nn.tanh(x)
else:
x = IdentityLayer(name='pre_logits')(x)
if self.num_classes:
x = nn.Dense(
features=self.num_classes,
name='head',
kernel_init=nn.initializers.zeros,
bias_init=nn.initializers.constant(self.head_bias_init))(x)
return x
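# End-to-end usage sketch (an added example, not part of the original module):
# a tiny pure-ViT configuration (no ResNet root) applied to a dummy image
# batch. All sizes below are illustrative assumptions, not a published config.
def _example_vision_transformer():  # pragma: no cover
  import jax
  import ml_collections
  model = VisionTransformer(
      num_classes=10,
      patches=ml_collections.ConfigDict({'size': (4, 4)}),
      transformer=ml_collections.ConfigDict({
          'num_layers': 2,
          'mlp_dim': 64,
          'num_heads': 2,
          'dropout_rate': 0.0,
          'attention_dropout_rate': 0.0,
      }),
      hidden_size=32,
      classifier='token')
  images = jnp.zeros([2, 32, 32, 3])  # (batch, height, width, channels)
  variables = model.init(jax.random.PRNGKey(0), images, train=False)
  logits = model.apply(variables, images, train=False)
  assert logits.shape == (2, 10)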
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing utilities for text/image models."""
import dataclasses
import numpy as np
import tensorflow as tf
import tensorflow_text
def get_tokenizer(tokenizer_name):
"""Returns a tokenizer specified by name ("bert" or "sentencpiece")."""
return {
'bert': BertTokenizer,
'sentencepiece': SentencepieceTokenizer,
}[tokenizer_name]
@dataclasses.dataclass(frozen=True)
class BertTokenizer:
"""BERT tokenizer with prepended CLS token and fixed sequence length.
  This class can be used to tokenize batches of texts to numpy arrays
(by calling `__call__()`), or as part of a TensorFlow preprocessing graph
(via the method `preprocess_tf()`).
Attributes:
vocab_path: Path pointing to the vocabulary file. Can be any path string
that is understood by `tf.io.gfile`.
max_len: Length of tokenized sequences. If the provided texts result in
fewer tokens, then the sequence is zero-padded. If the provided texts
result in more tokens, then the tokens are clipped.
cls_token: Will be set during class construction.
"""
vocab_path: str
max_len: int
cls_token: int = dataclasses.field(init=False)
_tokenizer: tensorflow_text.BertTokenizer = dataclasses.field(init=False)
def __post_init__(self):
tokenizer = tensorflow_text.BertTokenizer(
self.vocab_path, token_out_type=tf.int32, lower_case=True)
with tf.io.gfile.GFile(self.vocab_path) as f:
vocab = f.read().split('\n')
cls_token = vocab.index('[CLS]')
# Work-around for frozen dataclasses:
# https://stackoverflow.com/questions/53756788
object.__setattr__(self, 'cls_token', cls_token)
object.__setattr__(self, '_tokenizer', tokenizer)
def preprocess_tf(self, text):
"""Tokenizes a single text as part of a TensorFlow graph."""
return self._preprocess(text[None])[0]
def _preprocess(self, texts):
token_ids = self._tokenizer.tokenize(texts)
tokens, mask = tensorflow_text.pad_model_inputs(token_ids, self.max_len - 1)
del mask # Recovered from zero padding in model.
count = tf.shape(tokens)[0]
return tf.concat([tf.fill([count, 1], self.cls_token), tokens], axis=1)
def __call__(self, texts):
"""Tokenizes a batch of texts to a numpy array."""
return self._preprocess(tf.constant(texts)).numpy()
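# Usage sketch (an added example, not part of the original module). The vocab
# path is a hypothetical placeholder; any WordPiece vocabulary file readable
# via tf.io.gfile works. Every row starts with the [CLS] id and is padded or
# clipped to max_len.
def _example_bert_tokenizer():  # pragma: no cover
  tokenizer = BertTokenizer(vocab_path='/path/to/vocab.txt', max_len=16)
  tokens = tokenizer(['a photo of a cat', 'a photo of a dog'])
  print(tokens.shape, tokens.dtype)  # (2, 16) int32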
@dataclasses.dataclass(frozen=True)
class SentencepieceTokenizer:
"""SentencePiece tokenizer with sticky eos.
  Models that use this tokenizer usually use the *last* token, which is
guaranteed to be the "</s>" token (even if tokens are capped to `max_len`).
The same token is used for padding (and exposed as `eos_token`).
  This class can be used to tokenize batches of texts to numpy arrays
(by calling `__call__()`), or as part of a TensorFlow preprocessing graph
(via the method `preprocess_tf()`).
Attributes:
vocab_path: Path pointing to the vocabulary file. Can be any path string
that is understood by `tf.io.gfile`.
max_len: Length of tokenized sequences. If the provided texts result in
fewer tokens, then the sequence is zero-padded. If the provided texts
result in more tokens, then the tokens are clipped.
    eos_token: Token used for padding; the last token of every sequence is
      guaranteed to be this EOS token.
"""
vocab_path: str
max_len: int
eos_token: int = dataclasses.field(init=False)
  _tokenizer: tensorflow_text.SentencepieceTokenizer = dataclasses.field(
      init=False)
def __post_init__(self):
tokenizer = tensorflow_text.SentencepieceTokenizer(
model=tf.io.gfile.GFile(self.vocab_path, 'rb').read(), add_eos=True)
eos_token = tokenizer.string_to_id('</s>')
# Work-around for frozen dataclasses:
# https://stackoverflow.com/questions/53756788
object.__setattr__(self, 'eos_token', eos_token)
object.__setattr__(self, '_tokenizer', tokenizer)
def preprocess_tf(self, text):
"""Tokenizes a single text as part of a TensorFlow graph."""
tokens = self._tokenizer.tokenize(text)
tokens = tokens[:self.max_len - 1] # to guarantee eos at end
return tf.pad(
tokens, [(0, self.max_len - tf.shape(tokens)[0])],
constant_values=self.eos_token)
def __call__(self, texts):
"""Tokenizes a batch of texts to a numpy array."""
return tf.stack([self.preprocess_tf(text) for text in texts]).numpy()
@dataclasses.dataclass(frozen=True)
class PreprocessImages:
"""Resizes images and sets value range to [-1, 1].
  This class can be used to preprocess batches of images to numpy arrays
(by calling `__call__()`), or as part of a TensorFlow preprocessing graph
(via the method `preprocess_tf()`).
Attributes:
size: Target size of images.
crop: If set to true, then the image will first be resized maintaining the
original aspect ratio, and then a central crop of that resized image will
be returned.
"""
size: int
crop: bool = False
def _resize_small(self, image): # pylint: disable=missing-docstring
h, w = tf.shape(image)[0], tf.shape(image)[1]
# Figure out the necessary h/w.
ratio = (
tf.cast(self.size, tf.float32) /
tf.cast(tf.minimum(h, w), tf.float32))
h = tf.cast(tf.round(tf.cast(h, tf.float32) * ratio), tf.int32)
w = tf.cast(tf.round(tf.cast(w, tf.float32) * ratio), tf.int32)
return tf.image.resize(image, (h, w), method='bilinear')
def _crop(self, image):
h, w = self.size, self.size
dy = (tf.shape(image)[0] - h) // 2
dx = (tf.shape(image)[1] - w) // 2
return tf.image.crop_to_bounding_box(image, dy, dx, h, w)
def _resize(self, image):
return tf.image.resize(
image, size=[self.size, self.size], method='bilinear')
def _value_range(self, image):
image = tf.cast(image, tf.float32) / 255
return -1 + image * 2
def preprocess_tf(self, image):
"""Resizes a single image as part of a TensorFlowg graph."""
assert image.dtype == tf.uint8
if self.crop:
image = self._resize_small(image)
image = self._crop(image)
else:
image = self._resize(image)
image = tf.cast(image, tf.uint8)
return self._value_range(image)
def __call__(self, images):
"""Resizes a sequence of images, returns a numpy array."""
return np.stack([
self.preprocess_tf(tf.constant(image)) for image in images
])
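# Usage sketch (an added example, not part of the original module): resizing a
# batch of uint8 images with a central crop and mapping values to [-1, 1].
def _example_preprocess_images():  # pragma: no cover
  rgb = np.random.randint(0, 256, size=[8, 10, 3], dtype=np.uint8)
  pp_images = PreprocessImages(size=4, crop=True)
  batch = pp_images([rgb, rgb])
  print(batch.shape)  # (2, 4, 4, 3), float32 values in [-1, 1]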
def get_pp(*, tokenizer_name, vocab_path, max_len, size, crop=False):
"""Returns preprocessing function for "image" and "text" features.
The returned function can directly be used with `tf.data.Dataset.map()`.
  If the text feature (key "text") or the image feature (key "image") is not
  present in the input, it is simply skipped; all other features are passed
  through unchanged.
Note that the "image" feature is overwritten with the resized image, but the
"text" feature is tokenized into a new feature "tokens".
Args:
tokenizer_name: Name of tokenizer (either "bert", or "sentencepiece").
vocab_path: Argument passed to tokenizer.
max_len: Argument passed to tokenizer.
size: Argument passed to `PreprocessImages`.
crop: Argument passed to `PreprocessImages`.
"""
tokenizer_class = get_tokenizer(tokenizer_name)
tokenizer = tokenizer_class(vocab_path=vocab_path, max_len=max_len)
preprocess_images = PreprocessImages(size=size, crop=crop)
def pp(features):
features = {**features}
if 'image' in features:
features['image'] = preprocess_images.preprocess_tf(features['image'])
if 'text' in features:
features['tokens'] = tokenizer.preprocess_tf(features['text'])
return features
return pp
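# Usage sketch (an added example, not part of the original module): plugging
# the returned function into a tf.data pipeline. The vocabulary path is a
# hypothetical placeholder and the dataset is a dummy in-memory one.
def _example_get_pp():  # pragma: no cover
  pp = get_pp(tokenizer_name='bert', vocab_path='/path/to/vocab.txt',
              max_len=16, size=224)
  ds = tf.data.Dataset.from_tensor_slices({
      'image': tf.zeros([2, 32, 32, 3], tf.uint8),
      'text': tf.constant(['a photo of a cat', 'a photo of a dog']),
  })
  # "image" is replaced by the resized float image, "text" is kept, and a new
  # int32 feature "tokens" is added.
  batch = next(iter(ds.map(pp).batch(2)))
  print({k: (v.dtype, v.shape) for k, v in batch.items()})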
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import tempfile
from unittest import mock
from absl.testing import absltest
import numpy as np
import tensorflow as tf
from vit_jax import preprocess
VOCAB = """[PAD]
[CLS]
some
test
words"""
@contextlib.contextmanager
def _create_vocab():
with tempfile.NamedTemporaryFile('w') as f:
f.write(VOCAB)
f.flush()
yield f.name
class PreprocessTest(absltest.TestCase):
def test_bert_tokenizer(self):
with _create_vocab() as vocab_path:
tokenizer = preprocess.BertTokenizer(vocab_path=vocab_path, max_len=3)
tokens = tokenizer(['some', 'test', 'words', 'xxx'])
np.testing.assert_equal(tokens, [
[1, 2, 0],
[1, 3, 0],
[1, 4, 0],
[1, 5, 0],
])
@mock.patch('tensorflow_text.SentencepieceTokenizer')
@mock.patch('tensorflow.io.gfile.GFile')
def test_sentencepiece_tokenizer(self, gfile_patch, tokenizer_patch):
gfile_patch.return_value.read.return_value = 'test vocab'
eos_token = 7
tokenizer_patch.return_value.string_to_id.return_value = eos_token
tokenizer_patch.return_value.tokenize.side_effect = (
tf.constant([1, eos_token], tf.int32),
tf.constant([2, 3, eos_token], tf.int32),
tf.constant([4, 5, 6, eos_token], tf.int32),
)
tokenizer = preprocess.SentencepieceTokenizer(
vocab_path='test_path', max_len=3)
gfile_patch.assert_called_once_with('test_path', 'rb')
tokenizer_patch.assert_called_once_with(model='test vocab', add_eos=True)
tokenizer_patch.return_value.string_to_id.assert_called_once_with('</s>')
tokens = tokenizer(['some', 'test', 'words'])
tokenizer_patch.return_value.tokenize.assert_has_calls(
(mock.call('some'), mock.call('test'), mock.call('words')))
np.testing.assert_equal(tokens, [
[1, eos_token, eos_token],
[2, 3, eos_token],
[4, 5, eos_token],
])
def test_preprocess_images(self):
# white images with black border
img1 = 255 * np.concatenate([ # portrait image
np.zeros([2, 10, 3], np.uint8),
np.ones([12, 10, 3], np.uint8),
np.zeros([2, 10, 3], np.uint8),
], axis=0)
img2 = 255 * np.concatenate([ # landscape image
np.zeros([10, 2, 3], np.uint8),
np.ones([10, 12, 3], np.uint8),
np.zeros([10, 2, 3], np.uint8),
], axis=1)
preprocess_images = preprocess.PreprocessImages(size=4, crop=False)
imgs = preprocess_images([img1, img2])
self.assertEqual(imgs.shape, (2, 4, 4, 3))
self.assertLess(imgs.mean(), 1.0) # borders resized
preprocess_images = preprocess.PreprocessImages(size=4, crop=True)
imgs = preprocess_images([img1, img2])
self.assertEqual(imgs.shape, (2, 4, 4, 3))
self.assertEqual(imgs.mean(), 1.0) # borders cropped
def test_pp_bert(self):
with _create_vocab() as vocab_path:
pp = preprocess.get_pp(
tokenizer_name='bert', vocab_path=vocab_path, max_len=3, size=4)
ds = tf.data.Dataset.from_tensor_slices({
'text':
tf.constant(['test', 'test']),
'image': [
tf.ones([10, 10, 3], tf.uint8),
tf.ones([10, 10, 3], tf.uint8)
],
})
b = next(iter(ds.map(pp).batch(2).as_numpy_iterator()))
dtypes_shapes = {k: (v.dtype, v.shape) for k, v in b.items()}
np.testing.assert_equal(dtypes_shapes, {
'image': (np.float32, (2, 4, 4, 3)),
'text': (object, (2,)),
'tokens': (np.int32, (2, 3))
})
@mock.patch('tensorflow_text.SentencepieceTokenizer')
@mock.patch('tensorflow.io.gfile.GFile')
def test_pp_sentencepiece(self, gfile_patch, tokenizer_patch):
eos_token = 7
gfile_patch.return_value.read.return_value = 'test vocab'
tokenizer_patch.return_value.string_to_id.return_value = eos_token
tokenizer_patch.return_value.tokenize.side_effect = (
tf.constant([1, eos_token], tf.int32),
tf.constant([2, 3, eos_token], tf.int32),
)
pp = preprocess.get_pp(
tokenizer_name='sentencepiece',
vocab_path='test',
max_len=3,
size=4)
ds = tf.data.Dataset.from_tensor_slices({
'text':
tf.constant(['test', 'test']),
'image': [
tf.ones([10, 10, 3], tf.uint8),
tf.ones([10, 10, 3], tf.uint8)
],
})
b = next(iter(ds.map(pp).batch(2).as_numpy_iterator()))
dtypes_shapes = {k: (v.dtype, v.shape) for k, v in b.items()}
np.testing.assert_equal(dtypes_shapes, {
'image': (np.float32, (2, 4, 4, 3)),
'text': (object, (2,)),
'tokens': (np.int32, (2, 3))
})
if __name__ == '__main__':
absltest.main()
absl-py>=0.12.0
aqtp!=0.1.1 # https://github.com/google/aqt/issues/196
chex>=0.0.7
clu>=0.0.3
einops>=0.3.0
flax>=0.4.1
git+https://github.com/google/flaxformer
jax[tpu]>=0.2.16
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
ml-collections==0.1.0
numpy>=1.19.5
pandas>=1.1.0
tensorflow-datasets>=4.0.1
tensorflow-probability>=0.11.1
tensorflow-text>=2.9.0