Commit 82fe21d3 authored by Jiezhong Qiu

fix

parent 37d01e9c
@@ -90,7 +90,7 @@ class MultiHeadPositionwiseFF(nn.Module):
         assert d_model % n_head == 0
         self.n_head = n_head
-        d_head = d_model / n_head
+        d_head = d_model // n_head
         self.d_head = d_head
         self.d_model = d_model
         self.d_inner = d_inner
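Why the change from / to // matters: under Python 3 true division, d_model / n_head yields a float, and a float d_head later fails wherever an integer tensor dimension is required. A minimal sketch of the failure, with d_model and n_head values assumed for illustration (the real module receives them as constructor arguments):

import torch
import torch.nn as nn

d_model, n_head = 512, 8
d_head = d_model / n_head     # 64.0 -- a float under true division
# torch.zeros(n_head, d_model, d_head) -> TypeError: size must be a tuple of ints
d_head = d_model // n_head    # 64 -- an int, valid as a tensor dimension
v_weight = nn.Parameter(torch.zeros(n_head, d_model, d_head))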
@@ -138,7 +138,7 @@ class MultiHeadPositionwiseFF(nn.Module):
         attn_vec = torch.einsum('ibnh,ndh->ibnd', (attn_score, self.v_weight)) + self.v_bias
-        attn_vec = attn_vec.view(inp.size(0), inp.size(1), self.d_model)
+        attn_vec = attn_vec.contiguous().view(inp.size(0), inp.size(1), self.d_model)
         core_out = self.o_net(attn_vec)
         core_out = self.dropout(core_out)
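The added .contiguous() guards against attn_vec not being laid out contiguously in memory: .view() only reinterprets strides and raises a RuntimeError on tensors whose layout cannot be flattened in place. A standalone sketch of the error (the shapes below are assumptions, not taken from the model):

import torch

x = torch.randn(10, 4, 8, 16).permute(0, 1, 3, 2)   # permuted -> non-contiguous
try:
    x.view(10, 4, 128)           # RuntimeError: view size is not compatible ...
except RuntimeError:
    pass
y = x.contiguous().view(10, 4, 128)                  # copy to contiguous memory, then view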
@@ -987,7 +987,7 @@ class MemTransformerLM(nn.Module):
                              self.out_layer.bias, target, pred_hid, self.sampler)
             loss = -F.log_softmax(logit, -1)[:, :, 0]
         else:
-            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
+            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.contiguous().view(-1))
         loss = loss.view(tgt_len, -1)
         if new_mems is None:
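The same non-contiguity issue motivates target.contiguous().view(-1): a target batch built by transposing or slicing a larger tensor cannot be flattened by .view() alone. A sketch under the assumption that target has a [tgt_len, bsz] layout (consistent with loss.view(tgt_len, -1) in the hunk above):

import torch

data = torch.randint(0, 1000, (4, 192))   # assumed [bsz, seq_len] token ids
target = data.t()[1:]                      # transposed slice -> non-contiguous
# target.view(-1) would raise a RuntimeError here
flat = target.contiguous().view(-1)        # flattens after a contiguous copy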