[JAX] Propagate sm_margin to the underly layernorm kernels (#1089)
* Propagate sm_margin to the underly layernorm kernels --------- Signed-off-by:Reese Wang <rewang@nvidia.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
Showing
Please register or sign in to comment