from vit_jax import checkpoint
from vit_jax import input_pipeline
from vit_jax import utils
from vit_jax import models
from vit_jax import train
from vit_jax.configs import common as common_config
from vit_jax.configs import models as models_config
from absl import logging
import flax
import jax
from matplotlib import pyplot as plt
import numpy as np
import optax
import tqdm
import os
import time
# import PIL
import tensorflow_datasets as tfds
# import tensorflow as tf

logging.set_verbosity(logging.INFO)

'''Show the available devices and make sure we are running on a GPU backend.'''
from jax.lib import xla_bridge
jax_backend = xla_bridge.get_backend().platform
print(jax_backend, jax.local_devices())
if jax_backend != 'gpu':
  exit()

model_name = 'ViT-B_16'  #@param ["ViT-B_16", "ViT-B_32", "Mixer-B_16"]
# assert os.path.exists(f'./test_result/{model_name}.npz')

'''Load the dataset.'''
dataset = 'cifar100'   # imagenet2012  cifar10  cifar100
batch_size = 512
config = common_config.with_dataset(common_config.get_config(), dataset) 
# config.shuffle_buffer=1000
# config.accum_steps=64
config.batch = batch_size
config.pp.crop = 384
# Build the train/test datasets.
ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
num_classes = input_pipeline.get_dataset_info(dataset, 'train')['num_classes']
del config  # Only needed to instantiate datasets.
# Fetch a batch of test images for illustration purposes.
batch = next(iter(ds_test.as_numpy_iterator()))
# Note the shape: [num_local_devices, local_batch_size, h, w, c].
print(batch['image'].shape)
# exit()  # Debug stop: uncomment to halt after inspecting the batch shape.
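# Optional sanity check (an added sketch, not part of the original script): save a small
# grid of test images. This assumes the input pipeline scales pixel values to [-1, 1],
# as the upstream vit_jax preprocessing does; the output path is just an example.
sample_images = (batch['image'][0, :9] + 1.0) / 2.0  # drop the device axis, map back to [0, 1]
fig, axes = plt.subplots(3, 3, figsize=(6, 6))
for ax, img in zip(axes.flat, sample_images):
  ax.imshow(np.clip(img, 0.0, 1.0))
  ax.axis('off')
fig.savefig('./test_result/sample_batch.png')
plt.close(fig)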

# tf.config.set_visible_devices([], 'GPU')
# print(tf.config.get_visible_devices('GPU'))
'''Load the pre-trained model.'''
model_config = models_config.MODEL_CONFIGS[model_name]
print(model_config)
# Load the model definition and initialize random parameters.
# This also compiles the model with XLA (the first call takes a few minutes).
if model_name.startswith('Mixer'):
  model = models.MlpMixer(num_classes=num_classes, **model_config)
else:
  model = models.VisionTransformer(num_classes=num_classes, **model_config)
variables = jax.jit(lambda: model.init(
    jax.random.PRNGKey(0),
    # Drop the "num_local_devices" dimension of the batch used for initialization.
    batch['image'][0, :1],
    train=False,
), backend='cpu')()
# Load and convert the pre-trained checkpoint.
# This loads the actual pre-trained parameters but also adapts them slightly, e.g.
# replacing the final classification layer and resizing the position embeddings.
# See the code and the paper's methods section for details.
params = checkpoint.load_pretrained(
    pretrained_path=f'./test_result/{model_name}.npz',
    init_params=variables['params'],
    model_config=model_config
)
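# Quick inspection (an added sketch): print the shape of every leaf in the converted
# parameter tree, which makes the resized position embeddings and the re-initialized
# `head` layer with `num_classes` outputs easy to spot.
print(jax.tree_util.tree_map(lambda p: p.shape, params))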

'''Evaluation.'''
params_repl = flax.jax_utils.replicate(params)
print('params.cls:', type(params['head']['bias']).__name__,
      params['head']['bias'].shape)
print('params_repl.cls:', type(params_repl['head']['bias']).__name__,
      params_repl['head']['bias'].shape)
# Then map the model's forward pass over all available devices.
vit_apply_repl = jax.pmap(lambda params, inputs: model.apply(
    dict(params=params), inputs, train=False))
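# Smoke test of the replicated forward pass (an added sketch): `batch` was fetched above
# for illustration, so its leading axis already matches the number of local devices.
logits = vit_apply_repl(params_repl, batch['image'])
print('logits shape:', logits.shape)  # [num_local_devices, local_batch_size, num_classes]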

def get_accuracy(params_repl):
  """返回对测试集求值的精度"""
  good = total = 0
  steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size
  for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()):
    predicted = vit_apply_repl(params_repl, batch['image'])
    is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1)
    good += is_same.sum()
    total += len(is_same.flatten())
  return good / total
# Random-chance performance without fine-tuning.
print(get_accuracy(params_repl))
# exit()  # Debug stop: uncomment to halt after checking the un-finetuned accuracy.

'''Fine-tuning.'''
# 100 Steps take approximately 15 minutes in the TPU runtime.
total_steps = 50
warmup_steps = 5
decay_type = 'cosine'
grad_norm_clip = 1
# This controls into how many forward passes each batch is split. 8 works well on a
# TPU runtime with 8 devices; 64 should work on a GPU. You can also reduce the
# batch_size above instead, but then the learning rate needs to be adjusted accordingly.
accum_steps = 64  # TODO: may need tuning for this setup.
base_lr = 0.03
# See train.make_update_fn for how the schedule is used.
lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
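# Sanity check (an added sketch): sample the schedule at a few steps to confirm the
# warmup and cosine decay behave as expected before starting training.
print({step: float(lr_fn(step)) for step in (0, warmup_steps, total_steps // 2, total_steps)})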
# We use a momentum optimizer with half-precision accumulator state to save memory;
# gradient clipping is applied via optax.clip_by_global_norm.
tx = optax.chain(
    optax.clip_by_global_norm(grad_norm_clip),
    optax.sgd(
        learning_rate=lr_fn,
        momentum=0.9,
        accumulator_dtype='bfloat16',
    ),
)
update_fn_repl = train.make_update_fn(
    apply_fn=model.apply, accum_steps=accum_steps, tx=tx)
opt_state = tx.init(params)
opt_state_repl = flax.jax_utils.replicate(opt_state)
# Initialize PRNGs for dropout.
update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))
# Training loop.
losses = []
lrs = []
# Completes in ~20 min on the TPU runtime.
start = time.time() 
for step, batch in zip(
    tqdm.trange(1, total_steps + 1),
    ds_train.as_numpy_iterator(),
):
  params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
      params_repl, opt_state_repl, batch, update_rng_repl)
  losses.append(loss_repl[0])
  lrs.append(lr_fn(step))
end = time.time()
print(f"{model_name}_{dataset}_{total_steps}_{warmup_steps}微调时间为:",end-start) 
print(get_accuracy(params_repl))
# Plot the loss and learning-rate curves and save them.
os.makedirs(f'./test_result/{model_name}_{dataset}', exist_ok=True)
plt.plot(losses)
plt.savefig(f'./test_result/{model_name}_{dataset}/losses_plot.png')
plt.close()
plt.plot(lrs)
plt.savefig(f'./test_result/{model_name}_{dataset}/lrs_plot.png')
plt.close()
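# Optionally persist the fine-tuned weights (an added sketch; the file name is just an
# example). This un-replicates the parameters and writes them as a flat .npz with
# '/'-joined keys, similar in spirit to the released checkpoint files.
from flax import traverse_util
params_finetuned = flax.jax_utils.unreplicate(params_repl)
flat_params = {k: np.asarray(v)
               for k, v in traverse_util.flatten_dict(
                   flax.core.unfreeze(params_finetuned), sep='/').items()}
np.savez(f'./test_result/{model_name}_{dataset}/finetuned.npz', **flat_params)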
# On CIFAR-10, Mixer-B/16 should reach ~96.7% and ViT-B/32 ~97.7% (both at 224x224).
exit()

# exit()
'''Inference.'''
# # Download a pre-trained model.
# model_name = 'ViT-L_16' 
# model_config = models_config.MODEL_CONFIGS[model_name]
# print(model_config)
# model = models.VisionTransformer(num_classes=1000, **model_config)
# assert os.path.exists(f'./test_result/{model_name}_imagenet2012.npz')
# # Load and convert the pre-trained checkpoint.
# params = checkpoint.load(f'./test_result/{model_name}_imagenet2012.npz')
# params['pre_logits'] = {}  # Need to restore empty leaf for Flax.
# # Get the ImageNet labels.
# # get_ipython().system('wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt')
# imagenet_labels = dict(enumerate(open('./test_result/ilsvrc2012_wordnet_lemmas.txt')))
# # Get a random picture with the correct resolution.
# # resolution = 224 if model_name.startswith('Mixer') else 384
# # get_ipython().system('wget https://picsum.photos/$resolution -O picsum.jpg')
# img = PIL.Image.open('./test_result/picsum.jpg')
# # Predict on a batch with a single example (note the very efficient TPU usage...).
# logits, = model.apply(dict(params=params), (np.array(img) / 128 - 1)[None, ...], train=False)
# preds = np.array(jax.nn.softmax(logits))
# for idx in preds.argsort()[:-11:-1]:
#   print(f'{preds[idx]:.5f} : {imagenet_labels[idx]}', end='')