Unverified commit 8da5eaaf, authored by Guolin Ke, committed by GitHub

support softmax with bias and mask (#2)

* try

* fix bugs

* code clean

* support mask in softmax

* code clean

* check for shapes
parent 0da9683c
@@ -6,46 +6,64 @@
std::vector<c10::optional<torch::Tensor>> fwd_cuda(
bool is_training,
torch::Tensor &input,
const c10::optional<torch::Tensor> &attn_mask,
const c10::optional<torch::Tensor> &bias,
float dropout_prob,
c10::optional<at::Generator> gen_);
torch::Tensor bwd_cuda(
torch::Tensor &output_grads,
const torch::Tensor &softmax_results,
const c10::optional<torch::Tensor> &dropout_mask,
float dropout_prob);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<c10::optional<torch::Tensor>> fwd(
bool is_training,
torch::Tensor &input,
const c10::optional<torch::Tensor> &attn_mask,
const c10::optional<torch::Tensor> &bias,
float dropout_prob,
c10::optional<at::Generator> gen_)
{
CHECK_INPUT(input);
if (attn_mask)
{
CHECK_INPUT(attn_mask.value());
AT_ASSERTM(attn_mask->dim() == 3, "expected 3D tensor");
}
if (bias)
{
CHECK_INPUT(bias.value());
AT_ASSERTM(bias->dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.size(0) % bias->size(0) == 0, "wrong first dim of bias.");
AT_ASSERTM(bias->size(1) == input.size(1) && bias->size(2) == input.size(2), "the last two dims of bias and input should be the same.");
}
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16 ||
input.scalar_type() == at::ScalarType::Float,
"Only HALF/BFloat16/Float is supported");
return fwd_cuda(is_training, input, attn_mask, bias, dropout_prob, gen_);
}
torch::Tensor bwd(
torch::Tensor &output_grads,
const torch::Tensor &softmax_results,
const c10::optional<torch::Tensor> &dropout_mask,
float dropout_prob)
{
CHECK_INPUT(output_grads);
CHECK_INPUT(softmax_results);
if (dropout_mask)
{
CHECK_INPUT(dropout_mask.value());
}
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
@@ -54,15 +72,18 @@ torch::Tensor bwd(
AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half ||
output_grads.scalar_type() == at::ScalarType::BFloat16 ||
output_grads.scalar_type() == at::ScalarType::Float,
"Only HALF/BFloat16/Float is supported");
AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half ||
softmax_results.scalar_type() == at::ScalarType::BFloat16 ||
softmax_results.scalar_type() == at::ScalarType::Float,
"Only HALF/BFloat16/Float is supported");
AT_ASSERTM(output_grads.scalar_type() == softmax_results.scalar_type(), "the types mismatch");
return bwd_cuda(output_grads, softmax_results, dropout_mask, dropout_prob);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward", &fwd, "softmax dropout -- Forward.");
m.def("backward", &bwd, "softmax dropout -- Backward.");
}
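As a side note, here is a minimal sketch of how a fused extension like this could be JIT-compiled and loaded with torch.utils.cpp_extension; the source file names below are hypothetical, not taken from this commit:

from torch.utils.cpp_extension import load

# Hypothetical file names; the real build is wired up in the package's setup scripts.
unicore_fused_softmax_dropout = load(
    name="unicore_fused_softmax_dropout",
    sources=["softmax_dropout.cpp", "softmax_dropout_kernel.cu"],
    extra_cuda_cflags=["-O3"],
)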
(One file's diff is collapsed and not shown; presumably the CUDA kernel implementing fwd_cuda/bwd_cuda.)
import torch
import torch.nn.functional as F
from unicore.modules import softmax_dropout
def gen_attn_mask(mask, neg_inf):
assert neg_inf < -1e4
attn_mask = torch.zeros_like(mask)
attn_mask[mask == 0] = neg_inf
return attn_mask
def normal_softmax(a, mask, bias):
return F.softmax(a + mask + bias, dim=-1)
def fused_softmax(a, mask, bias):
return softmax_dropout(a, 0, True, mask=mask, bias=bias)
def wrap_forward_backward(func, a1, mask, bias1):
a = a1.clone()
bias = bias1.clone()
a.requires_grad = True
bias.requires_grad = True
output = func(a, mask, bias)
o = output.float().sum()
o.backward()
return output, a.grad, bias.grad
def check_diff(a, b, name, eps=1e-3):
assert (a - b).abs().max() < eps, "name {}, diff {}".format(
name, (a - b).abs().max()
)
def test_softmax():
n_batch = 4
n_heads = 8
n_query = 128
test_dims = [64, 128, 256, 512, 1024]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
for last_dim in test_dims:
for dtype in test_dtype:
x = torch.rand(
n_batch,
n_heads,
n_query,
last_dim,
dtype=dtype,
device=test_device,
)
mask = gen_attn_mask(
(
torch.rand(
n_batch,
1,
1,
last_dim,
dtype=dtype,
device=test_device,
)
> 0.2
).type(x.dtype),
-3e4,
)
bias = torch.rand(
n_batch, n_heads, n_query, last_dim, dtype=dtype, device=test_device
)
out_a1, out_b1, out_c1 = wrap_forward_backward(
normal_softmax, x, mask, bias
)
out_a2, out_b2, out_c2 = wrap_forward_backward(fused_softmax, x, mask, bias)
check_diff(out_a1, out_a2, "output")
check_diff(out_b1, out_b2, "grad_input")
check_diff(out_c1, out_c2, "grad_bias")
def test_tri_softmax1():
n_batch = 2
n_groups = 32
n_heads = 8
n_query = 128
test_dims = [64, 128, 256, 512, 1024]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
for last_dim in test_dims:
for dtype in test_dtype:
x = torch.rand(
n_batch,
n_groups,
n_heads,
n_query,
last_dim,
dtype=dtype,
device=test_device,
)
mask = gen_attn_mask(
(
torch.rand(
n_batch,
n_groups,
1,
1,
last_dim,
dtype=dtype,
device=test_device,
)
> 0.2
).type(x.dtype),
-3e4,
)
bias = torch.rand(
1, 1, n_heads, n_query, last_dim, dtype=dtype, device=test_device
)
out_a1, out_b1, out_c1 = wrap_forward_backward(
normal_softmax, x, mask, bias
)
out_a2, out_b2, out_c2 = wrap_forward_backward(fused_softmax, x, mask, bias)
check_diff(out_a1, out_a2, "output")
check_diff(out_b1, out_b2, "grad_input")
check_diff(out_c1, out_c2, "grad_bias")
def test_tri_softmax2():
n_batch = 2
n_groups = 32
n_heads = 8
n_query = 128
test_dims = [64, 128, 256, 512, 1024]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
for last_dim in test_dims:
for dtype in test_dtype:
x = torch.rand(
n_batch,
n_groups,
n_heads,
n_query,
last_dim,
dtype=dtype,
device=test_device,
)
mask = gen_attn_mask(
(
torch.rand(
n_batch,
n_groups,
n_heads,
1,
last_dim,
dtype=dtype,
device=test_device,
)
> 0.2
).type(x.dtype),
-3e4,
)
bias = torch.rand(
1, n_groups, n_heads, n_query, last_dim, dtype=dtype, device=test_device
)
out_a1, out_b1, out_c1 = wrap_forward_backward(
normal_softmax, x, mask, bias
)
out_a2, out_b2, out_c2 = wrap_forward_backward(fused_softmax, x, mask, bias)
check_diff(out_a1, out_a2, "output")
check_diff(out_b1, out_b2, "grad_input")
check_diff(out_c1, out_c2, "grad_bias")
@@ -8,6 +8,7 @@ import torch
from torch import Tensor, nn
from .softmax_dropout import softmax_dropout
class SelfMultiheadAttention(nn.Module):
def __init__(
self,
@@ -37,7 +38,7 @@ class SelfMultiheadAttention(nn.Module):
query,
key_padding_mask: Optional[Tensor] = None,
attn_bias: Optional[Tensor] = None,
return_attn: bool = False,
) -> Tensor:
bsz, tgt_len, embed_dim = query.size()
@@ -46,18 +47,25 @@
q, k, v = self.in_proj(query).chunk(3, dim=-1)
q = (
q.view(bsz, tgt_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz * self.num_heads, -1, self.head_dim)
* self.scaling
)
if k is not None:
k = (
k.view(bsz, -1, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz * self.num_heads, -1, self.head_dim)
)
if v is not None:
v = (
v.view(bsz, -1, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz * self.num_heads, -1, self.head_dim)
)
assert k is not None
@@ -72,37 +80,38 @@ class SelfMultiheadAttention(nn.Module):
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights.masked_fill_(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_probs = softmax_dropout(
attn_weights, self.dropout, self.training, bias=attn_bias
)
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = (
attn.view(bsz, self.num_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz, tgt_len, embed_dim)
)
attn = self.out_proj(attn)
if not return_attn:
return attn
else:
return attn, attn_weights, attn_probs
class CrossMultiheadAttention(nn.Module):
def __init__(
self,
@@ -147,18 +156,25 @@ class CrossMultiheadAttention(nn.Module):
v = self.v_proj(value)
q = (
q.view(bsz, tgt_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz * self.num_heads, -1, self.head_dim)
* self.scaling
)
if k is not None:
k = (
k.view(bsz, -1, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz * self.num_heads, -1, self.head_dim)
)
if v is not None:
v = (
v.view(bsz, -1, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz * self.num_heads, -1, self.head_dim)
)
assert k is not None
@@ -173,30 +189,28 @@ class CrossMultiheadAttention(nn.Module):
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights.masked_fill_(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_probs = softmax_dropout(attn_weights, self.dropout, self.training, bias=attn_bias)
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = (
attn.view(bsz, self.num_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz, tgt_len, embed_dim)
)
attn = self.out_proj(attn)
return attn
@@ -6,16 +6,23 @@ import torch
import unicore_fused_softmax_dropout
import torch.nn.functional as F
class SoftmaxDropoutFast(torch.autograd.Function):
@staticmethod
def forward(ctx, is_training, inputs, mask, bias, dropout_prob):
# don't use ctx.save_for_backward to save dropout_prob:
# allocating space for a tensor is time-consuming
(
dropout_results,
dropout_mask,
softmax_results,
) = unicore_fused_softmax_dropout.forward(
is_training, inputs, mask, bias, dropout_prob, None
)
if is_training:
ctx.dropout_prob = dropout_prob
ctx.save_for_backward(softmax_results, dropout_mask)
ctx.has_bias = bias is not None and bias.requires_grad
if ctx.has_bias:
ctx.bias_batch_dim = bias.shape[0]
return dropout_results
@staticmethod
@@ -23,15 +30,87 @@ class SoftmaxDropoutFast(torch.autograd.Function):
softmax_results, dropout_mask = ctx.saved_tensors
dropout_prob = ctx.dropout_prob
grad_output = grad_output.contiguous()
grad_input = unicore_fused_softmax_dropout.backward(
grad_output, softmax_results, dropout_mask, dropout_prob
)
if ctx.has_bias:
grad_bias = grad_input.view(
-1, ctx.bias_batch_dim, grad_input.shape[-2], grad_input.shape[-1]
).sum(dim=0)
else:
grad_bias = None
return None, grad_input, None, grad_bias, None
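The grad_bias branch above reduces grad_input over the leading broadcast groups. A small sketch in plain PyTorch (independent of the fused kernel) showing that this sum matches autograd's gradient for a broadcast bias add:

import torch

batch, heads, q, k = 4, 8, 16, 16
x = torch.randn(batch * heads, q, k)
bias = torch.randn(heads, q, k, requires_grad=True)
grad_out = torch.randn(batch * heads, q, k)

# adding bias broadcasts it over the leading batch groups
y = x.view(batch, heads, q, k) + bias
y.backward(grad_out.view(batch, heads, q, k))

# mirrors grad_input.view(-1, ctx.bias_batch_dim, ...).sum(dim=0) above
manual = grad_out.view(-1, heads, q, k).sum(dim=0)
assert torch.allclose(bias.grad, manual)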
def _check_mask(mask, input):
assert mask.dtype == input.dtype, "mask and input must have the same dtype"
assert len(mask.shape) == len(input.shape), "wrong length of mask.shape"
assert (
mask.shape[-3] == 1 or mask.shape[-3] == input.shape[-3]
), "mask.shape[-3] must be 1 or input.shape[-3]"
if mask.shape[-3] == 1:
assert mask.shape[-2] == 1, "when mask.shape[-3] == 1, mask.shape[-2] must be 1"
else:
assert (
mask.shape[-2] == 1 or mask.shape[-2] == input.shape[-2]
), "mask.shape[-2] must be 1 or input.shape[-2]"
def _check_bias(bias, input):
assert bias.dtype == input.dtype, "bias and input must have the same dtype"
assert len(bias.shape) == len(input.shape), "wrong length of bias.shape"
assert bias.shape[-1] == input.shape[-1], "bias.shape[-1] must be input.shape[-1]"
assert bias.shape[-2] == input.shape[-2], "bias.shape[-2] must be input.shape[-2]"
len_shape = len(input.shape)
if len_shape > 3:
# head dim should be the same
assert (
bias.shape[-3] == input.shape[-3]
), "bias.shape[-3] must be input.shape[-3]"
offset = 3
else:
offset = 2
prev_non_one = True
for i in range(len_shape - offset - 1, -1, -1):
if prev_non_one:
assert (
bias.shape[i] == input.shape[i] or bias.shape[i] == 1
), "bias.shape[{}] must be input.shape[{}] or 1".format(i, i)
else:
assert bias.shape[i] == 1, "bias.shape[{}] must be 1".format(i)
prev_non_one = bias.shape[i] != 1
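Concretely, these checks accept masks that broadcast over heads and/or query positions, and biases that may broadcast over leading batch dimensions; a few shapes that pass, taken from the test cases in this commit:

# input: (n_batch, n_groups, n_heads, n_query, last_dim), e.g. (2, 32, 8, 128, 64)
# mask:  (2, 32, 1, 1, 64)    passes: broadcast over heads and queries
# mask:  (2, 32, 8, 1, 64)    passes: per-head mask, broadcast over queries
# bias:  (1, 1, 8, 128, 64)   passes: shared across batch and groups
# bias:  (1, 32, 8, 128, 64)  passes: shared across batch only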
def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None):
"""softmax dropout, where mask and bias are optional.
Args:
input (torch.Tensor): input tensor
dropout_prob (float): dropout probability
is_training (bool, optional): whether in training mode. Defaults to True.
mask (torch.Tensor, optional): the mask tensor, used as input + mask. Defaults to None.
bias (torch.Tensor, optional): the bias tensor, used as input + bias. Defaults to None.
Returns:
torch.Tensor: the result after softmax
"""
input = input.contiguous()
input_size = input.size()
if mask is not None:
_check_mask(mask, input)
mask = mask.contiguous().view(-1, mask.shape[-2], mask.shape[-1])
if bias is not None:
_check_bias(bias, input)
bias = bias.contiguous().view(-1, input_size[-2], input_size[-1])
input = input.view(-1, input_size[-2], input_size[-1])
if input.is_cuda and input.shape[-1] <= 2048:
return SoftmaxDropoutFast.apply(
is_training, input, mask, bias, dropout_prob
).view(*input_size)
else:
if mask is not None:
input += mask
if bias is not None:
input += bias
return F.dropout(
F.softmax(input, dim=-1), p=dropout_prob, training=is_training
).view(*input_size)
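Finally, a minimal usage sketch of the extended API; the shapes mirror the test setup above rather than any particular model code:

import torch
from unicore.modules import softmax_dropout

x = torch.rand(4, 8, 128, 64, dtype=torch.float16, device="cuda")
bias = torch.rand(4, 8, 128, 64, dtype=torch.float16, device="cuda")
mask = torch.zeros(4, 1, 1, 64, dtype=torch.float16, device="cuda")

# fused equivalent of F.dropout(F.softmax(x + mask + bias, dim=-1), p=0.1)
probs = softmax_dropout(x, 0.1, is_training=True, mask=mask, bias=bias)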