Commit 82fe21d3 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

fix

parent 37d01e9c
......@@ -90,7 +90,7 @@ class MultiHeadPositionwiseFF(nn.Module):
assert d_model % n_head == 0
self.n_head = n_head
d_head = d_model / n_head
d_head = d_model // n_head
self.d_head = d_head
self.d_model = d_model
self.d_inner = d_inner
......@@ -138,7 +138,7 @@ class MultiHeadPositionwiseFF(nn.Module):
attn_vec = torch.einsum('ibnh,ndh->ibnd', (attn_score, self.v_weight)) + self.v_bias
attn_vec = attn_vec.view(inp.size(0), inp.size(1), self.d_model)
attn_vec = attn_vec.contiguous().view(inp.size(0), inp.size(1), self.d_model)
core_out = self.o_net(attn_vec)
core_out = self.dropout(core_out)
......@@ -987,7 +987,7 @@ class MemTransformerLM(nn.Module):
self.out_layer.bias, target, pred_hid, self.sampler)
loss = -F.log_softmax(logit, -1)[:, :, 0]
else:
loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.contiguous().view(-1))
loss = loss.view(tgt_len, -1)
if new_mems is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment