Unverified Commit 05977f44 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[TE/JAX] Remove tuple wrapper of singleton in HLO lowering return (#1000)



* removed singleton wrapper
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 33dbf62b
......@@ -98,7 +98,7 @@ class ActLuPrimitive(BasePrimitive):
out = custom_caller(ActLuPrimitive.name, args, opaque, False)
return [out]
return out
@staticmethod
def impl(x, act_enum):
......@@ -226,7 +226,7 @@ class DActLuPrimitive(BasePrimitive):
out = custom_caller(DActLuPrimitive.name, args, opaque, False)
return [out]
return out
@staticmethod
def impl(dz, x, act_enum):
......
......@@ -135,7 +135,7 @@ class SoftmaxPrimitive(BasePrimitive):
out = custom_caller(name, args, opaque, False)
return [out]
return out
@staticmethod
def forward_impl(primitive, logits, scale_factor):
......@@ -247,7 +247,7 @@ class SoftmaxPrimitive(BasePrimitive):
out = custom_caller(name, args, opaque, False)
return [out]
return out
@staticmethod
def backward_impl(primitive, dz, softmax_out, scale_factor):
......@@ -577,7 +577,7 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)
return [out]
return out
@staticmethod
def impl(logits, mask, scale_factor):
......
......@@ -103,7 +103,7 @@ class TransposePrimitive(BasePrimitive):
out = custom_caller(TransposePrimitive.name, args, opaque, False)
return [out]
return out
@staticmethod
def impl(x, static_axis_boundary, transpose_axis_boundary):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment