"transformer_engine/pytorch/attention.py" did not exist on "0c9c0ba1fe0f5f43b4bb68a690b9d8832496216b"
[JAX] Allow multi-dims for dgamma and dbeta in LN descriptor. (#780)
* Allow multi-dims for dgamma and dbeta in LN descriptor. Signed-off-by:Ming Huang <mingh@nvidia.com> * Fix the jit error in examples/jax Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com>
Showing
Please register or sign in to comment