train.sh 609 Bytes
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#!/bin/bash
echo "Export params ..."

export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # 自行修改为训练的卡号和数量
export HSA_FORCE_FINE_GRAIN_PCIE=1
export USE_MIOPEN_BATCHNORM=1

echo "Training start ..."
# HAT_SRx4

#  options/train/train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml文件中部分参数确认:
#  9行 datasets: 请确认数据地址正确
# 76行 pretrain_network_g: 请确认预训练模型地址正确

python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 hat/train.py -opt options/train/train_HAT_SRx4_finetune_from_ImageNet_pretrain.yml --launcher pytorch