Unverified Commit d3552ddb authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Use __torch_function__ as a class method (#783)



Use torch function as a class method
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 63c7a1a3
...@@ -263,7 +263,7 @@ class _ViewFunc(torch.autograd.Function): ...@@ -263,7 +263,7 @@ class _ViewFunc(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, def backward(ctx,
grad: torch.Tensor, grad: torch.Tensor,
) -> Tuple[[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad, Float8Tensor): if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like( dgrad = Float8Tensor.make_like(
...@@ -853,5 +853,8 @@ class Float8Tensor(torch.Tensor): ...@@ -853,5 +853,8 @@ class Float8Tensor(torch.Tensor):
_transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid"))
_scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv"))
# Do not force the Float8Tensor type on the returned tensor @classmethod
__torch_function__ = torch._C._disabled_torch_function_impl def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment