Commit 5f82b770 authored by yangzhong's avatar yangzhong

Add a parameter switch to control AMP (automatic mixed precision)

parent 4edfa95d
...@@ -30,6 +30,16 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
# fix the random seed
seed_everything(42)
# add the AMP parameter switch
parser = argparse.ArgumentParser(description='bert4torch training')
#parser.add_argument('--use-amp', type=bool, default=True, help='Use automatic mixed precision (AMP)')
parser.add_argument(
"--use-amp",
action="store_true",
help="Run model AMP (automatic mixed precision) mode.",
)
args = parser.parse_args()
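Note: with action="store_true" the flag defaults to False and becomes True only when it is passed on the command line, which avoids the pitfall of the commented-out type=bool variant (argparse converts any non-empty string, including "False", to True). A minimal usage sketch, assuming the script is invoked as training.py (hypothetical filename):
python training.py             # args.use_amp == False -> fp32 training
python training.py --use-amp   # args.use_amp == True  -> AMP/fp16 training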
# load the dataset
class MyDataset(ListDataset):
    @staticmethod
...@@ -107,7 +117,12 @@ class Loss(nn.Module):
    def forward(self, outputs, labels):
        return model.crf(*outputs, labels)
model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5)) # fp32
if args.use_amp:
    model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=True)  # train with AMP (fp16)
else:
    model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=False)  # train without AMP (fp32)
# model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5)) # fp32
# model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=True) # fp16
...
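For reference, a minimal sketch of what mixed-precision training typically looks like in plain PyTorch (torch.cuda.amp); how bert4torch handles use_amp inside model.compile() is assumed to follow this standard autocast + GradScaler pattern, and the loop below is illustrative rather than the library's actual code:
scaler = torch.cuda.amp.GradScaler()          # scales the loss to avoid fp16 gradient underflow
for batch, labels in dataloader:              # hypothetical dataloader yielding (inputs, labels)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():           # run the forward pass in mixed precision
        outputs = model(batch)
        loss = criterion(outputs, labels)
    scaler.scale(loss).backward()             # backward on the scaled loss
    scaler.step(optimizer)                    # unscales gradients; skips the step on inf/nan
    scaler.update()                           # adjust the scale factor for the next iteration
When use_amp is False the forward and backward passes run entirely in fp32, which is slower and uses more memory but is numerically safer.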