Unverified commit bea70f2e authored by Marks101, committed by GitHub

[PyTorch] forward attention_type in MultiheadAttention (#621)



[PyTorch] fix forward attention_type in MultiheadAttention
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4dc36f0e
@@ -3090,6 +3090,7 @@ class MultiheadAttention(torch.nn.Module):
sequence_parallel=sequence_parallel,
tp_group=tp_group,
layer_number=self.layer_number,
attention_type=self.attention_type,
)
# Linear
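A minimal sketch of the bug this commit fixes, using hypothetical stand-in classes rather than the real Transformer Engine API: before the change, `MultiheadAttention` constructed its inner attention module without forwarding `attention_type`, so the inner module silently fell back to its default even when the wrapper was built for cross-attention. The added line from the diff forwards the argument through.

```python
class CoreAttention:
    """Stand-in for the inner attention module (hypothetical names)."""
    def __init__(self, layer_number, attention_type="self"):
        self.layer_number = layer_number
        self.attention_type = attention_type


class MultiheadAttention:
    """Stand-in wrapper. Before the fix, attention_type was accepted here
    but never passed down, so CoreAttention always used its default."""
    def __init__(self, layer_number, attention_type="self"):
        self.layer_number = layer_number
        self.attention_type = attention_type
        self.core_attention = CoreAttention(
            layer_number=self.layer_number,
            attention_type=self.attention_type,  # the line added in the diff
        )


# With the argument forwarded, the inner module matches the wrapper.
mha = MultiheadAttention(layer_number=1, attention_type="cross")
print(mha.core_attention.attention_type)  # prints "cross"
```

Without the forwarded argument, `mha.core_attention.attention_type` would remain `"self"`, which is exactly the mismatch the one-line diff above corrects.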