Commit 28097c99 authored by ptrblck, committed by mcarilli

initial commit, add CUDA warning to check_params_fp32 (#263)

parent cd2708cc
@@ -75,18 +75,32 @@ def check_models(models):
 def check_params_fp32(models):
     for model in models:
         for name, param in model.named_parameters():
-            if param.is_floating_point() and param.type() != "torch.cuda.FloatTensor":
-                warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
-                    "When using amp.initialize, you do not need to call .half() on your model\n"
-                    "before passing it, no matter what optimization level you choose.".format(
-                    name, param.type()))
+            if param.is_floating_point():
+                if 'Half' in param.type():
+                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
+                        "When using amp.initialize, you do not need to call .half() on your model\n"
+                        "before passing it, no matter what optimization level you choose.".format(
+                        name, param.type()))
+                elif not param.is_cuda:
+                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
+                        "When using amp.initialize, you need to provide a model with parameters\n"
+                        "located on a CUDA device before passing it no matter what optimization level\n"
+                        "you chose. Use model.to('cuda') to use the default device.".format(
+                        name, param.type()))

         for name, buf in model.named_buffers():
-            if buf.is_floating_point() and buf.type() != "torch.cuda.FloatTensor":
-                warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
-                    "When using amp.initialize, you do not need to call .half() on your model\n"
-                    "before passing it, no matter what optimization level you choose.".format(
-                    name, buf.type()))
+            if buf.is_floating_point():
+                if 'Half' in buf.type():
+                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
+                        "When using amp.initialize, you do not need to call .half() on your model\n"
+                        "before passing it, no matter what optimization level you choose.".format(
+                        name, buf.type()))
+                elif not buf.is_cuda:
+                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
+                        "When using amp.initialize, you need to provide a model with buffers\n"
+                        "located on a CUDA device before passing it no matter what optimization level\n"
+                        "you chose. Use model.to('cuda') to use the default device.".format(
+                        name, buf.type()))

 def check_optimizers(optimizers):
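For context, a minimal sketch (not part of the commit) of the situation the new elif branch warns about, assuming the usual apex.amp workflow; the model and optimizer here are purely illustrative:

# Assumed example: the new `elif not param.is_cuda` branch fires when a model
# whose parameters are still on the CPU is handed to amp.initialize.
import torch
from apex import amp

model = torch.nn.Linear(4, 2)                  # parameters are torch.FloatTensor (CPU)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Passing the CPU model at this point would trigger check_params_fp32's new warning
# ("you need to provide a model with parameters located on a CUDA device ...").
# The fix suggested by the message:
model.to('cuda')                               # parameters become torch.cuda.FloatTensor

# Calling .half() yourself is likewise unnecessary; the pre-existing branch warns
# whenever a parameter type contains 'Half'.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')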