Unverified Commit 43a6f9fe authored by mcarilli, committed by GitHub

Don't patch tensor ops that aren't present (#899)



* Only attempt to patch Tensor methods if defined

* syntax
Co-authored-by: Michael Carilli <mcarilli@nvidia.com>
parent 44532b30
@@ -40,3 +40,7 @@ def scalar_python_val(x):
         return x.data[0]
     else:
         return x[0]
+
+# Accounts for the possibility that some ops may be removed from a namespace.
+def filter_attrs(module, attrs):
+    return list(attrname for attrname in attrs if hasattr(module, attrname))
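The new filter_attrs helper returns only the attribute names a namespace actually defines, so the override lists below never reference a method that a given PyTorch build has dropped. A minimal sketch of the behavior, assuming a PyTorch install ('__not_a_real_op__' is a made-up name for illustration):

import torch

def filter_attrs(module, attrs):
    # Keep only the names that the module actually defines.
    return list(attrname for attrname in attrs if hasattr(module, attrname))

# '__matmul__' exists on torch.Tensor; the second name does not.
print(filter_attrs(torch.Tensor, ['__matmul__', '__not_a_real_op__']))
# -> ['__matmul__']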
@@ -11,20 +11,20 @@ MODULE = torch.Tensor
 # MODULE = torch.autograd.Variable
 
 
-FP16_FUNCS = [
+FP16_FUNCS = compat.filter_attrs(MODULE, [
     '__matmul__',
-]
+])
 
-FP32_FUNCS = [
+FP32_FUNCS = compat.filter_attrs(MODULE, [
     '__ipow__',
     '__pow__',
     '__rpow__',
 
     # Cast to fp32 before transfer to CPU
     'cpu',
-]
+])
 
-CASTS = [
+CASTS = compat.filter_attrs(MODULE, [
     '__add__',
     '__div__',
     '__eq__',
@@ -46,7 +46,7 @@ CASTS = [
     '__rtruediv__',
     '__sub__',
     '__truediv__',
-]
+])
 
 # None of these, but here to make code cleaner.
 SEQUENCE_CASTS = []
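For context on why the guard matters: amp looks up every name in these lists on torch.Tensor via getattr before wrapping it, so a single removed op (for example '__div__', which some newer PyTorch builds no longer define) would turn importing the module into an AttributeError. A hedged sketch of that failure mode and the fix, not amp's actual patching loop:

import torch

def filter_attrs(module, attrs):
    return list(attrname for attrname in attrs if hasattr(module, attrname))

wanted = ['__pow__', '__div__', '__truediv__']

# Unfiltered lookup crashes on builds where an op has been removed:
try:
    handles = [getattr(torch.Tensor, name) for name in wanted]
except AttributeError as err:
    print(f"unfiltered patching would fail at import time: {err}")

# Filtered lookup only touches ops present on this build:
present = filter_attrs(torch.Tensor, wanted)
print(present)  # e.g. ['__pow__', '__truediv__'] if '__div__' is gone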