Unverified Commit d3552ddb authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Use __torch_function__ as a class method (#783)



Use torch function as a class method
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 63c7a1a3
......@@ -263,7 +263,7 @@ class _ViewFunc(torch.autograd.Function):
@staticmethod
def backward(ctx,
grad: torch.Tensor,
) -> Tuple[[torch.Tensor, None], ...]:
) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like(
......@@ -853,5 +853,8 @@ class Float8Tensor(torch.Tensor):
_transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid"))
_scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv"))
# Do not force the Float8Tensor type on the returned tensor
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment