Commit 189e72a7 authored by Jared Casper

Merge branch 'fused_softmax_kernel_fixes' into 'main'

Support for fully masked rows in the fused softmax kernel + avoid an in-place operation in the backward pass

See merge request ADLR/megatron-lm!435
parents 8df49e72 76db9583
@@ -293,6 +293,13 @@ __global__ void scaled_masked_softmax_warp_forward(
         }
     }
     warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

+    // compute scale value to account for full mask
+    acc_t scale_value[WARP_BATCH];
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0;
+    }
+
     acc_t sum[WARP_BATCH] { 0.0f };
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -316,7 +323,7 @@ __global__ void scaled_masked_softmax_warp_forward(
         if (element_index < element_count) {
             #pragma unroll
             for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
-                out[element] = elements[i][it + element] / sum[i];
+                out[element] = elements[i][it + element] * scale_value[i] / sum[i];
             }
             copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
         } else {
...
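Context for the kernel change above: the Python-side mask fill writes -10000.0 into every masked score, so a row in which every key position is masked still has a finite maximum and a plain softmax returns a uniform 1/n per element rather than zeros. The new scale_value pass detects that fully masked case (max_value[i] == -10000.0) and multiplies the output by zero, which is exactly what the new test_allmasked_softmax_forward test asserts. A minimal PyTorch sketch of the two behaviors, not part of the diff; the helper names are made up for illustration:

import torch

def naive_masked_softmax(scores, mask, scale=1.0):
    # Reference path used by the tests: masked positions are filled with -10000.0,
    # so a fully masked row degenerates to a uniform distribution, not zeros.
    return torch.softmax((scores * scale).masked_fill(mask, -10000.0), dim=-1)

def masked_softmax_with_full_mask_fix(scores, mask, scale=1.0):
    # Sketch of what the fused kernel now does: rows where every position is
    # masked are scaled to exactly zero instead of 1/n.
    probs = naive_masked_softmax(scores, mask, scale)
    fully_masked = mask.all(dim=-1, keepdim=True)  # analogue of max_value == -10000.0
    return probs * ~fully_masked

scores = torch.randn(2, 1, 4, 8)
mask = torch.ones(2, 1, 4, 8, dtype=torch.bool)  # everything masked
print(naive_masked_softmax(scores, mask)[0, 0, 0])               # uniform 0.125 everywhere
print(masked_softmax_with_full_mask_fix(scores, mask)[0, 0, 0])  # all zeros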
@@ -65,7 +65,7 @@ torch::Tensor fwd_cuda(
     input.scalar_type(),
     "dispatch_scaled_masked_softmax_forward",
     dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
       reinterpret_cast<scalar_t*>(softmax_results_ptr),
       reinterpret_cast<const scalar_t*>(input_ptr),
       reinterpret_cast<const uint8_t*>(mask_ptr),
       scale_factor,
@@ -92,14 +92,19 @@ torch::Tensor bwd_cuda(
   const int query_seq_len = output_grads.size(2);
   const int key_seq_len = output_grads.size(3);

+  auto act_options = output_grads.options().requires_grad(false);
+  torch::Tensor input_grads =
+      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
+  void* input_grads_ptr = static_cast<void*>(input_grads.data_ptr());

   //Softmax Grad
   DISPATCH_HALF_AND_BFLOAT(
     output_grads_.scalar_type(),
     "dispatch_scaled_masked_softmax_backward",
     dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
-      reinterpret_cast<scalar_t*>(output_grads_ptr),
+      reinterpret_cast<scalar_t*>(input_grads_ptr),
       reinterpret_cast<scalar_t*>(output_grads_ptr),
       reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
       scale_factor,
@@ -107,10 +112,9 @@ torch::Tensor bwd_cuda(
       key_seq_len,
       batches,
       attn_heads);
   );
-  //backward pass is completely in-place
-  return output_grads;
+  return input_grads;
 }
 }
 }
...
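On the backward change above: the old code pointed the kernel's gradient output at output_grads_ptr, overwriting the incoming gradient tensor and returning it ("backward pass is completely in-place"). Any other consumer of that tensor, including autograd if it is reused elsewhere, would then silently see input gradients instead of output gradients. The new code allocates a separate input_grads tensor and returns that, leaving output_grads untouched. A rough PyTorch analogue of the two patterns, assuming the standard softmax gradient scale * y * (g - sum(g * y)) that the fused kernel is understood to compute; the function names here are illustrative only:

import torch

def softmax_backward_in_place(grad_output, softmax_results, scale):
    # Old pattern: the incoming gradient buffer is overwritten and returned.
    # Any other reference to grad_output now holds input gradients instead.
    grad_output.mul_(softmax_results)
    grad_output.sub_(softmax_results * grad_output.sum(dim=-1, keepdim=True))
    return grad_output.mul_(scale)

def softmax_backward_out_of_place(grad_output, softmax_results, scale):
    # New pattern: results go into a fresh tensor; grad_output is left untouched.
    tmp = grad_output * softmax_results
    return scale * (tmp - softmax_results * tmp.sum(dim=-1, keepdim=True))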
@@ -7,7 +7,7 @@ from megatron.model.enums import AttnMaskType
 from megatron.model.fused_layer_norm import MixedFusedLayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.utils import attention_mask_func
+from megatron.fused_kernels import load


 def test_load_fused_kernels():
     try:
@@ -279,6 +279,90 @@ def test_layer_norm():
     )


+def attention_mask_func(attention_scores, attention_mask):
+    attention_scores.masked_fill_(attention_mask, -10000.0)
+    return attention_scores
+
+
+def forward_torch_softmax(input, mask, scale):
+    input = input * scale
+    mask_output = attention_mask_func(input, mask) if mask is not None else input
+    probs = torch.nn.Softmax(dim=-1)(mask_output)
+    return probs
+
+
+def test_masked_softmax_forward():
+    import scaled_masked_softmax_cuda
+
+    batch = 2
+    attn = 16
+    scale_t = torch.tensor([1.0])
+    for qlen in [128, 256, 1024, 2048, 4096]:
+        for klen in [128, 256, 1024, 2048]:
+            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
+            masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
+            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
+            softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
+            error = (softmax_results_torch - softmax_results).abs().max()
+            assert error < 1e-3
+
+
+def test_masked_softmax_backward():
+    import scaled_masked_softmax_cuda
+
+    batch = 2
+    attn = 16
+    scale_t = torch.tensor([1.0])
+    for qlen in [128, 256, 1024, 2048, 4096]:
+        for klen in [128, 256, 1024, 2048]:
+            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
+            backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0')
+            masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
+            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
+            back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item())
+
+            inputs.requires_grad = True
+            softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
+            softmax_results_torch.backward(backward)
+            error = (back_grad - inputs.grad).abs().max()
+            assert error < 1e-3
+
+
+def test_allmasked_softmax_forward():
+    import scaled_masked_softmax_cuda
+
+    batch = 2
+    attn = 16
+    scale_t = torch.tensor([1.0])
+    for qlen in [128, 256, 1024, 2048, 4096]:
+        for klen in [128, 256, 1024, 2048]:
+            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
+            masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
+            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
+            softmax_results_torch = torch.zeros_like(inputs)
+            error = (softmax_results_torch - softmax_results).abs().max()
+            assert error == 0.0
+
+
+def test_allmasked_softmax_backward():
+    import scaled_masked_softmax_cuda
+
+    batch = 2
+    attn = 16
+    scale_t = torch.tensor([1.0])
+    for qlen in [128, 256, 1024, 2048, 4096]:
+        for klen in [128, 256, 1024, 2048]:
+            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
+            backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0')
+            masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
+            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
+            back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item())
+
+            inputs.requires_grad = True
+            softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
+            softmax_results_torch.backward(backward)
+            error = (back_grad - inputs.grad).abs().max()
+            assert error < 1e-3
+
+
 if __name__ == "__main__":
     try:
         from transformers import BertTokenizer, GPT2Tokenizer
@@ -294,6 +378,11 @@ if __name__ == "__main__":
         print("\n[Fail] Please install `transformers` package to test fused kernels\n")
         exit(-1)

+    load()
+    test_masked_softmax_forward()
+    test_masked_softmax_backward()
+    test_allmasked_softmax_forward()
+    test_allmasked_softmax_backward()
     test_load_fused_kernels()
     test_fused_softmax()
     test_fused_upper_triangle_mask_softmax()
...
@@ -170,6 +170,7 @@ class FusedScaleMaskSoftmax(nn.Module):
             and self.input_in_float16  # input must be fp16
             and 16 < sk <= 4096  # sk must be 16 ~ 4096
             and sq % 4 == 0  # sq must be divisible by 4
+            and sk % 4 == 0  # sk must be divisible by 4
             and attn_batches % 4 == 0  # np * b must be divisible by 4
         ):
             if 0 <= sk <= 4096:
...
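The wrapper change above mirrors the kernel's vectorized loads and stores (presumably several elements at a time along the key dimension): the fused path already required sq % 4 == 0 and now also requires sk % 4 == 0. A condensed sketch of roughly what the availability check looks like after this MR; only the condition lines shown in the diff are taken from the source, while the method name, signature, and fusion flag are assumptions:

def is_kernel_available(self, mask, b, np, sq, sk):
    attn_batches = b * np
    return (
        self.scaled_masked_softmax_fusion  # fusion enabled (assumed flag name)
        and self.input_in_float16          # input must be fp16
        and 16 < sk <= 4096                # sk must be 16 ~ 4096
        and sq % 4 == 0                    # sq must be divisible by 4
        and sk % 4 == 0                    # sk must be divisible by 4 (new check)
        and attn_batches % 4 == 0          # np * b must be divisible by 4
    )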