[JAX] Fix unfused GQA performance (#643)
* Fix unfused GQA perf Signed-off-by:Reese Wang <rewang@nvidia.com> * Remove WAR for Check failed: reduction_kind.has_value() Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
Showing
Please register or sign in to comment