Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
ce29d4d5
Commit
ce29d4d5
authored
Apr 02, 2020
by
Mohammad
Browse files
working on refactoring text generation
parent
a0bcee94
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
189 additions
and
225 deletions
+189
-225
generate_samples.py
generate_samples.py
+174
-198
megatron/arguments.py
megatron/arguments.py
+1
-23
megatron/model/bert_model.py
megatron/model/bert_model.py
+1
-2
megatron/model/gpt2_model.py
megatron/model/gpt2_model.py
+6
-2
megatron/tokenizer/tokenizer.py
megatron/tokenizer/tokenizer.py
+7
-0
No files found.
generate_samples.py
View file @
ce29d4d5
This diff is collapsed.
Click to expand it.
megatron/arguments.py
View file @
ce29d4d5
...
@@ -357,29 +357,7 @@ def _add_gpt2_args(parser):
...
@@ -357,29 +357,7 @@ def _add_gpt2_args(parser):
def add_text_generate_args(parser):
    """Text generate arguments.

    Adds a 'Text generation' argument group to *parser* covering the
    sampling hyper-parameters and the input/output files used for
    conditional (interactive or file-driven) and unconditional generation.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, for chaining.
    """

    group = parser.add_argument_group('Text generation', 'configurations')

    # Sampling hyper-parameters.  Help strings added; flag names, types and
    # defaults are unchanged so existing scripts keep working.
    group.add_argument("--temperature", type=float, default=1.0,
                       help='sampling temperature')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='use greedy (argmax) decoding instead of sampling')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='top-p (nucleus) sampling cutoff, 0 disables')
    group.add_argument("--top_k", type=int, default=0,
                       help='top-k sampling cutoff, 0 disables')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='maximum output sequence length')

    # Conditional generation I/O.
    group.add_argument("--sample-input-file", type=str, default="",
                       help='get input from file instead of interactive mode, '
                       'each line is an input')
    group.add_argument("--sample-output-file", type=str, default="",
                       help='output file got from --sample-input-file')

    # Unconditional generation.
    group.add_argument("--num-samples", type=int, default=0,
                       help='number of samples to generate unconditionally, '
                       'defaults to 0 and interactive conditional sampling')
    group.add_argument("--genfile", type=str,
                       help='output file when generating unconditionally')

    group.add_argument("--recompute", action='store_true',
                       help='during generation recompute all attention '
                       'instead of using previously computed keys/values.')

    return parser
def
add_data_args_
(
parser
):
def
add_data_args_
(
parser
):
...
...
megatron/model/bert_model.py
View file @
ce29d4d5
...
@@ -137,8 +137,7 @@ class BertModel(MegatronModule):
...
@@ -137,8 +137,7 @@ class BertModel(MegatronModule):
self
.
_binary_head_key
=
'binary_head'
self
.
_binary_head_key
=
'binary_head'
def
forward
(
self
,
input_ids
,
attention_mask
,
def
forward
(
self
,
input_ids
,
attention_mask
,
tokentype_ids
=
None
):
tokentype_ids
=
None
):
extended_attention_mask
=
bert_extended_attention_mask
(
extended_attention_mask
=
bert_extended_attention_mask
(
attention_mask
,
next
(
self
.
language_model
.
parameters
()).
dtype
)
attention_mask
,
next
(
self
.
language_model
.
parameters
()).
dtype
)
...
...
megatron/model/gpt2_model.py
View file @
ce29d4d5
...
@@ -51,7 +51,8 @@ class GPT2Model(MegatronModule):
...
@@ -51,7 +51,8 @@ class GPT2Model(MegatronModule):
def
forward
(
self
,
input_ids
,
position_ids
,
attention_mask
,
def
forward
(
self
,
input_ids
,
position_ids
,
attention_mask
,
tokentype_ids
=
None
,
layer_past
=
None
,
get_key_value
=
False
):
tokentype_ids
=
None
,
layer_past
=
None
,
get_key_value
=
False
,
forward_method_parallel_output
=
None
):
# Language model.
# Language model.
lm_output
=
self
.
language_model
(
input_ids
,
lm_output
=
self
.
language_model
(
input_ids
,
...
@@ -65,10 +66,13 @@ class GPT2Model(MegatronModule):
...
@@ -65,10 +66,13 @@ class GPT2Model(MegatronModule):
lm_output
,
presents
=
lm_output
lm_output
,
presents
=
lm_output
# Output.
# Output.
parallel_output
=
self
.
parallel_output
if
forward_method_parallel_output
is
not
None
:
parallel_output
=
forward_method_parallel_output
output
=
parallel_lm_logits
(
output
=
parallel_lm_logits
(
lm_output
,
lm_output
,
self
.
language_model
.
embedding
.
word_embeddings
.
weight
,
self
.
language_model
.
embedding
.
word_embeddings
.
weight
,
self
.
parallel_output
)
parallel_output
)
if
get_key_value
:
if
get_key_value
:
output
=
[
output
,
presents
]
output
=
[
output
,
presents
]
...
...
megatron/tokenizer/tokenizer.py
View file @
ce29d4d5
...
@@ -91,6 +91,10 @@ class AbstractTokenizer(ABC):
...
@@ -91,6 +91,10 @@ class AbstractTokenizer(ABC):
def tokenize(self, text):
    """Convert *text* into a sequence of token ids.

    NOTE(review): this is the no-op base implementation on
    AbstractTokenizer; presumably decorated with @abstractmethod
    (decorator is outside the visible diff — confirm against the file).
    """
    pass
def detokenize(self, token_ids):
    """Convert a sequence of token ids back into text.

    Base implementation: detokenization is optional for subclasses, so the
    default simply raises, naming the concrete tokenizer via ``self.name``.

    Raises:
        NotImplementedError: always, unless overridden by a subclass.
    """
    raise NotImplementedError('detokenizer is not implemented for {} '
                              'tokenizer'.format(self.name))
@
property
@
property
def
cls
(
self
):
def
cls
(
self
):
raise
NotImplementedError
(
'CLS is not provided for {} '
raise
NotImplementedError
(
'CLS is not provided for {} '
...
@@ -190,6 +194,9 @@ class _GPT2BPETokenizer(AbstractTokenizer):
...
@@ -190,6 +194,9 @@ class _GPT2BPETokenizer(AbstractTokenizer):
def tokenize(self, text):
    """Encode *text* into token ids via the underlying BPE tokenizer."""
    encoder = self.tokenizer
    return encoder.encode(text)
def detokenize(self, token_ids):
    """Decode *token_ids* back into text via the underlying BPE tokenizer."""
    decoder = self.tokenizer
    return decoder.decode(token_ids)
@property
def eod(self):
    """Id of the end-of-document token (read-only view of ``self.eod_id``)."""
    return self.eod_id
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment