Commit 438db43d authored by thomwolf

update adaptive softmax head

parent c306869e
@@ -89,24 +89,35 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         return logit
 
-    def forward(self, hidden, target, keep_order=False):
+    def forward(self, hidden, target=None, keep_order=False):
         '''
-            hidden :: [len*bsz x d_proj]
-            target :: [len*bsz]
+            Params:
+                hidden :: [len*bsz x d_proj]
+                target :: [len*bsz]
+            Return:
+                if target is None:
+                    out :: [len*bsz] Negative log likelihood
+                else:
+                    out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary
             We could replace this implementation by the native PyTorch one
-            if their was an option to set bias on all clusters in the native one.
-            line https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
+            if theirs had an option to set bias on all clusters in the native one.
+            here: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
         '''
-        if hidden.size(0) != target.size(0):
-            raise RuntimeError('Input and target should have the same size '
-                               'in the batch dimension.')
+        if target is not None:
+            target = target.view(-1)
+            if hidden.size(0) != target.size(0):
+                raise RuntimeError('Input and target should have the same size '
+                                   'in the batch dimension.')
 
         if self.n_clusters == 0:
             logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                         self.out_layers[0].bias, self.out_projs[0])
-            nll = -F.log_softmax(logit, dim=-1) \
-                    .gather(1, target.unsqueeze(1)).squeeze(1)
+            if target is not None:
+                out = -F.log_softmax(logit, dim=-1) \
+                        .gather(1, target.unsqueeze(1)).squeeze(1)
+            else:
+                out = F.log_softmax(logit, dim=-1)
         else:
             # construct weights and biases
             weights, biases = [], []
@@ -133,44 +144,55 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
             head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
             head_logprob = F.log_softmax(head_logit, dim=1)
 
-            nll = torch.zeros_like(target,
-                                   dtype=hidden.dtype, device=hidden.device)
+            if target is None:
+                out = hidden.new_empty((head_logit.size(0), self.n_token))
+            else:
+                out = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device)
 
             offset = 0
             cutoff_values = [0] + self.cutoffs
             for i in range(len(cutoff_values) - 1):
                 l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
 
-                mask_i = (target >= l_idx) & (target < r_idx)
-                indices_i = mask_i.nonzero().squeeze()
+                if target is not None:
+                    mask_i = (target >= l_idx) & (target < r_idx)
+                    indices_i = mask_i.nonzero().squeeze()
 
-                if indices_i.numel() == 0:
-                    continue
+                    if indices_i.numel() == 0:
+                        continue
 
-                target_i = target.index_select(0, indices_i) - l_idx
-                head_logprob_i = head_logprob.index_select(0, indices_i)
+                    target_i = target.index_select(0, indices_i) - l_idx
+                    head_logprob_i = head_logprob.index_select(0, indices_i)
+                    hidden_i = hidden.index_select(0, indices_i)
+                else:
+                    hidden_i = hidden
 
                 if i == 0:
-                    logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
+                    if target is not None:
+                        logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
+                    else:
+                        out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
                 else:
                     weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
 
-                    hidden_i = hidden.index_select(0, indices_i)
-
                     tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                     tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
 
                     cluster_prob_idx = self.cutoffs[0] + i - 1  # No probability for the head cluster
-                    logprob_i = head_logprob_i[:, cluster_prob_idx] \
-                              + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
+                    if target is not None:
+                        logprob_i = head_logprob_i[:, cluster_prob_idx] \
+                                  + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
+                    else:
+                        logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
+                        out[:, l_idx:r_idx] = logprob_i
 
-                if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
-                    nll.index_copy_(0, indices_i, -logprob_i)
-                else:
-                    nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
-
-                offset += logprob_i.size(0)
+                if target is not None:
+                    if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
+                        out.index_copy_(0, indices_i, -logprob_i)
+                    else:
+                        out[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
+                    offset += logprob_i.size(0)
 
-        return nll
+        return out
 
     def log_prob(self, hidden):
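
For reference, a minimal usage sketch of the updated forward() contract (not part of the commit; the constructor signature ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_proj, cutoffs, div_val=1) and all sizes below are assumptions):

# Hypothetical usage sketch of the new forward() contract; the constructor
# signature and all sizes are assumptions, not part of this commit.
import torch
# Assumes ProjectedAdaptiveLogSoftmax, defined above in this file, is in scope.

n_token, d_embed, d_proj = 10000, 128, 128
cutoffs = [1000, 5000]                      # head covers ids < 1000, then two tail clusters

crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_proj, cutoffs, div_val=1)

hidden = torch.randn(32, d_proj)            # [len*bsz x d_proj]
target = torch.randint(0, n_token, (32,))   # [len*bsz]

nll = crit(hidden, target)                  # with target  -> [len*bsz] negative log likelihood
logprob = crit(hidden)                      # target=None  -> [len*bsz x n_token] log probabilities
print(nll.shape, logprob.shape)             # torch.Size([32]) torch.Size([32, 10000])

Passing target=None exercises the new branches in the diff: instead of gathering per-target losses, forward() fills a full [len*bsz x n_token] matrix of log probabilities.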
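
The tail-cluster arithmetic in the loop can be read as a two-level factorization: log p(token) = log p(cluster | head) + log p(token | cluster). A standalone sketch of that identity with made-up tensors (every name and size here is illustrative, not from the repository):

# Standalone illustration of the head + tail log-probability factorization
# used in the loop above; all tensors and sizes are made up for the example.
import torch
import torch.nn.functional as F

head_size, n_clusters, tail_size = 1000, 2, 4000   # hypothetical cutoffs

# The head softmax covers the head tokens plus one "cluster token" per tail cluster.
head_logit = torch.randn(3, head_size + n_clusters)
head_logprob = F.log_softmax(head_logit, dim=1)

# Each tail cluster has its own softmax over the tokens it contains.
tail_logit_i = torch.randn(3, tail_size)
tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

i = 1                                   # first tail cluster
cluster_prob_idx = head_size + i - 1    # its "cluster token" index in the head, as in the code above
target_i = torch.tensor([7, 42, 0])     # token positions relative to the cluster start

# log p(token) = log p(cluster | head) + log p(token | cluster)
logprob = head_logprob[:, cluster_prob_idx] \
        + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
print(logprob.shape)                    # torch.Size([3])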