Commit 14a27ede authored by yangzhong's avatar yangzhong
Browse files

添加amp参数开关控制

parent 5f82b770
......@@ -38,6 +38,16 @@ torch.distributed.init_process_group(backend='nccl')
# Fix the random seed for reproducibility across runs.
seed_everything(42)
# Command-line switch controlling AMP (automatic mixed precision).
# NOTE: an earlier variant used `type=bool`, which is broken in argparse
# (bool('False') is True); `action="store_true"` is the correct idiom and
# defaults to False, so plain runs stay in fp32.
parser = argparse.ArgumentParser(description='bert4torch training')
parser.add_argument(
    "--use-amp",
    action="store_true",
    help="Run model AMP (automatic mixed precision) mode.",
)
args = parser.parse_args()
# 加载数据集
class MyDataset(ListDataset):
@staticmethod
......@@ -120,7 +130,12 @@ class Loss(nn.Module):
def forward(self, outputs, labels):
    """Compute the CRF loss by delegating to the CRF layer of the DDP-wrapped model."""
    crf_layer = model.module.crf  # model is the module-level DDP-wrapped network
    return crf_layer(*outputs, labels)
# Configure loss and optimizer; precision is selected by the --use-amp flag.
# The previous code compiled unconditionally in fp32 and then compiled AGAIN
# inside an if/else whose branches differed only in the flag value — the first
# call (and its Adam instance) was wasted. Pass the flag through directly.
# compile signature for reference:
#   compile(self, loss, optimizer, scheduler=None, max_grad_norm=None,
#           use_amp=False, metrics=None, adversarial_train={'name': ''})
model.compile(
    loss=Loss(),
    optimizer=optim.Adam(model.parameters(), lr=2e-5),
    use_amp=args.use_amp,  # True -> AMP/fp16 training, False -> fp32
)
......
......@@ -13,4 +13,5 @@ export HIP_VISIBLE_DEVICES=$(seq -s, ${START} ${LAST})
export HSA_FORCE_FINE_GRAIN_PCIE=1
# Multi-GPU (DDP) training run; fp32 by default, uncomment the fp16 line to
# train with AMP instead. (The script previously launched each job twice.)
logfile=bert_base_${NUM}dcu_`date +%Y%m%d%H%M%S`.log
python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py 2>&1 | tee $logfile # fp32
#python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py --use-amp 2>&1 | tee $logfile # fp16
# Single-process training run; same fp32/fp16 switch.
logfile=bert_base_`date +%Y%m%d%H%M%S`.log
python3 crf.py 2>&1 | tee $logfile # fp32
#python3 crf.py --use-amp 2>&1 | tee $logfile # fp16
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment