OpenDAS / Megatron-LM · Commits

Commit 14c85e64, authored Mar 04, 2021 by Rewon Child

Merge branch 'main' into rc-debug-underflow

Parents: 4e77e7c6, c1faa9fe

Showing 6 changed files with 41 additions and 11 deletions (+41 / -11)
Files changed:

  megatron/arguments.py             +19  -3
  megatron/checkpointing.py          +5  -2
  megatron/model/fused_softmax.py   +10  -5
  megatron/training.py               +2  -0
  tools/generate_samples_gpt.py      +3  -1
  tools/merge_mp_partitions.py       +2  -0
megatron/arguments.py

@@ -202,7 +202,23 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.checkpoint_activations, \
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'
 
+    # custom kernel constraints check
+    seq_len = args.seq_length
+    attn_batch_size = \
+        (args.num_attention_heads / args.tensor_model_parallel_size) * \
+        args.micro_batch_size
+    # constraints on sequence length and attn_batch_size to enable warp based
+    # optimization and upper triangular optimization (for causal mask)
+    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+    if not (args.fp16 and custom_kernel_constraint and
+            args.masked_softmax_fusion):
+        print('WARNING: constraints for invoking optimized'
+              ' fused softmax kernel are not met. We default back to unfused'
+              ' kernel invocations.')
+
     # Load scaled_masked_softmax_fusion_kernels
     if args.masked_softmax_fusion:
         fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()

@@ -480,9 +496,9 @@ def _add_checkpointing_args(parser):
                        help='Output directory to save checkpoints to.')
     group.add_argument('--save-interval', type=int, default=None,
                        help='Number of iterations between checkpoint saves.')
-    group.add_argument('--no-save-optim', action='store_true',
+    group.add_argument('--no-save-optim', action='store_true', default=None,
                        help='Do not save current optimizer.')
-    group.add_argument('--no-save-rng', action='store_true',
+    group.add_argument('--no-save-rng', action='store_true', default=None,
                        help='Do not save current rng state.')
     group.add_argument('--load', type=str, default=None,
                        help='Directory containing a model checkpoint.')
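For reference, the eligibility check added to parse_args can be exercised in isolation. The sketch below is not part of the commit; it re-implements the same arithmetic with hypothetical example values so the constraint is easy to verify by hand.

# Standalone sketch (not from the commit): reproduces the fused-softmax
# eligibility arithmetic from parse_args with made-up example values.

def fused_softmax_eligible(seq_length, num_attention_heads,
                           tensor_model_parallel_size, micro_batch_size):
    # Attention heads are split across tensor-parallel ranks, so the
    # per-rank attention batch is (heads / tp_size) * micro_batch_size.
    attn_batch_size = (num_attention_heads / tensor_model_parallel_size) \
        * micro_batch_size
    return (16 < seq_length <= 2048
            and seq_length % 4 == 0
            and attn_batch_size % 4 == 0)

# Hypothetical configuration: 1024-token sequences, 16 heads, 2-way
# tensor parallelism, micro batch of 4 -> attn_batch_size = 32.
print(fused_softmax_eligible(1024, 16, 2, 4))   # True
# 1022 is not divisible by 4, so the unfused kernel would be used.
print(fused_softmax_eligible(1022, 16, 2, 4))   # False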
megatron/checkpointing.py

@@ -343,12 +343,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
             np.random.set_state(state_dict['np_rng_state'])
             torch.set_rng_state(state_dict['torch_rng_state'])
             torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+            # Check for empty states array
+            if not state_dict['rng_tracker_states']:
+                raise KeyError
             mpu.get_cuda_rng_tracker().set_states(
                 state_dict['rng_tracker_states'])
         except KeyError:
-            print_rank_0('Unable to load optimizer from checkpoint {}. '
+            print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
-                         'attempting to load the optimizer state, '
+                         'attempting to load the rng state, '
                          'exiting ...'.format(checkpoint_name))
             sys.exit()
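The added guard reuses the surrounding try/except: raising KeyError on an empty rng_tracker_states entry routes control to the existing error path. A minimal sketch of that pattern, detached from Megatron (the function, its arguments, and the checkpoint layout here are stand-ins), might look like this:

import sys

# Minimal sketch (not Megatron code): an empty tracker-states entry is
# treated the same as a missing key, so one except clause covers both.
def restore_rng(state_dict, checkpoint_name):
    try:
        # Mirrors the commit's guard: empty states count as missing.
        if not state_dict['rng_tracker_states']:
            raise KeyError
        # ... restore the tracker states here ...
    except KeyError:
        print('Unable to load rng state from checkpoint {}. '
              'Specify --no-load-rng or --finetune to prevent '
              'attempting to load the rng state, exiting ...'
              .format(checkpoint_name))
        sys.exit()

restore_rng({'rng_tracker_states': {}}, 'ckpt_dir')  # prints the message, then exits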
megatron/model/fused_softmax.py

@@ -113,18 +113,23 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         assert (
             self.scale is None or softmax_in_fp32
         ), "softmax should be in fp32 when scaled"
 
     def forward(self, input, mask):
         # [b, np, sq, sk]
+        assert input.dim() == 4
         data_size = input.size()
         query_seq_len = data_size[-2]
         key_seq_len = data_size[-1]
-        assert input.dim() == 4
+        attn_batch_size = data_size[0] * data_size[1]
 
-        # invoke custom kernel
-        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
-                query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
+        # constraints on various tensor dimensions to enable warp based
+        # optimization and upper triangular optimization (for causal mask)
+        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
+            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
+
+        # invoke custom kernel
+        if self.input_in_fp16 and mask is not None and \
+                custom_kernel_constraint and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0
 
             if self.attn_mask_type == AttnMaskType.causal:
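The same constraint now depends on the runtime shape of the attention scores rather than only on the parsed arguments. As a rough illustration (not the FusedScaleMaskSoftmax module itself; the function name and flags below are placeholders), the check on a [b, np, sq, sk] tensor could be reproduced like this:

import torch

# Rough sketch: evaluate the custom-kernel constraint for an
# attention-score tensor of shape [b, np, sq, sk].
def can_use_fused_kernel(scores, input_in_fp16, have_mask, fusion_enabled):
    assert scores.dim() == 4
    b, np_, sq, sk = scores.size()
    attn_batch_size = b * np_
    custom_kernel_constraint = 16 < sk <= 2048 and sq % 4 == 0 \
        and attn_batch_size % 4 == 0
    return input_in_fp16 and have_mask and custom_kernel_constraint \
        and fusion_enabled

# Hypothetical shape: batch 2, 4 heads, 1024 x 1024 scores.
scores = torch.empty(2, 4, 1024, 1024, dtype=torch.float16)
print(can_use_fused_kernel(scores, True, True, True))   # True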
megatron/training.py

@@ -351,6 +351,8 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
     reqs = torch.distributed.batch_isend_irecv(ops)
     for req in reqs:
         req.wait()
+    # Temporary workaround for batch_isend_irecv() race condition.
+    torch.cuda.synchronize()
 
     return tensor_recv_prev, tensor_recv_next
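For context on the workaround, here is a minimal standalone sketch of the batch_isend_irecv pattern. It is not Megatron's communicate(); the function name and the neighbor-exchange setup are assumptions, and it presumes a two-rank, single-node NCCL job launched with torchrun.

# Minimal sketch: neighbor exchange with batch_isend_irecv(), followed by
# the synchronize() call the commit adds as a guard after req.wait().
# Run with e.g. `torchrun --nproc_per_node=2 this_file.py` on 2 GPUs.
import torch
import torch.distributed as dist

def exchange_with_neighbors(send_next, send_prev):
    rank, world = dist.get_rank(), dist.get_world_size()
    next_rank, prev_rank = (rank + 1) % world, (rank - 1) % world
    recv_prev = torch.empty_like(send_next)
    recv_next = torch.empty_like(send_prev)
    ops = [dist.P2POp(dist.isend, send_next, next_rank),
           dist.P2POp(dist.irecv, recv_prev, prev_rank),
           dist.P2POp(dist.isend, send_prev, prev_rank),
           dist.P2POp(dist.irecv, recv_next, next_rank)]
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()
    # Workaround kept by the commit: wait() alone may return before the
    # receive buffers are safe to read, so synchronize the device.
    torch.cuda.synchronize()
    return recv_prev, recv_next

if __name__ == '__main__':
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(dist.get_rank())
    t = torch.full((4,), float(dist.get_rank()), device='cuda')
    prev, nxt = exchange_with_neighbors(t.clone(), t.clone())
    print(dist.get_rank(), prev.tolist(), nxt.tolist())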
tools/generate_samples_gpt.py

@@ -92,7 +92,9 @@ def main():
     """Main program."""
 
     initialize_megatron(extra_args_provider=add_text_generate_args,
-                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
+                                       'no_load_rng': True,
+                                       'no_load_optim': True})
 
     # Set up model and load checkpoint.
     model = get_model(model_provider)
tools/merge_mp_partitions.py

@@ -200,6 +200,8 @@ def main():
                                      'micro_batch_size': 1,
                                      'no_load_optim': True,
                                      'no_load_rng': True,
                                      'no_save_optim': True,
                                      'no_save_rng': True,
                                      'save_interval': 1})

     args = get_args()
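Both tool changes lean on initialize_megatron's args_defaults mechanism, which pre-seeds parsed arguments that the user did not set explicitly; the default=None change to the store_true flags above is what makes "not set" detectable. A generic sketch of that idea (plain argparse, not Megatron's actual implementation; parse_with_defaults is a made-up name) is:

import argparse

# Generic sketch of the args_defaults idea: a caller-supplied dict fills in
# values the user did not pass on the command line.
def parse_with_defaults(args_defaults=None, argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--tokenizer-type', default=None)
    parser.add_argument('--no-load-rng', action='store_true', default=None)
    parser.add_argument('--no-load-optim', action='store_true', default=None)
    args = parser.parse_args(argv)
    for key, value in (args_defaults or {}).items():
        # Only override values still at their None default.
        if getattr(args, key, None) is None:
            setattr(args, key, value)
    return args

args = parse_with_defaults({'tokenizer_type': 'GPT2BPETokenizer',
                            'no_load_rng': True,
                            'no_load_optim': True}, argv=[])
print(args.tokenizer_type, args.no_load_rng, args.no_load_optim)
# GPT2BPETokenizer True True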