Unverified Commit 3104fd59 authored by Thor Johnsen's avatar Thor Johnsen Committed by GitHub
Browse files

Merge pull request #910 from szmigacz/smigacz/mha_xavier_init_gain_fix

Fixed weight init for fused weight matrices in fused MHA by adding correct gain factor
parents 4027bcba a0d99fdb
import math
import torch
from torch import nn
from torch.nn import Parameter
@@ -76,7 +78,11 @@ class EncdecMultiheadAttn(nn.Module):
def reset_parameters(self):
    nn.init.xavier_uniform_(self.in_proj_weight_q)
    nn.init.xavier_uniform_(self.in_proj_weight_kv)
    # in_proj_weight_kv has shape [2 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5)
# therefore xavier_uniform gain should be set to sqrt(1.5).
nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))
    nn.init.xavier_uniform_(self.out_proj_weight)
    if self.bias:
        nn.init.constant_(self.in_proj_bias_q, 0.)
......
import math
import torch
from torch import nn
from torch.nn import Parameter
@@ -98,7 +100,11 @@ class SelfMultiheadAttn(nn.Module):
        nn.init.xavier_uniform_(self.k_weight)
        nn.init.xavier_uniform_(self.v_weight)
    else:
        nn.init.xavier_uniform_(self.in_proj_weight)
        # in_proj_weight has shape [3 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
# therefore xavier_uniform gain should be set to sqrt(2).
nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))
    nn.init.xavier_uniform_(self.out_proj_weight)
    if self.bias:
        if self.separate_qkv_params:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment