Commit 1e8b6e33 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix softmax cpu

parent 21cb6b39
......@@ -94,23 +94,21 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None)
torch.Tensor: the result after softmax
"""
input = input.contiguous()
input_size = input.size()
if mask is not None:
_check_mask(mask, input)
mask = mask.contiguous().view(-1, mask.shape[-2], mask.shape[-1])
if bias is not None:
_check_bias(bias, input)
bias = bias.contiguous().view(-1, input_size[-2], input_size[-1])
input = input.view(-1, input_size[-2], input_size[-1])
if input.is_cuda and input.shape[-1] <= 2048:
input_size = input.size()
if mask is not None:
_check_mask(mask, input)
mask = mask.contiguous().view(-1, mask.shape[-2], mask.shape[-1])
if bias is not None:
_check_bias(bias, input)
bias = bias.contiguous().view(-1, input_size[-2], input_size[-1])
input = input.view(-1, input_size[-2], input_size[-1])
return SoftmaxDropoutFast.apply(
is_training, input, mask, bias, dropout_prob
).view(*input_size)
else:
if mask is None:
if mask is not None:
input += mask
if bias is not None:
input += bias
return F.dropout(
F.softmax(input, dim=-1), p=dropout_prob, training=is_training
).view(*input_size)
return F.dropout(F.softmax(input, dim=-1), p=dropout_prob, training=is_training)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment