[JAX] Adapt latest JAX/PAX image (#744)
* value_and_grad requires same shape for input and gradients

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use high precision layernorm

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove local_device_ids as it caused unexpected behaviors

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Revert "Remove local_device_ids as it caused unexpected behaviors"

This reverts commit c54349b2ce1e96ae696cf0d74f5210e55002cf72.

Signed-off-by: Reese Wang <rewang@nvidia.com>

---------

Signed-off-by: Reese Wang <rewang@nvidia.com>