Unverified Commit 1bb8b6eb authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Fix UB names in MHA (#588)


Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
parent 9406633c
......@@ -2737,6 +2737,7 @@ class MultiheadAttention(torch.nn.Module):
ub_split_ag=ub_split_ag,
normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
ub_name="qkv",
**common_gemm_kwargs,
)
else:
......@@ -2768,6 +2769,7 @@ class MultiheadAttention(torch.nn.Module):
ub_split_ag=ub_split_ag,
normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
ub_name="qkv",
**common_gemm_kwargs,
)
else:
......@@ -2816,6 +2818,7 @@ class MultiheadAttention(torch.nn.Module):
ub_split_ag=ub_split_ag,
ub_atomic_gemm_rs=ub_atomic_gemm_rs,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
ub_name="proj",
**common_gemm_kwargs,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment