OpenDAS / Megatron-LM / Commits

Commit 2096d356
Authored Feb 05, 2021 by Jared Casper

Merge branch 'fused_kernel_cond' into 'main'

conditioning fused kernels

See merge request ADLR/megatron-lm!228

Parents: 872e38ea, 0cb36de2
Showing 2 changed files with 27 additions and 6 deletions (+27 -6)

megatron/arguments.py            +17 -1
megatron/model/fused_softmax.py  +10 -5
megatron/arguments.py

@@ -202,7 +202,23 @@ def parse_args(extra_args_provider=None, defaults={},

         assert args.checkpoint_activations, \
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'
 
+    # custom kernel constraints check
+    seq_len = args.seq_length
+    attn_batch_size = \
+        (args.num_attention_heads / args.tensor_model_parallel_size) * \
+        args.micro_batch_size
+    # constraints on sequence length and attn_batch_size to enable warp based
+    # optimization and upper triangular optimization (for causal mask)
+    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+    if not (args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion):
+        print('WARNING: constraints for invoking optimized'
+              ' fused softmax kernel are not met. We default back to unfused'
+              ' kernel invocations.')
+
     # Load scaled_masked_softmax_fusion_kernels
     if args.masked_softmax_fusion:
         fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
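To see what the new check in parse_args computes, here is a small standalone sketch of the same arithmetic with hypothetical argument values (the numbers are illustrative, not taken from this merge request):

# Standalone sketch of the constraint check added to parse_args.
# The argument values below are hypothetical, chosen only for illustration.
seq_length = 1024
num_attention_heads = 16
tensor_model_parallel_size = 2
micro_batch_size = 4

seq_len = seq_length
# attention "batch" seen by the kernel: heads per tensor-parallel rank x micro batch
attn_batch_size = \
    (num_attention_heads / tensor_model_parallel_size) * micro_batch_size  # 32.0

# warp-based and upper-triangular (causal) optimizations require
# 16 < seq_len <= 2048, seq_len % 4 == 0, and attn_batch_size % 4 == 0
custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
    seq_len % 4 == 0 and attn_batch_size % 4 == 0

print(custom_kernel_constraint)  # True -> the fused softmax kernel may be used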
megatron/model/fused_softmax.py

@@ -113,18 +113,23 @@ class FusedScaleMaskSoftmax(torch.nn.Module):

         assert (
             self.scale is None or softmax_in_fp32
         ), "softmax should be in fp32 when scaled"
 
     def forward(self, input, mask):
         # [b, np, sq, sk]
+        assert input.dim() == 4
         data_size = input.size()
         query_seq_len = data_size[-2]
         key_seq_len = data_size[-1]
-        assert input.dim() == 4
+        attn_batch_size = data_size[0] * data_size[1]
 
-        # invoke custom kernel
-        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
-            query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
+        # constraints on various tensor dimensions to enable warp based
+        # optimization and upper triangular optimization (for causal mask)
+        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
+            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
+
+        # invoke custom kernel
+        if self.input_in_fp16 and mask is not None and \
+            custom_kernel_constraint and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0
 
             if self.attn_mask_type == AttnMaskType.causal:
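The forward() change applies the same predicate at run time. Below is a minimal, framework-free sketch of that dispatch decision; the function name and flat argument list are illustrative, and only the predicate itself mirrors the diff above:

# Sketch of the run-time fused/unfused dispatch in FusedScaleMaskSoftmax.forward.
# Names and signature are illustrative; only the predicate mirrors the diff.
def use_fused_softmax(input_in_fp16, mask_present, scaled_masked_softmax_fusion,
                      batch, num_heads, query_seq_len, key_seq_len):
    attn_batch_size = batch * num_heads
    custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
        query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
    return input_in_fp16 and mask_present and \
        custom_kernel_constraint and scaled_masked_softmax_fusion

# an fp16 [b=4, np=16, sq=1024, sk=1024] input with a mask takes the fused path
print(use_fused_softmax(True, True, True, 4, 16, 1024, 1024))  # True
# key sequences longer than 2048 fall back to the unfused implementation
print(use_fused_softmax(True, True, True, 4, 16, 4096, 4096))  # False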