"tests/vscode:/vscode.git/clone" did not exist on "a96d63c21d18ad6610adfcabd3aae02c6357334e"
Unverified Commit c49f90d3 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix in MHA cross attention path (#43)



fix unfused qkv param Xattn path
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 003f0549
...@@ -332,7 +332,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -332,7 +332,7 @@ class MultiHeadAttention(torch.nn.Module):
**common_gemm_kwargs, **common_gemm_kwargs,
) )
else: else:
self.query = Linear( self.query_layer = Linear(
hidden_size, hidden_size,
hidden_size, hidden_size,
init_method=init_method, init_method=init_method,
...@@ -632,7 +632,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -632,7 +632,7 @@ class MultiHeadAttention(torch.nn.Module):
else: else:
query_layer = layernorm_query_outputs query_layer = layernorm_query_outputs
else: else:
query_layer = self.query( query_layer = self.query_layer(
hidden_states, hidden_states,
weight=self.query, weight=self.query,
bias=self.query_bias, bias=self.query_bias,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment