Commit a0d99fdb authored by Szymon Migacz

Fixed weight init for fused weight matrices in fused MHA by adding the correct Xavier gain factor.

parent 1ff54b8f
+import math
 import torch
 from torch import nn
 from torch.nn import Parameter
...
@@ -76,7 +78,11 @@ class EncdecMultiheadAttn(nn.Module):
     def reset_parameters(self):
         nn.init.xavier_uniform_(self.in_proj_weight_q)
-        nn.init.xavier_uniform_(self.in_proj_weight_kv)
+        # in_proj_weight_kv has shape [2 * hidden, hidden] but it should be
+        # initialized like a [hidden, hidden] matrix.
+        # sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5)
+        # therefore xavier_uniform gain should be set to sqrt(1.5).
+        nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))
         nn.init.xavier_uniform_(self.out_proj_weight)
         if self.bias:
             nn.init.constant_(self.in_proj_bias_q, 0.)
...
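
Not part of the diff above, but a quick way to check the derivation in the new comment: with gain = sqrt(1.5), xavier_uniform_ applied to the fused [2 * hidden, hidden] KV weight uses the same uniform bound (gain * sqrt(6 / (fan_in + fan_out))) that it would use for a plain [hidden, hidden] projection. A minimal sketch, with hidden chosen only for illustration:

import math

hidden = 1024  # illustrative model dimension, not taken from the repository

# Bound xavier_uniform_ would use for a square [hidden, hidden] projection.
square_bound = math.sqrt(6.0 / (hidden + hidden))

# Bound it uses for the fused [2 * hidden, hidden] KV weight with gain sqrt(1.5);
# for a 2D weight, fan_in = hidden and fan_out = 2 * hidden.
fused_bound = math.sqrt(1.5) * math.sqrt(6.0 / (hidden + 2 * hidden))

assert math.isclose(square_bound, fused_bound)  # both equal sqrt(3 / hidden)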
+import math
 import torch
 from torch import nn
 from torch.nn import Parameter
...
@@ -98,7 +100,11 @@ class SelfMultiheadAttn(nn.Module):
             nn.init.xavier_uniform_(self.k_weight)
             nn.init.xavier_uniform_(self.v_weight)
         else:
-            nn.init.xavier_uniform_(self.in_proj_weight)
+            # in_proj_weight has shape [3 * hidden, hidden] but it should be
+            # initialized like a [hidden, hidden] matrix.
+            # sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
+            # therefore xavier_uniform gain should be set to sqrt(2).
+            nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))
         nn.init.xavier_uniform_(self.out_proj_weight)
         if self.bias:
             if self.separate_qkv_params:
...
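
The same reasoning, stated once for both hunks: for a fused weight of shape [n * hidden, hidden], matching the per-matrix Xavier bound requires gain = sqrt((n + 1) / 2), which gives sqrt(1.5) for the fused KV weight (n = 2) and sqrt(2) for the fused QKV weight (n = 3). Below is a sketch of a generic helper built on that rule; the function name and sizes are illustrative and not part of the repository:

import math

import torch
from torch import nn


def xavier_uniform_fused_(weight: torch.Tensor, num_fused: int) -> torch.Tensor:
    """Init a [num_fused * hidden, hidden] fused weight with the bound that
    xavier_uniform_ would use for a single [hidden, hidden] projection."""
    gain = math.sqrt((num_fused + 1) / 2.0)
    return nn.init.xavier_uniform_(weight, gain=gain)


hidden = 512  # illustrative size
kv = xavier_uniform_fused_(torch.empty(2 * hidden, hidden), num_fused=2)   # gain = sqrt(1.5)
qkv = xavier_uniform_fused_(torch.empty(3 * hidden, hidden), num_fused=3)  # gain = sqrt(2)

# Both draws stay within the per-matrix Xavier bound sqrt(3 / hidden)
# (small tolerance only for floating-point rounding of the bound).
bound = math.sqrt(3.0 / hidden)
assert kv.abs().max().item() <= bound + 1e-6
assert qkv.abs().max().item() <= bound + 1e-6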