Commit c49f90d3 authored by Kirthi Shankar Sivamani, committed by GitHub

Fix in MHA cross attention path (#43)



fix unfused qkv param Xattn path
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 003f0549
@@ -332,7 +332,7 @@ class MultiHeadAttention(torch.nn.Module):
**common_gemm_kwargs,
)
else:
-self.query = Linear(
+self.query_layer = Linear(
hidden_size,
hidden_size,
init_method=init_method,
@@ -632,7 +632,7 @@ class MultiHeadAttention(torch.nn.Module):
else:
query_layer = layernorm_query_outputs
else:
-query_layer = self.query(
+query_layer = self.query_layer(
hidden_states,
weight=self.query,
bias=self.query_bias,
......
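
The second hunk passes `weight=self.query` to the projection call, which suggests that `self.query` is meant to hold a weight rather than the projection module. The diff does not show where that weight is defined, so the following is only a toy sketch of this kind of attribute-name collision in plain Python, not the library's code. The classes `Projection`, `CollidingAttention`, and `RenamedAttention` are hypothetical names introduced for illustration.

```python
class Projection:
    """Hypothetical stand-in for a linear layer that accepts an external weight."""

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x, weight=None):
        # Use the externally supplied weight if given, else the layer's own.
        w = self.scale if weight is None else weight
        return [w * v for v in x]


class CollidingAttention:
    """Assumed pre-fix shape: the module and the weight share the name `query`."""

    def __init__(self):
        self.query = 2.0              # stand-in for an externally allocated weight
        self.query = Projection(1.0)  # overwrites the weight above

    def forward(self, x):
        # `weight=self.query` now passes the module itself, not a weight.
        return self.query(x, weight=self.query)


class RenamedAttention:
    """Assumed post-fix shape: the module lives under `query_layer`."""

    def __init__(self):
        self.query = 2.0
        self.query_layer = Projection(1.0)

    def forward(self, x):
        return self.query_layer(x, weight=self.query)


if __name__ == "__main__":
    print(RenamedAttention().forward([1.0, 2.0]))
    try:
        CollidingAttention().forward([1.0, 2.0])
    except TypeError as exc:
        print("colliding names fail:", exc)
```

In this sketch the renamed variant keeps the weight and the module under separate attribute names, so the call receives a real weight; the colliding variant fails because the module is passed where a weight is expected.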