Unverified Commit 43b9e1ee authored by Md Fahim Faysal Khan's avatar Md Fahim Faysal Khan Committed by GitHub
Browse files

fix assertion bug for SWA API in TE-JAX (#1242)



fixed assertion bug for SWA
Signed-off-by: default avatarMd Fahim Faysal Khan <mdfahimfaysa@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent f6b766bd
......@@ -1061,7 +1061,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
# Call base implementation for non-context parallel mesh to avoid unecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
is_context_parallel and config.window_size[0] == -1
not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
......@@ -1158,7 +1158,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
# Call base implementation for non-context parallel mesh to avoid unecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
is_context_parallel and config.window_size[0] == -1
not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment