Commit 0be40526 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'github_fused_softmax' into 'main'

Fused softmax checks and additions from Github (#133)

See merge request ADLR/megatron-lm!312
parents 23266c57 bc7b3539
...@@ -32,6 +32,12 @@ torch::Tensor bwd_cuda( ...@@ -32,6 +32,12 @@ torch::Tensor bwd_cuda(
torch::Tensor const& softmax_results, torch::Tensor const& softmax_results,
float scale_factor); float scale_factor);
int get_batch_per_block_cuda(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads);
torch::Tensor fwd( torch::Tensor fwd(
torch::Tensor const& input, torch::Tensor const& input,
torch::Tensor const& mask, torch::Tensor const& mask,
...@@ -63,6 +69,14 @@ torch::Tensor bwd( ...@@ -63,6 +69,14 @@ torch::Tensor bwd(
return bwd_cuda(output_grads, softmax_results, scale_factor); return bwd_cuda(output_grads, softmax_results, scale_factor);
} }
int get_batch_per_block(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}
} // end namespace scaled_masked_softmax } // end namespace scaled_masked_softmax
} // end namespace fused_softmax } // end namespace fused_softmax
} // end namespace multihead_attn } // end namespace multihead_attn
...@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", m.def("forward",
&multihead_attn::fused_softmax::scaled_masked_softmax::fwd, &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward."); "Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
m.def("backward",
&multihead_attn::fused_softmax::scaled_masked_softmax::bwd, &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward."); "Self Multihead Attention scaled, time masked softmax -- Backward.");
m.def("get_batch_per_block",
&multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
"Return Batch per block size."
);
} }
...@@ -310,9 +310,22 @@ __global__ void scaled_masked_softmax_warp_backward( ...@@ -310,9 +310,22 @@ __global__ void scaled_masked_softmax_warp_backward(
} }
} }
} }
} // end of anonymous namespace } // end of anonymous namespace
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
return batches_per_block;
}
template<typename input_t, typename output_t, typename acc_t> template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward( void dispatch_scaled_masked_softmax_forward(
output_t *dst, output_t *dst,
......
...@@ -28,6 +28,11 @@ namespace multihead_attn { ...@@ -28,6 +28,11 @@ namespace multihead_attn {
namespace fused_softmax { namespace fused_softmax {
namespace scaled_masked_softmax { namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda( torch::Tensor fwd_cuda(
torch::Tensor const& input, torch::Tensor const& input,
torch::Tensor const& mask, torch::Tensor const& mask,
......
...@@ -361,6 +361,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( ...@@ -361,6 +361,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int warps_per_block = (threads_per_block / warp_size); int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp; int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block; int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1); dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1); dim3 threads(warp_size, warps_per_block, 1);
...@@ -451,6 +452,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( ...@@ -451,6 +452,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int warps_per_block = (threads_per_block / warp_size); int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp; int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block; int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1); dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1); dim3 threads(warp_size, warps_per_block, 1);
......
import math
import torch
from torch.nn import LayerNorm
from megatron.model.enums import AttnMaskType
from megatron.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func
def test_load_fused_kernels():
try:
import fused_mix_prec_layer_norm_cuda
import scaled_masked_softmax_cuda
import scaled_upper_triang_masked_softmax_cuda
import torch
print("[Success] load_fused_kernels")
except ImportError as e:
print("[Fail] load_fused_kernels")
raise e
def test_fused_softmax():
bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
test_text = (
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32
)
tokens = tokenizer(
[test_text] * 4,
return_tensors="pt",
)
embedding_output = bert.embeddings(
input_ids=tokens["input_ids"].cuda(),
position_ids=None,
token_type_ids=tokens["token_type_ids"].cuda(),
inputs_embeds=None,
past_key_values_length=0,
)
# (bsz, 1, 1, seq_len)
mask = bert.get_extended_attention_mask(
attention_mask=tokens["attention_mask"].cuda(),
input_shape=tokens["input_ids"].shape,
device=bert.device,
)
# (bsz, 1, seq_len, seq_len)
mask = mask.repeat(1, 1, mask.size()[-1], 1)
attention = bert.encoder.layer[0].attention.self
key_layer = attention.transpose_for_scores(attention.key(embedding_output))
query_layer = attention.transpose_for_scores(attention.query(embedding_output))
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores /= math.sqrt(key_layer.size()[-1])
fused_softmax = (
FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.padding,
scaled_masked_softmax_fusion=True,
)
.cuda()
.half()
)
fused_softmax_output = fused_softmax(
attention_scores,
(mask != 0),
)
torch_softmax = (
FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.padding,
scaled_masked_softmax_fusion=False,
)
.cuda()
.half()
)
torch_softmax_output = torch_softmax(
attention_scores,
(mask != 0),
)
test_result = (fused_softmax_output - torch_softmax_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_fused_softmax"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
)
else:
print(
f"\n[Fail] test_fused_softmax"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
)
def test_fused_upper_triangle_mask_softmax():
gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
test_text = (
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi" # 24
)
tokens = tokenizer(
[test_text] * 4,
return_tensors="pt",
)
attention_mask = tokens["attention_mask"].cuda()
attention_mask = attention_mask.view(attention_mask.size(0), -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)
attn = gpt.h[0]
hidden_states = gpt.wte(tokens["input_ids"].cuda())
q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
attn_weights = torch.matmul(q, k.transpose(-1, -2))
sq, sk = q.size(-2), k.size(-2)
causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
total_mask = ~(causal_mask & (attention_mask == 0))
"""
tensor([[[[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., False, True, True],
[False, False, False, ..., False, False, True],
[False, False, False, ..., False, False, False]]]
"""
fused_softmax = (
FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=True,
)
.cuda()
.half()
)
fused_softmax_output = fused_softmax(
attn_weights,
total_mask,
)
torch_softmax = (
FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=False,
)
.cuda()
.half()
)
torch_softmax_output = torch_softmax(
attn_weights,
total_mask,
)
test_result = (fused_softmax_output - torch_softmax_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_fused_upper_triangle_mask_softmax"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
)
else:
print(
f"\n[Fail] test_fused_upper_triangle_mask_softmax"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
)
def test_layer_norm():
bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
test_text = (
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32
)
tokens = tokenizer(
[test_text] * 4,
return_tensors="pt",
)
# [bsz, seq_len, d_model]
embedding_output = (
bert.embeddings(
input_ids=tokens["input_ids"].cuda(),
position_ids=None,
token_type_ids=tokens["token_type_ids"].cuda(),
inputs_embeds=None,
past_key_values_length=0,
)
.cuda()
.half()
)
fused_layernorm_layer = (
MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
)
torch_layernorm_layer = (
LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
)
fused_output = fused_layernorm_layer(embedding_output)
torch_output = torch_layernorm_layer(embedding_output)
test_result = (fused_output - torch_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_layer_norm"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
)
else:
print(
f"\n[Fail] test_layer_norm"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
)
if __name__ == "__main__":
try:
from transformers import BertTokenizer, GPT2Tokenizer
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
import transformers
transformers.logging.set_verbosity(
transformers.logging.FATAL,
)
except:
print("\n[Fail] Please install `transformers` package to test fused kernels\n")
exit(-1)
test_load_fused_kernels()
test_fused_softmax()
test_fused_upper_triangle_mask_softmax()
test_layer_norm()
...@@ -13,7 +13,9 @@ ...@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch import torch
import torch.nn as nn
from megatron.model.enums import AttnMaskType from megatron.model.enums import AttnMaskType
...@@ -30,10 +32,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): ...@@ -30,10 +32,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import scaled_upper_triang_masked_softmax_cuda import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale]) scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
inputs, scale_t[0] inputs, scale_t[0]
) )
ctx.save_for_backward(softmax_results, scale_t) ctx.save_for_backward(softmax_results, scale_t)
return softmax_results return softmax_results
...@@ -42,10 +44,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): ...@@ -42,10 +44,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import scaled_upper_triang_masked_softmax_cuda import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward( input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0] output_grads, softmax_results, scale_t[0]
) )
return input_grads, None return input_grads, None
...@@ -63,9 +65,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function): ...@@ -63,9 +65,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
scale_t = torch.tensor([scale]) scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward( softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
inputs, mask, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t) ctx.save_for_backward(softmax_results, scale_t)
return softmax_results return softmax_results
...@@ -81,16 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function): ...@@ -81,16 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
return input_grads, None, None return input_grads, None, None
class FusedScaleMaskSoftmax(torch.nn.Module): class FusedScaleMaskSoftmax(nn.Module):
""" """
fused operation: scaling + mask + softmax fused operation: scaling + mask + softmax
Arguments: Arguments:
input_in_fp16: flag to indicate if input in fp16 data format. input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal) attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
mask_func: mask function to be applied. mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax in performed at fp32 precision. softmax_in_fp32: if true, softmax in performed at fp32 precision.
scale: scaling factor used in input tensor scaling. scale: scaling factor used in input tensor scaling.
""" """
def __init__( def __init__(
...@@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module): ...@@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
super(FusedScaleMaskSoftmax, self).__init__() super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16 self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16 self.input_in_bf16 = input_in_bf16
assert not (self.input_in_fp16 and self.input_in_bf16),\ assert not (
'both fp16 and bf16 flags cannot be active at the same time.' self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
...@@ -118,47 +121,72 @@ class FusedScaleMaskSoftmax(torch.nn.Module): ...@@ -118,47 +121,72 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
assert ( assert (
self.scale is None or softmax_in_fp32 self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled" ), "softmax should be in fp32 when scaled"
def forward(self, input, mask): def forward(self, input, mask):
# [b, np, sq, sk] # [b, np, sq, sk]
assert input.dim() == 4 assert input.dim() == 4
data_size = input.size()
query_seq_len = data_size[-2] if self.is_kernel_available(mask, *input.size()):
key_seq_len = data_size[-1] return self.forward_fused_softmax(input, mask)
attn_batch_size = data_size[0] * data_size[1]
# constraints on various tensor dimensions to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
# invoke custom kernel
if self.input_in_float16 and mask is not None and \
custom_kernel_constraint and self.scaled_masked_softmax_fusion:
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type == AttnMaskType.causal:
assert query_seq_len == key_seq_len, \
"causal mask is only for self attention"
input = input.view(-1, query_seq_len, key_seq_len)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
probs = probs.view(*data_size)
else:
assert self.attn_mask_type == AttnMaskType.padding
probs = ScaledMaskedSoftmax.apply(input, mask, scale)
else: else:
if self.input_in_float16 and self.softmax_in_fp32: return self.forward_torch_softmax(input, mask)
input = input.float()
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be 16 ~ 2048
and sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
if self.scale is not None: def forward_fused_softmax(self, input, mask):
input = input * self.scale b, np, sq, sk = input.size()
mask_output = self.mask_func(input, mask) if mask is not None else input scale = self.scale if self.scale is not None else 1.0
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32: if self.attn_mask_type == AttnMaskType.causal:
if self.input_in_fp16: assert sq == sk, "causal mask is only for self attention"
probs = probs.half()
else: # input is 3D tensor (attn_batches, sq, sk)
probs = probs.bfloat16() input = input.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment