Unverified Commit 0b303dad authored by Charlene Yang's avatar Charlene Yang Committed by GitHub

[PyTorch] Fix tp_size for MQA/GQA (#1044)



fix tp_size for GQA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 4cc220c9
@@ -5125,7 +5125,7 @@ class DotProductAttention(TransformerEngineBaseModule):
         self.hidden_size_per_attention_head = kv_channels
         self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
-        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
+        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
         assert (
             num_attention_heads % self.num_gqa_groups == 0
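The one-line change divides the number of GQA key/value groups by the tensor-parallel size stored on the module (`self.tp_size`) instead of the raw `tp_size` constructor argument. A minimal sketch of the partitioning arithmetic, assuming a simplified stand-in class (not the real Transformer Engine `DotProductAttention`):

```python
class DotProductAttention:
    """Simplified stand-in illustrating the MQA/GQA partitioning math."""

    def __init__(self, num_attention_heads, num_gqa_groups=None, tp_size=1):
        # Store tp_size on the instance; the fix reads this stored value
        # rather than the constructor local.
        self.tp_size = tp_size
        # MQA/GQA: with no group count given, fall back to one KV group
        # per query head, i.e. plain multi-head attention.
        self.num_gqa_groups = (
            num_attention_heads if num_gqa_groups is None else num_gqa_groups
        )
        # KV groups are sharded across tensor-parallel ranks.
        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
        # Query heads must divide evenly into KV groups.
        assert num_attention_heads % self.num_gqa_groups == 0

# e.g. 32 query heads, 8 KV groups (GQA), tensor parallelism over 4 GPUs:
attn = DotProductAttention(32, num_gqa_groups=8, tp_size=4)
print(attn.num_gqa_groups_per_partition)  # 2 KV groups per TP rank
```

MQA is the special case `num_gqa_groups=1`, where the single KV head is replicated rather than split once `tp_size` exceeds the group count, which is why getting the divisor right matters.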