OpenDAS / Megatron-LM · Commits

Commit fffa0497
Authored Apr 02, 2020 by Mohammad
Parent: 752eeae3

    sample generation runs

Showing 2 changed files with 97 additions and 74 deletions:

    generate_samples_gpt2.py            +95  -0
    megatron/text_generation_utils.py    +2  -74
generate_samples_gpt2.py (new file, mode 100644) @ fffa0497
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Generate GPT2"""
from
megatron
import
get_args
from
megatron
import
get_tokenizer
from
megatron
import
print_rank_0
from
megatron.checkpointing
import
load_checkpoint
from
megatron.initialize
import
initialize_megatron
from
megatron.model
import
GPT2Model
from
megatron.training
import
get_model
from
megatron.text_generation_utils
import
generate_and_write_samples_unconditional
from
megatron.text_generation_utils
import
generate_samples_input_from_file
from
megatron.text_generation_utils
import
generate_samples_interactive
def
model_provider
():
"""Build the model."""
print_rank_0
(
'building GPT2 model ...'
)
model
=
GPT2Model
(
num_tokentypes
=
0
,
parallel_output
=
False
)
return
model
def
add_text_generate_args
(
parser
):
"""Text generation arguments."""
group
=
parser
.
add_argument_group
(
title
=
'text generation'
)
group
.
add_argument
(
"--temperature"
,
type
=
float
,
default
=
1.0
,
help
=
'Sampling temperature.'
)
group
.
add_argument
(
"--greedy"
,
action
=
'store_true'
,
default
=
False
,
help
=
'Use greedy sampling.'
)
group
.
add_argument
(
"--top_p"
,
type
=
float
,
default
=
0.0
,
help
=
'Top p sampling.'
)
group
.
add_argument
(
"--top_k"
,
type
=
int
,
default
=
0
,
help
=
'Top k sampling.'
)
group
.
add_argument
(
"--out-seq-length"
,
type
=
int
,
default
=
1024
,
help
=
'Size of the output generated text.'
)
group
.
add_argument
(
"--sample-input-file"
,
type
=
str
,
default
=
None
,
help
=
'Get input from file instead of interactive mode, '
'each line is an input.'
)
group
.
add_argument
(
"--sample-output-file"
,
type
=
str
,
default
=
None
,
help
=
'Output file got from --sample-input-file'
)
group
.
add_argument
(
"--num-samples"
,
type
=
int
,
default
=
0
,
help
=
'Number of samples to generate unconditionally, '
'defaults to 0 and interactive conditional sampling'
)
group
.
add_argument
(
"--genfile"
,
type
=
str
,
help
=
'Output file when generating unconditionally'
)
group
.
add_argument
(
"--recompute"
,
action
=
'store_true'
,
help
=
'During generation recompute all attention '
'instead of using previously computed keys/values.'
)
return
parser
def
main
():
"""Main program."""
initialize_megatron
(
extra_args_provider
=
add_text_generate_args
,
args_defaults
=
{
'tokenizer_type'
:
'GPT2BPETokenizer'
})
# Set up model and load checkpoint.
model
=
get_model
(
model_provider
)
args
=
get_args
()
if
args
.
load
is
not
None
:
_
=
load_checkpoint
(
model
,
None
,
None
)
# Generate samples.
if
args
.
num_samples
==
0
:
args
.
batch_size
=
1
if
args
.
sample_input_file
!=
""
:
generate_samples_input_from_file
(
model
)
else
:
generate_samples_interactive
(
model
)
else
:
generate_and_write_samples_unconditional
(
model
)
if
__name__
==
"__main__"
:
main
()
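The --greedy, --temperature, --top_k, and --top_p flags registered above select the decoding strategy. As a minimal standalone sketch of what top-k/top-p (nucleus) filtering does to the next-token distribution; this illustrates the general technique, not this commit's implementation, which lives in megatron/text_generation_utils.py and is not shown in the diff:

import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
    # Top-k: keep only the k highest-scoring tokens.
    if top_k > 0:
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value
    # Top-p: keep the smallest set of tokens whose cumulative
    # probability exceeds p.
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

# Greedy decoding (--greedy) is just argmax over the logits; otherwise
# divide by --temperature, filter, and sample from what remains:
logits = torch.randn(1, 50257)  # fake next-token logits, GPT-2 vocab size
probs = F.softmax(filter_logits(logits / 0.9, top_k=40), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

The --recompute flag, per its help text, trades speed for memory: instead of reusing cached keys/values from earlier steps, attention is recomputed over the whole prefix at every generation step.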
generate_samples.py → megatron/text_generation_utils.py (renamed) @ fffa0497
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Sample Generate GPT2"""
+"""Utilities for generating text."""
 
 import copy
 import json
@@ -26,23 +26,9 @@ import torch.nn.functional as F
 
 from megatron import get_args
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron import print_rank_0
-from megatron.checkpointing import load_checkpoint
-from megatron.initialize import initialize_megatron
-from megatron.model import GPT2Model
-from megatron.training import get_model
 from megatron.utils import get_ltor_masks_and_position_ids
 
 
-def model_provider():
-    """Build the model."""
-
-    print_rank_0('building GPT2 model ...')
-    model = GPT2Model(num_tokentypes=0, parallel_output=False)
-
-    return model
-
-
 def get_batch(context_tokens):
     """Generate batch from context tokens."""
     args = get_args()
@@ -280,7 +266,7 @@ def generate_samples_unconditional(model):
             break
 
 
-def write_and_generate_samples_unconditional(model):
+def generate_and_write_samples_unconditional(model):
 
     args = get_args()
     assert args.genfile is not None
@@ -423,61 +409,3 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
             yield tokens, lengths
             if done:
                 break
-
-
-def add_text_generate_args(parser):
-    """Text generation arguments."""
-
-    group = parser.add_argument_group(title='text generation')
-
-    group.add_argument("--temperature", type=float, default=1.0,
-                       help='Sampling temperature.')
-    group.add_argument("--greedy", action='store_true', default=False,
-                       help='Use greedy sampling.')
-    group.add_argument("--top_p", type=float, default=0.0,
-                       help='Top p sampling.')
-    group.add_argument("--top_k", type=int, default=0,
-                       help='Top k sampling.')
-    group.add_argument("--out-seq-length", type=int, default=1024,
-                       help='Size of the output generated text.')
-    group.add_argument("--sample-input-file", type=str, default=None,
-                       help='Get input from file instead of interactive mode, '
-                       'each line is an input.')
-    group.add_argument("--sample-output-file", type=str, default=None,
-                       help='Output file got from --sample-input-file')
-    group.add_argument("--num-samples", type=int, default=0,
-                       help='Number of samples to generate unconditionally, '
-                       'defaults to 0 and interactive conditional sampling')
-    group.add_argument("--genfile", type=str,
-                       help='Output file when generating unconditionally')
-    group.add_argument("--recompute", action='store_true',
-                       help='During generation recompute all attention '
-                       'instead of using previously computed keys/values.')
-
-    return parser
-
-
-def main():
-    """Main program."""
-
-    initialize_megatron(extra_args_provider=add_text_generate_args,
-                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
-
-    # Set up model and load checkpoint.
-    model = get_model(model_provider)
-    args = get_args()
-    if args.load is not None:
-        _ = load_checkpoint(model, None, None)
-
-    # Generate samples.
-    if args.num_samples == 0:
-        args.batch_size = 1
-        if args.sample_input_file != "":
-            generate_samples_input_from_file(model)
-        else:
-            generate_samples_interactive(model)
-    else:
-        write_and_generate_samples_unconditional(model)
-
-
-if __name__ == "__main__":
-    main()
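The last hunk header names sample_sequence_batch(model, context_tokens, context_lengths, ...), which stays in text_generation_utils.py and, per the surviving context lines, yields (tokens, lengths) pairs until every sequence in the batch is done. A caller-side sketch of how a function like the renamed generate_and_write_samples_unconditional might drain such a generator and write --genfile; this is an illustration under stated assumptions (the helper name write_unconditional_samples and the tokenizer.detokenize call are mine, not this commit's code):

import json

def write_unconditional_samples(token_stream, genfile, tokenizer):
    """Hypothetical helper: drain a (tokens, lengths) generator such as
    sample_sequence_batch(...) and write one JSON object per sample."""
    tokens, lengths = None, None
    for tokens, lengths in token_stream:
        pass  # each yield is a snapshot; the last holds the finished batch
    if tokens is None:
        return  # empty stream, nothing to write
    with open(genfile, 'w') as f:
        for row, length in zip(tokens.tolist(), lengths.tolist()):
            # assumes a Megatron-style tokenizer with a detokenize() method
            text = tokenizer.detokenize(row[:length])
            f.write(json.dumps({'text': text}) + '\n')

The real function presumably also honors --num-samples by looping over fresh batches until enough samples are written; that bookkeeping is omitted here.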