Commit 3bdbb28d authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

fix

parent 1b7cbeb5
...@@ -120,6 +120,7 @@ class SparsePositionwiseFF(nn.Module): ...@@ -120,6 +120,7 @@ class SparsePositionwiseFF(nn.Module):
relu_out = self.CoreNet_1(inp).view(-1, self.d_inner) relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out) sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2 core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
core_out = self.dropout_final(core_out) core_out = self.dropout_final(core_out)
output = core_out + residual output = core_out + residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment