from vit_jax import checkpoint
from vit_jax import input_pipeline
from vit_jax import utils
from vit_jax import models
from vit_jax import train
from vit_jax.configs import common as common_config
from vit_jax.configs import models as models_config
from absl import logging
import flax
import jax
from matplotlib import pyplot as plt
import numpy as np
import optax
import tqdm
import os
logging.set_verbosity(logging.INFO)
# import PIL
import tensorflow_datasets as tfds
import time
# import tensorflow as tf

'''Show the available GPU devices.'''
from jax.lib import xla_bridge
jax_test = xla_bridge.get_backend().platform
print(jax_test, jax.local_devices())
if jax_test != 'gpu':
    exit()

model_name = 'ViT-B_16'  #@param ["ViT-B_32", "Mixer-B_16"]
# assert os.path.exists(f'./test_result/{model_name}.npz')

'''Load the dataset'''
dataset = 'cifar100'  # options: imagenet2012, cifar10, cifar100
batch_size = 512
config = common_config.with_dataset(common_config.get_config(), dataset)
# config.shuffle_buffer = 1000
# config.accum_steps = 64
config.batch = batch_size
config.pp.crop = 384

# Build the datasets.
ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
num_classes = input_pipeline.get_dataset_info(dataset, 'train')['num_classes']
del config  # Only needed to instantiate datasets.

# Fetch a batch of test images for illustration purposes.
batch = next(iter(ds_test.as_numpy_iterator()))
# Note the shape: [num_local_devices, local_batch_size, h, w, c]
print(batch['image'].shape)
# exit()  # uncomment to stop here after checking the input pipeline
# tf.config.set_visible_devices([], 'GPU')
# print(tf.config.get_visible_devices('GPU'))

'''Load the pre-trained model'''
model_config = models_config.MODEL_CONFIGS[model_name]
print(model_config)

# Load the model definition and initialize random parameters.
# This also compiles the model with XLA (the first run takes a few minutes).
if model_name.startswith('Mixer'):
    model = models.MlpMixer(num_classes=num_classes, **model_config)
else:
    model = models.VisionTransformer(num_classes=num_classes, **model_config)
variables = jax.jit(lambda: model.init(
    jax.random.PRNGKey(0),
    # Drop the "num_local_devices" dimension of the batch used for initialization.
    batch['image'][0, :1],
    train=False,
), backend='cpu')()

# Load and convert the pre-trained checkpoint.
# This loads the actual pre-trained weights, but also adapts some parameters,
# e.g. replacing the final classification layer and resizing the positional
# embeddings. See the code and the paper's methods section for details.
params = checkpoint.load_pretrained(
    pretrained_path=f'./test_result/{model_name}.npz',
    init_params=variables['params'],
    model_config=model_config,
)

'''Evaluation'''
params_repl = flax.jax_utils.replicate(params)
print('params.cls:', type(params['head']['bias']).__name__,
      params['head']['bias'].shape)
print('params_repl.cls:', type(params_repl['head']['bias']).__name__,
      params_repl['head']['bias'].shape)

# Then map the model's forward pass onto all available devices.
vit_apply_repl = jax.pmap(lambda params, inputs: model.apply(
    dict(params=params), inputs, train=False))


def get_accuracy(params_repl):
    """Returns the accuracy evaluated on the test set."""
    good = total = 0
    steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size
    for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()):
        predicted = vit_apply_repl(params_repl, batch['image'])
        is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1)
        good += is_same.sum()
        total += len(is_same.flatten())
    return good / total


# Random performance without fine-tuning.
print(get_accuracy(params_repl))
# exit()  # uncomment to stop after the pre-fine-tuning evaluation
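
# The fine-tuning section below splits each per-device batch into `accum_steps`
# micro-batches inside train.make_update_fn. The helper here is only a
# conceptual sketch of that gradient-accumulation idea, added for illustration:
# it is NOT the vit_jax implementation and is not called anywhere in this
# script, and `loss_fn(params, images, labels)` is an assumed scalar loss.
def accumulated_gradients(loss_fn, params, images, labels, accum_steps):
    """Averages gradients over `accum_steps` micro-batches before one update."""
    micro_images = np.array_split(images, accum_steps)
    micro_labels = np.array_split(labels, accum_steps)
    grads = None
    for x, y in zip(micro_images, micro_labels):
        g = jax.grad(loss_fn)(params, x, y)
        grads = g if grads is None else jax.tree_util.tree_map(
            lambda a, b: a + b, grads, g)
    # A single optimizer update is then applied with the averaged gradients.
    return jax.tree_util.tree_map(lambda g: g / accum_steps, grads)
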
'''Fine-tuning'''
# 100 steps take approximately 15 minutes in the TPU runtime.
total_steps = 50
warmup_steps = 5
decay_type = 'cosine'
grad_norm_clip = 1
# This controls how many forward passes each batch is split into (see the
# conceptual gradient-accumulation sketch above). 8 works for a TPU runtime
# with 8 devices; 64 should work on a GPU. You can of course also adjust
# batch_size above, but then the learning rate needs to be adapted accordingly.
accum_steps = 64  # TODO: may need tuning
base_lr = 0.03

# See train.make_update_fn for how these values are used.
lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)

# We use a momentum optimizer with half-precision state to save memory.
# It also applies gradient clipping.
tx = optax.chain(
    optax.clip_by_global_norm(grad_norm_clip),
    optax.sgd(
        learning_rate=lr_fn,
        momentum=0.9,
        accumulator_dtype='bfloat16',
    ),
)

update_fn_repl = train.make_update_fn(
    apply_fn=model.apply, accum_steps=accum_steps, tx=tx)
opt_state = tx.init(params)
opt_state_repl = flax.jax_utils.replicate(opt_state)

# Initialize PRNGs for dropout.
update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

# Training updates.
losses = []
lrs = []
# Completes in ~20 min on the TPU runtime.
start = time.time()
for step, batch in zip(
    tqdm.trange(1, total_steps + 1),
    ds_train.as_numpy_iterator(),
):
    params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
        params_repl, opt_state_repl, batch, update_rng_repl)
    losses.append(loss_repl[0])
    lrs.append(lr_fn(step))
end = time.time()
print(f"{model_name}_{dataset}_{total_steps}_{warmup_steps} fine-tuning time:", end - start)

print(get_accuracy(params_repl))

# Plot and save the loss and learning-rate curves.
os.makedirs(f'./test_result/{model_name}_{dataset}', exist_ok=True)
plt.plot(losses)
plt.savefig(f'./test_result/{model_name}_{dataset}/losses_plot.png')
plt.close()
plt.plot(lrs)
plt.savefig(f'./test_result/{model_name}_{dataset}/lrs_plot.png')
plt.close()
# On CIFAR-10, Mixer-B/16 should reach ~96.7% and ViT-B/32 ~97.7% (both @224).
exit()

'''Inference'''
# # Download a pre-trained model.
# model_name = 'ViT-L_16'
# model_config = models_config.MODEL_CONFIGS[model_name]
# print(model_config)
# model = models.VisionTransformer(num_classes=1000, **model_config)
# assert os.path.exists(f'./test_result/{model_name}_imagenet2012.npz')
# # Load and convert the pre-trained checkpoint.
# params = checkpoint.load(f'./test_result/{model_name}_imagenet2012.npz')
# params['pre_logits'] = {}  # Need to restore empty leaf for Flax.
# # Get the ImageNet labels.
# # get_ipython().system('wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt')
# imagenet_labels = dict(enumerate(open('./test_result/ilsvrc2012_wordnet_lemmas.txt')))
# # Get a random image with the correct dimensions.
# # resolution = 224 if model_name.startswith('Mixer') else 384
# # get_ipython().system('wget https://picsum.photos/$resolution -O picsum.jpg')
# img = PIL.Image.open('./test_result/picsum.jpg')
# # Predict on a batch with a single item (note the very efficient TPU usage...).
# logits, = model.apply(dict(params=params), (np.array(img) / 128 - 1)[None, ...], train=False)
# preds = np.array(jax.nn.softmax(logits))
# for idx in preds.argsort()[:-11:-1]:
#     print(f'{preds[idx]:.5f} : {imagenet_labels[idx]}', end='')
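
# An added sketch (not part of the original script): persist the fine-tuned
# parameters with standard Flax serialization helpers so they can be restored
# later via flax.serialization.from_bytes. The variable name and output path
# are arbitrary choices. Left commented out, like the inference example above;
# to run it, uncomment and move these lines above the exit() after the plots.
# finetuned_params = flax.jax_utils.unreplicate(params_repl)
# with open(f'./test_result/{model_name}_{dataset}_finetuned.msgpack', 'wb') as f:
#     f.write(flax.serialization.to_bytes(finetuned_params))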