from vit_jax import checkpoint
from vit_jax import input_pipeline
from vit_jax import utils
from vit_jax import models
from vit_jax import train
from vit_jax.configs import common as common_config
from vit_jax.configs import models as models_config

from absl import logging
import flax
import jax
from matplotlib import pyplot as plt
import numpy as np
import optax
import tqdm
import os

logging.set_verbosity(logging.INFO)

import PIL
import tensorflow_datasets as tfds
import time
# import tensorflow as tf

'''Check the backend (DCU/GPU)'''
from jax.lib import xla_bridge
jax_test = xla_bridge.get_backend().platform
if jax_test != 'gpu':
    exit()

'''Select the model'''
model_name = 'ViT-B_16'  #@param ["ViT-B_32", "Mixer-B_16"]
pretrained_path = f'./test_result/{model_name}.npz'
model_path = f'./test_result/{model_name}_imagenet2012.npz'

'''Load the dataset -- for fine-tuning'''
# dataset = 'cifar100'  # imagenet2012 cifar10 cifar100
# batch_size = 512
# config = common_config.with_dataset(common_config.get_config(), dataset)
# # config.shuffle_buffer = 1000
# # config.accum_steps = 64
# config.batch = batch_size
# config.pp.crop = 384

# # Build the datasets.
# ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
# ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
# num_classes = input_pipeline.get_dataset_info(dataset, 'train')['num_classes']
# del config  # Only needed to instantiate datasets.

# # Fetch a batch of test images for illustration purposes.
# batch = next(iter(ds_test.as_numpy_iterator()))
# # Note the shape: [num_local_devices, local_batch_size, h, w, c]
# print("dataset shape:", batch['image'].shape)

'''Load the pretrained model -- for fine-tuning'''
# model_config = models_config.MODEL_CONFIGS[model_name]
# print("model config:", model_config)
# # Load the model definition and initialize random parameters.
# # This also compiles the model with XLA (takes a few minutes the first time).
# if model_name.startswith('Mixer'):
#     model = models.MlpMixer(num_classes=num_classes, **model_config)
# else:
#     model = models.VisionTransformer(num_classes=num_classes, **model_config)
# variables = jax.jit(lambda: model.init(
#     jax.random.PRNGKey(0),
#     # Discard the "num_local_devices" dimension of the batch used for initialization.
#     batch['image'][0, :1],
#     train=False,
# ), backend='cpu')()

# # Load and convert the pretrained checkpoint.
# # This loads the actual pretrained weights, but also adapts some parameters,
# # e.g. replacing the final layer and resizing the position embeddings.
# # See the code and the paper's Methods section for details.
# params = checkpoint.load_pretrained(
#     pretrained_path=pretrained_path,
#     init_params=variables['params'],
#     model_config=model_config,
# )

'''Evaluation'''
# params_repl = flax.jax_utils.replicate(params)
# print('params.cls:', type(params['head']['bias']).__name__,
#       params['head']['bias'].shape)
# print('params_repl.cls:', type(params_repl['head']['bias']).__name__,
#       params_repl['head']['bias'].shape)

# # Map the call to our model's forward pass onto all available devices.
# vit_apply_repl = jax.pmap(lambda params, inputs: model.apply(
#     dict(params=params), inputs, train=False))

# def get_accuracy(params_repl):
#     """Returns the accuracy evaluated on the test set."""
#     good = total = 0
#     steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size
#     for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()):
#         predicted = vit_apply_repl(params_repl, batch['image'])
#         is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1)
#         good += is_same.sum()
#         total += len(is_same.flatten())
#     return good / total

# # Random performance of the model before fine-tuning.
# print(get_accuracy(params_repl))
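# The evaluation above relies on jax.pmap: both the replicated parameters and
# the input batches carry a leading [num_local_devices] axis, which
# flax.jax_utils.replicate adds to the params and the input pipeline already
# adds to the batches. A minimal sketch of producing that layout by hand,
# assuming a hypothetical host-side array `x` of shape [batch, h, w, c]
# (not part of the original script); kept commented out like the section above:
# n_dev = jax.local_device_count()
# x_sharded = x.reshape((n_dev, x.shape[0] // n_dev) + x.shape[1:])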
'''Fine-tuning'''
# # 100 steps take approximately 15 minutes in the TPU runtime.
# total_steps = 50
# warmup_steps = 5
# decay_type = 'cosine'
# grad_norm_clip = 1
# # This controls how many forward passes the batch is split into.
# # 8 works for a TPU runtime with 8 devices; 64 should work on a GPU.
# # You can of course also adjust batch_size above, but that requires
# # adjusting the learning rate accordingly.
# accum_steps = 64  # TODO: may need tuning.
# base_lr = 0.03

# # See train.make_update_fn.
# lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)

# # We use a momentum optimizer with half-precision state to save memory.
# # It also implements gradient clipping.
# tx = optax.chain(
#     optax.clip_by_global_norm(grad_norm_clip),
#     optax.sgd(
#         learning_rate=lr_fn,
#         momentum=0.9,
#         accumulator_dtype='bfloat16',
#     ),
# )

# update_fn_repl = train.make_update_fn(
#     apply_fn=model.apply, accum_steps=accum_steps, tx=tx)
# opt_state = tx.init(params)
# opt_state_repl = flax.jax_utils.replicate(opt_state)

# # Initialize PRNGs for dropout.
# update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

# # Training updates.
# losses = []
# lrs = []
# # Completes in ~20 min on the TPU runtime.
# start = time.time()
# for step, batch in zip(
#     tqdm.trange(1, total_steps + 1),
#     ds_train.as_numpy_iterator(),
# ):
#     params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
#         params_repl, opt_state_repl, batch, update_rng_repl)
#     losses.append(loss_repl[0])
#     lrs.append(lr_fn(step))
# end = time.time()
# print(f"{model_name}_{dataset}_{total_steps}_{warmup_steps} fine-tuning time:", end - start)
# print(get_accuracy(params_repl))

# # Plot and save the loss and learning-rate curves.
# plt.plot(losses)
# plt.savefig(f'./test_result/{model_name}_{dataset}/losses_plot.png')
# plt.close()
# plt.plot(lrs)
# plt.savefig(f'./test_result/{model_name}_{dataset}/lrs_plot.png')
# plt.close()

# # On CIFAR-10, Mixer-B/16 should reach ~96.7% and ViT-B/32 ~97.7% (both @224).

'''Inference'''
model_config = models_config.MODEL_CONFIGS[model_name]
print("model config:", model_config)
model = models.VisionTransformer(num_classes=1000, **model_config)
assert os.path.exists(model_path)

# Load and convert the pretrained checkpoint.
params = checkpoint.load(model_path)
params['pre_logits'] = {}  # Need to restore empty leaf for Flax.

# Get the ImageNet labels.
# get_ipython().system('wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt')
imagenet_labels = dict(enumerate(open('./dataset/ilsvrc2012_wordnet_lemmas.txt')))

# Get a random picture with the correct dimensions.
# resolution = 224 if model_name.startswith('Mixer') else 384
# get_ipython().system('wget https://picsum.photos/$resolution -O picsum.jpg')
img = PIL.Image.open('./dataset/picsum.jpg')

# Predict.
start_time = time.time()
logits, = model.apply(dict(params=params), (np.array(img) / 128 - 1)[None, ...], train=False)
end_time = time.time()
preds = np.array(jax.nn.softmax(logits))
print("inference time =", end_time - start_time)
for idx in preds.argsort()[:-11:-1]:
    print(f'{preds[idx]:.5f} : {imagenet_labels[idx]}', end='')
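
# Note on timing: JAX dispatches device work asynchronously, so end_time above
# may be taken before the forward pass has actually finished (and the first
# call also includes XLA compilation). A minimal sketch of a more reliable
# measurement using block_until_ready; kept commented out so the script's
# behavior is unchanged:
# inputs = (np.array(img) / 128 - 1)[None, ...]
# _ = model.apply(dict(params=params), inputs, train=False).block_until_ready()  # warm-up / compile
# start_time = time.time()
# out = model.apply(dict(params=params), inputs, train=False)
# out.block_until_ready()  # wait for the device computation to finish
# print("forward pass time =", time.time() - start_time)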