OpenDAS / Megatron-LM · Commits

Commit 9acc8956
Authored Jan 06, 2021 by Jared Casper

Change some arguments to default to on.

Parent: 9a297541
Showing 2 changed files with 31 additions and 23 deletions

megatron/arguments.py   +29 -22
pretrain_gpt2.py        +2 -1
megatron/arguments.py
@@ -183,13 +183,15 @@ def parse_args(extra_args_provider=None, defaults={},
         'for distribute-checkpointed-activations to work you '\
         'need to enable checkpoint-activations'
 
-    # load scaled_upper_triang_masked_softmax_fusion kernel
-    if args.scaled_upper_triang_masked_softmax_fusion:
-        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
-
-    # load scaled_masked_softmax_fusion kernel
-    if args.scaled_masked_softmax_fusion:
-        fused_kernels.load_scaled_masked_softmax_fusion_kernel()
+    if args.scaled_masked_softmax_fusion:
+        if args.scaled_upper_triang_masked_softmax_fusion:
+            fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
+        else:
+            fused_kernels.load_scaled_masked_softmax_fusion_kernel()
+    else:
+        # This argument will eventually go away, for now make sure it is off
+        # if scaled_masked_softmax_fusion is off.
+        args.scaled_upper_triang_masked_softmax_fusion = False
 
     # Load mixed precision fused layer norm.
     if args.fp32_residual_connection:
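Read together with the argument changes below, the new branch makes scaled_masked_softmax_fusion the single switch for softmax fusion: when it is on (now the default), the upper-triangular kernel is loaded for GPT-style causal masking and the general kernel otherwise, and when it is off the deprecated upper-triangular flag is forced off as well. A minimal standalone sketch of that selection; the function and the 'upper_triang'/'general' return values are illustrative placeholders, not names from fused_kernels:

# Sketch of the kernel-selection logic above, for illustration only.
def select_softmax_kernel(scaled_masked_softmax_fusion,
                          scaled_upper_triang_masked_softmax_fusion):
    if not scaled_masked_softmax_fusion:
        # Fusion disabled entirely; the deprecated flag is forced off upstream.
        return None
    if scaled_upper_triang_masked_softmax_fusion:
        return 'upper_triang'   # causal (upper-triangular) masking, GPT default
    return 'general'            # arbitrary mask

# With the new GPT-2 defaults, both flags end up True, so the
# upper-triangular kernel is selected.
assert select_softmax_kernel(True, True) == 'upper_triang'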
@@ -328,18 +330,22 @@ def _add_training_args(parser):
                        help='Exit the program after this many minutes.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
-    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
-                       action='store_true',
-                       help='Enable fusion of query_key_value_scaling '
-                       'time (upper diagonal) masking and softmax.')
-    group.add_argument('--scaled-masked-softmax-fusion',
-                       action='store_true',
-                       help='Enable fusion of query_key_value_scaling '
-                       'general masking and softmax.')
-    group.add_argument('--bias-gelu-fusion', action='store_true',
-                       help='Enable bias and gelu fusion.')
-    group.add_argument('--bias-dropout-fusion', action='store_true',
-                       help='Enable bias and dropout fusion.')
+    group.add_argument('--no-scaled-masked-softmax-fusion',
+                       action='store_false',
+                       help='Disable fusion of query_key_value scaling, '
+                       'masking, and softmax.',
+                       dest='scaled_masked_softmax_fusion')
+    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
+                       type=bool,
+                       help='Use upper triangular version of fused '
+                       'scale, mask, softmax fusion kernel (default for GPT). '
+                       '- DEPRECATED')
+    group.add_argument('--no-bias-gelu-fusion', action='store_false',
+                       help='Disable bias and gelu fusion.',
+                       dest='bias_gelu_fusion')
+    group.add_argument('--no-bias-dropout-fusion', action='store_false',
+                       help='Disable bias and dropout fusion.',
+                       dest='bias_dropout_fusion')
 
     return parser
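The pattern used above is the standard argparse way of flipping a boolean option's default to on: the visible flag becomes --no-<name> with action='store_false', and dest= maps it back onto the positively named attribute the rest of the code reads. A small self-contained sketch using one of the flags from the diff:

import argparse

parser = argparse.ArgumentParser()
# store_false + dest: the attribute defaults to True; passing the flag turns it off.
parser.add_argument('--no-bias-gelu-fusion', action='store_false',
                    help='Disable bias and gelu fusion.',
                    dest='bias_gelu_fusion')

print(parser.parse_args([]).bias_gelu_fusion)                          # True
print(parser.parse_args(['--no-bias-gelu-fusion']).bias_gelu_fusion)   # False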
@@ -447,12 +453,13 @@ def _add_mixed_precision_args(parser):
                        help='hysteresis for dynamic loss scaling')
     group.add_argument('--fp32-residual-connection', action='store_true',
                        help='Move residual connections to fp32.')
-    group.add_argument('--apply-query-key-layer-scaling', action='store_true',
-                       help='Scale Q * K^T by 1 / layer-number. If this flag '
-                       'is set, then it will automatically set '
-                       'attention-softmax-in-fp32 to true')
+    group.add_argument('--no-query-key-layer-scaling', action='store_false',
+                       help='Do not scale Q * K^T by 1 / layer-number.',
+                       dest='apply_query_key_layer_scaling')
     group.add_argument('--attention-softmax-in-fp32', action='store_true',
-                       help='Run attention masking and softmax in fp32.')
+                       help='Run attention masking and softmax in fp32. '
+                       'This flag is ignored unless '
+                       '--no-query-key-layer-scaling is specified.')
     group.add_argument('--fp32-allreduce', action='store_true',
                        help='All-reduce in fp32')
     group.add_argument('--fp16-lm-cross-entropy', action='store_true',
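The updated help strings describe a coupling between the two options: query-key layer scaling is now on by default, and while it is on, attention masking and softmax run in fp32 regardless of --attention-softmax-in-fp32, which is why that flag is documented as ignored unless --no-query-key-layer-scaling is given. A sketch of that resolution, inferred only from the help text rather than from the surrounding Megatron code:

# Assumed behaviour, based on the help strings above: layer scaling implies
# fp32 softmax; otherwise the explicit flag decides.
def softmax_in_fp32(apply_query_key_layer_scaling, attention_softmax_in_fp32):
    if apply_query_key_layer_scaling:
        return True
    return attention_softmax_in_fp32

assert softmax_in_fp32(True, False) is True    # flag ignored while scaling is on
assert softmax_in_fp32(False, False) is False  # honoured with --no-query-key-layer-scaling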
pretrain_gpt2.py
@@ -141,4 +141,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":
 
     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
-             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+             args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
+                            'scaled_upper_triang_masked_softmax_fusion': True})
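The GPT-2 entry point now seeds args_defaults with 'scaled_upper_triang_masked_softmax_fusion': True, so the causal fused kernel stays the default for GPT even though the command-line flag itself is deprecated. The parse_args signature in the first hunk (defaults={}) suggests these values are layered on top of the parser's own defaults; the sketch below is a rough illustration of that idea under the assumption (not verified against the Megatron source) that explicitly supplied values win over entries in the dict:

from argparse import Namespace

def apply_arg_defaults(args, defaults):
    # Rough sketch only: fill in values the parser left at None. This mirrors
    # the idea behind parse_args(..., defaults={...}), not its actual code.
    for key, value in defaults.items():
        if getattr(args, key, None) is None:
            setattr(args, key, value)
    return args

args = apply_arg_defaults(
    Namespace(tokenizer_type=None,
              scaled_upper_triang_masked_softmax_fusion=None),
    {'tokenizer_type': 'GPT2BPETokenizer',
     'scaled_upper_triang_masked_softmax_fusion': True})
assert args.scaled_upper_triang_masked_softmax_fusion is True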