"googlemock/include/gmock/vscode:/vscode.git/clone" did not exist on "96f7ba83cb225dd9152dcffe3104af093e66510a"
Unverified Commit 1bb8b6eb authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Fix UB names in MHA (#588)


Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
parent 9406633c
...@@ -2737,6 +2737,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2737,6 +2737,7 @@ class MultiheadAttention(torch.nn.Module):
ub_split_ag=ub_split_ag, ub_split_ag=ub_split_ag,
normalization=normalization, normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag, ub_atomic_gemm_ag=ub_atomic_gemm_ag,
ub_name="qkv",
**common_gemm_kwargs, **common_gemm_kwargs,
) )
else: else:
...@@ -2768,6 +2769,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2768,6 +2769,7 @@ class MultiheadAttention(torch.nn.Module):
ub_split_ag=ub_split_ag, ub_split_ag=ub_split_ag,
normalization=normalization, normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag, ub_atomic_gemm_ag=ub_atomic_gemm_ag,
ub_name="qkv",
**common_gemm_kwargs, **common_gemm_kwargs,
) )
else: else:
...@@ -2816,6 +2818,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2816,6 +2818,7 @@ class MultiheadAttention(torch.nn.Module):
ub_split_ag=ub_split_ag, ub_split_ag=ub_split_ag,
ub_atomic_gemm_rs=ub_atomic_gemm_rs, ub_atomic_gemm_rs=ub_atomic_gemm_rs,
ub_atomic_gemm_ag=ub_atomic_gemm_ag, ub_atomic_gemm_ag=ub_atomic_gemm_ag,
ub_name="proj",
**common_gemm_kwargs, **common_gemm_kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment