"vscode:/vscode.git/clone" did not exist on "97a4e42e0887e7ff1e6be2b69d977be1111a63db"
Commit 0a79672a authored by Guolin Ke
Browse files

Fall back to torch when the bias/mask check fails in softmax

parent f4ce5889
......@@ -51,6 +51,7 @@ class SoftmaxDropoutFast(torch.autograd.Function):
def _check_mask(mask, input):
try:
assert mask.dtype == input.dtype, "mask and input must have the same dtype"
assert len(mask.shape) == len(input.shape), "wrong length of mask.shape"
assert (
......@@ -62,9 +63,13 @@ def _check_mask(mask, input):
assert (
mask.shape[-2] == 1 or mask.shape[-2] == input.shape[-2]
), "mask.shape[-2] must be 1 or input.shape[-2]"
return True
except:
return False
def _check_bias(bias, input):
try:
assert bias.dtype == input.dtype, "bias and input must have the same dtype"
assert len(bias.shape) == len(input.shape), "wrong length of bias.shape"
assert bias.shape[-1] == input.shape[-1], "bias.shape[-1] must be input.shape[-1]"
......@@ -87,6 +92,9 @@ def _check_bias(bias, input):
else:
assert bias.shape[i] == 1, "bias.shape[{}] must be 1".format(i)
prev_non_one = bias.shape[i] != 1
return True
except:
return False
def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None, inplace=True):
......@@ -102,18 +110,24 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None,
torch.Tensor: the result after softmax
"""
input = input.contiguous()
if not inplace:
# copy a input for non-inplace case
input = input.clone()
if input.is_cuda and HAS_SOFTMAX:
input_size = input.size()
if mask is not None:
_check_mask(mask, input)
if _check_mask(mask, input):
mask = mask.contiguous().view(-1, mask.shape[-2], mask.shape[-1])
else:
input += mask
mask = None
if bias is not None:
_check_bias(bias, input)
if _check_bias(bias, input):
bias = bias.contiguous().view(-1, input_size[-2], input_size[-1])
else:
input += bias
bias = None
input = input.view(-1, input_size[-2], input_size[-1])
if not inplace:
# copy a input for non-inplace case
input = input.clone()
if dropout_prob <= 0.0 or input_size[-1] <= 1024:
return SoftmaxDropoutFast.apply(
is_training, input, mask, bias, dropout_prob
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment