Commit 97392643 authored by thomwolf

fix differences with tensorflow version (mem cells and adaptive softmax clusters)

parent ba9e4eb3
@@ -1088,7 +1088,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         if self.mem_len > 0:
             mems = []
             param = next(self.parameters())
-            for i in range(self.n_layer+1):
+            for i in range(self.n_layer):
                 empty = torch.zeros(self.mem_len, data.size(1), self.config.d_model,
                                     dtype=param.dtype, device=param.device)
                 mems.append(empty)
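For context, a minimal sketch (with illustrative config values, not taken from the diff) of what the fixed initialization produces: exactly one empty memory tensor per decoder layer, so that `mems[i]` always pairs with `self.layers[i]` in the loops below.

```python
import torch

# Illustrative values; the real ones come from the model config.
n_layer, mem_len, bsz, d_model = 16, 512, 4, 1024
param = torch.zeros(1)  # stands in for next(self.parameters())

mems = [torch.zeros(mem_len, bsz, d_model,
                    dtype=param.dtype, device=param.device)
        for _ in range(n_layer)]  # n_layer entries, not n_layer + 1

assert len(mems) == n_layer  # one memory per layer, as in the TF version
```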
@@ -1151,15 +1151,14 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
             core_out = self.drop(word_emb)
             pos_emb = self.drop(pos_emb)
-            hids.append(core_out)
             for i, layer in enumerate(self.layers):
+                hids.append(core_out)
                 mems_i = None if mems is None else mems[i]
                 core_out = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i)
-                hids.append(core_out)
         elif self.attn_type == 1: # learnable
             core_out = self.drop(word_emb)
-            hids.append(core_out)
             for i, layer in enumerate(self.layers):
+                hids.append(core_out)
                 if self.clamp_len > 0:
                     r_emb = self.r_emb[i][-self.clamp_len :]
                     r_bias = self.r_bias[i][-self.clamp_len :]
@@ -1169,7 +1168,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                 mems_i = None if mems is None else mems[i]
                 core_out = layer(core_out, r_emb, self.r_w_bias[i],
                                  r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
-                hids.append(core_out)
         elif self.attn_type == 2: # absolute
             pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
@@ -1179,19 +1177,18 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
             core_out = self.drop(word_emb + pos_emb[-qlen:])
-            hids.append(core_out)
             for i, layer in enumerate(self.layers):
+                hids.append(core_out)
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and i == 0:
                     mems_i += pos_emb[:mlen]
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
-                hids.append(core_out)
         elif self.attn_type == 3:
             core_out = self.drop(word_emb)
-            hids.append(core_out)
             for i, layer in enumerate(self.layers):
+                hids.append(core_out)
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and mlen > 0:
                     cur_emb = self.r_emb[i][:-qlen]
@@ -1206,7 +1203,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
-                hids.append(core_out)

         core_out = self.drop(core_out)
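Across all four attention branches above, `hids.append(core_out)` moves from outside the layer loop to its top, so `hids[i]` now caches the input to layer `i` (as in the reference TensorFlow implementation), which is what the memory update consumes. A condensed sketch of that update, ignoring the `ext_len` handling (the method in this file is `_update_mems`; the standalone signature here is assumed for illustration):

```python
import torch

def update_mems(hids, mems, qlen, mlen, mem_len):
    # hids[i] is the input to layer i for the current segment (qlen steps);
    # mems[i] caches the previous mlen steps seen by the same layer.
    with torch.no_grad():
        new_mems = []
        end_idx = mlen + max(0, qlen)
        beg_idx = max(0, end_idx - mem_len)
        for h, m in zip(hids, mems):        # len(hids) == len(mems) == n_layer
            cat = torch.cat([m, h], dim=0)  # [mlen + qlen, bsz, d_model]
            new_mems.append(cat[beg_idx:end_idx].detach())
    return new_mems
```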
@@ -1241,5 +1237,4 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         if new_mems is None:
             return [loss]
         else:
-            return [loss] + new_mems
+            return (loss, new_mems)
@@ -93,6 +93,9 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         '''
             hidden :: [len*bsz x d_proj]
            target :: [len*bsz]
+            We could replace this implementation with the native PyTorch one
+            if there were an option to set a bias on all clusters in the native one.
+            See https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
         '''
         if hidden.size(0) != target.size(0):
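For comparison, a minimal sketch of the native API the new comment refers to: `nn.AdaptiveLogSoftmaxWithLoss` only exposes a `head_bias` flag, so biases cannot be enabled on the tail clusters (sizes below are illustrative):

```python
import torch
import torch.nn as nn

# `head_bias` adds a bias to the head cluster only; the tail clusters are
# bias-free projections, which is the limitation noted in the docstring.
adaptive = nn.AdaptiveLogSoftmaxWithLoss(
    in_features=512, n_classes=10000,
    cutoffs=[1000, 5000], div_value=4.0, head_bias=True)

hidden = torch.randn(8, 512)             # [len*bsz, d_proj]
target = torch.randint(0, 10000, (8,))   # [len*bsz]
output, loss = adaptive(hidden, target)  # returns an (output, loss) namedtuple
```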
@@ -156,9 +159,9 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                 tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

-                logprob_i = head_logprob_i[:, -i] \
-                    + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)
+                cluster_prob_idx = self.cutoffs[0] + i - 1  # No probability for the head cluster
+                logprob_i = head_logprob_i[:, cluster_prob_idx] \
+                    + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)

                 if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
                     nll.index_copy_(0, indices_i, -logprob_i)
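A worked example of the head layout this fix relies on (cutoff values are illustrative): the head softmax has `cutoffs[0]` token columns followed by one column per tail cluster, so cluster `i` lives at column `cutoffs[0] + i - 1`, while the old `[:, -i]` walked those cluster columns in reverse order:

```python
cutoffs = [1000, 5000, 10000]  # illustrative: head of 1000 tokens + 2 tail clusters
head_width = cutoffs[0] + len(cutoffs) - 1  # 1000 token columns + 2 cluster columns

for i in range(1, len(cutoffs)):
    cluster_prob_idx = cutoffs[0] + i - 1   # fixed: i=1 -> 1000, i=2 -> 1001
    old_idx = head_width - i                # old `[:, -i]`: i=1 -> 1001, i=2 -> 1000
    print(i, cluster_prob_idx, old_idx)
# the old indexing paired each token with the wrong cluster's log-probability
```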
@@ -169,6 +172,69 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         return nll

+    def log_prob(self, hidden):
+        r""" Computes log probabilities for all :math:`n\_classes`
+        From: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py
+        Args:
+            hidden (Tensor): a minibatch of examples
+        Returns:
+            log-probabilities for each class :math:`c`
+            in range :math:`0 <= c < n\_classes`, where :math:`n\_classes` is a
+            parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.
+        Shape:
+            - Input: :math:`(N, in\_features)`
+            - Output: :math:`(N, n\_classes)`
+        """
+        if self.n_clusters == 0:
+            logit = self._compute_logit(hidden, self.out_layers[0].weight,
+                                        self.out_layers[0].bias, self.out_projs[0])
+            return F.log_softmax(logit, dim=-1)
+        else:
+            # construct weights and biases
+            weights, biases = [], []
+            for i in range(len(self.cutoffs)):
+                if self.div_val == 1:
+                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
+                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
+                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
+                else:
+                    weight_i = self.out_layers[i].weight
+                    bias_i = self.out_layers[i].bias
+
+                if i == 0:
+                    weight_i = torch.cat(
+                        [weight_i, self.cluster_weight], dim=0)
+                    bias_i = torch.cat(
+                        [bias_i, self.cluster_bias], dim=0)
+
+                weights.append(weight_i)
+                biases.append(bias_i)
+
+            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
+            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
+
+            out = hidden.new_empty((head_logit.size(0), self.n_token))
+            head_logprob = F.log_softmax(head_logit, dim=1)
+
+            cutoff_values = [0] + self.cutoffs
+            for i in range(len(cutoff_values) - 1):
+                start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1]
+
+                if i == 0:
+                    out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
+                else:
+                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
+
+                    tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i)
+                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
+
+                    # same cluster indexing as the forward() fix above; the extra
+                    # None axis broadcasts the cluster log-prob over the tail vocab
+                    cluster_prob_idx = self.cutoffs[0] + i - 1
+                    logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
+                    out[:, start_idx:stop_idx] = logprob_i
+
+            return out
 class LogUniformSampler(object):
     def __init__(self, range_max, n_sample):
         """
......
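A hedged usage sketch for the new `log_prob` method (the constructor arguments below are illustrative and assumed to follow `ProjectedAdaptiveLogSoftmax.__init__` in this file):

```python
import torch

# Assumed illustrative sizes; div_val=1 keeps a single projected output layer.
crit = ProjectedAdaptiveLogSoftmax(n_token=10000, d_embed=512, d_proj=512,
                                   cutoffs=[1000, 5000], div_val=1)

hidden = torch.randn(8, 512)     # [len*bsz, d_proj]
logprob = crit.log_prob(hidden)  # [8, 10000]: a full log-distribution per row
print(logprob.exp().sum(dim=1))  # each row sums to ~1.0
```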