Commit 2f21497d authored by thomwolf

fixing param.grad is None in fp16 examples

parent da73925f
@@ -555,7 +555,8 @@ def main():
                 if args.fp16 and args.loss_scale != 1.0:
                     # scale down gradients for fp16 training
                     for param in model.parameters():
-                        param.grad.data = param.grad.data / args.loss_scale
+                        if param.grad is not None:
+                            param.grad.data = param.grad.data / args.loss_scale
                     is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
                     if is_nan:
                         logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
...
@@ -898,7 +898,8 @@ def main():
                 if args.fp16 and args.loss_scale != 1.0:
                     # scale down gradients for fp16 training
                     for param in model.parameters():
-                        param.grad.data = param.grad.data / args.loss_scale
+                        if param.grad is not None:
+                            param.grad.data = param.grad.data / args.loss_scale
                     is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
                     if is_nan:
                         logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
...
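The guard matters because parameters that are frozen or that do not participate in the forward pass never receive a gradient, so param.grad is None and dividing param.grad.data unconditionally raises an AttributeError. Below is a minimal, self-contained sketch of the same manual loss-scaling pattern outside the repository's code; the toy model and the static loss_scale value are hypothetical, not taken from the commit.

# Minimal sketch (assumptions: toy model, static loss scale) of manual
# loss scaling for fp16 training. The loss is scaled up before backward()
# and each gradient is scaled back down afterwards; parameters with no
# gradient (frozen or unused in the forward pass) are skipped.
import torch

model = torch.nn.Linear(4, 2)   # hypothetical model
loss_scale = 128.0              # hypothetical static loss scale

x = torch.randn(8, 4)
loss = model(x).sum()

(loss * loss_scale).backward()  # scale the loss before backpropagation

for param in model.parameters():
    if param.grad is not None:  # the fix: skip params that got no gradient
        param.grad.data = param.grad.data / loss_scale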