Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
40becfc9
Commit
40becfc9
authored
Aug 13, 2021
by
hyunwoongko
Committed by
mshoeybi
Aug 22, 2021
Browse files
Improve and fix bugs in the fused softmax layer
parent
23266c57
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
415 additions
and
54 deletions
+415
-54
megatron/fused_kernels/scaled_masked_softmax.cpp
megatron/fused_kernels/scaled_masked_softmax.cpp
+21
-1
megatron/fused_kernels/scaled_masked_softmax.h
megatron/fused_kernels/scaled_masked_softmax.h
+16
-4
megatron/fused_kernels/scaled_masked_softmax_cuda.cu
megatron/fused_kernels/scaled_masked_softmax_cuda.cu
+5
-0
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
+0
-4
megatron/fused_kernels/tests/__init__.py
megatron/fused_kernels/tests/__init__.py
+0
-0
megatron/fused_kernels/tests/test_fused_kernels.py
megatron/fused_kernels/tests/test_fused_kernels.py
+300
-0
megatron/model/fused_softmax.py
megatron/model/fused_softmax.py
+73
-45
No files found.
megatron/fused_kernels/scaled_masked_softmax.cpp
View file @
40becfc9
...
@@ -32,6 +32,12 @@ torch::Tensor bwd_cuda(
...
@@ -32,6 +32,12 @@ torch::Tensor bwd_cuda(
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
softmax_results
,
float
scale_factor
);
float
scale_factor
);
int
get_batch_per_block_cuda
(
int
query_seq_len
,
int
key_seq_len
,
int
batches
,
int
attn_heads
);
torch
::
Tensor
fwd
(
torch
::
Tensor
fwd
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
mask
,
torch
::
Tensor
const
&
mask
,
...
@@ -63,6 +69,14 @@ torch::Tensor bwd(
...
@@ -63,6 +69,14 @@ torch::Tensor bwd(
return
bwd_cuda
(
output_grads
,
softmax_results
,
scale_factor
);
return
bwd_cuda
(
output_grads
,
softmax_results
,
scale_factor
);
}
}
// Host-side entry point exposed through the pybind11 module below.
// Forwards to the CUDA translation unit, which computes how many
// (query-row, head, batch) entries a single CUDA block processes.
int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}
}
// end namespace scaled_masked_softmax
}
// end namespace scaled_masked_softmax
}
// end namespace fused_softmax
}
// end namespace fused_softmax
}
// end namespace multihead_attn
}
// end namespace multihead_attn
...
@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"forward"
,
m
.
def
(
"forward"
,
&
multihead_attn
::
fused_softmax
::
scaled_masked_softmax
::
fwd
,
&
multihead_attn
::
fused_softmax
::
scaled_masked_softmax
::
fwd
,
"Self Multihead Attention scaled, time masked softmax -- Forward."
);
"Self Multihead Attention scaled, time masked softmax -- Forward."
);
m
.
def
(
"backward"
,
m
.
def
(
"backward"
,
&
multihead_attn
::
fused_softmax
::
scaled_masked_softmax
::
bwd
,
&
multihead_attn
::
fused_softmax
::
scaled_masked_softmax
::
bwd
,
"Self Multihead Attention scaled, time masked softmax -- Backward."
);
"Self Multihead Attention scaled, time masked softmax -- Backward."
);
m
.
def
(
"get_batch_per_block"
,
&
multihead_attn
::
fused_softmax
::
scaled_masked_softmax
::
get_batch_per_block
,
"Return Batch per block size."
);
}
}
megatron/fused_kernels/scaled_masked_softmax.h
View file @
40becfc9
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#pragma once
#pragma once
#include <stdio.h>
#include <assert.h>
#include <assert.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <cfloat>
...
@@ -310,9 +311,23 @@ __global__ void scaled_masked_softmax_warp_backward(
...
@@ -310,9 +311,23 @@ __global__ void scaled_masked_softmax_warp_backward(
}
}
}
}
}
}
}
// end of anonymous namespace
}
// end of anonymous namespace
// Number of softmax rows a single CUDA block processes, mirroring the launch
// configuration computed in dispatch_scaled_masked_softmax_forward below
// (warp_size / batches_per_warp / threads_per_block must stay in sync).
//
// Only key_seq_len affects the result; query_seq_len, batches and attn_heads
// are unused but kept so the signature matches the get_batch_per_block
// binding exposed to Python.
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    // Short rows use fewer lanes than a full warp.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    // Rows of <=128 elements let each warp handle two batches at once.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    // Fix: removed the dead local
    //   int batch_count = batches * attn_heads * query_seq_len;
    // which was computed but never read.
    return batches_per_block;
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
void
dispatch_scaled_masked_softmax_forward
(
void
dispatch_scaled_masked_softmax_forward
(
output_t
*
dst
,
output_t
*
dst
,
...
@@ -325,7 +340,6 @@ void dispatch_scaled_masked_softmax_forward(
...
@@ -325,7 +340,6 @@ void dispatch_scaled_masked_softmax_forward(
int
attn_heads
,
int
attn_heads
,
int
pad_batches
)
int
pad_batches
)
{
{
TORCH_INTERNAL_ASSERT
(
key_seq_len
>=
0
&&
key_seq_len
<=
2048
);
if
(
key_seq_len
==
0
)
{
if
(
key_seq_len
==
0
)
{
return
;
return
;
}
else
{
}
else
{
...
@@ -344,7 +358,6 @@ void dispatch_scaled_masked_softmax_forward(
...
@@ -344,7 +358,6 @@ void dispatch_scaled_masked_softmax_forward(
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
TORCH_INTERNAL_ASSERT
(
query_seq_len
%
batches_per_block
==
0
);
dim3
blocks
(
query_seq_len
/
batches_per_block
,
attn_heads
,
batches
);
dim3
blocks
(
query_seq_len
/
batches_per_block
,
attn_heads
,
batches
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
...
@@ -414,7 +427,6 @@ void dispatch_scaled_masked_softmax_backward(
...
@@ -414,7 +427,6 @@ void dispatch_scaled_masked_softmax_backward(
int
batches
,
int
batches
,
int
attn_heads
)
int
attn_heads
)
{
{
TORCH_INTERNAL_ASSERT
(
key_seq_len
>=
0
&&
key_seq_len
<=
2048
);
if
(
key_seq_len
==
0
)
{
if
(
key_seq_len
==
0
)
{
return
;
return
;
}
else
{
}
else
{
...
...
megatron/fused_kernels/scaled_masked_softmax_cuda.cu
View file @
40becfc9
...
@@ -28,6 +28,11 @@ namespace multihead_attn {
...
@@ -28,6 +28,11 @@ namespace multihead_attn {
namespace
fused_softmax
{
namespace
fused_softmax
{
namespace
scaled_masked_softmax
{
namespace
scaled_masked_softmax
{
// Thin host wrapper so scaled_masked_softmax.cpp can query the
// batches-per-block value computed by get_batch_per_block in
// scaled_masked_softmax.h without including CUDA-specific headers.
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch
::
Tensor
fwd_cuda
(
torch
::
Tensor
fwd_cuda
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
mask
,
torch
::
Tensor
const
&
mask
,
...
...
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
View file @
40becfc9
...
@@ -340,7 +340,6 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
...
@@ -340,7 +340,6 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int
softmax_elements_stride
,
int
softmax_elements_stride
,
int
attn_batches
)
int
attn_batches
)
{
{
TORCH_INTERNAL_ASSERT
(
softmax_elements
>=
0
&&
softmax_elements
<=
2048
);
if
(
softmax_elements
==
0
)
{
if
(
softmax_elements
==
0
)
{
return
;
return
;
}
else
{
}
else
{
...
@@ -360,7 +359,6 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
...
@@ -360,7 +359,6 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
TORCH_INTERNAL_ASSERT
(
attn_batches
%
batches_per_block
==
0
);
int
blocks_per_seq
=
attn_batches
/
batches_per_block
;
int
blocks_per_seq
=
attn_batches
/
batches_per_block
;
dim3
blocks
(
seq_len
,
blocks_per_seq
,
1
);
dim3
blocks
(
seq_len
,
blocks_per_seq
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
...
@@ -430,7 +428,6 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
...
@@ -430,7 +428,6 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int
softmax_elements_stride
,
int
softmax_elements_stride
,
int
attn_batches
)
int
attn_batches
)
{
{
TORCH_INTERNAL_ASSERT
(
softmax_elements
>=
0
&&
softmax_elements
<=
2048
);
if
(
softmax_elements
==
0
)
{
if
(
softmax_elements
==
0
)
{
return
;
return
;
}
else
{
}
else
{
...
@@ -450,7 +447,6 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
...
@@ -450,7 +447,6 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
TORCH_INTERNAL_ASSERT
(
attn_batches
%
batches_per_block
==
0
);
int
blocks_per_seq
=
attn_batches
/
batches_per_block
;
int
blocks_per_seq
=
attn_batches
/
batches_per_block
;
dim3
blocks
(
seq_len
,
blocks_per_seq
,
1
);
dim3
blocks
(
seq_len
,
blocks_per_seq
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
...
...
megatron/fused_kernels/tests/__init__.py
0 → 100644
View file @
40becfc9
megatron/fused_kernels/tests/test_fused_kernels.py
0 → 100644
View file @
40becfc9
import
math
import
torch
from
torch.nn
import
LayerNorm
from
megatron.model.enums
import
AttnMaskType
from
megatron.model.fused_layer_norm
import
MixedFusedLayerNorm
from
megatron.model.fused_softmax
import
FusedScaleMaskSoftmax
from
megatron.model.utils
import
attention_mask_func
def test_load_fused_kernels():
    """Smoke test: verify the three compiled fused-kernel extensions import.

    Fails (re-raising the ImportError) if the extensions have not been built
    by megatron/fused_kernels for the current environment.
    """
    try:
        import fused_mix_prec_layer_norm_cuda
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import torch

        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e
def test_fused_softmax():
    """Compare FusedScaleMaskSoftmax with fusion on vs. off (padding mask).

    Uses real BERT attention scores as input. Requires CUDA, fp16, and the
    `transformers` package; BertModel/BertTokenizer are module globals
    imported by the __main__ block at the bottom of this file.
    Prints [Success]/[Fail] based on the mean absolute difference.
    """
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len)
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len) — broadcast the padding mask over query rows
    mask = mask.repeat(1, 1, mask.size()[-1], 1)

    # Reuse BERT's first self-attention layer to build realistic Q/K scores.
    attention = bert.encoder.layer[0].attention.self
    key_layer = attention.transpose_for_scores(attention.key(embedding_output))
    query_layer = attention.transpose_for_scores(attention.query(embedding_output))

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    # (mask != 0): FusedScaleMaskSoftmax expects a boolean mask where True
    # marks positions to mask out — presumably matching get_extended_attention_mask's
    # additive-mask convention; TODO confirm against the module's mask_func.
    fused_softmax_output = fused_softmax(
        attention_scores,
        (mask != 0),
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attention_scores,
        (mask != 0),
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    # Collapse all but one dimension, then reduce to a scalar mean difference.
    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
def test_fused_upper_triangle_mask_softmax():
    """Compare FusedScaleMaskSoftmax with fusion on vs. off (causal mask).

    Uses real GPT-2 attention weights as input. Requires CUDA, fp16, and the
    `transformers` package; GPT2Model/GPT2Tokenizer are module globals
    imported by the __main__ block at the bottom of this file.
    Prints [Success]/[Fail] based on the mean absolute difference.
    """
    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi"  # 24
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    # (bsz, 1, 1, seq_len)
    attention_mask = attention_mask[:, None, None, :]
    # HF additive convention: 0 for keep, -10000 for masked positions.
    attention_mask = (1.0 - attention_mask) * -10000.0
    # (bsz, 1, seq_len, seq_len)
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)
    attn = gpt.h[0]

    # Build raw attention weights from GPT-2's first block.
    hidden_states = gpt.wte(tokens["input_ids"].cuda())
    q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
    q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
    k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2))

    sq, sk = q.size(-2), k.size(-2)
    causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
    # True marks positions to mask out: either non-causal or padding.
    total_mask = ~(causal_mask & (attention_mask == 0))
    """
    tensor([[[[False, True, True, ..., True, True, True],
              [False, False, True, ..., True, True, True],
              [False, False, False, ..., True, True, True],
              ...,
              [False, False, False, ..., False, True, True],
              [False, False, False, ..., False, False, True],
              [False, False, False, ..., False, False, False]]]
    """

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    # Collapse all but one dimension, then reduce to a scalar mean difference.
    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
def test_layer_norm():
    """Compare MixedFusedLayerNorm against torch.nn.LayerNorm on BERT embeddings.

    Requires CUDA, fp16, and the `transformers` package; BertModel and
    BertTokenizer are module globals imported by the __main__ block below.
    Prints [Success]/[Fail] based on the mean absolute difference.
    """
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    # [bsz, seq_len, d_model]
    embedding_output = (
        bert.embeddings(
            input_ids=tokens["input_ids"].cuda(),
            position_ids=None,
            token_type_ids=tokens["token_type_ids"].cuda(),
            inputs_embeds=None,
            past_key_values_length=0,
        )
        .cuda()
        .half()
    )

    fused_layernorm_layer = (
        MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    torch_layernorm_layer = (
        LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    fused_output = fused_layernorm_layer(embedding_output)
    torch_output = torch_layernorm_layer(embedding_output)
    test_result = (fused_output - torch_output).abs()

    # Collapse all but one dimension, then reduce to a scalar mean difference.
    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_layer_norm"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_layer_norm"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
if __name__ == "__main__":
    # The transformers names imported here become module globals that the
    # test functions above rely on (BertModel, GPT2Model, ...).
    try:
        from transformers import BertTokenizer, GPT2Tokenizer
        from transformers.models.bert.modeling_bert import BertModel
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
        import transformers

        # Silence HF warning spam so the kernel test output stays readable.
        transformers.logging.set_verbosity(
            transformers.logging.FATAL,
        )

    # Fix: was a bare `except:`, which would also swallow KeyboardInterrupt
    # and SystemExit; only a missing package should trigger this message.
    except ImportError:
        print("\n[Fail] Please install `transformers` package to test fused kernels\n")
        exit(-1)

    test_load_fused_kernels()
    test_fused_softmax()
    test_fused_upper_triangle_mask_softmax()
    test_layer_norm()
megatron/model/fused_softmax.py
View file @
40becfc9
...
@@ -13,7 +13,9 @@
...
@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
torch
import
torch
import
torch.nn
as
nn
from
megatron.model.enums
import
AttnMaskType
from
megatron.model.enums
import
AttnMaskType
...
@@ -30,10 +32,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
...
@@ -30,10 +32,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import
scaled_upper_triang_masked_softmax_cuda
import
scaled_upper_triang_masked_softmax_cuda
scale_t
=
torch
.
tensor
([
scale
])
scale_t
=
torch
.
tensor
([
scale
])
softmax_results
=
scaled_upper_triang_masked_softmax_cuda
.
forward
(
softmax_results
=
scaled_upper_triang_masked_softmax_cuda
.
forward
(
inputs
,
scale_t
[
0
]
inputs
,
scale_t
[
0
]
)
)
ctx
.
save_for_backward
(
softmax_results
,
scale_t
)
ctx
.
save_for_backward
(
softmax_results
,
scale_t
)
return
softmax_results
return
softmax_results
...
@@ -42,10 +44,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
...
@@ -42,10 +44,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import
scaled_upper_triang_masked_softmax_cuda
import
scaled_upper_triang_masked_softmax_cuda
softmax_results
,
scale_t
=
ctx
.
saved_tensors
softmax_results
,
scale_t
=
ctx
.
saved_tensors
input_grads
=
scaled_upper_triang_masked_softmax_cuda
.
backward
(
input_grads
=
scaled_upper_triang_masked_softmax_cuda
.
backward
(
output_grads
,
softmax_results
,
scale_t
[
0
]
output_grads
,
softmax_results
,
scale_t
[
0
]
)
)
return
input_grads
,
None
return
input_grads
,
None
...
@@ -63,9 +65,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
...
@@ -63,9 +65,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
scale_t
=
torch
.
tensor
([
scale
])
scale_t
=
torch
.
tensor
([
scale
])
softmax_results
=
scaled_masked_softmax_cuda
.
forward
(
softmax_results
=
scaled_masked_softmax_cuda
.
forward
(
inputs
,
mask
,
scale_t
[
0
])
inputs
,
mask
,
scale_t
[
0
]
)
ctx
.
save_for_backward
(
softmax_results
,
scale_t
)
ctx
.
save_for_backward
(
softmax_results
,
scale_t
)
return
softmax_results
return
softmax_results
...
@@ -81,16 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
...
@@ -81,16 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
return
input_grads
,
None
,
None
return
input_grads
,
None
,
None
class
FusedScaleMaskSoftmax
(
torch
.
nn
.
Module
):
class
FusedScaleMaskSoftmax
(
nn
.
Module
):
"""
"""
fused operation: scaling + mask + softmax
fused operation: scaling + mask + softmax
Arguments:
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate whether the user wants to use softmax fusion
mask_func: mask function to be applied.
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed at fp32 precision.
softmax_in_fp32: if true, softmax is performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
scale: scaling factor used in input tensor scaling.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
...
@@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
super
(
FusedScaleMaskSoftmax
,
self
).
__init__
()
super
(
FusedScaleMaskSoftmax
,
self
).
__init__
()
self
.
input_in_fp16
=
input_in_fp16
self
.
input_in_fp16
=
input_in_fp16
self
.
input_in_bf16
=
input_in_bf16
self
.
input_in_bf16
=
input_in_bf16
assert
not
(
self
.
input_in_fp16
and
self
.
input_in_bf16
),
\
assert
not
(
'both fp16 and bf16 flags cannot be active at the same time.'
self
.
input_in_fp16
and
self
.
input_in_bf16
),
"both fp16 and bf16 flags cannot be active at the same time."
self
.
input_in_float16
=
self
.
input_in_fp16
or
self
.
input_in_bf16
self
.
input_in_float16
=
self
.
input_in_fp16
or
self
.
input_in_bf16
self
.
attn_mask_type
=
attn_mask_type
self
.
attn_mask_type
=
attn_mask_type
self
.
scaled_masked_softmax_fusion
=
scaled_masked_softmax_fusion
self
.
scaled_masked_softmax_fusion
=
scaled_masked_softmax_fusion
...
@@ -118,47 +121,72 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
...
@@ -118,47 +121,72 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
assert
(
assert
(
self
.
scale
is
None
or
softmax_in_fp32
self
.
scale
is
None
or
softmax_in_fp32
),
"softmax should be in fp32 when scaled"
),
"softmax should be in fp32 when scaled"
def
forward
(
self
,
input
,
mask
):
def
forward
(
self
,
input
,
mask
):
# [b, np, sq, sk]
# [b, np, sq, sk]
assert
input
.
dim
()
==
4
assert
input
.
dim
()
==
4
data_size
=
input
.
size
()
query_seq_len
=
data_size
[
-
2
]
if
self
.
is_kernel_available
(
mask
,
*
input
.
size
()):
key_seq_len
=
data_size
[
-
1
]
return
self
.
forward_fused_softmax
(
input
,
mask
)
attn_batch_size
=
data_size
[
0
]
*
data_size
[
1
]
# constraints on various tensor dimensions to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint
=
key_seq_len
>
16
and
key_seq_len
<=
2048
and
\
query_seq_len
%
4
==
0
and
attn_batch_size
%
4
==
0
# invoke custom kernel
if
self
.
input_in_float16
and
mask
is
not
None
and
\
custom_kernel_constraint
and
self
.
scaled_masked_softmax_fusion
:
scale
=
self
.
scale
if
self
.
scale
is
not
None
else
1.0
if
self
.
attn_mask_type
==
AttnMaskType
.
causal
:
assert
query_seq_len
==
key_seq_len
,
\
"causal mask is only for self attention"
input
=
input
.
view
(
-
1
,
query_seq_len
,
key_seq_len
)
probs
=
ScaledUpperTriangMaskedSoftmax
.
apply
(
input
,
scale
)
probs
=
probs
.
view
(
*
data_size
)
else
:
assert
self
.
attn_mask_type
==
AttnMaskType
.
padding
probs
=
ScaledMaskedSoftmax
.
apply
(
input
,
mask
,
scale
)
else
:
else
:
if
self
.
input_in_float16
and
self
.
softmax_in_fp32
:
return
self
.
forward_torch_softmax
(
input
,
mask
)
input
=
input
.
float
()
def is_kernel_available(self, mask, b, np, sq, sk):
    """Return True when the fused CUDA softmax kernel can handle this input.

    Args are the mask tensor plus the unpacked input shape:
    b (batch), np (attention heads), sq (query length), sk (key length).
    """
    attn_batches = b * np

    if (
        self.scaled_masked_softmax_fusion  # user wants to fuse
        and self.input_in_float16  # input must be fp16
        and mask is not None  # mask tensor must not be None
        # NOTE(review): the 16~2048 bound is applied to sq here, while the
        # CUDA kernel's hard limit (assert in scaled_masked_softmax.h) is on
        # key_seq_len (sk) — confirm whether sq/sk are swapped in these checks.
        and 16 < sq <= 2048  # sq must be 16 ~ 2048
        and sk % 4 == 0  # sk must be a multiple of 4
        and attn_batches % 4 == 0  # np * b must be a multiple of 4
    ):
        if 0 <= sk <= 2048:
            batch_per_block = self.get_batch_per_block(sq, sk, b, np)

            # The kernel launch divides its grid by batch_per_block, so the
            # relevant dimension must split evenly across blocks.
            if self.attn_mask_type == AttnMaskType.causal:
                if attn_batches % batch_per_block == 0:
                    return True
            else:
                if sq % batch_per_block == 0:
                    return True
    return False
if
self
.
scale
is
not
None
:
def
forward_fused_softmax
(
self
,
input
,
mask
):
input
=
input
*
self
.
scale
b
,
np
,
sq
,
sk
=
input
.
size
()
mask_output
=
self
.
mask_func
(
input
,
mask
)
if
mask
is
not
None
else
input
scale
=
self
.
scale
if
self
.
scale
is
not
None
else
1.0
probs
=
torch
.
nn
.
Softmax
(
dim
=-
1
)(
mask_output
)
if
self
.
input_in_float16
and
self
.
softmax_in_fp32
:
if
self
.
attn_mask_type
==
AttnMaskType
.
causal
:
if
self
.
input_in_fp16
:
assert
sq
==
sk
,
"causal mask is only for self attention"
probs
=
probs
.
half
()
else
:
# input is 3D tensor (attn_batches, sq, sk)
probs
=
probs
.
bfloat16
()
input
=
input
.
view
(
-
1
,
sq
,
sk
)
probs
=
ScaledUpperTriangMaskedSoftmax
.
apply
(
input
,
scale
)
return
probs
.
view
(
b
,
np
,
sq
,
sk
)
else
:
# input is 4D tensor (b, np, sq, sk)
return
ScaledMaskedSoftmax
.
apply
(
input
,
mask
,
scale
)
def forward_torch_softmax(self, input, mask):
    """Fallback path: scaling + mask + softmax in plain PyTorch ops.

    Used when the fused CUDA kernel is unavailable or unsuitable for the
    input shape; mask may be None, in which case no masking is applied.
    """
    if self.input_in_float16 and self.softmax_in_fp32:
        # Upcast so the softmax itself runs in fp32.
        input = input.float()

    if self.scale is not None:
        input = input * self.scale
    mask_output = self.mask_func(input, mask) if mask is not None else input
    probs = torch.nn.Softmax(dim=-1)(mask_output)

    if self.input_in_float16 and self.softmax_in_fp32:
        # Cast the result back to the caller's 16-bit dtype.
        if self.input_in_fp16:
            probs = probs.half()
        else:
            probs = probs.bfloat16()

    return probs
@staticmethod
def get_batch_per_block(b, np, sq, sk):
    """Query the compiled extension for batches-per-CUDA-block.

    Imported lazily so this module can load without the compiled kernels.
    Note the argument reorder: the CUDA binding expects (sq, sk, b, np).
    """
    import scaled_masked_softmax_cuda

    return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment