Unverified commit bea70f2e authored by Marks101, committed by GitHub

[PyTorch] forward attention_type in MultiheadAttention (#621)



[PyTorch] Fix forwarding of attention_type in MultiheadAttention
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4dc36f0e
@@ -3090,6 +3090,7 @@ class MultiheadAttention(torch.nn.Module):
             sequence_parallel=sequence_parallel,
             tp_group=tp_group,
             layer_number=self.layer_number,
+            attention_type=self.attention_type,
         )
         # Linear
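
The fix is the single added line: MultiheadAttention already accepted and stored attention_type, but did not pass it on when constructing its inner core-attention module, so that module silently fell back to its own default. Below is a minimal sketch of the forwarding pattern the patch restores; the class names, signatures, and the "self" default here are illustrative assumptions, not Transformer Engine's actual API.

# Hypothetical stand-in classes for illustration only; names and the
# "self" default are assumptions, not Transformer Engine's real code.
class CoreAttention:
    def __init__(self, layer_number, attention_type="self"):
        self.layer_number = layer_number
        # Falls back to "self" whenever the caller does not forward a value.
        self.attention_type = attention_type

class MultiheadAttention:
    def __init__(self, layer_number, attention_type="self"):
        self.layer_number = layer_number
        self.attention_type = attention_type
        # Before the patch, attention_type stopped here: it was stored on
        # the outer module but never handed to the inner module, which
        # therefore always ran with its default.
        self.core_attention = CoreAttention(
            layer_number=self.layer_number,
            attention_type=self.attention_type,  # the line this commit adds
        )

mha = MultiheadAttention(layer_number=1, attention_type="cross")
assert mha.core_attention.attention_type == "cross"

With the forwarding in place, an attention_type passed to the outer layer (e.g. "cross") now reaches the inner module, as the assertion above checks.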