Commit 14e34f7f authored by ptrblck's avatar ptrblck Committed by mcarilli
Browse files

add backwards compatibility for PyTorch 0.4 for named_buffers (#331)

parent e6eec3ba
......@@ -89,7 +89,16 @@ def check_params_fp32(models):
"you chose. Use model.to('cuda') to use the default device.".format(
name, param.type()))
for name, buf in model.named_buffers():
# Backward compatibility for PyTorch 0.4
if hasattr(model, 'named_buffers'):
buf_iter = model.named_buffers()
else:
buf_iter = model._buffers
for obj in buf_iter:
if type(obj)==tuple:
name, buf = obj
else:
name, buf = obj, buf_iter[obj]
if buf.is_floating_point():
if 'Half' in buf.type():
warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment