Commit d7706f0d authored by Jiezhong Qiu

Fixes pytorch/pytorch/issues/36035

Follows the suggestion from https://github.com/pytorch/pytorch/issues/36035#issuecomment-770960405
parent e86dea53
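The underlying problem, as discussed in the linked issue, is that `nn.DataParallel` does not replicate parameters held in an `nn.ParameterList`, so the projection matrices go missing on the non-primary replicas. The workaround applied in this commit wraps each projection matrix in a tiny `nn.Module` (`Projection`) and stores those wrappers in an `nn.ModuleList`. A minimal sketch of the pattern; `TinyAdaptiveEmbedding` and the weight initialization are illustrative stand-ins, not code from the patch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal sketch of the workaround, assuming the failure mode described in the
# linked issue: parameters held in an nn.ParameterList appear empty on the
# replicas that nn.DataParallel creates for devices 1..N. Wrapping each
# projection matrix in a small nn.Module and keeping those wrappers in an
# nn.ModuleList makes them ordinary submodules, which DataParallel replicates.

class Projection(nn.Module):
    def __init__(self, out_feat, in_feat):
        super().__init__()
        self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))
        nn.init.normal_(self.weight, 0.0, 0.02)  # illustrative init, not from the patch

class TinyAdaptiveEmbedding(nn.Module):  # hypothetical stand-in for AdaptiveEmbedding
    def __init__(self, n_token=1000, d_embed=128, d_proj=256):
        super().__init__()
        self.emb = nn.Embedding(n_token, d_embed)
        # before: nn.ParameterList([nn.Parameter(torch.Tensor(d_proj, d_embed))])
        self.emb_projs = nn.ModuleList([Projection(d_proj, d_embed)])

    def forward(self, inp):
        embed = self.emb(inp)
        # the tensor is now reached through .weight, as in the patched code
        return F.linear(embed, self.emb_projs[0].weight)

model = TinyAdaptiveEmbedding()
out = model(torch.randint(0, 1000, (8, 16)))
print(out.shape)  # torch.Size([8, 16, 256])
```

Stored this way, `emb_projs` is an ordinary submodule, so each replica created by `nn.DataParallel` carries its own copy of the projection weights.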
@@ -10,7 +10,7 @@ import torch.nn.functional as F
# import torch_sparse
sys.path.append('utils')
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax, Projection
from log_uniform_sampler import LogUniformSampler, sample_logits
class PositionalEmbedding(nn.Module):
@@ -822,7 +822,7 @@ class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
def activation(x):
return self.dropout(F.relu(x))
super().__init__(num_expert=8, d_model=d_model, d_hidden=d_inner,
super().__init__(num_expert=64, d_model=d_model, d_hidden=d_inner, topk=2,
pre_lnorm=pre_lnorm, activation=activation)
self.dropout = nn.Dropout(dropout)
self.bias = nn.Parameter(
@@ -896,6 +896,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
# return output, relu_out
class AdaptiveEmbedding(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
sample_softmax=False):
@@ -913,25 +914,26 @@ class AdaptiveEmbedding(nn.Module):
self.cutoff_ends = [0] + self.cutoffs
self.emb_layers = nn.ModuleList()
self.emb_projs = nn.ParameterList()
self.emb_projs = nn.ModuleList()
if div_val == 1:
self.emb_layers.append(
nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
)
if d_proj != d_embed:
self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
self.emb_projs.append(Projection(d_proj, d_embed))
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
d_emb_i = d_embed // (div_val ** i)
self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))
                self.emb_projs.append(Projection(d_proj, d_emb_i))
def forward(self, inp):
if self.div_val == 1:
embed = self.emb_layers[0](inp)
if self.d_proj != self.d_embed:
embed = F.linear(embed, self.emb_projs[0])
embed = F.linear(embed, self.emb_projs[0].weight)
else:
param = next(self.parameters())
inp_flat = inp.view(-1)
@@ -948,7 +950,7 @@ class AdaptiveEmbedding(nn.Module):
inp_i = inp_flat.index_select(0, indices_i) - l_idx
emb_i = self.emb_layers[i](inp_i)
emb_i = F.linear(emb_i, self.emb_projs[i])
emb_i = F.linear(emb_i, self.emb_projs[i].weight)
emb_flat.index_copy_(0, indices_i, emb_i)
@@ -1035,9 +1037,9 @@ class MemTransformerLM(nn.Module):
if tie_projs:
for i, tie_proj in enumerate(tie_projs):
if tie_proj and div_val == 1 and d_model != d_embed:
self.crit.out_projs[i] = self.word_emb.emb_projs[0]
self.crit.out_projs[i].weight = self.word_emb.emb_projs[0].weight
elif tie_proj and div_val != 1:
self.crit.out_projs[i] = self.word_emb.emb_projs[i]
self.crit.out_projs[i].weight = self.word_emb.emb_projs[i].weight
self.same_length = same_length
self.clamp_len = clamp_len
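Since `emb_projs` and `out_projs` are now `ModuleList`s of `Projection` modules, the tying above assigns the shared `nn.Parameter` through `.weight` instead of replacing a list entry. A small sketch of why this still shares storage; everything except `Projection` is illustrative:

```python
import torch
import torch.nn as nn

# Hedged sketch of the tying pattern used above: assigning one Projection's
# .weight Parameter to another makes both modules reference the same tensor,
# so the output projections stay tied to the embedding projections exactly as
# the old ParameterList-based code intended.

class Projection(nn.Module):
    def __init__(self, out_feat, in_feat):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_feat, in_feat))

emb_proj = Projection(512, 128)    # plays the role of word_emb.emb_projs[i]
out_proj = Projection(512, 128)    # plays the role of crit.out_projs[i]

out_proj.weight = emb_proj.weight              # the tie performed in the diff
assert out_proj.weight is emb_proj.weight      # one shared Parameter
emb_proj.weight.data.fill_(0.5)
print(out_proj.weight[0, 0].item())            # 0.5 -- updates are visible to both
```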
@@ -1070,12 +1072,11 @@ class MemTransformerLM(nn.Module):
self.mem_len = mem_len
self.ext_len = ext_len
def init_mems(self):
def init_mems(self, x):
if self.mem_len > 0:
mems = []
param = next(self.parameters())
for i in range(self.n_layer+1):
empty = torch.empty(0, dtype=param.dtype, device=param.device)
empty = torch.empty(0, dtype=x.dtype, device=x.device)
mems.append(empty)
return mems
@@ -1215,7 +1216,7 @@ class MemTransformerLM(nn.Module):
# So, have to initialize size(0) mems inside the model forward.
# Moreover, have to return new_mems to allow nn.DataParallel to piece
# them together.
if not mems: mems = self.init_mems()
if not mems: mems = self.init_mems(data)
tgt_len = target.size(0)
hidden, new_mems = self._forward(data, mems=mems)
......
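The `init_mems` change follows the same DataParallel-driven reasoning as the comment in the hunk above: the empty memory tensors are now created from the input tensor's dtype and device instead of from `next(self.parameters())`, so on each replica the mems land on the same GPU as that replica's slice of the batch. A hedged, self-contained illustration; the layer count, `mem_len`, and shapes below are placeholders:

```python
import torch

# Sketch of the new init_mems behaviour, assuming the rest of MemTransformerLM
# is unchanged: memory placeholders take their dtype/device from the forward
# input `x` rather than from the model's first parameter, which is what keeps
# them on the right device when DataParallel scatters the batch across GPUs.

def init_mems(x, n_layer=3, mem_len=160):
    if mem_len > 0:
        mems = []
        for _ in range(n_layer + 1):
            mems.append(torch.empty(0, dtype=x.dtype, device=x.device))
        return mems
    return None

data = torch.randn(36, 4)                # stand-in for the per-replica input slice
mems = init_mems(data)
print(mems[0].dtype, mems[0].device)     # matches the input, not a parameter
```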
@@ -9,6 +9,10 @@ import torch.nn.functional as F
CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
CUDA_MINOR = int(torch.version.cuda.split('.')[1])
class Projection(nn.Module):
    def __init__(self, out_feat, in_feat):
        super().__init__()
        self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))
class ProjectedAdaptiveLogSoftmax(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
keep_order=False):
@@ -31,13 +35,13 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
self.out_layers = nn.ModuleList()
self.out_projs = nn.ParameterList()
self.out_projs = nn.ModuleList()
if div_val == 1:
for i in range(len(self.cutoffs)):
if d_proj != d_embed:
self.out_projs.append(
nn.Parameter(torch.Tensor(d_proj, d_embed))
Projection(d_proj, d_embed)
)
else:
self.out_projs.append(None)
@@ -49,7 +53,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
d_emb_i = d_embed // (div_val ** i)
self.out_projs.append(
nn.Parameter(torch.Tensor(d_proj, d_emb_i))
Projection(d_proj, d_emb_i)
)
self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
@@ -82,7 +86,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
if self.n_clusters == 0:
logit = self._compute_logit(hidden, self.out_layers[0].weight,
self.out_layers[0].bias, self.out_projs[0])
self.out_layers[0].bias, self.out_projs[0].weight if self.out_projs[0] is not None else None)
nll = -F.log_softmax(logit, dim=-1) \
.gather(1, target.unsqueeze(1)).squeeze(1)
else:
@@ -106,7 +110,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
weights.append(weight_i)
biases.append(bias_i)
head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0].weight
head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
head_logprob = F.log_softmax(head_logit, dim=1)
@@ -131,7 +135,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
if i == 0:
logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
else:
weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i].weight
hidden_i = hidden.index_select(0, indices_i)
......
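In the `n_clusters == 0` branch the projection is now read through `.weight`, but only when an entry is present, because `out_projs` can still contain `None` (the `div_val == 1`, `d_proj == d_embed` case appends `None`). A minimal sketch of that guard, using made-up sizes and a simplified stand-in for `_compute_logit`:

```python
import torch
import torch.nn.functional as F

# Sketch of the guarded access above, assuming the same shape conventions as
# the original _compute_logit: out_projs entries may be None (div_val == 1 with
# d_proj == d_embed needs no projection), so .weight is only read when a
# Projection module is actually stored. All sizes here are illustrative.

def compute_logit(hidden, weight, bias, proj):
    if proj is None:
        return F.linear(hidden, weight, bias=bias)
    proj_hid = F.linear(hidden, proj.t().contiguous())
    return F.linear(proj_hid, weight, bias=bias)

out_projs = [None]                    # as appended when d_proj == d_embed
hidden = torch.randn(5, 200)
weight, bias = torch.randn(1000, 200), torch.zeros(1000)

proj = out_projs[0].weight if out_projs[0] is not None else None  # the guard from the diff
print(compute_logit(hidden, weight, bias, proj).shape)            # torch.Size([5, 1000])
```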