"examples/research_projects/bertabs/modeling_bertabs.py" did not exist on "81422c4e6d213767dc075f20049e8fd201675029"
Unverified commit 45e11091, authored by twaka, committed by GitHub
Browse files

Make loading of pretrained gpt2 faster by avoiding initialization of Conv1D's weights (#21879)

apply normal_ after assigning weight as nn.Parameter to avoid unnecessary initialization computation
parent 1d3a1cc4
......@@ -105,10 +105,9 @@ class Conv1D(nn.Module):
def __init__(self, nf, nx):
super().__init__()
self.nf = nf
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.weight = nn.Parameter(w)
self.weight = nn.Parameter(torch.empty(nx, nf))
self.bias = nn.Parameter(torch.zeros(nf))
nn.init.normal_(self.weight, std=0.02)
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.