Unverified Commit 23caab3f authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[TE/JAX] Disable FusedAttn with FFI by default (#1298)



* disable fused attn with ffi

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 9dddb36d
...@@ -379,7 +379,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -379,7 +379,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
wkspace_aval = ctx.avals_out[-1] wkspace_aval = ctx.avals_out[-1]
if is_ffi_enabled(): if is_ffi_enabled() and bool(os.getenv("NVTE_JAX_FUSED_ATTN_WITH_FFI", "0")):
name = "te_fused_attn_forward_ffi" name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)( out = ffi.ffi_lowering(name)(
ctx, ctx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment