test.py 6.74 KB
Newer Older
suily's avatar
suily committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from vit_jax import checkpoint
from vit_jax import input_pipeline
from vit_jax import utils
from vit_jax import models
from vit_jax import train
from vit_jax.configs import common as common_config
from vit_jax.configs import models as models_config
from absl import logging
import flax
import jax
from matplotlib import pyplot as plt
import numpy as np
import optax
import tqdm
import os
logging.set_verbosity(logging.INFO)
suily's avatar
suily committed
17
import PIL
suily's avatar
suily committed
18
19
20
21
import tensorflow_datasets as tfds
import time
# import tensorflow as tf

suily's avatar
suily committed
22
'''测试dcu/gpu'''
suily's avatar
suily committed
23
24
25
26
27
from jax.lib import xla_bridge

# Abort early unless JAX picked up a GPU/DCU backend: every benchmark in
# this script is meaningless on CPU.
# NOTE: the original `exit()` quit silently with status 0, so CI would
# report success even when no accelerator was found — fail loudly instead.
jax_test = xla_bridge.get_backend().platform
if jax_test != 'gpu':
  raise SystemExit(f'No GPU/DCU backend detected (got {jax_test!r}); aborting.')

suily's avatar
suily committed
28
'''指定模型'''
suily's avatar
suily committed
29
# Which architecture to test/evaluate.
model_name = 'ViT-B_16'  #@param ["ViT-B_32", "Mixer-B_16"]

# Checkpoint locations: the raw pre-trained weights and the
# ImageNet-2012 fine-tuned weights used by the inference section below.
_result_dir = './test_result'
pretrained_path = f'{_result_dir}/{model_name}.npz'
model_path = f'{_result_dir}/{model_name}_imagenet2012.npz'
suily's avatar
suily committed
32

suily's avatar
suily committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
'''加载数据集--微调用'''
# dataset = 'cifar100'   # imagenet2012  cifar10  cifar100
# batch_size = 512
# config = common_config.with_dataset(common_config.get_config(), dataset) 
# # config.shuffle_buffer=1000
# # config.accum_steps=64
# config.batch = batch_size
# config.pp.crop = 384
# # 建立数据集
# ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
# ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
# num_classes = input_pipeline.get_dataset_info(dataset, 'train')['num_classes']
# del config  # Only needed to instantiate datasets.
# # Fetch a batch of test images for illustration purposes.
# batch = next(iter(ds_test.as_numpy_iterator()))
# # Note the shape : [num_local_devices, local_batch_size, h, w, c]
# print("数据集shape:",batch['image'].shape)
suily's avatar
suily committed
50

suily's avatar
suily committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
'''加载预训练模型--微调用'''
# model_config = models_config.MODEL_CONFIGS[model_name]
# print("模型config:",model_config)
# # 加载模型定义并初始化随机参数。
# # 这也将模型编译为XLA(第一次需要几分钟)。
# if model_name.startswith('Mixer'):
#   model = models.MlpMixer(num_classes=num_classes, **model_config)
# else:
#   model = models.VisionTransformer(num_classes=num_classes, **model_config)
# variables = jax.jit(lambda: model.init(
#     jax.random.PRNGKey(0),
#     # 丢弃用于初始化的批处理的“num_local_devices”维度。
#     batch['image'][0, :1],
#     train=False,
# ), backend='cpu')()
# #加载和转换预训练检查点。
# # 这涉及到加载实际的预训练模型结果,但也要修改一点参数,例如改变最终层,并调整位置嵌入的大小。有关详细信息,请参阅代码和本文的方法。
# params = checkpoint.load_pretrained(
#     pretrained_path=pretrained_path, 
#     init_params=variables['params'],
#     model_config=model_config
# )
suily's avatar
suily committed
73
74

'''评估'''
suily's avatar
suily committed
75
76
77
78
79
80
81
82
# params_repl = flax.jax_utils.replicate(params)
# print('params.cls:', type(params['head']['bias']).__name__,
#       params['head']['bias'].shape)
# print('params_repl.cls:', type(params_repl['head']['bias']).__name__,
#       params_repl['head']['bias'].shape)
# # 然后将调用映射到我们模型的forward pass到所有可用的设备。
# vit_apply_repl = jax.pmap(lambda params, inputs: model.apply(
#     dict(params=params), inputs, train=False))
suily's avatar
suily committed
83

suily's avatar
suily committed
84
85
86
87
88
89
90
91
92
93
94
95
# def get_accuracy(params_repl):
#   """返回对测试集求值的精度"""
#   good = total = 0
#   steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size
#   for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()):
#     predicted = vit_apply_repl(params_repl, batch['image'])
#     is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1)
#     good += is_same.sum()
#     total += len(is_same.flatten())
#   return good / total
# # 模型的随机性能
# print(get_accuracy(params_repl))
suily's avatar
suily committed
96
97

'''微调'''
suily's avatar
suily committed
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# # 100 Steps take approximately 15 minutes in the TPU runtime.
# total_steps = 50
# warmup_steps = 5
# decay_type = 'cosine'
# grad_norm_clip = 1
# # 这控制了批处理被分割的转发次数。8适用于具有8个设备的TPU运行时。64应该可以在GPU上工作。当然,您也可以调整上面的batch_size,但这需要相应地调整学习率。
# accum_steps = 64  # TODO:可能要改
# base_lr = 0.03
# # 检查 train.make_update_fn
# lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
# # 我们使用一个动量优化器,使用一半精度的状态来节省内存。它还实现了梯度裁剪
# tx = optax.chain(
#     optax.clip_by_global_norm(grad_norm_clip),
#     optax.sgd(
#         learning_rate=lr_fn,
#         momentum=0.9,
#         accumulator_dtype='bfloat16',
#     ),
# )
# update_fn_repl = train.make_update_fn(
#     apply_fn=model.apply, accum_steps=accum_steps, tx=tx)
# opt_state = tx.init(params)
# opt_state_repl = flax.jax_utils.replicate(opt_state)
# # Initialize PRNGs for dropout.
# update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))
# # 训练更新
# losses = []
# lrs = []
# # Completes in ~20 min on the TPU runtime.
# start = time.time() 
# for step, batch in zip(
#     tqdm.trange(1, total_steps + 1),
#     ds_train.as_numpy_iterator(),
# ):
#   params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
#       params_repl, opt_state_repl, batch, update_rng_repl)
#   losses.append(loss_repl[0])
#   lrs.append(lr_fn(step))
# end = time.time()
# print(f"{model_name}_{dataset}_{total_steps}_{warmup_steps}微调时间为:",end-start) 
# print(get_accuracy(params_repl))
suily's avatar
suily committed
139
# 绘制学习率变化曲线并保存
suily's avatar
suily committed
140
141
142
143
144
145
# plt.plot(losses)
# plt.savefig(f'./test_result/{model_name}_{dataset}/losses_plot.png')
# plt.close()
# plt.plot(lrs)
# plt.savefig(f'./test_result/{model_name}_{dataset}/lrs_plot.png')
# plt.close()
suily's avatar
suily committed
146
147
148
# 在CIFAR10上,Mixer-B/16应该是~96.7%,vitb /32应该是97.7%(都是@224)

'''推理'''
suily's avatar
suily committed
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# ---- Inference: single-image forward pass with the fine-tuned checkpoint ----
model_config = models_config.MODEL_CONFIGS[model_name]
print("模型config:",model_config)
model = models.VisionTransformer(num_classes=1000, **model_config)

# Fail fast with a clear error: the original `assert` is stripped when
# running under `python -O`, which would turn a missing checkpoint into a
# confusing failure deeper inside `checkpoint.load`.
if not os.path.exists(model_path):
  raise FileNotFoundError(f'checkpoint not found: {model_path}')

# Load and convert the fine-tuned checkpoint.
params = checkpoint.load(model_path)
params['pre_logits'] = {}  # Need to restore empty leaf for Flax.

# ImageNet class index -> human-readable label.  Use a context manager so
# the file handle is closed (the original leaked it).
# (labels originally fetched with:
#   wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt)
with open('./dataset/ilsvrc2012_wordnet_lemmas.txt') as labels_file:
  imagenet_labels = dict(enumerate(labels_file))

# A test image at the model's expected input resolution.
# resolution = 224 if model_name.startswith('Mixer') else 384
# (image originally fetched with: wget https://picsum.photos/$resolution -O picsum.jpg)
img = PIL.Image.open('./dataset/picsum.jpg')

# Time a single forward pass.  Pixels are mapped from [0, 255] to roughly
# [-1, 1); `[None, ...]` adds the leading batch dimension.
start_time=time.time()
logits, = model.apply(dict(params=params), (np.array(img) / 128 - 1)[None, ...], train=False)
end_time=time.time()

preds = np.array(jax.nn.softmax(logits))
print("推理结果:time=",end_time-start_time)
# Top-10 predictions, highest probability first.  Each label kept its
# trailing '\n' from the file, hence `end=''`.
for idx in preds.argsort()[:-11:-1]:
  print(f'{preds[idx]:.5f} : {imagenet_labels[idx]}', end='')