OpenDAS / Megatron-LM · Commit 368ad0d3

authored Apr 14, 2020 by Mohammad
parent 7d75b3b5

    made size arguments optional so they can be set from input function

Showing 1 changed file with 14 additions and 31 deletions.

megatron/arguments.py (+14, -31)
@@ -35,8 +35,6 @@ def parse_args(extra_args_provider=None, defaults={}):
     parser = _add_validation_args(parser)
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
-    # TODO: Refactor
-    parser = _add_gpt2_args(parser)
 
     # Custom arguments.
     if extra_args_provider is not None:
@@ -54,6 +52,12 @@ def parse_args(extra_args_provider=None, defaults={}):
                 'defaults can only be overwritten for args with None values.'
             setattr(args, key, defaults[key])
 
+    # Check required arguments.
+    _check_arg_is_not_none(args, 'num_layers')
+    _check_arg_is_not_none(args, 'hidden_size')
+    _check_arg_is_not_none(args, 'num_attention_heads')
+    _check_arg_is_not_none(args, 'max_position_embeddings')
+
     # Distributed args.
     args.rank = int(os.getenv('RANK', '0'))
     args.world_size = int(os.getenv("WORLD_SIZE", '1'))
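The size arguments are now declared with default=None, so a caller can leave them off the command line and supply them through the defaults dict that parse_args already accepts; the new _check_arg_is_not_none calls then catch any size that was set by neither path. A minimal sketch of that flow, assuming parse_args is invoked from a launcher script (the concrete size values below are illustrative, not part of this commit):

from megatron.arguments import parse_args

# Sizes omitted from the command line are filled in here, because the
# argparse defaults are now None and parse_args only overwrites
# None-valued attributes from `defaults`.
args = parse_args(defaults={'num_layers': 24,
                            'hidden_size': 1024,
                            'num_attention_heads': 16,
                            'max_position_embeddings': 512})

# If any of the four sizes were still None at this point, the new
# _check_arg_is_not_none(...) calls inside parse_args would raise an
# AssertionError instead of argparse rejecting the flags up front.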
@@ -93,16 +97,20 @@ def _print_args(args):
     print('---------------- end of arguments ----------------', flush=True)
 
 
+def _check_arg_is_not_none(args, arg):
+    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
+
+
 def _add_network_size_args(parser):
     group = parser.add_argument_group(title='network size')
 
-    group.add_argument('--num-layers', type=int, required=True,
+    group.add_argument('--num-layers', type=int, default=None,
                        help='Number of transformer layers.')
-    group.add_argument('--hidden-size', type=int, required=True,
+    group.add_argument('--hidden-size', type=int, default=None,
                        help='Tansformer hidden size.')
-    group.add_argument('--num-attention-heads', type=int, required=True,
+    group.add_argument('--num-attention-heads', type=int, default=None,
                        help='Number of transformer attention heads.')
-    group.add_argument('--max-position-embeddings', type=int, required=True,
+    group.add_argument('--max-position-embeddings', type=int, default=None,
                        help='Maximum number of position embeddings to use. '
                        'This is the size of position embedding.')
     group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
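Replacing required=True with default=None moves the failure point: argparse no longer rejects a missing size at parse time, and the assertion in the new helper fires only after defaults (or an extra_args_provider) has had a chance to fill the value in. A self-contained sketch of that ordering using plain argparse; the parser below is a stand-alone illustration, not the Megatron parser:

import argparse

def _check_arg_is_not_none(args, arg):
    # Same assertion shape as the helper added in this commit.
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)

parser = argparse.ArgumentParser()
parser.add_argument('--num-layers', type=int, default=None)

args = parser.parse_args([])        # no error: the flag is optional now
args.num_layers = 24                # e.g. filled in later from a defaults dict
_check_arg_is_not_none(args, 'num_layers')   # passes; asserts if still None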
@@ -342,28 +350,3 @@ def _add_autoresume_args(parser):
                        'termination signal')
 
     return parser
-
-
-########################################################################
-
-def _add_gpt2_args(parser):
-
-    group = parser.add_argument_group(title='gpt2')
-
-    group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
-                       help='The filename containing all the shards '
-                       'sizes for numpy data loader')
-
-    return parser
-
-
-def add_data_args_(parser):
-    """Train/valid/test data arguments."""
-
-    group = parser.add_argument_group('data', 'data configurations')
-
-    group.add_argument('--data-loader', type=str, default=None,
-                       choices=['raw', 'lazy', 'tfrecords', 'numpy', 'binary'],
-                       help='Which data loader to use. Default varies by model.')
-
-    return parser
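The 'input function' in the commit message presumably maps to the extra_args_provider hook (together with defaults) visible in the parse_args signature above. A hedged sketch of how a caller-side provider might be combined with the now-optional sizes; the provider name and its flag are made up for illustration:

def my_extra_args(parser):
    # Hypothetical provider: registers a task-specific flag; not part of this commit.
    group = parser.add_argument_group(title='my task')
    group.add_argument('--my-task-option', type=str, default=None,
                       help='Illustrative extra argument.')
    return parser

args = parse_args(extra_args_provider=my_extra_args,
                  defaults={'num_layers': 12,
                            'hidden_size': 768,
                            'num_attention_heads': 12,
                            'max_position_embeddings': 512})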