Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
752eeae3
Commit
752eeae3
authored
Apr 02, 2020
by
Mohammad
Browse files
code runs
parent
a6ba254f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
10 deletions
+8
-10
generate_samples.py
generate_samples.py
+2
-2
megatron/arguments.py
megatron/arguments.py
+6
-8
No files found.
generate_samples.py
View file @
752eeae3
...
@@ -319,7 +319,7 @@ def get_token_stream(model, context_tokens):
...
@@ -319,7 +319,7 @@ def get_token_stream(model, context_tokens):
group
=
mpu
.
get_model_parallel_group
())
group
=
mpu
.
get_model_parallel_group
())
context_length
=
context_length_tensor
.
min
().
item
()
context_length
=
context_length_tensor
.
min
().
item
()
tokens
,
attention_mask
,
position_ids
=
get_batch
(
context_tokens_tensor
,
args
)
tokens
,
attention_mask
,
position_ids
=
get_batch
(
context_tokens_tensor
)
batch_token_iterator
=
sample_sequence_batch
(
model
,
context_tokens_tensor
,
batch_token_iterator
=
sample_sequence_batch
(
model
,
context_tokens_tensor
,
context_length_tensor
,
context_length_tensor
,
...
@@ -469,7 +469,7 @@ def main():
...
@@ -469,7 +469,7 @@ def main():
# Generate samples.
# Generate samples.
if
args
.
num_samples
==
0
:
if
args
.
num_samples
==
0
:
assert
args
.
batch_size
==
1
args
.
batch_size
=
1
if
args
.
sample_input_file
!=
""
:
if
args
.
sample_input_file
!=
""
:
generate_samples_input_from_file
(
model
)
generate_samples_input_from_file
(
model
)
else
:
else
:
...
...
megatron/arguments.py
View file @
752eeae3
...
@@ -69,8 +69,10 @@ def parse_args(extra_args_provider=None, defaults={}):
...
@@ -69,8 +69,10 @@ def parse_args(extra_args_provider=None, defaults={}):
# Checks.
# Checks.
assert
args
.
hidden_size
%
args
.
num_attention_heads
==
0
assert
args
.
hidden_size
%
args
.
num_attention_heads
==
0
assert
args
.
max_position_embeddings
>=
args
.
seq_length
if
args
.
seq_length
is
not
None
:
assert
args
.
min_lr
<=
args
.
lr
assert
args
.
max_position_embeddings
>=
args
.
seq_length
if
args
.
lr
is
not
None
:
assert
args
.
min_lr
<=
args
.
lr
if
args
.
save
is
not
None
:
if
args
.
save
is
not
None
:
assert
args
.
save_interval
is
not
None
assert
args
.
save_interval
is
not
None
...
@@ -134,7 +136,7 @@ def _add_regularization_args(parser):
...
@@ -134,7 +136,7 @@ def _add_regularization_args(parser):
def
_add_training_args
(
parser
):
def
_add_training_args
(
parser
):
group
=
parser
.
add_argument_group
(
title
=
'training'
)
group
=
parser
.
add_argument_group
(
title
=
'training'
)
group
.
add_argument
(
'--batch-size'
,
type
=
int
,
required
=
True
,
group
.
add_argument
(
'--batch-size'
,
type
=
int
,
default
=
None
,
help
=
'Batch size per model instance (local batch size). '
help
=
'Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'Global batch size is local batch size times data '
'parallel size.'
)
'parallel size.'
)
...
@@ -301,7 +303,7 @@ def _add_data_args(parser):
...
@@ -301,7 +303,7 @@ def _add_data_args(parser):
help
=
'Path to the vocab file.'
)
help
=
'Path to the vocab file.'
)
group
.
add_argument
(
'--merge-file'
,
type
=
str
,
default
=
None
,
group
.
add_argument
(
'--merge-file'
,
type
=
str
,
default
=
None
,
help
=
'Path to the BPE merge file.'
)
help
=
'Path to the BPE merge file.'
)
group
.
add_argument
(
'--seq-length'
,
type
=
int
,
required
=
True
,
group
.
add_argument
(
'--seq-length'
,
type
=
int
,
default
=
None
,
help
=
"Maximum sequence length to process."
)
help
=
"Maximum sequence length to process."
)
group
.
add_argument
(
'--mask-prob'
,
type
=
float
,
default
=
0.15
,
group
.
add_argument
(
'--mask-prob'
,
type
=
float
,
default
=
0.15
,
help
=
'Probability of replacing a token with mask.'
)
help
=
'Probability of replacing a token with mask.'
)
...
@@ -356,10 +358,6 @@ def _add_gpt2_args(parser):
...
@@ -356,10 +358,6 @@ def _add_gpt2_args(parser):
def
add_data_args_
(
parser
):
def
add_data_args_
(
parser
):
"""Train/valid/test data arguments."""
"""Train/valid/test data arguments."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment