[JAX] Clamped Swiglu Integration (#2194)
Signed-off-by:
Varun Thumbe <vthumbe@nvidia.com>
*Jax integration for clamped swiglu. This is the continuation of PR which added Clamped Swiglu(used in GPT OSS) support in TE along with Pytorch integration. This PR hooks up the clamped swiglu and dswiglu's nvte APIs to TE Jax.
Showing
This diff is collapsed.
Please register or sign in to comment