export HIP_VISIBLE_DEVICES=4 # 配置GPU/dcu训练，指定"编码"（多个即多GPU/dcu）
export USE_MIOPEN_BATCHNORM=1  # 启用MIOPEN库的批归一化优化，用于加快训练速度？

# 指定存储模型的位置
# 用于微调的模型/数据集
# 预训练模型的位置
# 累加梯度（tpu=8,cpu=64）
# 微调轮次
# 学习率衰减轮次
# 训练批次
# 图像块的分辨率
# 精度
for model_datasets in 'b16,cifar10' 'b16,cifar100' 'l16,cifar10' 'l16,cifar100'
do
    python -m vit_jax.main --workdir=$(pwd)/test_result/dcu/vit-$(date +%s) \
        --config=$(pwd)/vit_jax/configs/vit.py:$model_datasets \
        --config.pretrained_dir=$(pwd)/test_result \
        --config.accum_steps=64 \
        --config.total_steps=500 \
        --config.warmup_steps=50 \
        --config.batch=512 \
        --config.pp.crop=384 \
        --config.optim_dtype='bfloat16'
done