Unverified Commit b14f417a authored by Jan Bernlöhr's avatar Jan Bernlöhr Committed by GitHub
Browse files

[PyTorch] Fix assertion error message formatting in DotProductAttention (#2103)


Signed-off-by: default avatarjanbernloehr <jan@bernloehrs.de>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 632c4c3e
...@@ -1033,14 +1033,14 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1033,14 +1033,14 @@ class DotProductAttention(TransformerEngineBaseModule):
query_layer.shape[-1] == key_layer.shape[-1] query_layer.shape[-1] == key_layer.shape[-1]
), "Queries and keys must have the same head dimension!" ), "Queries and keys must have the same head dimension!"
head_dim_qk, head_dim_v = query_layer.shape[-1], value_layer.shape[-1] head_dim_qk, head_dim_v = query_layer.shape[-1], value_layer.shape[-1]
assert ( assert head_dim_qk == self.hidden_size_per_attention_head_k, (
head_dim_qk == self.hidden_size_per_attention_head_k f"Keys have head_dim = {head_dim_qk}, but expected head_dim ="
), f"Keys have head_dim = {head_dim_qk}, " f" {self.hidden_size_per_attention_head_k}!"
"but expected head_dim = {self.hidden_size_per_attention_head_k}!" )
assert ( assert head_dim_v == self.hidden_size_per_attention_head_v, (
head_dim_v == self.hidden_size_per_attention_head_v f"Values have head_dim = {head_dim_v}, but expected head_dim ="
), f"Values have head_dim = {head_dim_v}, " f" {self.hidden_size_per_attention_head_v}!"
"but expected head_dim = {self.hidden_size_per_attention_head_v}!" )
assert num_gqa_groups == self.num_gqa_groups_per_partition, ( assert num_gqa_groups == self.num_gqa_groups_per_partition, (
"Keys and values must have num_gqa_group =" "Keys and values must have num_gqa_group ="
f" {self.num_gqa_groups_per_partition} heads! Found {num_gqa_groups}." f" {self.num_gqa_groups_per_partition} heads! Found {num_gqa_groups}."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment