Unverified Commit 05977f44 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[TE/JAX] Remove tuple wrapper of singleton in HLO lowering return (#1000)



* removed singleton wrapper
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 33dbf62b
...@@ -98,7 +98,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -98,7 +98,7 @@ class ActLuPrimitive(BasePrimitive):
out = custom_caller(ActLuPrimitive.name, args, opaque, False) out = custom_caller(ActLuPrimitive.name, args, opaque, False)
return [out] return out
@staticmethod @staticmethod
def impl(x, act_enum): def impl(x, act_enum):
...@@ -226,7 +226,7 @@ class DActLuPrimitive(BasePrimitive): ...@@ -226,7 +226,7 @@ class DActLuPrimitive(BasePrimitive):
out = custom_caller(DActLuPrimitive.name, args, opaque, False) out = custom_caller(DActLuPrimitive.name, args, opaque, False)
return [out] return out
@staticmethod @staticmethod
def impl(dz, x, act_enum): def impl(dz, x, act_enum):
......
...@@ -135,7 +135,7 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -135,7 +135,7 @@ class SoftmaxPrimitive(BasePrimitive):
out = custom_caller(name, args, opaque, False) out = custom_caller(name, args, opaque, False)
return [out] return out
@staticmethod @staticmethod
def forward_impl(primitive, logits, scale_factor): def forward_impl(primitive, logits, scale_factor):
...@@ -247,7 +247,7 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -247,7 +247,7 @@ class SoftmaxPrimitive(BasePrimitive):
out = custom_caller(name, args, opaque, False) out = custom_caller(name, args, opaque, False)
return [out] return out
@staticmethod @staticmethod
def backward_impl(primitive, dz, softmax_out, scale_factor): def backward_impl(primitive, dz, softmax_out, scale_factor):
...@@ -577,7 +577,7 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -577,7 +577,7 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False) out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)
return [out] return out
@staticmethod @staticmethod
def impl(logits, mask, scale_factor): def impl(logits, mask, scale_factor):
......
...@@ -103,7 +103,7 @@ class TransposePrimitive(BasePrimitive): ...@@ -103,7 +103,7 @@ class TransposePrimitive(BasePrimitive):
out = custom_caller(TransposePrimitive.name, args, opaque, False) out = custom_caller(TransposePrimitive.name, args, opaque, False)
return [out] return out
@staticmethod @staticmethod
def impl(x, static_axis_boundary, transpose_axis_boundary): def impl(x, static_axis_boundary, transpose_axis_boundary):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment