Commit af4f9088 authored by Guolin Ke

support softmax in non-inplace cases

parent 49c9895b
@@ -99,7 +99,7 @@ class SelfMultiheadAttention(nn.Module):
         else:
             attn_weights += attn_bias
         attn = softmax_dropout(
-            attn_weights, self.dropout, self.training,
+            attn_weights, self.dropout, self.training, inplace=False,
         )
         o = torch.bmm(attn, v)
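A likely motivation for passing inplace=False at this call site is that the fused kernel may overwrite attn_weights when it runs in place, so code that still needs the raw pre-softmax scores after this call has to opt out. A minimal sketch of that distinction, assuming a plain-PyTorch stand-in rather than the fused SoftmaxDropoutFast path from this repo:

import torch
import torch.nn.functional as F

def softmax_dropout_sketch(x, dropout_prob, is_training=True, inplace=True):
    # Hypothetical stand-in for the fused kernel: emulate an in-place softmax write.
    if not inplace:
        x = x.clone()  # mirrors the clone added in this commit
    x.copy_(F.softmax(x, dim=-1))
    return F.dropout(x, p=dropout_prob, training=is_training)

attn_weights = torch.randn(8, 16, 16)         # (batch * heads, tgt_len, src_len)
raw_scores = attn_weights.clone()
attn = softmax_dropout_sketch(attn_weights, 0.1, inplace=False)
assert torch.equal(attn_weights, raw_scores)  # raw scores survive the call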
@@ -81,7 +81,7 @@ def _check_bias(bias, input):
             prev_non_one = bias.shape[i] != 1


-def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None):
+def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None, inplace=True):
     """softmax dropout, and mask, bias are optional.
     Args:
         input (torch.Tensor): input tensor
@@ -103,6 +103,9 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None):
         _check_bias(bias, input)
         bias = bias.contiguous().view(-1, input_size[-2], input_size[-1])
     input = input.view(-1, input_size[-2], input_size[-1])
+    if not inplace:
+        # copy the input for the non-inplace case
+        input = input.clone()
     if dropout_prob <= 0.0 or input_size[-1] <= 1024:
         return SoftmaxDropoutFast.apply(
             is_training, input, mask, bias, dropout_prob
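A hedged usage sketch of the updated signature; the import path, the CUDA device, and half-precision input are assumptions on my part (the fused extension must be built for this to run), not stated in this diff:

import torch
from unicore.modules import softmax_dropout  # assumed import path

x = torch.randn(8, 16, 16, device="cuda", dtype=torch.half)  # last dim <= 1024, so the fused path is taken
keep = x.clone()

# inplace=True (the default) lets the kernel reuse x's buffer;
# inplace=False clones first, leaving the caller's tensor untouched.
probs = softmax_dropout(x, 0.1, is_training=True, inplace=False)
assert torch.equal(x, keep)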