Unverified Commit 4b8ffef4 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Fix for Disable FusedAttn with FFI by default (#1304)



rm default value for NVTE_JAX_FUSED_ATTN_WITH_FFI
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 05c0fb02
...@@ -379,7 +379,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -379,7 +379,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
wkspace_aval = ctx.avals_out[-1] wkspace_aval = ctx.avals_out[-1]
if is_ffi_enabled() and bool(os.getenv("NVTE_JAX_FUSED_ATTN_WITH_FFI", "0")): if is_ffi_enabled() and bool(os.getenv("NVTE_JAX_FUSED_ATTN_WITH_FFI")):
name = "te_fused_attn_forward_ffi" name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)( out = ffi.ffi_lowering(name)(
ctx, ctx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment