Commit 0a79672a authored by Guolin Ke

roll back to torch when the bias/mask check fails in softmax

parent f4ce5889
@@ -51,42 +51,50 @@ class SoftmaxDropoutFast(torch.autograd.Function):

 def _check_mask(mask, input):
-    assert mask.dtype == input.dtype, "mask and input must have the same dtype"
-    assert len(mask.shape) == len(input.shape), "wrong length of mask.shape"
-    assert (
-        mask.shape[-3] == 1 or mask.shape[-3] == input.shape[-3]
-    ), "mask.shape[-3] must be 1 or input.shape[-3]"
-    if mask.shape[-3] == 1:
-        assert mask.shape[-2] == 1, "when mask.shape[-3] == 1, mask.shape[-2] must be 1"
-    else:
-        assert (
-            mask.shape[-2] == 1 or mask.shape[-2] == input.shape[-2]
-        ), "mask.shape[-2] must be 1 or input.shape[-2]"
+    try:
+        assert mask.dtype == input.dtype, "mask and input must have the same dtype"
+        assert len(mask.shape) == len(input.shape), "wrong length of mask.shape"
+        assert (
+            mask.shape[-3] == 1 or mask.shape[-3] == input.shape[-3]
+        ), "mask.shape[-3] must be 1 or input.shape[-3]"
+        if mask.shape[-3] == 1:
+            assert mask.shape[-2] == 1, "when mask.shape[-3] == 1, mask.shape[-2] must be 1"
+        else:
+            assert (
+                mask.shape[-2] == 1 or mask.shape[-2] == input.shape[-2]
+            ), "mask.shape[-2] must be 1 or input.shape[-2]"
+        return True
+    except:
+        return False
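For intuition, here is a minimal sketch of what the post-commit _check_mask accepts. The shapes are illustrative assumptions (4-D attention scores of shape (batch, heads, seq, seq)), and _check_mask is assumed to be in scope, as it is a module-level helper in the file being patched:

import torch

# Hypothetical sizes for illustration only.
batch, heads, seq = 2, 8, 16
scores = torch.randn(batch, heads, seq, seq)

# Accepted: a per-sample padding mask, broadcast over heads and query rows
# (mask.shape[-3] == 1 forces mask.shape[-2] == 1).
pad_mask = torch.zeros(batch, 1, 1, seq)
assert _check_mask(pad_mask, scores)

# Accepted: a full per-head mask whose trailing dims match the input.
full_mask = torch.zeros(batch, heads, seq, seq)
assert _check_mask(full_mask, scores)

# Rejected: mask.shape[-3] == 1 but mask.shape[-2] != 1, so the check now
# returns False and softmax_dropout takes the torch fallback instead.
bad_mask = torch.zeros(batch, 1, seq, seq)
assert not _check_mask(bad_mask, scores)

# A dtype mismatch is likewise a quiet rejection after this commit, not a crash.
assert not _check_mask(pad_mask.half(), scores)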


 def _check_bias(bias, input):
-    assert bias.dtype == input.dtype, "bias and input must have the same dtype"
-    assert len(bias.shape) == len(input.shape), "wrong length of bias.shape"
-    assert bias.shape[-1] == input.shape[-1], "bias.shape[-1] must be input.shape[-1]"
-    assert bias.shape[-2] == input.shape[-2], "bias.shape[-2] must be input.shape[-2]"
-    len_shape = len(input.shape)
-    if len_shape > 3:
-        # head dim should be the same
-        assert (
-            bias.shape[-3] == input.shape[-3]
-        ), "bias.shape[-3] must be input.shape[-3]"
-        offset = 3
-    else:
-        offset = 2
-    prev_non_one = True
-    for i in range(len_shape - offset - 1, -1, -1):
-        if prev_non_one:
-            assert (
-                bias.shape[i] == input.shape[i] or bias.shape[i] == 1
-            ), "bias.shape[{}] must be input.shape[{}] or 1".format(i, i)
-        else:
-            assert bias.shape[i] == 1, "bias.shape[{}] must be 1".format(i)
-        prev_non_one = bias.shape[i] != 1
+    try:
+        assert bias.dtype == input.dtype, "bias and input must have the same dtype"
+        assert len(bias.shape) == len(input.shape), "wrong length of bias.shape"
+        assert bias.shape[-1] == input.shape[-1], "bias.shape[-1] must be input.shape[-1]"
+        assert bias.shape[-2] == input.shape[-2], "bias.shape[-2] must be input.shape[-2]"
+        len_shape = len(input.shape)
+        if len_shape > 3:
+            # head dim should be the same
+            assert (
+                bias.shape[-3] == input.shape[-3]
+            ), "bias.shape[-3] must be input.shape[-3]"
+            offset = 3
+        else:
+            offset = 2
+        prev_non_one = True
+        for i in range(len_shape - offset - 1, -1, -1):
+            if prev_non_one:
+                assert (
+                    bias.shape[i] == input.shape[i] or bias.shape[i] == 1
+                ), "bias.shape[{}] must be input.shape[{}] or 1".format(i, i)
+            else:
+                assert bias.shape[i] == 1, "bias.shape[{}] must be 1".format(i)
+            prev_non_one = bias.shape[i] != 1
+        return True
+    except:
+        return False
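The bias rule is stricter than the mask rule: the last two dims (and the head dim, for >3-D inputs) must match exactly, and leading dims may only broadcast from the left. A sketch with the same assumed shapes as above, again assuming _check_bias is in scope:

import torch

batch, heads, seq = 2, 8, 16
scores = torch.randn(batch, heads, seq, seq)

# Accepted: trailing (seq, seq) dims and the head dim match exactly; the
# leading batch dim may be 1 (broadcast) or equal to the input's.
pair_bias = torch.randn(1, heads, seq, seq)
assert _check_bias(pair_bias, scores)

# Rejected: the last two dims must match the input exactly, so a per-row
# bias of shape (batch, heads, seq, 1) goes down the torch fallback path.
row_bias = torch.randn(batch, heads, seq, 1)
assert not _check_bias(row_bias, scores)

# For >4-D inputs the prev_non_one scan means: once a leading dim broadcasts
# (== 1), every dim to its left must also be 1.
scores5d = torch.randn(2, 3, heads, seq, seq)
assert _check_bias(torch.randn(1, 3, heads, seq, seq), scores5d)
assert not _check_bias(torch.randn(2, 1, heads, seq, seq), scores5d)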


 def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None, inplace=True):
@@ -102,18 +110,24 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None,
         torch.Tensor: the result after softmax
     """
     input = input.contiguous()
+    if not inplace:
+        # copy the input for the non-inplace case
+        input = input.clone()
     if input.is_cuda and HAS_SOFTMAX:
         input_size = input.size()
         if mask is not None:
-            _check_mask(mask, input)
-            mask = mask.contiguous().view(-1, mask.shape[-2], mask.shape[-1])
+            if _check_mask(mask, input):
+                mask = mask.contiguous().view(-1, mask.shape[-2], mask.shape[-1])
+            else:
+                input += mask
+                mask = None
         if bias is not None:
-            _check_bias(bias, input)
-            bias = bias.contiguous().view(-1, input_size[-2], input_size[-1])
+            if _check_bias(bias, input):
+                bias = bias.contiguous().view(-1, input_size[-2], input_size[-1])
+            else:
+                input += bias
+                bias = None
         input = input.view(-1, input_size[-2], input_size[-1])
-        if not inplace:
-            # copy a input for non-inplace case
-            input = input.clone()
         if dropout_prob <= 0.0 or input_size[-1] <= 1024:
             return SoftmaxDropoutFast.apply(
                 is_training, input, mask, bias, dropout_prob
...
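The behavioral change, then: a mask or bias layout the fused CUDA kernel cannot consume no longer raises an assertion. It is folded into the input up front with a broadcast add, the corresponding kernel argument is set to None, and the result matches the plain-torch computation. A minimal sketch of that fallback semantics, assuming standard torch.nn.functional calls (softmax_dropout_reference is a hypothetical name, not part of the library):

import torch
import torch.nn.functional as F

def softmax_dropout_reference(input, dropout_prob, is_training=True,
                              mask=None, bias=None):
    # Adding mask/bias before the softmax relies only on standard
    # broadcasting, so any layout torch can broadcast works here, even
    # when the fused kernel's _check_mask/_check_bias reject it.
    if mask is not None:
        input = input + mask
    if bias is not None:
        input = input + bias
    out = F.softmax(input, dim=-1)
    # Dropout after the softmax, matching the function's name and docstring.
    return F.dropout(out, p=dropout_prob, training=is_training)

Note that moving the non-inplace clone() before the CUDA branch also makes the new in-place "input += mask" / "input += bias" fallback safe: the caller's tensor is copied before it can be mutated.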