Unverified Commit b495120e authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] Fix GQA error message (#1328)



* fix GQA error message
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 994f19d0
...@@ -7952,7 +7952,10 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -7952,7 +7952,10 @@ class DotProductAttention(TransformerEngineBaseModule):
assert ( assert (
key_layer.shape[-2] == self.num_gqa_groups_per_partition key_layer.shape[-2] == self.num_gqa_groups_per_partition
and value_layer.shape[-2] == self.num_gqa_groups_per_partition and value_layer.shape[-2] == self.num_gqa_groups_per_partition
), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" ), (
"Keys and values must have num_gqa_group ="
f" {self.num_gqa_groups_per_partition} heads!"
)
assert qkv_format in [ assert qkv_format in [
"sbhd", "sbhd",
"bshd", "bshd",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment