ModelZoo / vision_transformer_jax / Commits / 67a1d360

Commit 67a1d360, authored Aug 30, 2024 by suily
Add inference results
parent 188f0cfa

Showing 5 changed files with 1182 additions and 148 deletions (+1182 −148)
README.md (+49 −11)
dataset/ilsvrc2012_wordnet_lemmas.txt (+1000 −0)
dataset/picsum.jpg (+0 −0)
doc/picsum.jpg (+0 −0)
test.py (+133 −137)
README.md

@@ -76,6 +76,7 @@ pip install -r requirements.txt
```
pip install tensorflow-cpu==2.13.1
```
## Datasets
### Training datasets
`cifar10 cifar100`
The datasets are downloaded and processed automatically by tensorflow_datasets; the relevant code is in vision_transformer/vit_jax/input_pipeline.py.
@@ -102,11 +103,25 @@ vim /usr/local/lib/python3.10/site-packages/tensorflow_datasets/core/utils/gcs_u
```
│ ├── features.json
│ └── label.labels.txt
```
### Inference dataset
The image and the label file used for inference can be downloaded as follows:
```
# ./dataset is the storage path and can be customized
wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -P ./dataset
wget https://picsum.photos/384 -O ./dataset/picsum.jpg # fetch the image at 384 resolution
```
The dataset directory structure is as follows:
```
── dataset
│ ├── ilsvrc2012_wordnet_lemmas.txt
│ └── picsum.jpg
```
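Once downloaded, test.py feeds the picture to the model after rescaling its uint8 pixels with `np.array(img) / 128 - 1`. A minimal sketch of that normalization, using a dummy all-zero image in place of `dataset/picsum.jpg` (the helper name `normalize_image` is ours, not part of the repo):

```
import numpy as np

def normalize_image(img_uint8):
    # Map uint8 pixels [0, 255] to roughly [-1, 1), as test.py does with
    # (np.array(img) / 128 - 1), and add a leading batch dimension.
    return (np.asarray(img_uint8, dtype=np.float32) / 128.0 - 1.0)[None, ...]

# Dummy 384x384 RGB image standing in for dataset/picsum.jpg.
dummy = np.zeros((384, 384, 3), dtype=np.uint8)
batch = normalize_image(dummy)
print(batch.shape)  # (1, 384, 384, 3)
print(batch.min())  # -1.0 for all-zero pixels
```

Note the mapping is not exactly symmetric: a 255 pixel lands at 255/128 − 1 ≈ 0.992, not 1.0.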
## Training
Checkpoints can be downloaded as follows:
```
cd /your_code_path/vision_transformer/test_result # test_result is the checkpoint download path and can be customized
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-L_16.npz
```
### Single node, single card
@@ -124,24 +139,47 @@ sh test.sh
```
...
# config.optim_dtype='bfloat16' # precision
```
## Inference
Checkpoints can be downloaded as follows:
```
cd /your_code_path/vision_transformer/test_result # test_result is the checkpoint download path and can be customized
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16.npz -O ViT-B_16_imagenet2012.npz
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-L_16.npz -O ViT-L_16_imagenet2012.npz
```
Run inference:
```
cd /your_code_path/vision_transformer
python test.py
```
## result
The test image is:
<div align=center>
<img src="./doc/picsum.jpg" />
</div>

```
DCU inference results:
0.73861 : alp
0.24576 : valley, vale
0.00416 : lakeside, lakeshore
0.00404 : cliff, drop, drop-off
0.00094 : promontory, headland, head, foreland
0.00060 : mountain_tent
0.00055 : dam, dike, dyke
0.00033 : volcano
0.00031 : ski
0.00012 : solar_dish, solar_collector, solar_furnace
GPU inference results:
0.73976 : alp
0.24465 : valley, vale
0.00414 : lakeside, lakeshore
0.00404 : cliff, drop, drop-off
0.00094 : promontory, headland, head, foreland
0.00060 : mountain_tent
0.00054 : dam, dike, dyke
0.00033 : volcano
0.00031 : ski
0.00012 : solar_dish, solar_collector, solar_furnace
```
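The two runs rank the ten classes identically and disagree only in the probability mass. A quick sketch that quantifies the agreement, hardcoding the top-10 probabilities printed above:

```
# Top-10 probabilities reported above for the same input image.
dcu = [0.73861, 0.24576, 0.00416, 0.00404, 0.00094,
       0.00060, 0.00055, 0.00033, 0.00031, 0.00012]
gpu = [0.73976, 0.24465, 0.00414, 0.00404, 0.00094,
       0.00060, 0.00054, 0.00033, 0.00031, 0.00012]

# Largest per-class probability gap between the two accelerators.
max_abs_diff = max(abs(d - g) for d, g in zip(dcu, gpu))
print(f"max |DCU - GPU| over top-10: {max_abs_diff:.5f}")  # 0.00115
```

The worst-case gap of about 1.2e-3 sits on the top class, consistent with minor floating-point divergence between backends rather than a ranking change.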
### Precision
Test data: [test data](link), accelerator card used: xxx.
None.
Fill in the table according to the test results:
| xxx | xxx | xxx | xxx | xxx |
| :------: | :------: | :------: | :------: | :------: |
| xxx | xxx | xxx | xxx | xxx |
| xxx | xx | xxx | xxx | xxx |
## Application scenarios
### Algorithm category
`Image recognition`
...
dataset/ilsvrc2012_wordnet_lemmas.txt (new file, mode 100644; diff collapsed)

dataset/picsum.jpg (new file, mode 100644; 21.9 KB)

doc/picsum.jpg (new file, mode 100644; 21.9 KB)
test.py

@@ -14,7 +14,7 @@ import optax
```
import tqdm
import os
logging.set_verbosity(logging.INFO)
import PIL
import tensorflow_datasets as tfds
import time
# import tensorflow as tf
```
@@ -22,153 +22,149 @@ import time
```
'''Show the number of available GPU devices.'''
from jax.lib import xla_bridge
jax_test = xla_bridge.get_backend().platform
print(jax_test, jax.local_devices())
if not (jax_test == 'gpu'):
    exit()

'''Select the model.'''
model_name = 'ViT-L_16'  #@param ["ViT-B_32", "Mixer-B_16"]
model_path = f'./test_result/{model_name}_imagenet2012.npz'

'''Load the dataset -- used for fine-tuning.'''
# dataset = 'cifar100'  # imagenet2012 cifar10 cifar100
# batch_size = 512
# config = common_config.with_dataset(common_config.get_config(), dataset)
# # config.shuffle_buffer=1000
# # config.accum_steps=64
# config.batch = batch_size
# config.pp.crop = 384
# # Build the datasets.
# ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
# ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
# num_classes = input_pipeline.get_dataset_info(dataset, 'train')['num_classes']
# del config  # Only needed to instantiate datasets.
# # Fetch a batch of test images for illustration purposes.
# batch = next(iter(ds_test.as_numpy_iterator()))
# # Note the shape: [num_local_devices, local_batch_size, h, w, c]
# print("dataset shape:", batch['image'].shape)
# tf.config.set_visible_devices([], 'GPU')
# print(tf.config.get_visible_devices('GPU'))

'''Load the pretrained model -- used for fine-tuning.'''
# model_config = models_config.MODEL_CONFIGS[model_name]
# print("model config:", model_config)
# # Load the model definition and initialize random parameters.
# # This also compiles the model to XLA (takes a few minutes the first time).
# if model_name.startswith('Mixer'):
#     model = models.MlpMixer(num_classes=num_classes, **model_config)
# else:
#     model = models.VisionTransformer(num_classes=num_classes, **model_config)
# variables = jax.jit(lambda: model.init(
#     jax.random.PRNGKey(0),
#     # Discard the "num_local_devices" dimension of the batch used for initialization.
#     batch['image'][0, :1],
#     train=False,
# ), backend='cpu')()
# # Load and convert the pretrained checkpoint.
# # This loads the actual pretrained weights, but also modifies a few
# # parameters, e.g. replacing the final layer and resizing the position
# # embeddings. See the code and the paper's methods section for details.
# params = checkpoint.load_pretrained(
#     pretrained_path=pretrained_path,
#     init_params=variables['params'],
#     model_config=model_config
# )

'''Evaluation.'''
# params_repl = flax.jax_utils.replicate(params)
# print('params.cls:', type(params['head']['bias']).__name__,
#       params['head']['bias'].shape)
# print('params_repl.cls:', type(params_repl['head']['bias']).__name__,
#       params_repl['head']['bias'].shape)
# # Then map the call to our model's forward pass onto all available devices.
# vit_apply_repl = jax.pmap(lambda params, inputs: model.apply(
#     dict(params=params), inputs, train=False))
# def get_accuracy(params_repl):
#     """Returns the accuracy evaluated on the test set."""
#     good = total = 0
#     steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size
#     for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()):
#         predicted = vit_apply_repl(params_repl, batch['image'])
#         is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1)
#         good += is_same.sum()
#         total += len(is_same.flatten())
#     return good / total
# # Random performance of the model (before fine-tuning).
# print(get_accuracy(params_repl))

'''Fine-tuning.'''
# # 100 steps take approximately 15 minutes in the TPU runtime.
# total_steps = 50
# warmup_steps = 5
# decay_type = 'cosine'
# grad_norm_clip = 1
# # This controls how many forward passes the batch is split into. 8 works
# # for a TPU runtime with 8 devices; 64 should work on a GPU. You can of
# # course also adjust batch_size above, but that requires adjusting the
# # learning rate accordingly.
# accum_steps = 64  # TODO: may need changing
# base_lr = 0.03
# # See train.make_update_fn.
# lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
# # We use a momentum optimizer with half-precision state to save memory.
# # It also implements gradient clipping.
# tx = optax.chain(
#     optax.clip_by_global_norm(grad_norm_clip),
#     optax.sgd(
#         learning_rate=lr_fn,
#         momentum=0.9,
#         accumulator_dtype='bfloat16',
#     ),
# )
# update_fn_repl = train.make_update_fn(
#     apply_fn=model.apply, accum_steps=accum_steps, tx=tx)
# opt_state = tx.init(params)
# opt_state_repl = flax.jax_utils.replicate(opt_state)
# # Initialize PRNGs for dropout.
# update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))
# # Training updates.
# losses = []
# lrs = []
# # Completes in ~20 min on the TPU runtime.
# start = time.time()
# for step, batch in zip(
#     tqdm.trange(1, total_steps + 1),
#     ds_train.as_numpy_iterator(),
# ):
#     params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl(
#         params_repl, opt_state_repl, batch, update_rng_repl)
#     losses.append(loss_repl[0])
#     lrs.append(lr_fn(step))
# end = time.time()
# print(f"{model_name}_{dataset}_{total_steps}_{warmup_steps} fine-tuning time:", end - start)
# print(get_accuracy(params_repl))
# # Plot the loss and learning-rate curves and save them.
# plt.plot(losses)
# plt.savefig(f'./test_result/{model_name}_{dataset}/losses_plot.png')
# plt.close()
# plt.plot(lrs)
# plt.savefig(f'./test_result/{model_name}_{dataset}/lrs_plot.png')
# plt.close()
# # On CIFAR10, Mixer-B/16 should reach ~96.7% and ViT-B/32 ~97.7% (both @224).
# exit()

'''Inference.'''
# Download a pretrained model.
model_config = models_config.MODEL_CONFIGS[model_name]
print("model config:", model_config)
model = models.VisionTransformer(num_classes=1000, **model_config)
assert os.path.exists(model_path)
# Load and convert the pretrained checkpoint.
params = checkpoint.load(model_path)
params['pre_logits'] = {}  # Need to restore empty leaf for Flax.
# Get the image labels.
# get_ipython().system('wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt')
imagenet_labels = dict(enumerate(open('./dataset/ilsvrc2012_wordnet_lemmas.txt')))
# Get a random image with the correct dimensions.
# resolution = 224 if model_name.startswith('Mixer') else 384
# get_ipython().system('wget https://picsum.photos/$resolution -O picsum.jpg')
img = PIL.Image.open('./dataset/picsum.jpg')
# Predict.
start_time = time.time()
logits, = model.apply(dict(params=params), (np.array(img) / 128 - 1)[None, ...], train=False)
end_time = time.time()
preds = np.array(jax.nn.softmax(logits))
print("inference result: time=", end_time - start_time)
for idx in preds.argsort()[:-11:-1]:
    print(f'{preds[idx]:.5f} : {imagenet_labels[idx]}', end='')
```
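The final loop in test.py selects the ten highest-probability classes with the `preds.argsort()[:-11:-1]` idiom. A minimal sketch of how that slice works, on a toy probability vector (the array values here are made up for illustration):

```
import numpy as np

# Toy probability vector standing in for the softmax output `preds`.
preds = np.array([0.05, 0.60, 0.10, 0.20, 0.05])

# argsort() sorts ascending; the [:-k-1:-1] slice walks backwards over the
# last k entries, yielding the k highest-probability indices, best first.
top3 = preds.argsort()[:-4:-1]
print(top3)  # [1 3 2]
```

With k = 10 and `preds[idx]` looked up per index, this reproduces the top-10 listings shown in the README's result section.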