Commit 1e8b6e33 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix softmax cpu

parent 21cb6b39
...@@ -94,6 +94,7 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None) ...@@ -94,6 +94,7 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None)
torch.Tensor: the result after softmax torch.Tensor: the result after softmax
""" """
input = input.contiguous() input = input.contiguous()
if input.is_cuda and input.shape[-1] <= 2048:
input_size = input.size() input_size = input.size()
if mask is not None: if mask is not None:
_check_mask(mask, input) _check_mask(mask, input)
...@@ -102,15 +103,12 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None) ...@@ -102,15 +103,12 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None)
_check_bias(bias, input) _check_bias(bias, input)
bias = bias.contiguous().view(-1, input_size[-2], input_size[-1]) bias = bias.contiguous().view(-1, input_size[-2], input_size[-1])
input = input.view(-1, input_size[-2], input_size[-1]) input = input.view(-1, input_size[-2], input_size[-1])
if input.is_cuda and input.shape[-1] <= 2048:
return SoftmaxDropoutFast.apply( return SoftmaxDropoutFast.apply(
is_training, input, mask, bias, dropout_prob is_training, input, mask, bias, dropout_prob
).view(*input_size) ).view(*input_size)
else: else:
if mask is None: if mask is not None:
input += mask input += mask
if bias is not None: if bias is not None:
input += bias input += bias
return F.dropout( return F.dropout(F.softmax(input, dim=-1), p=dropout_prob, training=is_training)
F.softmax(input, dim=-1), p=dropout_prob, training=is_training
).view(*input_size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment