Commit 189e72a7 authored by Jared Casper

Merge branch 'fused_softmax_kernel_fixes' into 'main'

Support for fully masked rows in the fused kernel + avoiding an in-place operation in the backward pass

See merge request ADLR/megatron-lm!435
parents 8df49e72 76db9583
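The forward change in this merge targets rows in which every key position is masked. Because attention_mask_func fills masked scores with -10000.0, a plain softmax over such a row degenerates into a uniform distribution over padding; after this change the fused kernel returns zeros for fully masked rows instead (see test_allmasked_softmax_forward below). A minimal PyTorch sketch of the intended reference behavior, with illustrative names only:

import torch

def reference_masked_softmax(scores, mask, scale):
    # Masked positions are filled with the same -10000.0 value the fused
    # kernel assumes (see attention_mask_func in the test diff below).
    scores = scores * scale
    scores = scores.masked_fill(mask, -10000.0)
    probs = torch.softmax(scores, dim=-1)
    # Without this step, a row where every position is masked would come out
    # as a uniform 1/row_length distribution; the fixed kernel emits zeros.
    fully_masked = mask.all(dim=-1, keepdim=True)
    return probs.masked_fill(fully_masked, 0.0)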
@@ -293,6 +293,13 @@ __global__ void scaled_masked_softmax_warp_forward(
     }
     warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
 
+    // compute scale value to account for full mask
+    acc_t scale_value[WARP_BATCH];
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0;
+    }
+
     acc_t sum[WARP_BATCH] { 0.0f };
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -316,7 +323,7 @@ __global__ void scaled_masked_softmax_warp_forward(
         if (element_index < element_count) {
             #pragma unroll
             for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
-                out[element] = elements[i][it + element] / sum[i];
+                out[element] = elements[i][it + element] * scale_value[i] / sum[i];
             }
             copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
         } else {
......
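The detection above relies on the mask fill value: after the per-row max reduction, a maximum that is still exactly -10000.0 can only come from a row in which every element was masked, so scale_value zeroes that row's output. A short sketch of the same trick outside the kernel (names are illustrative; it assumes genuine logits never land exactly on -10000.0 after scaling):

import torch

def softmax_with_full_mask_scale(masked_scores):
    # masked_scores: logits already filled with -10000.0 at masked positions.
    row_max = masked_scores.max(dim=-1, keepdim=True).values
    # 0.0 for fully masked rows, 1.0 otherwise -- mirrors scale_value above.
    scale_value = (row_max != -10000.0).to(masked_scores.dtype)
    exp = torch.exp(masked_scores - row_max)
    return exp * scale_value / exp.sum(dim=-1, keepdim=True)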
@@ -65,7 +65,7 @@ torch::Tensor fwd_cuda(
       input.scalar_type(),
       "dispatch_scaled_masked_softmax_forward",
       dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
-          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
           reinterpret_cast<const scalar_t*>(input_ptr),
           reinterpret_cast<const uint8_t*>(mask_ptr),
           scale_factor,
@@ -92,14 +92,19 @@ torch::Tensor bwd_cuda(
   const int query_seq_len = output_grads.size(2);
   const int key_seq_len = output_grads.size(3);
 
+  auto act_options = output_grads.options().requires_grad(false);
+  torch::Tensor input_grads =
+      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
+
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
+  void* input_grads_ptr = static_cast<void*>(input_grads.data_ptr());
 
   //Softmax Grad
   DISPATCH_HALF_AND_BFLOAT(
       output_grads_.scalar_type(),
       "dispatch_scaled_masked_softmax_backward",
       dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
-          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(input_grads_ptr),
           reinterpret_cast<scalar_t*>(output_grads_ptr),
           reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
           scale_factor,
@@ -107,10 +112,9 @@ torch::Tensor bwd_cuda(
           key_seq_len,
           batches,
           attn_heads);
   );
   );
-  //backward pass is completely in-place
-  return output_grads;
+  return input_grads;
 }
 }
 }
......
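The backward fix above stops the extension from writing gradients into output_grads and returning it; it now allocates input_grads and returns that instead. This matters because a custom backward must not modify the incoming gradient in place: autograd may still need that buffer elsewhere. A sketch of how the extension could be wrapped in an autograd.Function after this change (the wrapper itself is illustrative, not taken from this merge; the scaled_masked_softmax_cuda.forward/backward signatures match the tests below):

import torch
import scaled_masked_softmax_cuda  # built/loaded via megatron.fused_kernels.load()

class ScaledMaskedSoftmaxSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, mask, scale):
        scale_t = torch.tensor([scale])
        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0].item())
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors
        # After this merge the extension fills a freshly allocated
        # input_grads tensor and leaves output_grads untouched.
        input_grads = scaled_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0].item())
        return input_grads, None, None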
@@ -7,7 +7,7 @@ from megatron.model.enums import AttnMaskType
 from megatron.model.fused_layer_norm import MixedFusedLayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.utils import attention_mask_func
+from megatron.fused_kernels import load
 
 def test_load_fused_kernels():
     try:
@@ -279,6 +279,90 @@ def test_layer_norm():
     )
 
 
+def attention_mask_func(attention_scores, attention_mask):
+    attention_scores.masked_fill_(attention_mask, -10000.0)
+    return attention_scores
+
+
+def forward_torch_softmax(input, mask, scale):
+    input = input * scale
+    mask_output = attention_mask_func(input, mask) if mask is not None else input
+    probs = torch.nn.Softmax(dim=-1)(mask_output)
+    return probs
+
+
+def test_masked_softmax_forward():
+    import scaled_masked_softmax_cuda
+
+    batch = 2
+    attn = 16
+    scale_t = torch.tensor([1.0])
+    for qlen in [128, 256, 1024, 2048, 4096]:
+        for klen in [128, 256, 1024, 2048]:
+            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
+            masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
+            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
+            softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
+            error = (softmax_results_torch - softmax_results).abs().max()
+            assert error < 1e-3
+
+
+def test_masked_softmax_backward():
+    import scaled_masked_softmax_cuda
+
+    batch = 2
+    attn = 16
+    scale_t = torch.tensor([1.0])
+    for qlen in [128, 256, 1024, 2048, 4096]:
+        for klen in [128, 256, 1024, 2048]:
+            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
+            backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0')
+            masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
+            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
+            back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item())
+            inputs.requires_grad = True
+            softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
+            softmax_results_torch.backward(backward)
+            error = (back_grad - inputs.grad).abs().max()
+            assert error < 1e-3
+
+
+def test_allmasked_softmax_forward():
+    import scaled_masked_softmax_cuda
+
+    batch = 2
+    attn = 16
+    scale_t = torch.tensor([1.0])
+    for qlen in [128, 256, 1024, 2048, 4096]:
+        for klen in [128, 256, 1024, 2048]:
+            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
+            masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
+            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
+            softmax_results_torch = torch.zeros_like(inputs)
+            error = (softmax_results_torch - softmax_results).abs().max()
+            assert error == 0.0
+
+
+def test_allmasked_softmax_backward():
+    import scaled_masked_softmax_cuda
+
+    batch = 2
+    attn = 16
+    scale_t = torch.tensor([1.0])
+    for qlen in [128, 256, 1024, 2048, 4096]:
+        for klen in [128, 256, 1024, 2048]:
+            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
+            backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0')
+            masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
+            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
+            back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item())
+            inputs.requires_grad = True
+            softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
+            softmax_results_torch.backward(backward)
+            error = (back_grad - inputs.grad).abs().max()
+            assert error < 1e-3
+
+
 if __name__ == "__main__":
     try:
         from transformers import BertTokenizer, GPT2Tokenizer
@@ -294,6 +378,11 @@ if __name__ == "__main__":
         print("\n[Fail] Please install `transformers` package to test fused kernels\n")
         exit(-1)
 
+    load()
+    test_masked_softmax_forward()
+    test_masked_softmax_backward()
+    test_allmasked_softmax_forward()
+    test_allmasked_softmax_backward()
     test_load_fused_kernels()
     test_fused_softmax()
     test_fused_upper_triangle_mask_softmax()
......
@@ -170,6 +170,7 @@ class FusedScaleMaskSoftmax(nn.Module):
             and self.input_in_float16  # input must be fp16
             and 16 < sk <= 4096  # sk must be 16 ~ 4096
             and sq % 4 == 0  # sq must be divisible by 4
+            and sk % 4 == 0  # sk must be divisible by 4
             and attn_batches % 4 == 0  # np * b must be divisible by 4
         ):
             if 0 <= sk <= 4096:
......
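For reference, the availability check above (only the conditions visible in this hunk) can be read as a standalone predicate; the new sk % 4 == 0 term is what this merge adds. The function and parameter names here are illustrative, not Megatron's API, and the real check includes further conditions not shown in the hunk:

def fused_softmax_kernel_available(input_in_float16, sq, sk, attn_batches):
    # Mirrors the visible part of the check in FusedScaleMaskSoftmax.
    return (
        input_in_float16            # fused kernel handles fp16/bf16 only
        and 16 < sk <= 4096         # key sequence length in supported range
        and sq % 4 == 0             # query sequence length divisible by 4
        and sk % 4 == 0             # key sequence length divisible by 4 (new)
        and attn_batches % 4 == 0   # np * b divisible by 4
    )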