Commit 262da9c6 authored by Michael Carilli

Adding more helpful exception messages to prep_param_lists

parent f22201e7
...
@@ -68,9 +68,23 @@ def prep_param_lists(model, flat_master=False):
     model_params = [param for param in model.parameters() if param.requires_grad]
     if flat_master:
-        # flatten_dense_tensors returns a contiguous flat array.
-        # http://pytorch.org/docs/master/_modules/torch/_utils.html
-        master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
+        # Give the user some more useful error messages
+        try:
+            # flatten_dense_tensors returns a contiguous flat array.
+            # http://pytorch.org/docs/master/_modules/torch/_utils.html
+            master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
+        except TypeError as instance:
+            # This is brittle, and depends on how cat chooses to word its error message.
+            if "cat received an invalid combination of arguments" not in instance.args[0]:
+                raise
+            else:
+                # If you append a message to the exception instance, via
+                # instance.args = instance.args + ("Error...",)
+                # this messes up the terminal-formatted printing of the instance's original message.
+                # Basic solution for now:
+                print("Error in prep_param_lists: model likely contains a mixture of parameters "
+                      "of different types. Use flat_master=False, or use FP16_Optimizer.")
+                raise
     master_params = torch.nn.Parameter(master_params)
     master_params.requires_grad = True
     # master_params.register_hook(backwards_debug_hook)
...
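For context, here is a minimal sketch of the failure mode this guard targets. The toy model and its dtype mixture are illustrative assumptions, not part of the commit; on PyTorch builds contemporary with this change, torch.cat refused to concatenate HalfTensor and FloatTensor data and raised the TypeError caught above (newer releases may type-promote instead, so the except branch may never fire there).

import torch
from torch._utils import _flatten_dense_tensors

# Hypothetical model mixing parameter dtypes: half-precision conv weights
# alongside single-precision batch-norm affine parameters.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3).half(),
    torch.nn.BatchNorm2d(8).float(),
)

model_params = [p for p in model.parameters() if p.requires_grad]
try:
    # Flattening concatenates every parameter into one contiguous buffer;
    # with mixed dtypes, torch.cat on older PyTorch raised the TypeError
    # that the commit's new except clause intercepts.
    master_params = _flatten_dense_tensors([p.data for p in model_params]).float()
except TypeError:
    print("prep_param_lists would fail here: "
          "use flat_master=False or FP16_Optimizer instead.")

Note that matching on the wording of cat's error message is, as the commit's own comment concedes, brittle; re-raising after the print preserves the original traceback for the user.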