OpenDAS / Megatron-LM · Commit 76db9583
Authored Jul 28, 2022 by Vijay Korthikanti
Parent: 8df49e72

support for all mask in fused kernel + avoiding inplace operation in bwd pass

Showing 4 changed files with 108 additions and 7 deletions:
- megatron/fused_kernels/scaled_masked_softmax.h (+8, -1)
- megatron/fused_kernels/scaled_masked_softmax_cuda.cu (+9, -5)
- megatron/fused_kernels/tests/test_fused_kernels.py (+90, -1)
- megatron/model/fused_softmax.py (+1, -0)
megatron/fused_kernels/scaled_masked_softmax.h
```diff
@@ -293,6 +293,13 @@ __global__ void scaled_masked_softmax_warp_forward(
         }
         warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
 
+        // compute scale value to account for full mask
+        acc_t scale_value[WARP_BATCH];
+        #pragma unroll
+        for (int i = 0;  i < WARP_BATCH;  ++i) {
+            scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0;
+        }
+
         acc_t sum[WARP_BATCH] { 0.0f };
         #pragma unroll
         for (int i = 0;  i < WARP_BATCH;  ++i) {
@@ -316,7 +323,7 @@ __global__ void scaled_masked_softmax_warp_forward(
             if (element_index < element_count) {
                 #pragma unroll
                 for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
-                    out[element] = elements[i][it + element] / sum[i];
+                    out[element] = elements[i][it + element] * scale_value[i] / sum[i];
                 }
                 copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
             } else {
```
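The new `scale_value` logic is the "support for all mask" half of the commit message: when every position in a row is masked, all scores sit at the fill value -10000.0, and a plain softmax over such a constant row returns a uniform distribution; scaling the row by zero makes the kernel return zeros instead. A minimal eager-PyTorch sketch of that semantics (a hypothetical reference helper for intuition, not the fused kernel itself):

```python
import torch

def reference_scaled_masked_softmax(inputs, mask, scale):
    """Reference semantics only: masked positions are filled with -10000.0
    before the softmax, and rows in which every position is masked are forced
    to zero rather than the uniform row a softmax over constant scores gives."""
    scores = (inputs * scale).masked_fill(mask, -10000.0)
    probs = torch.softmax(scores.float(), dim=-1)
    # Mirrors scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0 above.
    fully_masked = mask.all(dim=-1, keepdim=True)
    return probs.masked_fill(fully_masked, 0.0).to(inputs.dtype)
```

The new `test_allmasked_softmax_forward` test added below checks exactly this case by comparing the kernel output against `torch.zeros_like(inputs)`.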
megatron/fused_kernels/scaled_masked_softmax_cuda.cu
```diff
@@ -65,7 +65,7 @@ torch::Tensor fwd_cuda(
       input.scalar_type(),
       "dispatch_scaled_masked_softmax_forward",
       dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
           reinterpret_cast<scalar_t*>(softmax_results_ptr),
           reinterpret_cast<const scalar_t*>(input_ptr),
           reinterpret_cast<const uint8_t*>(mask_ptr),
           scale_factor,
@@ -92,14 +92,19 @@ torch::Tensor bwd_cuda(
   const int query_seq_len = output_grads.size(2);
   const int key_seq_len = output_grads.size(3);
 
+  auto act_options = output_grads.options().requires_grad(false);
+  torch::Tensor input_grads =
+      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
+
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
+  void* input_grads_ptr = static_cast<void*>(input_grads.data_ptr());
 
   //Softmax Grad
   DISPATCH_HALF_AND_BFLOAT(
       output_grads_.scalar_type(),
       "dispatch_scaled_masked_softmax_backward",
       dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
-          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(input_grads_ptr),
           reinterpret_cast<scalar_t*>(output_grads_ptr),
           reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
           scale_factor,
@@ -107,10 +112,9 @@ torch::Tensor bwd_cuda(
           key_seq_len,
           batches,
           attn_heads);
   );
 
-  //backward pass is completely in-place
-  return output_grads;
+  return input_grads;
 }
 }
 }
```
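Two things change in `bwd_cuda`: the gradient is now written into a freshly allocated `input_grads` tensor (created with `requires_grad(false)`) instead of back into `output_grads`, and that tensor is what the function returns. This is the "avoiding inplace operation in bwd pass" half of the commit message, and it leaves the caller's `output_grads` buffer untouched. For intuition, a hedged eager-PyTorch sketch of the gradient such a backward is expected to produce for the masked softmax forward (a hypothetical reference; the tests below only require agreement with autograd to within 1e-3):

```python
import torch

def reference_scaled_masked_softmax_backward(output_grads, softmax_results, scale):
    """Hedged reference: the standard softmax backward
    dL/dx = scale * y * (dL/dy - sum_j dL/dy_j * y_j),
    computed out of place so the incoming output_grads buffer is not modified."""
    dot = (output_grads * softmax_results).sum(dim=-1, keepdim=True)
    return scale * softmax_results * (output_grads - dot)
```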
megatron/fused_kernels/tests/test_fused_kernels.py
```diff
@@ -7,7 +7,7 @@ from megatron.model.enums import AttnMaskType
 from megatron.model.fused_layer_norm import MixedFusedLayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.utils import attention_mask_func
+from megatron.fused_kernels import load
 
 def test_load_fused_kernels():
     try:
@@ -279,6 +279,90 @@ def test_layer_norm():
     )
 
 
+def attention_mask_func(attention_scores, attention_mask):
+    attention_scores.masked_fill_(attention_mask, -10000.0)
+    return attention_scores
+
+
+def forward_torch_softmax(input, mask, scale):
+    input = input * scale
+    mask_output = attention_mask_func(input, mask) if mask is not None else input
+    probs = torch.nn.Softmax(dim=-1)(mask_output)
+    return probs
+
+
+def test_masked_softmax_forward():
+    import scaled_masked_softmax_cuda
+
+    batch = 2
+    attn = 16
+    scale_t = torch.tensor([1.0])
+    for qlen in [128, 256, 1024, 2048, 4096]:
+        for klen in [128, 256, 1024, 2048]:
+            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
+            masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
+            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
+            softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
+            error = (softmax_results_torch - softmax_results).abs().max()
+            assert error < 1e-3
+
+
+def test_masked_softmax_backward():
+    import scaled_masked_softmax_cuda
+
+    batch = 2
+    attn = 16
+    scale_t = torch.tensor([1.0])
+    for qlen in [128, 256, 1024, 2048, 4096]:
+        for klen in [128, 256, 1024, 2048]:
+            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
+            backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0')
+            masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
+            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
+            back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item())
+
+            inputs.requires_grad = True
+            softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
+            softmax_results_torch.backward(backward)
+            error = (back_grad - inputs.grad).abs().max()
+            assert error < 1e-3
+
+
+def test_allmasked_softmax_forward():
+    import scaled_masked_softmax_cuda
+
+    batch = 2
+    attn = 16
+    scale_t = torch.tensor([1.0])
+    for qlen in [128, 256, 1024, 2048, 4096]:
+        for klen in [128, 256, 1024, 2048]:
+            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
+            masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
+            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
+            softmax_results_torch = torch.zeros_like(inputs)
+            error = (softmax_results_torch - softmax_results).abs().max()
+            assert error == 0.0
+
+
+def test_allmasked_softmax_backward():
+    import scaled_masked_softmax_cuda
+
+    batch = 2
+    attn = 16
+    scale_t = torch.tensor([1.0])
+    for qlen in [128, 256, 1024, 2048, 4096]:
+        for klen in [128, 256, 1024, 2048]:
+            inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
+            backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0')
+            masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
+            softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
+            back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item())
+            inputs.requires_grad = True
+            softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
+            softmax_results_torch.backward(backward)
+            error = (back_grad - inputs.grad).abs().max()
+            assert error < 1e-3
+
+
 if __name__ == "__main__":
     try:
         from transformers import BertTokenizer, GPT2Tokenizer
@@ -294,6 +378,11 @@ if __name__ == "__main__":
         print("\n[Fail] Please install `transformers` package to test fused kernels\n")
         exit(-1)
 
+    load()
+    test_masked_softmax_forward()
+    test_masked_softmax_backward()
+    test_allmasked_softmax_forward()
+    test_allmasked_softmax_backward()
     test_load_fused_kernels()
     test_fused_softmax()
     test_fused_upper_triangle_mask_softmax()
```
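The new checks are wired into the script's `__main__` rather than relying on a test runner, presumably because the `scaled_masked_softmax_cuda` extension that the tests import has to be built by `load()` first. A hedged sketch of running only the new masked-softmax checks, assuming a CUDA device is available and the megatron package (including its fused_kernels/tests directory) is importable from the working tree:

```python
# Usage sketch under the assumptions stated above; import paths are taken from
# the file layout in this commit but are not guaranteed to be package-importable.
from megatron.fused_kernels import load
from megatron.fused_kernels.tests.test_fused_kernels import (
    test_masked_softmax_forward,
    test_masked_softmax_backward,
    test_allmasked_softmax_forward,
    test_allmasked_softmax_backward,
)

load()  # builds the scaled_masked_softmax_cuda extension the tests import
test_masked_softmax_forward()
test_masked_softmax_backward()
test_allmasked_softmax_forward()
test_allmasked_softmax_backward()
```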
megatron/model/fused_softmax.py
```diff
@@ -170,6 +170,7 @@ class FusedScaleMaskSoftmax(nn.Module):
             and self.input_in_float16  # input must be fp16
             and 16 < sk <= 4096  # sk must be 16 ~ 2048
             and sq % 4 == 0  # sq must be divisor of 4
             and sk % 4 == 0  # sk must be divisor of 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
             if 0 <= sk <= 4096:
```
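The conditions above gate whether `FusedScaleMaskSoftmax` uses the fused CUDA kernel for a given input; inputs that fail any check take the unfused PyTorch path instead. A small illustrative helper, purely an assumption for clarity and not part of the commit, that evaluates the same constraints for a batch size `b`, head count `np`, and sequence lengths `sq` (query) and `sk` (key):

```python
# Hypothetical helper mirroring the availability checks shown in the hunk above.
def fused_softmax_shape_ok(b, np, sq, sk, input_in_float16=True, fusion_enabled=True):
    return (
        fusion_enabled               # user wants the fused path
        and input_in_float16         # input must be fp16
        and 16 < sk <= 4096          # key sequence length supported by the kernel
        and sq % 4 == 0              # sq must be divisible by 4
        and sk % 4 == 0              # sk must be divisible by 4
        and (b * np) % 4 == 0        # np * b must be divisible by 4
    )

# e.g. fused_softmax_shape_ok(2, 16, 1024, 1024) -> True
#      fused_softmax_shape_ok(2, 16, 1024, 8192) -> False (sk too large)
```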