export HIP_VISIBLE_DEVICES=4 # 配置GPU/dcu训练,指定"编码"(多个即多GPU/dcu) export USE_MIOPEN_BATCHNORM=1 # 启用MIOPEN库的批归一化优化,用于加快训练速度? # 指定存储模型的位置 # 用于微调的模型/数据集 # 预训练模型的位置 # 累加梯度(tpu=8,cpu=64) # 微调轮次 # 学习率衰减轮次 # 训练批次 # 图像块的分辨率 # 精度 for model_datasets in 'b16,cifar10' 'b16,cifar100' 'l16,cifar10' 'l16,cifar100' do python -m vit_jax.main --workdir=$(pwd)/test_result/dcu/vit-$(date +%s) \ --config=$(pwd)/vit_jax/configs/vit.py:$model_datasets \ --config.pretrained_dir=$(pwd)/test_result \ --config.accum_steps=64 \ --config.total_steps=500 \ --config.warmup_steps=50 \ --config.batch=512 \ --config.pp.crop=384 \ --config.optim_dtype='bfloat16' done