Unverified Commit 90c267f2 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] [B] Fixed Batcher in DBiasCastTranspose Primitive (#843)



fixed batcher in dbias_cast_transpose primitive
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 476f659e
...@@ -4298,7 +4298,7 @@ class DBiasCastTransposePrimitive(BasePrimitive): ...@@ -4298,7 +4298,7 @@ class DBiasCastTransposePrimitive(BasePrimitive):
_check_valid_batch_dims(batch_dims) _check_valid_batch_dims(batch_dims)
assert DBiasCastTransposePrimitive.outer_primitive is not None assert DBiasCastTransposePrimitive.outer_primitive is not None
dz, amax, scale, scale_inv = batched_args dz, amax, scale, scale_inv = batched_args
dz_bdim, _, amax_bdim, _, _ = batch_dims dz_bdim, amax_bdim, _, _ = batch_dims
# Minus batch dim. # Minus batch dim.
transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1) transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment