Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
40becfc9
"...models/git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "5ea40abf613e47bb56a0c06f695644d55671f585"
Commit
40becfc9
authored
Aug 13, 2021
by
hyunwoongko
Committed by
mshoeybi
Aug 22, 2021
Browse files
Improve and fix bugs about fused softmax layer
parent
23266c57
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
415 additions
and
54 deletions
+415
-54
megatron/fused_kernels/scaled_masked_softmax.cpp
megatron/fused_kernels/scaled_masked_softmax.cpp
+21
-1
megatron/fused_kernels/scaled_masked_softmax.h
megatron/fused_kernels/scaled_masked_softmax.h
+16
-4
megatron/fused_kernels/scaled_masked_softmax_cuda.cu
megatron/fused_kernels/scaled_masked_softmax_cuda.cu
+5
-0
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
+0
-4
megatron/fused_kernels/tests/__init__.py
megatron/fused_kernels/tests/__init__.py
+0
-0
megatron/fused_kernels/tests/test_fused_kernels.py
megatron/fused_kernels/tests/test_fused_kernels.py
+300
-0
megatron/model/fused_softmax.py
megatron/model/fused_softmax.py
+73
-45
No files found.
megatron/fused_kernels/scaled_masked_softmax.cpp
View file @
40becfc9
...
...
@@ -32,6 +32,12 @@ torch::Tensor bwd_cuda(
torch
::
Tensor
const
&
softmax_results
,
float
scale_factor
);
int
get_batch_per_block_cuda
(
int
query_seq_len
,
int
key_seq_len
,
int
batches
,
int
attn_heads
);
torch
::
Tensor
fwd
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
mask
,
...
...
@@ -63,6 +69,14 @@ torch::Tensor bwd(
return
bwd_cuda
(
output_grads
,
softmax_results
,
scale_factor
);
}
int
get_batch_per_block
(
int
query_seq_len
,
int
key_seq_len
,
int
batches
,
int
attn_heads
)
{
return
get_batch_per_block_cuda
(
query_seq_len
,
key_seq_len
,
batches
,
attn_heads
);
}
}
// end namespace scaled_masked_softmax
}
// end namespace fused_softmax
}
// end namespace multihead_attn
...
...
@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"forward"
,
&
multihead_attn
::
fused_softmax
::
scaled_masked_softmax
::
fwd
,
"Self Multihead Attention scaled, time masked softmax -- Forward."
);
m
.
def
(
"backward"
,
m
.
def
(
"backward"
,
&
multihead_attn
::
fused_softmax
::
scaled_masked_softmax
::
bwd
,
"Self Multihead Attention scaled, time masked softmax -- Backward."
);
m
.
def
(
"get_batch_per_block"
,
&
multihead_attn
::
fused_softmax
::
scaled_masked_softmax
::
get_batch_per_block
,
"Return Batch per block size."
);
}
megatron/fused_kernels/scaled_masked_softmax.h
View file @
40becfc9
...
...
@@ -16,6 +16,7 @@
#pragma once
#include <stdio.h>
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
...
...
@@ -310,9 +311,23 @@ __global__ void scaled_masked_softmax_warp_backward(
}
}
}
}
// end of anonymous namespace
int
get_batch_per_block
(
int
query_seq_len
,
int
key_seq_len
,
int
batches
,
int
attn_heads
){
int
log2_elements
=
log2_ceil
(
key_seq_len
);
const
int
next_power_of_two
=
1
<<
log2_elements
;
int
batch_count
=
batches
*
attn_heads
*
query_seq_len
;
int
warp_size
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
int
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
return
batches_per_block
;
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
void
dispatch_scaled_masked_softmax_forward
(
output_t
*
dst
,
...
...
@@ -325,7 +340,6 @@ void dispatch_scaled_masked_softmax_forward(
int
attn_heads
,
int
pad_batches
)
{
TORCH_INTERNAL_ASSERT
(
key_seq_len
>=
0
&&
key_seq_len
<=
2048
);
if
(
key_seq_len
==
0
)
{
return
;
}
else
{
...
...
@@ -344,7 +358,6 @@ void dispatch_scaled_masked_softmax_forward(
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
TORCH_INTERNAL_ASSERT
(
query_seq_len
%
batches_per_block
==
0
);
dim3
blocks
(
query_seq_len
/
batches_per_block
,
attn_heads
,
batches
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
...
...
@@ -414,7 +427,6 @@ void dispatch_scaled_masked_softmax_backward(
int
batches
,
int
attn_heads
)
{
TORCH_INTERNAL_ASSERT
(
key_seq_len
>=
0
&&
key_seq_len
<=
2048
);
if
(
key_seq_len
==
0
)
{
return
;
}
else
{
...
...
megatron/fused_kernels/scaled_masked_softmax_cuda.cu
View file @
40becfc9
...
...
@@ -28,6 +28,11 @@ namespace multihead_attn {
namespace
fused_softmax
{
namespace
scaled_masked_softmax
{
int
get_batch_per_block_cuda
(
int
query_seq_len
,
int
key_seq_len
,
int
batches
,
int
attn_heads
){
return
get_batch_per_block
(
query_seq_len
,
key_seq_len
,
batches
,
attn_heads
);
}
torch
::
Tensor
fwd_cuda
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
mask
,
...
...
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
View file @
40becfc9
...
...
@@ -340,7 +340,6 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int
softmax_elements_stride
,
int
attn_batches
)
{
TORCH_INTERNAL_ASSERT
(
softmax_elements
>=
0
&&
softmax_elements
<=
2048
);
if
(
softmax_elements
==
0
)
{
return
;
}
else
{
...
...
@@ -360,7 +359,6 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
TORCH_INTERNAL_ASSERT
(
attn_batches
%
batches_per_block
==
0
);
int
blocks_per_seq
=
attn_batches
/
batches_per_block
;
dim3
blocks
(
seq_len
,
blocks_per_seq
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
...
...
@@ -430,7 +428,6 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int
softmax_elements_stride
,
int
attn_batches
)
{
TORCH_INTERNAL_ASSERT
(
softmax_elements
>=
0
&&
softmax_elements
<=
2048
);
if
(
softmax_elements
==
0
)
{
return
;
}
else
{
...
...
@@ -450,7 +447,6 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
TORCH_INTERNAL_ASSERT
(
attn_batches
%
batches_per_block
==
0
);
int
blocks_per_seq
=
attn_batches
/
batches_per_block
;
dim3
blocks
(
seq_len
,
blocks_per_seq
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
...
...
megatron/fused_kernels/tests/__init__.py
0 → 100644
View file @
40becfc9
megatron/fused_kernels/tests/test_fused_kernels.py
0 → 100644
View file @
40becfc9
import
math
import
torch
from
torch.nn
import
LayerNorm
from
megatron.model.enums
import
AttnMaskType
from
megatron.model.fused_layer_norm
import
MixedFusedLayerNorm
from
megatron.model.fused_softmax
import
FusedScaleMaskSoftmax
from
megatron.model.utils
import
attention_mask_func
def
test_load_fused_kernels
():
try
:
import
fused_mix_prec_layer_norm_cuda
import
scaled_masked_softmax_cuda
import
scaled_upper_triang_masked_softmax_cuda
import
torch
print
(
"[Success] load_fused_kernels"
)
except
ImportError
as
e
:
print
(
"[Fail] load_fused_kernels"
)
raise
e
def
test_fused_softmax
():
bert
=
BertModel
.
from_pretrained
(
"bert-base-cased"
).
cuda
().
half
()
tokenizer
=
BertTokenizer
.
from_pretrained
(
"bert-base-cased"
)
test_text
=
(
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi hi hi hi hi hi hi"
# 32
)
tokens
=
tokenizer
(
[
test_text
]
*
4
,
return_tensors
=
"pt"
,
)
embedding_output
=
bert
.
embeddings
(
input_ids
=
tokens
[
"input_ids"
].
cuda
(),
position_ids
=
None
,
token_type_ids
=
tokens
[
"token_type_ids"
].
cuda
(),
inputs_embeds
=
None
,
past_key_values_length
=
0
,
)
# (bsz, 1, 1, seq_len)
mask
=
bert
.
get_extended_attention_mask
(
attention_mask
=
tokens
[
"attention_mask"
].
cuda
(),
input_shape
=
tokens
[
"input_ids"
].
shape
,
device
=
bert
.
device
,
)
# (bsz, 1, seq_len, seq_len)
mask
=
mask
.
repeat
(
1
,
1
,
mask
.
size
()[
-
1
],
1
)
attention
=
bert
.
encoder
.
layer
[
0
].
attention
.
self
key_layer
=
attention
.
transpose_for_scores
(
attention
.
key
(
embedding_output
))
query_layer
=
attention
.
transpose_for_scores
(
attention
.
query
(
embedding_output
))
attention_scores
=
torch
.
matmul
(
query_layer
,
key_layer
.
transpose
(
-
1
,
-
2
))
attention_scores
/=
math
.
sqrt
(
key_layer
.
size
()[
-
1
])
fused_softmax
=
(
FusedScaleMaskSoftmax
(
input_in_fp16
=
True
,
input_in_bf16
=
False
,
mask_func
=
attention_mask_func
,
scale
=
None
,
softmax_in_fp32
=
False
,
attn_mask_type
=
AttnMaskType
.
padding
,
scaled_masked_softmax_fusion
=
True
,
)
.
cuda
()
.
half
()
)
fused_softmax_output
=
fused_softmax
(
attention_scores
,
(
mask
!=
0
),
)
torch_softmax
=
(
FusedScaleMaskSoftmax
(
input_in_fp16
=
True
,
input_in_bf16
=
False
,
mask_func
=
attention_mask_func
,
scale
=
None
,
softmax_in_fp32
=
False
,
attn_mask_type
=
AttnMaskType
.
padding
,
scaled_masked_softmax_fusion
=
False
,
)
.
cuda
()
.
half
()
)
torch_softmax_output
=
torch_softmax
(
attention_scores
,
(
mask
!=
0
),
)
test_result
=
(
fused_softmax_output
-
torch_softmax_output
).
abs
()
while
test_result
.
dim
()
!=
1
:
test_result
=
test_result
.
mean
(
dim
=-
1
)
diff
=
test_result
.
mean
(
dim
=-
1
)
if
diff
<=
1e-3
:
print
(
f
"
\n
[Success] test_fused_softmax"
f
"
\n
> mean_difference=
{
diff
}
"
f
"
\n
> fused_values=
{
fused_softmax_output
[
-
1
][
-
1
][
-
1
][:
5
].
tolist
()
}
"
f
"
\n
> torch_values=
{
torch_softmax_output
[
-
1
][
-
1
][
-
1
][:
5
].
tolist
()
}
"
)
else
:
print
(
f
"
\n
[Fail] test_fused_softmax"
f
"
\n
> mean_difference=
{
diff
}
, "
f
"
\n
> fused_values=
{
fused_softmax_output
[
-
1
][
-
1
][
-
1
][:
5
].
tolist
()
}
, "
f
"
\n
> torch_values=
{
torch_softmax_output
[
-
1
][
-
1
][
-
1
][:
5
].
tolist
()
}
"
)
def
test_fused_upper_triangle_mask_softmax
():
gpt
=
GPT2Model
.
from_pretrained
(
"gpt2"
).
cuda
().
half
()
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
"gpt2"
)
test_text
=
(
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi"
# 24
)
tokens
=
tokenizer
(
[
test_text
]
*
4
,
return_tensors
=
"pt"
,
)
attention_mask
=
tokens
[
"attention_mask"
].
cuda
()
attention_mask
=
attention_mask
.
view
(
attention_mask
.
size
(
0
),
-
1
)
attention_mask
=
attention_mask
[:,
None
,
None
,
:]
attention_mask
=
(
1.0
-
attention_mask
)
*
-
10000.0
attention_mask
=
attention_mask
.
repeat
(
1
,
1
,
attention_mask
.
size
()[
-
1
],
1
)
attn
=
gpt
.
h
[
0
]
hidden_states
=
gpt
.
wte
(
tokens
[
"input_ids"
].
cuda
())
q
,
k
,
v
=
attn
.
attn
.
c_attn
(
hidden_states
).
split
(
768
,
dim
=-
1
)
q
=
attn
.
attn
.
_split_heads
(
q
,
attn
.
attn
.
num_heads
,
attn
.
attn
.
head_dim
)
k
=
attn
.
attn
.
_split_heads
(
k
,
attn
.
attn
.
num_heads
,
attn
.
attn
.
head_dim
)
attn_weights
=
torch
.
matmul
(
q
,
k
.
transpose
(
-
1
,
-
2
))
sq
,
sk
=
q
.
size
(
-
2
),
k
.
size
(
-
2
)
causal_mask
=
attn
.
attn
.
bias
[:,
:,
sk
-
sq
:
sk
,
:
sk
].
bool
()
total_mask
=
~
(
causal_mask
&
(
attention_mask
==
0
))
"""
tensor([[[[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., False, True, True],
[False, False, False, ..., False, False, True],
[False, False, False, ..., False, False, False]]]
"""
fused_softmax
=
(
FusedScaleMaskSoftmax
(
input_in_fp16
=
True
,
input_in_bf16
=
False
,
mask_func
=
attention_mask_func
,
scale
=
None
,
softmax_in_fp32
=
False
,
attn_mask_type
=
AttnMaskType
.
causal
,
scaled_masked_softmax_fusion
=
True
,
)
.
cuda
()
.
half
()
)
fused_softmax_output
=
fused_softmax
(
attn_weights
,
total_mask
,
)
torch_softmax
=
(
FusedScaleMaskSoftmax
(
input_in_fp16
=
True
,
input_in_bf16
=
False
,
mask_func
=
attention_mask_func
,
scale
=
None
,
softmax_in_fp32
=
False
,
attn_mask_type
=
AttnMaskType
.
causal
,
scaled_masked_softmax_fusion
=
False
,
)
.
cuda
()
.
half
()
)
torch_softmax_output
=
torch_softmax
(
attn_weights
,
total_mask
,
)
test_result
=
(
fused_softmax_output
-
torch_softmax_output
).
abs
()
while
test_result
.
dim
()
!=
1
:
test_result
=
test_result
.
mean
(
dim
=-
1
)
diff
=
test_result
.
mean
(
dim
=-
1
)
if
diff
<=
1e-3
:
print
(
f
"
\n
[Success] test_fused_upper_triangle_mask_softmax"
f
"
\n
> mean_difference=
{
diff
}
"
f
"
\n
> fused_values=
{
fused_softmax_output
[
-
1
][
-
1
][
-
1
][:
5
].
tolist
()
}
"
f
"
\n
> torch_values=
{
torch_softmax_output
[
-
1
][
-
1
][
-
1
][:
5
].
tolist
()
}
"
)
else
:
print
(
f
"
\n
[Fail] test_fused_upper_triangle_mask_softmax"
f
"
\n
> mean_difference=
{
diff
}
, "
f
"
\n
> fused_values=
{
fused_softmax_output
[
-
1
][
-
1
][
-
1
][:
5
].
tolist
()
}
, "
f
"
\n
> torch_values=
{
torch_softmax_output
[
-
1
][
-
1
][
-
1
][:
5
].
tolist
()
}
"
)
def
test_layer_norm
():
bert
=
BertModel
.
from_pretrained
(
"bert-base-cased"
).
cuda
().
half
()
tokenizer
=
BertTokenizer
.
from_pretrained
(
"bert-base-cased"
)
test_text
=
(
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi hi hi hi hi hi hi"
# 32
)
tokens
=
tokenizer
(
[
test_text
]
*
4
,
return_tensors
=
"pt"
,
)
# [bsz, seq_len, d_model]
embedding_output
=
(
bert
.
embeddings
(
input_ids
=
tokens
[
"input_ids"
].
cuda
(),
position_ids
=
None
,
token_type_ids
=
tokens
[
"token_type_ids"
].
cuda
(),
inputs_embeds
=
None
,
past_key_values_length
=
0
,
)
.
cuda
()
.
half
()
)
fused_layernorm_layer
=
(
MixedFusedLayerNorm
(
normalized_shape
=
embedding_output
.
size
(
-
1
)).
cuda
().
half
()
)
torch_layernorm_layer
=
(
LayerNorm
(
normalized_shape
=
embedding_output
.
size
(
-
1
)).
cuda
().
half
()
)
fused_output
=
fused_layernorm_layer
(
embedding_output
)
torch_output
=
torch_layernorm_layer
(
embedding_output
)
test_result
=
(
fused_output
-
torch_output
).
abs
()
while
test_result
.
dim
()
!=
1
:
test_result
=
test_result
.
mean
(
dim
=-
1
)
diff
=
test_result
.
mean
(
dim
=-
1
)
if
diff
<=
1e-3
:
print
(
f
"
\n
[Success] test_layer_norm"
f
"
\n
> mean_difference=
{
diff
}
"
f
"
\n
> fused_values=
{
fused_output
[
-
1
][
-
1
][:
5
].
tolist
()
}
"
f
"
\n
> torch_values=
{
torch_output
[
-
1
][
-
1
][:
5
].
tolist
()
}
"
)
else
:
print
(
f
"
\n
[Fail] test_layer_norm"
f
"
\n
> mean_difference=
{
diff
}
, "
f
"
\n
> fused_values=
{
fused_output
[
-
1
][
-
1
][:
5
].
tolist
()
}
, "
f
"
\n
> torch_values=
{
torch_output
[
-
1
][
-
1
][:
5
].
tolist
()
}
"
)
if
__name__
==
"__main__"
:
try
:
from
transformers
import
BertTokenizer
,
GPT2Tokenizer
from
transformers.models.bert.modeling_bert
import
BertModel
from
transformers.models.gpt2.modeling_gpt2
import
GPT2Model
import
transformers
transformers
.
logging
.
set_verbosity
(
transformers
.
logging
.
FATAL
,
)
except
:
print
(
"
\n
[Fail] Please install `transformers` package to test fused kernels
\n
"
)
exit
(
-
1
)
test_load_fused_kernels
()
test_fused_softmax
()
test_fused_upper_triangle_mask_softmax
()
test_layer_norm
()
megatron/model/fused_softmax.py
View file @
40becfc9
...
...
@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn
as
nn
from
megatron.model.enums
import
AttnMaskType
...
...
@@ -30,10 +32,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import
scaled_upper_triang_masked_softmax_cuda
scale_t
=
torch
.
tensor
([
scale
])
softmax_results
=
scaled_upper_triang_masked_softmax_cuda
.
forward
(
inputs
,
scale_t
[
0
]
)
ctx
.
save_for_backward
(
softmax_results
,
scale_t
)
return
softmax_results
...
...
@@ -42,10 +44,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import
scaled_upper_triang_masked_softmax_cuda
softmax_results
,
scale_t
=
ctx
.
saved_tensors
input_grads
=
scaled_upper_triang_masked_softmax_cuda
.
backward
(
output_grads
,
softmax_results
,
scale_t
[
0
]
)
return
input_grads
,
None
...
...
@@ -63,9 +65,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
scale_t
=
torch
.
tensor
([
scale
])
softmax_results
=
scaled_masked_softmax_cuda
.
forward
(
inputs
,
mask
,
scale_t
[
0
]
)
softmax_results
=
scaled_masked_softmax_cuda
.
forward
(
inputs
,
mask
,
scale_t
[
0
])
ctx
.
save_for_backward
(
softmax_results
,
scale_t
)
return
softmax_results
...
...
@@ -81,16 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
return
input_grads
,
None
,
None
class
FusedScaleMaskSoftmax
(
torch
.
nn
.
Module
):
class
FusedScaleMaskSoftmax
(
nn
.
Module
):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax in performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def
__init__
(
...
...
@@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
super
(
FusedScaleMaskSoftmax
,
self
).
__init__
()
self
.
input_in_fp16
=
input_in_fp16
self
.
input_in_bf16
=
input_in_bf16
assert
not
(
self
.
input_in_fp16
and
self
.
input_in_bf16
),
\
'both fp16 and bf16 flags cannot be active at the same time.'
assert
not
(
self
.
input_in_fp16
and
self
.
input_in_bf16
),
"both fp16 and bf16 flags cannot be active at the same time."
self
.
input_in_float16
=
self
.
input_in_fp16
or
self
.
input_in_bf16
self
.
attn_mask_type
=
attn_mask_type
self
.
scaled_masked_softmax_fusion
=
scaled_masked_softmax_fusion
...
...
@@ -118,47 +121,72 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
assert
(
self
.
scale
is
None
or
softmax_in_fp32
),
"softmax should be in fp32 when scaled"
def
forward
(
self
,
input
,
mask
):
# [b, np, sq, sk]
assert
input
.
dim
()
==
4
data_size
=
input
.
size
()
query_seq_len
=
data_size
[
-
2
]
key_seq_len
=
data_size
[
-
1
]
attn_batch_size
=
data_size
[
0
]
*
data_size
[
1
]
# constraints on various tensor dimensions to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint
=
key_seq_len
>
16
and
key_seq_len
<=
2048
and
\
query_seq_len
%
4
==
0
and
attn_batch_size
%
4
==
0
# invoke custom kernel
if
self
.
input_in_float16
and
mask
is
not
None
and
\
custom_kernel_constraint
and
self
.
scaled_masked_softmax_fusion
:
scale
=
self
.
scale
if
self
.
scale
is
not
None
else
1.0
if
self
.
attn_mask_type
==
AttnMaskType
.
causal
:
assert
query_seq_len
==
key_seq_len
,
\
"causal mask is only for self attention"
input
=
input
.
view
(
-
1
,
query_seq_len
,
key_seq_len
)
probs
=
ScaledUpperTriangMaskedSoftmax
.
apply
(
input
,
scale
)
probs
=
probs
.
view
(
*
data_size
)
else
:
assert
self
.
attn_mask_type
==
AttnMaskType
.
padding
probs
=
ScaledMaskedSoftmax
.
apply
(
input
,
mask
,
scale
)
if
self
.
is_kernel_available
(
mask
,
*
input
.
size
()):
return
self
.
forward_fused_softmax
(
input
,
mask
)
else
:
if
self
.
input_in_float16
and
self
.
softmax_in_fp32
:
input
=
input
.
float
()
return
self
.
forward_torch_softmax
(
input
,
mask
)
def
is_kernel_available
(
self
,
mask
,
b
,
np
,
sq
,
sk
):
attn_batches
=
b
*
np
if
(
self
.
scaled_masked_softmax_fusion
# user want to fuse
and
self
.
input_in_float16
# input must be fp16
and
mask
is
not
None
# mask tensor must not be None
and
16
<
sq
<=
2048
# sq must be 16 ~ 2048
and
sk
%
4
==
0
# sk must be divisor of 4
and
attn_batches
%
4
==
0
# np * b must be divisor of 4
):
if
0
<=
sk
<=
2048
:
batch_per_block
=
self
.
get_batch_per_block
(
sq
,
sk
,
b
,
np
)
if
self
.
attn_mask_type
==
AttnMaskType
.
causal
:
if
attn_batches
%
batch_per_block
==
0
:
return
True
else
:
if
sq
%
batch_per_block
==
0
:
return
True
return
False
if
self
.
scale
is
not
None
:
input
=
input
*
self
.
scale
mask_output
=
self
.
mask_func
(
input
,
mask
)
if
mask
is
not
None
else
input
probs
=
torch
.
nn
.
Softmax
(
dim
=-
1
)(
mask_output
)
def
forward_fused_softmax
(
self
,
input
,
mask
):
b
,
np
,
sq
,
sk
=
input
.
size
()
scale
=
self
.
scale
if
self
.
scale
is
not
None
else
1.0
if
self
.
input_in_float16
and
self
.
softmax_in_fp32
:
if
self
.
input_in_fp16
:
probs
=
probs
.
half
()
else
:
probs
=
probs
.
bfloat16
()
if
self
.
attn_mask_type
==
AttnMaskType
.
causal
:
assert
sq
==
sk
,
"causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
input
=
input
.
view
(
-
1
,
sq
,
sk
)
probs
=
ScaledUpperTriangMaskedSoftmax
.
apply
(
input
,
scale
)
return
probs
.
view
(
b
,
np
,
sq
,
sk
)
else
:
# input is 4D tensor (b, np, sq, sk)
return
ScaledMaskedSoftmax
.
apply
(
input
,
mask
,
scale
)
def
forward_torch_softmax
(
self
,
input
,
mask
):
if
self
.
input_in_float16
and
self
.
softmax_in_fp32
:
input
=
input
.
float
()
if
self
.
scale
is
not
None
:
input
=
input
*
self
.
scale
mask_output
=
self
.
mask_func
(
input
,
mask
)
if
mask
is
not
None
else
input
probs
=
torch
.
nn
.
Softmax
(
dim
=-
1
)(
mask_output
)
if
self
.
input_in_float16
and
self
.
softmax_in_fp32
:
if
self
.
input_in_fp16
:
probs
=
probs
.
half
()
else
:
probs
=
probs
.
bfloat16
()
return
probs
@
staticmethod
def
get_batch_per_block
(
b
,
np
,
sq
,
sk
):
import
scaled_masked_softmax_cuda
return
scaled_masked_softmax_cuda
.
get_batch_per_block
(
sq
,
sk
,
b
,
np
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment