Commit 5498e94a authored by suily

Initial commit

parent 14530156
Pipeline #1635 failed in 0 seconds
absl-py>=0.12.0
aqtp!=0.1.1  # https://github.com/google/aqt/issues/196
chex>=0.0.7
clu>=0.0.3  # e.g. clu==0.0.9
einops>=0.3.0
flax>=0.6.4  # e.g. flax==0.8.4, optax==0.2.2, orbax-checkpoint==0.4.1
# git+https://github.com/google/flaxformer
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda11_cudnn86]==0.4.23
ml-collections>=0.1.0
numpy>=1.19.5
pandas>=1.1.0
tensorflow-cpu>=2.4.0  # e.g. tensorflow-cpu==2.14.0; tensorflow-cpu leaves all GPU memory for JAX.
tensorflow-datasets>=4.0.1
tensorflow-probability>=0.11.1
tensorflow-text>=2.9.0
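
# Install sketch (assuming this file is saved as requirements.txt):
#
#   pip install -r requirements.txt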
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility code to create fake pre-trained checkpoints."""

import dataclasses
import os

import flax
import jax
import jax.numpy as jnp
import numpy as np

from vit_jax import models


def _traverse_with_names(tree):
  """Traverses nested dicts/dataclasses and emits (leaf_name, leaf_val)."""
  if dataclasses.is_dataclass(tree):
    tree = flax.serialization.to_state_dict(tree)
  if isinstance(tree, (dict, flax.core.FrozenDict)):
    keys = sorted(tree.keys())
    for key in keys:
      for path, v in _traverse_with_names(tree[key]):
        yield (key + '/' + path).rstrip('/'), v
  else:
    yield '', tree
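

# Illustration (a hypothetical helper, not part of the original module):
# shows the '/'-joined names emitted for a small nested dict; keys are
# visited in sorted order at every level.
def _demo_traverse_with_names():
  tree = {'b': {'c': 1.0, 'a': 2.0}, 'a': 3.0}
  names = [name for name, _ in _traverse_with_names(tree)]
  assert names == ['a', 'b/a', 'b/c'], names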


def _tree_flatten_with_names(tree):
  """Populates the output of tree_flatten with leaf names.

  This function pairs the output of jax.tree.flatten with leaf names produced
  by the custom traversal above. The custom traversal does NOT have to visit
  the tree in the same order as jax, as we take care of automatically aligning
  jax's and the custom traversal's orders.

  Args:
    tree: python tree.

  Returns:
    A list of values with names [(name, value), ...], and the jax tree_def.
  """
  vals, tree_def = jax.tree.flatten(tree)

  # "Fake" token tree that is used to track jax's internal tree traversal and
  # adjust our custom tree traversal to be compatible with it.
  tokens = range(len(vals))
  token_tree = tree_def.unflatten(tokens)
  val_names, perm = zip(*_traverse_with_names(token_tree))
  inv_perm = np.argsort(perm)

  # Custom traversal should visit the same number of leaves.
  assert len(val_names) == len(vals)

  return [(val_names[i], v) for i, v in zip(inv_perm, vals)], tree_def
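

# Illustration (a hypothetical helper, not part of the original module): the
# names returned by _tree_flatten_with_names line up with the values produced
# by jax.tree.flatten, regardless of traversal order.
def _demo_tree_flatten_with_names():
  tree = {'conv': {'kernel': np.zeros((3, 3))}, 'bias': np.zeros(3)}
  names_and_vals, _ = _tree_flatten_with_names(tree)
  assert [name for name, _ in names_and_vals] == ['bias', 'conv/kernel']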


def _save(data, path):
  """Util for checkpointing: saves jax pytree objects to disk."""
  names_and_vals, _ = _tree_flatten_with_names(data)
  os.makedirs(os.path.dirname(path), exist_ok=True)
  with open(path, 'wb') as f:
    np.savez(f, **dict(names_and_vals))


def create_checkpoint(model_config, path):
  """Initializes model and stores weights in specified path."""
  model = models.VisionTransformer(num_classes=1, **model_config)
  variables = model.init(
      jax.random.PRNGKey(0),
      jnp.ones([1, 16, 16, 3], jnp.float32),
      train=False,
  )
  _save(variables['params'], path)
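

# Usage sketch (a hypothetical helper, not part of the original module): the
# testing config comes from vit_jax.configs.models, as used by train_test.py.
def _demo_create_checkpoint():
  from vit_jax.configs import models as config_models
  create_checkpoint(config_models.get_testing_config(),
                    '/tmp/fake/testing.npz')
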
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
import time

from absl import logging
from clu import metric_writers
from clu import periodic_actions
import flax
from flax.training import checkpoints as flax_checkpoints
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow as tf

from vit_jax import checkpoint
from vit_jax import input_pipeline
from vit_jax import models
from vit_jax import utils


def make_update_fn(*, apply_fn, accum_steps, tx):
  """Returns update step for data parallel training."""

  def update_fn(params, opt_state, batch, rng):

    _, new_rng = jax.random.split(rng)
    # Bind the rng key to the device id (which is unique across hosts).
    # Note: this is only used for multi-host training (i.e. multiple computers
    # each with multiple accelerators).
    dropout_rng = jax.random.fold_in(rng, jax.lax.axis_index('batch'))

    def cross_entropy_loss(*, logits, labels):
      logp = jax.nn.log_softmax(logits)
      return -jnp.mean(jnp.sum(logp * labels, axis=1))

    def loss_fn(params, images, labels):
      logits = apply_fn(
          dict(params=params),
          rngs=dict(dropout=dropout_rng),
          inputs=images,
          train=True)
      return cross_entropy_loss(logits=logits, labels=labels)

    l, g = utils.accumulate_gradient(
        jax.value_and_grad(loss_fn), params, batch['image'], batch['label'],
        accum_steps)
    g = jax.tree_util.tree_map(
        lambda x: jax.lax.pmean(x, axis_name='batch'), g)
    updates, opt_state = tx.update(g, opt_state)
    params = optax.apply_updates(params, updates)
    l = jax.lax.pmean(l, axis_name='batch')
    return params, opt_state, l, new_rng

  return jax.pmap(update_fn, axis_name='batch', donate_argnums=(0,))
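

# Numeric illustration (a hypothetical helper, not part of the original
# module): with one-hot labels, the cross-entropy used above reduces to the
# negative log-probability of the true class.
def _demo_cross_entropy():
  logits = jnp.array([[2.0, 0.0, 0.0]])
  labels = jnp.array([[1.0, 0.0, 0.0]])  # one-hot, true class is 0
  logp = jax.nn.log_softmax(logits)
  loss = -jnp.mean(jnp.sum(logp * labels, axis=1))
  # loss == -log(softmax(logits)[0, 0]) ~= 0.2395
  return loss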


def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs training interleaved with evaluation."""

  # Setup input pipeline.
  dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train')
  ds_train, ds_test = input_pipeline.get_datasets(config)
  batch = next(iter(ds_train))
  logging.info(ds_train)
  logging.info(ds_test)

  # Build VisionTransformer architecture.
  model_cls = {
      'ViT': models.VisionTransformer,
      'Mixer': models.MlpMixer,
  }[config.get('model_type', 'ViT')]
  model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

  def init_model():
    return model.init(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name),
        train=False)

  # Use JIT to make sure params reside in CPU memory.
  variables = jax.jit(init_model, backend='cpu')()

  model_or_filename = config.get('model_or_filename')
  if model_or_filename:
    # Loading model from the repo published with the "How to train your ViT?
    # Data, Augmentation, and Regularization in Vision Transformers" paper.
    # https://arxiv.org/abs/2106.10270
    if '-' in model_or_filename:
      filename = model_or_filename
    else:
      # Select best checkpoint from i21k pretraining by final upstream
      # validation accuracy.
      df = checkpoint.get_augreg_df(directory=config.pretrained_dir)
      sel = df.filename.apply(
          lambda filename: filename.split('-')[0] == model_or_filename)
      best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1]
      filename = best.filename
      logging.info('Selected filename="%s" for "%s" with final_val=%.3f',
                   filename, model_or_filename, best.final_val)
    pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  else:
    # ViT / Mixer papers.
    filename = config.model.model_name
    pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')

  if not tf.io.gfile.exists(pretrained_path):
    raise ValueError(
        f'Could not find "{pretrained_path}" - you can download models from '
        '"gs://vit_models/imagenet21k" or directly set '
        '--config.pretrained_dir="gs://vit_models/imagenet21k".')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=variables['params'],
      model_config=config.model)

  total_steps = config.total_steps
  lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                              config.decay_type,
                                              config.warmup_steps)
  tx = optax.chain(
      optax.clip_by_global_norm(config.grad_norm_clip),
      optax.sgd(
          learning_rate=lr_fn,
          momentum=0.9,
          accumulator_dtype='bfloat16',
      ),
  )

  update_fn_repl = make_update_fn(
      apply_fn=model.apply, accum_steps=config.accum_steps, tx=tx)
  infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))

  initial_step = 1
  opt_state = tx.init(params)
  params, opt_state, initial_step = flax_checkpoints.restore_checkpoint(
      workdir, (params, opt_state, initial_step))
  logging.info('Will start/continue training at initial_step=%d', initial_step)

  params_repl, opt_state_repl = flax.jax_utils.replicate((params, opt_state))
  # Delete references to the objects that are not needed anymore.
  del opt_state
  del params

  # Replicate the update rng; each device folds in its own axis index.
  update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

  # Setup metric writer & hooks.
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())
  hooks = [
      periodic_actions.Profile(logdir=workdir),
      periodic_actions.ReportProgress(
          num_train_steps=total_steps, writer=writer),
  ]

  # Run training loop.
  logging.info('Starting training loop; initial compile can take a while...')
  t0 = lt0 = time.time()
  lstep = initial_step
  for step, batch in zip(
      range(initial_step, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    with jax.profiler.StepTraceAnnotation('train', step_num=step):
      params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
          params_repl, opt_state_repl, batch, update_rng_repl)

    for hook in hooks:
      hook(step)

    if step == initial_step:
      logging.info('First step took %.1f seconds.', time.time() - t0)
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics.
    if config.progress_every and step % config.progress_every == 0:
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100 * done:.1f}%, '  # pylint: disable=logging-fstring-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time() - t0) / done * (1 - done) / 3600:.2f}h')

    # Run evaluation.
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):
      accuracies = []
      lt0 = time.time()
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(dict(params=params_repl), test_batch['image'])
        accuracies.append(
            (np.argmax(logits, axis=-1) ==
             np.argmax(test_batch['label'], axis=-1)).mean())
      accuracy_test = np.mean(accuracies)
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          (time.time() - lt0) / jax.device_count())
      lt0 = time.time()

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-fstring-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

    # Store checkpoint.
    if ((config.checkpoint_every and step % config.checkpoint_every == 0) or
        step == total_steps):
      checkpoint_path = flax_checkpoints.save_checkpoint(
          workdir, (flax.jax_utils.unreplicate(params_repl),
                    flax.jax_utils.unreplicate(opt_state_repl), step), step)
      logging.info('Stored checkpoint at step %d to "%s"', step,
                   checkpoint_path)

  return flax.jax_utils.unreplicate(params_repl)
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

from absl.testing import absltest
from absl.testing import parameterized
import ml_collections
import tensorflow_datasets as tfds

from vit_jax import test_utils
from vit_jax import train
from vit_jax.configs import common
from vit_jax.configs import models
# from PIL import Image
# import numpy as np
# Image.fromarray(np.array([[[0, 0, 0]]], np.uint8)).save('black1px.jpg')
# print(repr(open('black1px.jpg', 'rb').read()))
JPG_BLACK_1PX = (b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c'
b' $.\' '
b'",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\t\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xc0\x00\x11\x08\x00\x01\x00\x01\x03\x01"\x00\x02\x11\x01\x03\x11\x01\xff\xc4\x00\x1f\x00\x00\x01\x05\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x10\x00\x02\x01\x03\x03\x02\x04\x03\x05\x05\x04\x04\x00\x00\x01}\x01\x02\x03\x00\x04\x11\x05\x12!1A\x06\x13Qa\x07"q\x142\x81\x91\xa1\x08#B\xb1\xc1\x15R\xd1\xf0$3br\x82\t\n\x16\x17\x18\x19\x1a%&\'()*456789:CDEFGHIJSTUVWXYZcdefghijstuvwxyz\x83\x84\x85\x86\x87\x88\x89\x8a\x92\x93\x94\x95\x96\x97\x98\x99\x9a\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xff\xc4\x00\x1f\x01\x00\x03\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x11\x00\x02\x01\x02\x04\x04\x03\x04\x07\x05\x04\x04\x00\x01\x02w\x00\x01\x02\x03\x11\x04\x05!1\x06\x12AQ\x07aq\x13"2\x81\x08\x14B\x91\xa1\xb1\xc1\t#3R\xf0\x15br\xd1\n\x16$4\xe1%\xf1\x17\x18\x19\x1a&\'()*56789:CDEFGHIJSTUVWXYZcdefghijstuvwxyz\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x92\x93\x94\x95\x96\x97\x98\x99\x9a\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xff\xda\x00\x0c\x03\x01\x00\x02\x11\x03\x11\x00?\x00\xf9\xfe\x8a(\xa0\x0f\xff\xd9') # pylint: disable=line-too-long


class TrainTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ('tfds', 'tfds'),
      ('directory', 'directory'),
  )
  def test_train_and_evaluate(self, dataset_source):
    config = common.get_config()
    config.model = models.get_testing_config()
    config.batch = 64
    config.accum_steps = 2
    config.batch_eval = 8
    config.total_steps = 1

    with tempfile.TemporaryDirectory() as workdir:
      if dataset_source == 'tfds':
        config.dataset = 'cifar10'
        config.pp = ml_collections.ConfigDict({
            'train': 'train[:98%]',
            'test': 'test',
            'crop': 224,
        })
      elif dataset_source == 'directory':
        config.dataset = os.path.join(workdir, 'dataset')
        config.pp = ml_collections.ConfigDict({'crop': 224})
        # Write a tiny fake dataset of 1x1 black JPEGs.
        for mode in ('train', 'test'):
          for class_name in ('test1', 'test2'):
            for i in range(8):
              path = os.path.join(config.dataset, mode, class_name, f'{i}.jpg')
              os.makedirs(os.path.dirname(path), exist_ok=True)
              with open(path, 'wb') as f:
                f.write(JPG_BLACK_1PX)
      else:
        raise ValueError(f'Unknown dataset_source: "{dataset_source}"')

      config.pretrained_dir = workdir
      test_utils.create_checkpoint(config.model, f'{workdir}/testing.npz')
      _ = train.train_and_evaluate(config, workdir)
      self.assertTrue(os.path.exists(f'{workdir}/checkpoint_1'))


if __name__ == '__main__':
  absltest.main()
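
# To run this test directly (assuming the vit_jax package is importable):
#
#   python -m vit_jax.train_test
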
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging as python_logging
import os
import threading

from absl import logging
import jax
import jax.numpy as jnp
import tensorflow as tf


class GFileHandler(python_logging.StreamHandler):
  """Writes log messages to file using tf.io.gfile."""

  def __init__(self, filename, mode, flush_secs=1.0):
    super().__init__()
    tf.io.gfile.makedirs(os.path.dirname(filename))
    if mode == 'a' and not tf.io.gfile.exists(filename):
      mode = 'w'
    self.filehandle = tf.io.gfile.GFile(filename, mode)
    self.flush_secs = flush_secs
    self.flush_timer = None

  def flush(self):
    self.filehandle.flush()

  def emit(self, record):
    msg = self.format(record)
    self.filehandle.write(f'{msg}\n')
    if self.flush_timer is not None:
      self.flush_timer.cancel()
    self.flush_timer = threading.Timer(self.flush_secs, self.flush)
    self.flush_timer.start()


def add_gfile_logger(workdir, *, basename='train', level=python_logging.INFO):
  """Adds GFile file logger to Python logging handlers."""
  fh = GFileHandler(f'{workdir}/{basename}.log', 'a')
  fh.setLevel(level)
  fh.setFormatter(logging.PythonFormatter())
  python_logging.getLogger('').addHandler(fh)
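

# Usage sketch (a hypothetical helper, not part of the original module):
# mirrors absl/Python logging into <workdir>/train.log via tf.io.gfile.
def _demo_add_gfile_logger():
  add_gfile_logger('/tmp/workdir')
  logging.info('This message also lands in /tmp/workdir/train.log')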


def create_learning_rate_schedule(total_steps,
                                  base,
                                  decay_type,
                                  warmup_steps,
                                  linear_end=1e-5):
  """Creates learning rate schedule.

  Currently only supports warmup + {linear,cosine} decay, but will become a
  proper mini-language like the preprocessing one in the future.

  Args:
    total_steps: The total number of steps to run.
    base: The starting learning-rate (without warmup).
    decay_type: 'linear' or 'cosine'.
    warmup_steps: How many steps to warm up for.
    linear_end: Minimum learning rate (only used for 'linear' decay).

  Returns:
    A function learning_rate(step): step -> learning rate (float32 scalar).
  """

  def step_fn(step):
    """Step to learning rate function."""
    lr = base
    progress = (step - warmup_steps) / float(total_steps - warmup_steps)
    progress = jnp.clip(progress, 0.0, 1.0)
    if decay_type == 'linear':
      lr = linear_end + (lr - linear_end) * (1.0 - progress)
    elif decay_type == 'cosine':
      lr = lr * 0.5 * (1. + jnp.cos(jnp.pi * progress))
    else:
      raise ValueError(f'Unknown lr type {decay_type}')
    if warmup_steps:
      lr = lr * jnp.minimum(1., step / warmup_steps)
    return jnp.asarray(lr, dtype=jnp.float32)

  return step_fn
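

# Worked example (a hypothetical helper, not part of the original module):
# warmup ramps the rate from 0 up to `base`, then cosine decay takes it to ~0.
def _demo_learning_rate_schedule():
  lr_fn = create_learning_rate_schedule(
      total_steps=1000, base=0.1, decay_type='cosine', warmup_steps=100)
  assert float(lr_fn(0)) == 0.0                # start of warmup
  assert abs(float(lr_fn(100)) - 0.1) < 1e-6   # warmup done: full base rate
  assert abs(float(lr_fn(1000))) < 1e-6        # fully decayed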


def accumulate_gradient(loss_and_grad_fn, params, images, labels, accum_steps):
  """Accumulates gradient over multiple steps to save on memory."""
  if accum_steps and accum_steps > 1:
    assert images.shape[0] % accum_steps == 0, (
        f'Bad accum_steps {accum_steps} for batch size {images.shape[0]}')
    step_size = images.shape[0] // accum_steps
    l, g = loss_and_grad_fn(params, images[:step_size], labels[:step_size])

    def acc_grad_and_loss(i, l_and_g):
      imgs = jax.lax.dynamic_slice(images, (i * step_size, 0, 0, 0),
                                   (step_size,) + images.shape[1:])
      lbls = jax.lax.dynamic_slice(labels, (i * step_size, 0),
                                   (step_size, labels.shape[1]))
      li, gi = loss_and_grad_fn(params, imgs, lbls)
      l, g = l_and_g
      return (l + li, jax.tree_util.tree_map(lambda x, y: x + y, g, gi))

    l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
    return jax.tree_util.tree_map(lambda x: x / accum_steps, (l, g))
  else:
    return loss_and_grad_fn(params, images, labels)
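

# Sanity-check sketch (a hypothetical helper, not part of the original
# module): for a mean-reduced loss, accumulating over 4 equal chunks matches
# a single full-batch pass, up to float rounding.
def _demo_accumulate_gradient():
  def loss_fn(params, images, labels):
    preds = jnp.mean(images, axis=(1, 2, 3)) * params['w']
    return jnp.mean((preds - jnp.argmax(labels, axis=1)) ** 2)

  params = {'w': jnp.ones([])}
  images = jax.random.uniform(jax.random.PRNGKey(0), [8, 4, 4, 3])
  labels = jax.nn.one_hot(jnp.arange(8) % 2, 2)
  grad_fn = jax.value_and_grad(loss_fn)
  l_acc, g_acc = accumulate_gradient(grad_fn, params, images, labels, 4)
  l_full, g_full = accumulate_gradient(grad_fn, params, images, labels, 1)
  assert jnp.allclose(l_acc, l_full, rtol=1e-5)
  assert jnp.allclose(g_acc['w'], g_full['w'], rtol=1e-5)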