Unverified Commit 23caab3f authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[TE/JAX] Disable FusedAttn with FFI by default (#1298)



* disable fused attn with ffi

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 9dddb36d
......@@ -379,7 +379,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
wkspace_aval = ctx.avals_out[-1]
if is_ffi_enabled():
if is_ffi_enabled() and bool(os.getenv("NVTE_JAX_FUSED_ATTN_WITH_FFI", "0")):
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment