Unverified Commit 43a6f9fe authored by mcarilli's avatar mcarilli Committed by GitHub
Browse files

Don't patch tensor ops that aren't present (#899)



* Only attempt to patch Tensor methods if defined

* syntax
Co-authored-by: default avatarMichael Carilli <mcarilli@nvidia.com>
parent 44532b30
...@@ -40,3 +40,7 @@ def scalar_python_val(x): ...@@ -40,3 +40,7 @@ def scalar_python_val(x):
return x.data[0] return x.data[0]
else: else:
return x[0] return x[0]
# Accounts for the possibility that some ops may be removed from a namespace.
def filter_attrs(module, attrs):
return list(attrname for attrname in attrs if hasattr(module, attrname))
...@@ -11,20 +11,20 @@ MODULE = torch.Tensor ...@@ -11,20 +11,20 @@ MODULE = torch.Tensor
# MODULE = torch.autograd.Variable # MODULE = torch.autograd.Variable
FP16_FUNCS = [ FP16_FUNCS = compat.filter_attrs(MODULE, [
'__matmul__', '__matmul__',
] ])
FP32_FUNCS = [ FP32_FUNCS = compat.filter_attrs(MODULE, [
'__ipow__', '__ipow__',
'__pow__', '__pow__',
'__rpow__', '__rpow__',
# Cast to fp32 before transfer to CPU # Cast to fp32 before transfer to CPU
'cpu', 'cpu',
] ])
CASTS = [ CASTS = compat.filter_attrs(MODULE, [
'__add__', '__add__',
'__div__', '__div__',
'__eq__', '__eq__',
...@@ -46,7 +46,7 @@ CASTS = [ ...@@ -46,7 +46,7 @@ CASTS = [
'__rtruediv__', '__rtruediv__',
'__sub__', '__sub__',
'__truediv__', '__truediv__',
] ])
# None of these, but here to make code cleaner. # None of these, but here to make code cleaner.
SEQUENCE_CASTS = [] SEQUENCE_CASTS = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment