Commit 67a1d360 authored by suily's avatar suily
Browse files

添加推理结果

parent 188f0cfa
...@@ -76,6 +76,7 @@ pip install -r requirements.txt ...@@ -76,6 +76,7 @@ pip install -r requirements.txt
pip install tensorflow-cpu==2.13.1 pip install tensorflow-cpu==2.13.1
``` ```
## 数据集 ## 数据集
### 训练数据集
`cifar10 cifar100` `cifar10 cifar100`
数据集由tensorflow_datasets自动下载和处理,相关代码见vision_transformer/vit_jax/input_pipeline.py 数据集由tensorflow_datasets自动下载和处理,相关代码见vision_transformer/vit_jax/input_pipeline.py
...@@ -102,11 +103,25 @@ vim /usr/local/lib/python3.10/site-packages/tensorflow_datasets/core/utils/gcs_u ...@@ -102,11 +103,25 @@ vim /usr/local/lib/python3.10/site-packages/tensorflow_datasets/core/utils/gcs_u
│    ├── features.json │    ├── features.json
│ └── label.labels.txt │ └── label.labels.txt
``` ```
### 推理数据集
推理所用的图片和标签文件可通过以下命令下载:
```
# ./dataset是存储地址,可自订
wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -P ./dataset
wget https://picsum.photos/384 -O ./dataset/picsum.jpg # 下载一张384×384分辨率的随机图片
```
数据集目录结构如下:
```
── dataset
│   ├── ilsvrc2012_wordnet_lemmas.txt
│ └── picsum.jpg
```
## 训练 ## 训练
检查点可通过以下方式进行下载: 检查点可通过以下方式进行下载:
``` ```
cd /your_code_path/vision_transformer/test_result # test_result为检查点下载地址,可自订 cd /your_code_path/vision_transformer/test_result # test_result为检查点下载地址,可自订
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-L_16.npz
``` ```
### 单机单卡 ### 单机单卡
``` ```
...@@ -124,24 +139,47 @@ sh test.sh ...@@ -124,24 +139,47 @@ sh test.sh
# config.optim_dtype='bfloat16' # 精度 # config.optim_dtype='bfloat16' # 精度
``` ```
## 推理 ## 推理
检查点可通过以下方式进行下载:
```
cd /your_code_path/vision_transformer/test_result # test_result为检查点下载地址,可自订
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16.npz -O ViT-B_16_imagenet2012.npz
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-L_16.npz -O ViT-L_16_imagenet2012.npz
``` ```
```
cd /your_code_path/vision_transformer
python test.py python test.py
``` ```
## result ## result
此处填算法效果测试图(包括输入、输出) 测试图为:
<div align=center> <div align=center>
<img src="./doc/xxx.png"/> <img src="./doc/picsum.jpg"/>
</div> </div>
```
dcu推理结果:
0.73861 : alp
0.24576 : valley, vale
0.00416 : lakeside, lakeshore
0.00404 : cliff, drop, drop-off
0.00094 : promontory, headland, head, foreland
0.00060 : mountain_tent
0.00055 : dam, dike, dyke
0.00033 : volcano
0.00031 : ski
0.00012 : solar_dish, solar_collector, solar_furnace
gpu推理结果:
0.73976 : alp
0.24465 : valley, vale
0.00414 : lakeside, lakeshore
0.00404 : cliff, drop, drop-off
0.00094 : promontory, headland, head, foreland
0.00060 : mountain_tent
0.00054 : dam, dike, dyke
0.00033 : volcano
0.00031 : ski
0.00012 : solar_dish, solar_collector, solar_furnace
```
### 精度 ### 精度
测试数据:[test data](链接),使用的加速卡:xxx。
根据测试结果情况填写表格:
| xxx | xxx | xxx | xxx | xxx |
| :------: | :------: | :------: | :------: |:------: |
| xxx | xxx | xxx | xxx | xxx |
| xxx | xx | xxx | xxx | xxx |
## 应用场景 ## 应用场景
### 算法类别 ### 算法类别
`图像识别` `图像识别`
......
This diff is collapsed.
...@@ -14,7 +14,7 @@ import optax ...@@ -14,7 +14,7 @@ import optax
import tqdm import tqdm
import os import os
logging.set_verbosity(logging.INFO) logging.set_verbosity(logging.INFO)
# import PIL import PIL
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
import time import time
# import tensorflow as tf # import tensorflow as tf
...@@ -22,153 +22,149 @@ import time ...@@ -22,153 +22,149 @@ import time
'''显示可用设备gpu的数量。''' '''显示可用设备gpu的数量。'''
from jax.lib import xla_bridge from jax.lib import xla_bridge
jax_test=xla_bridge.get_backend().platform jax_test=xla_bridge.get_backend().platform
print(jax_test,jax.local_devices())
if not (jax_test=='gpu'): if not (jax_test=='gpu'):
exit() exit()
model_name = 'ViT-B_16' #@param ["ViT-B_32", "Mixer-B_16"] '''指定模型'''
# assert os.path.exists(f'./test_result/{model_name}.npz') model_name = 'ViT-L_16' #@param ["ViT-B_32", "Mixer-B_16"]
pretrained_path=f'./test_result/{model_name}.npz'
model_path=f'./test_result/{model_name}_imagenet2012.npz'
'''加载数据集''' '''加载数据集--微调用'''
dataset = 'cifar100' # imagenet2012 cifar10 cifar100 # dataset = 'cifar100' # imagenet2012 cifar10 cifar100
batch_size = 512 # batch_size = 512
config = common_config.with_dataset(common_config.get_config(), dataset) # config = common_config.with_dataset(common_config.get_config(), dataset)
# config.shuffle_buffer=1000 # # config.shuffle_buffer=1000
# config.accum_steps=64 # # config.accum_steps=64
config.batch = batch_size # config.batch = batch_size
config.pp.crop = 384 # config.pp.crop = 384
# 建立数据集 # # 建立数据集
ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train') # ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test') # ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
num_classes = input_pipeline.get_dataset_info(dataset, 'train')['num_classes'] # num_classes = input_pipeline.get_dataset_info(dataset, 'train')['num_classes']
del config # Only needed to instantiate datasets. # del config # Only needed to instantiate datasets.
# Fetch a batch of test images for illustration purposes. # # Fetch a batch of test images for illustration purposes.
batch = next(iter(ds_test.as_numpy_iterator())) # batch = next(iter(ds_test.as_numpy_iterator()))
# Note the shape : [num_local_devices, local_batch_size, h, w, c] # # Note the shape : [num_local_devices, local_batch_size, h, w, c]
print(batch['image'].shape) # print("数据集shape:",batch['image'].shape)
exit()
# tf.config.set_visible_devices([], 'GPU') '''加载预训练模型--微调用'''
# print(tf.config.get_visible_devices('GPU')) # model_config = models_config.MODEL_CONFIGS[model_name]
'''加载预训练模型''' # print("模型config:",model_config)
model_config = models_config.MODEL_CONFIGS[model_name] # # 加载模型定义并初始化随机参数。
print(model_config) # # 这也将模型编译为XLA(第一次需要几分钟)。
# 加载模型定义并初始化随机参数。 # if model_name.startswith('Mixer'):
# 这也将模型编译为XLA(第一次需要几分钟)。 # model = models.MlpMixer(num_classes=num_classes, **model_config)
if model_name.startswith('Mixer'): # else:
model = models.MlpMixer(num_classes=num_classes, **model_config) # model = models.VisionTransformer(num_classes=num_classes, **model_config)
else: # variables = jax.jit(lambda: model.init(
model = models.VisionTransformer(num_classes=num_classes, **model_config) # jax.random.PRNGKey(0),
variables = jax.jit(lambda: model.init( # # 丢弃用于初始化的批处理的“num_local_devices”维度。
jax.random.PRNGKey(0), # batch['image'][0, :1],
# 丢弃用于初始化的批处理的“num_local_devices”维度。 # train=False,
batch['image'][0, :1], # ), backend='cpu')()
train=False, # #加载和转换预训练检查点。
), backend='cpu')() # # 这涉及到加载实际的预训练模型结果,但也要修改一点参数,例如改变最终层,并调整位置嵌入的大小。有关详细信息,请参阅代码和本文的方法。
#加载和转换预训练检查点。 # params = checkpoint.load_pretrained(
# 这涉及到加载实际的预训练模型结果,但也要修改一点参数,例如改变最终层,并调整位置嵌入的大小。有关详细信息,请参阅代码和本文的方法。 # pretrained_path=pretrained_path,
params = checkpoint.load_pretrained( # init_params=variables['params'],
pretrained_path=f'./test_result/{model_name}.npz', # model_config=model_config
init_params=variables['params'], # )
model_config=model_config
)
'''评估''' '''评估'''
params_repl = flax.jax_utils.replicate(params) # params_repl = flax.jax_utils.replicate(params)
print('params.cls:', type(params['head']['bias']).__name__, # print('params.cls:', type(params['head']['bias']).__name__,
params['head']['bias'].shape) # params['head']['bias'].shape)
print('params_repl.cls:', type(params_repl['head']['bias']).__name__, # print('params_repl.cls:', type(params_repl['head']['bias']).__name__,
params_repl['head']['bias'].shape) # params_repl['head']['bias'].shape)
# 然后将调用映射到我们模型的forward pass到所有可用的设备。 # # 然后将调用映射到我们模型的forward pass到所有可用的设备。
vit_apply_repl = jax.pmap(lambda params, inputs: model.apply( # vit_apply_repl = jax.pmap(lambda params, inputs: model.apply(
dict(params=params), inputs, train=False)) # dict(params=params), inputs, train=False))
def get_accuracy(params_repl): # def get_accuracy(params_repl):
"""返回对测试集求值的精度""" # """返回对测试集求值的精度"""
good = total = 0 # good = total = 0
steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size # steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size
for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()): # for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()):
predicted = vit_apply_repl(params_repl, batch['image']) # predicted = vit_apply_repl(params_repl, batch['image'])
is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1) # is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1)
good += is_same.sum() # good += is_same.sum()
total += len(is_same.flatten()) # total += len(is_same.flatten())
return good / total # return good / total
# 没有微调的随机性能。 # # 模型的随机性能
print(get_accuracy(params_repl)) # print(get_accuracy(params_repl))
exit()
'''微调''' '''微调'''
# 100 Steps take approximately 15 minutes in the TPU runtime. # # 100 Steps take approximately 15 minutes in the TPU runtime.
total_steps = 50 # total_steps = 50
warmup_steps = 5 # warmup_steps = 5
decay_type = 'cosine' # decay_type = 'cosine'
grad_norm_clip = 1 # grad_norm_clip = 1
# 这控制了批处理被分割的转发次数。8适用于具有8个设备的TPU运行时。64应该可以在GPU上工作。当然,您也可以调整上面的batch_size,但这需要相应地调整学习率。 # # 这控制了批处理被分割的转发次数。8适用于具有8个设备的TPU运行时。64应该可以在GPU上工作。当然,您也可以调整上面的batch_size,但这需要相应地调整学习率。
accum_steps = 64 # TODO:可能要改 # accum_steps = 64 # TODO:可能要改
base_lr = 0.03 # base_lr = 0.03
# 检查 train.make_update_fn # # 检查 train.make_update_fn
lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps) # lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
# 我们使用一个动量优化器,使用一半精度的状态来节省内存。它还实现了梯度裁剪 # # 我们使用一个动量优化器,使用一半精度的状态来节省内存。它还实现了梯度裁剪
tx = optax.chain( # tx = optax.chain(
optax.clip_by_global_norm(grad_norm_clip), # optax.clip_by_global_norm(grad_norm_clip),
optax.sgd( # optax.sgd(
learning_rate=lr_fn, # learning_rate=lr_fn,
momentum=0.9, # momentum=0.9,
accumulator_dtype='bfloat16', # accumulator_dtype='bfloat16',
), # ),
) # )
update_fn_repl = train.make_update_fn( # update_fn_repl = train.make_update_fn(
apply_fn=model.apply, accum_steps=accum_steps, tx=tx) # apply_fn=model.apply, accum_steps=accum_steps, tx=tx)
opt_state = tx.init(params) # opt_state = tx.init(params)
opt_state_repl = flax.jax_utils.replicate(opt_state) # opt_state_repl = flax.jax_utils.replicate(opt_state)
# Initialize PRNGs for dropout. # # Initialize PRNGs for dropout.
update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0)) # update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))
# 训练更新 # # 训练更新
losses = [] # losses = []
lrs = [] # lrs = []
# Completes in ~20 min on the TPU runtime. # # Completes in ~20 min on the TPU runtime.
start = time.time() # start = time.time()
for step, batch in zip( # for step, batch in zip(
tqdm.trange(1, total_steps + 1), # tqdm.trange(1, total_steps + 1),
ds_train.as_numpy_iterator(), # ds_train.as_numpy_iterator(),
): # ):
params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl( # params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
params_repl, opt_state_repl, batch, update_rng_repl) # params_repl, opt_state_repl, batch, update_rng_repl)
losses.append(loss_repl[0]) # losses.append(loss_repl[0])
lrs.append(lr_fn(step)) # lrs.append(lr_fn(step))
end = time.time() # end = time.time()
print(f"{model_name}_{dataset}_{total_steps}_{warmup_steps}微调时间为:",end-start) # print(f"{model_name}_{dataset}_{total_steps}_{warmup_steps}微调时间为:",end-start)
print(get_accuracy(params_repl)) # print(get_accuracy(params_repl))
# 绘制损失和学习率变化曲线并保存 # 绘制损失和学习率变化曲线并保存
plt.plot(losses) # plt.plot(losses)
plt.savefig(f'./test_result/{model_name}_{dataset}/losses_plot.png') # plt.savefig(f'./test_result/{model_name}_{dataset}/losses_plot.png')
plt.close() # plt.close()
plt.plot(lrs) # plt.plot(lrs)
plt.savefig(f'./test_result/{model_name}_{dataset}/lrs_plot.png') # plt.savefig(f'./test_result/{model_name}_{dataset}/lrs_plot.png')
plt.close() # plt.close()
# 在CIFAR10上,Mixer-B/16应该是~96.7%,ViT-B/32应该是97.7%(都是@224) # 在CIFAR10上,Mixer-B/16应该是~96.7%,ViT-B/32应该是97.7%(都是@224)
exit()
# exit()
'''推理''' '''推理'''
# #下载一个预训练的模型 model_config = models_config.MODEL_CONFIGS[model_name]
# model_name = 'ViT-L_16' print("模型config:",model_config)
# model_config = models_config.MODEL_CONFIGS[model_name] model = models.VisionTransformer(num_classes=1000, **model_config)
# print(model_config) assert os.path.exists(model_path)
# model = models.VisionTransformer(num_classes=1000, **model_config) # 加载和转换预训练的检查点
# assert os.path.exists(f'./test_result/{model_name}_imagenet2012.npz') params = checkpoint.load(model_path)
# # 加载和转换预训练的检查点 params['pre_logits'] = {} # Need to restore empty leaf for Flax.
# params = checkpoint.load(f'./test_result/{model_name}_imagenet2012.npz') # 获取图像标签.
# params['pre_logits'] = {} # Need to restore empty leaf for Flax. # get_ipython().system('wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt')
# # 获取图像标签. imagenet_labels = dict(enumerate(open('./dataset/ilsvrc2012_wordnet_lemmas.txt')))
# # get_ipython().system('wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt') # 得到一张具有正确尺寸的随机图片
# imagenet_labels = dict(enumerate(open('./test_result/ilsvrc2012_wordnet_lemmas.txt'))) # resolution = 224 if model_name.startswith('Mixer') else 384
# # 得到一张具有正确尺寸的随机图片 # get_ipython().system('wget https://picsum.photos/$resolution -O picsum.jpg')
# # resolution = 224 if model_name.startswith('Mixer') else 384 img = PIL.Image.open('./dataset/picsum.jpg')
# # get_ipython().system('wget https://picsum.photos/$resolution -O picsum.jpg') # 预测
# img = PIL.Image.open('./test_result/picsum.jpg') start_time=time.time()
# # 预测单个项目的批处理(注意非常高效的TPU使用…) logits, = model.apply(dict(params=params), (np.array(img) / 128 - 1)[None, ...], train=False)
# logits, = model.apply(dict(params=params), (np.array(img) / 128 - 1)[None, ...], train=False) end_time=time.time()
# preds = np.array(jax.nn.softmax(logits)) preds = np.array(jax.nn.softmax(logits))
# for idx in preds.argsort()[:-11:-1]: print("推理结果:time=",end_time-start_time)
# print(f'{preds[idx]:.5f} : {imagenet_labels[idx]}', end='') for idx in preds.argsort()[:-11:-1]:
\ No newline at end of file print(f'{preds[idx]:.5f} : {imagenet_labels[idx]}', end='')
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment