Commit 28097c99 authored by ptrblck, committed by mcarilli

initial commit, add CUDA warning to check_params_fp32 (#263)

parent cd2708cc
@@ -75,18 +75,32 @@ def check_models(models):
 def check_params_fp32(models):
     for model in models:
         for name, param in model.named_parameters():
-            if param.is_floating_point() and param.type() != "torch.cuda.FloatTensor":
-                warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
-                            "When using amp.initialize, you do not need to call .half() on your model\n"
-                            "before passing it, no matter what optimization level you choose.".format(
-                            name, param.type()))
+            if param.is_floating_point():
+                if 'Half' in param.type():
+                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
+                                "When using amp.initialize, you do not need to call .half() on your model\n"
+                                "before passing it, no matter what optimization level you choose.".format(
+                                name, param.type()))
+                elif not param.is_cuda:
+                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
+                                "When using amp.initialize, you need to provide a model with parameters\n"
+                                "located on a CUDA device before passing it no matter what optimization level\n"
+                                "you chose. Use model.to('cuda') to use the default device.".format(
+                                name, param.type()))
         for name, buf in model.named_buffers():
-            if buf.is_floating_point() and buf.type() != "torch.cuda.FloatTensor":
-                warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
-                            "When using amp.initialize, you do not need to call .half() on your model\n"
-                            "before passing it, no matter what optimization level you choose.".format(
-                            name, buf.type()))
+            if buf.is_floating_point():
+                if 'Half' in buf.type():
+                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
+                                "When using amp.initialize, you do not need to call .half() on your model\n"
+                                "before passing it, no matter what optimization level you choose.".format(
+                                name, buf.type()))
+                elif not buf.is_cuda:
+                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
+                                "When using amp.initialize, you need to provide a model with buffers\n"
+                                "located on a CUDA device before passing it no matter what optimization level\n"
+                                "you chose. Use model.to('cuda') to use the default device.".format(
+                                name, buf.type()))
 
 def check_optimizers(optimizers):
...
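For context (not part of the commit), a minimal usage sketch of what the new branch catches, assuming apex and a CUDA-capable PyTorch build are installed: a model whose parameters still live on the CPU now reaches the added elif not param.is_cuda branch in check_params_fp32 and gets the "located on a CUDA device" message, instead of the previous, misleading advice about not calling .half().

# Hypothetical sketch, not part of this commit: illustrates the new
# `elif not param.is_cuda` branch. Assumes apex is installed and a CUDA
# device is available.
import torch
from apex import amp

# Wrong: parameters are torch.FloatTensor (CPU), so check_params_fp32 now
# calls warn_or_err with the new "located on a CUDA device" message
# (which may warn or raise, depending on amp's configuration).
cpu_model = torch.nn.Linear(4, 4)
cpu_opt = torch.optim.SGD(cpu_model.parameters(), lr=1e-3)
try:
    amp.initialize(cpu_model, cpu_opt, opt_level="O1")
except RuntimeError as e:
    print(e)

# Intended usage that passes the check: move the model to CUDA first.
cuda_model = torch.nn.Linear(4, 4).to('cuda')
cuda_opt = torch.optim.SGD(cuda_model.parameters(), lr=1e-3)
cuda_model, cuda_opt = amp.initialize(cuda_model, cuda_opt, opt_level="O1")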