Commit 63262827 authored by Raul Puri's avatar Raul Puri
Browse files

Merge branch 'staging_text_generation' into 'staging'

Refactoring text generation

See merge request ADLR/megatron-lm!39
parents 3977b721 fffa0497
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Generate GPT2"""
from megatron import get_args
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPT2Model
from megatron.training import get_model
from megatron.text_generation_utils import generate_and_write_samples_unconditional
from megatron.text_generation_utils import generate_samples_input_from_file
from megatron.text_generation_utils import generate_samples_interactive
def model_provider():
"""Build the model."""
print_rank_0('building GPT2 model ...')
model = GPT2Model(num_tokentypes=0, parallel_output=False)
return model
def add_text_generate_args(parser):
"""Text generation arguments."""
group = parser.add_argument_group(title='text generation')
group.add_argument("--temperature", type=float, default=1.0,
help='Sampling temperature.')
group.add_argument("--greedy", action='store_true', default=False,
help='Use greedy sampling.')
group.add_argument("--top_p", type=float, default=0.0,
help='Top p sampling.')
group.add_argument("--top_k", type=int, default=0,
help='Top k sampling.')
group.add_argument("--out-seq-length", type=int, default=1024,
help='Size of the output generated text.')
group.add_argument("--sample-input-file", type=str, default=None,
help='Get input from file instead of interactive mode, '
'each line is an input.')
group.add_argument("--sample-output-file", type=str, default=None,
help='Output file got from --sample-input-file')
group.add_argument("--num-samples", type=int, default=0,
help='Number of samples to generate unconditionally, '
'defaults to 0 and interactive conditional sampling')
group.add_argument("--genfile", type=str,
help='Output file when generating unconditionally')
group.add_argument("--recompute", action='store_true',
help='During generation recompute all attention '
'instead of using previously computed keys/values.')
return parser
def main():
"""Main program."""
initialize_megatron(extra_args_provider=add_text_generate_args,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
# Set up model and load checkpoint.
model = get_model(model_provider)
args = get_args()
if args.load is not None:
_ = load_checkpoint(model, None, None)
# Generate samples.
if args.num_samples == 0:
args.batch_size = 1
if args.sample_input_file != "":
generate_samples_input_from_file(model)
else:
generate_samples_interactive(model)
else:
generate_and_write_samples_unconditional(model)
if __name__ == "__main__":
main()
...@@ -69,7 +69,9 @@ def parse_args(extra_args_provider=None, defaults={}): ...@@ -69,7 +69,9 @@ def parse_args(extra_args_provider=None, defaults={}):
# Checks. # Checks.
assert args.hidden_size % args.num_attention_heads == 0 assert args.hidden_size % args.num_attention_heads == 0
if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length assert args.max_position_embeddings >= args.seq_length
if args.lr is not None:
assert args.min_lr <= args.lr assert args.min_lr <= args.lr
if args.save is not None: if args.save is not None:
assert args.save_interval is not None assert args.save_interval is not None
...@@ -134,7 +136,7 @@ def _add_regularization_args(parser): ...@@ -134,7 +136,7 @@ def _add_regularization_args(parser):
def _add_training_args(parser): def _add_training_args(parser):
group = parser.add_argument_group(title='training') group = parser.add_argument_group(title='training')
group.add_argument('--batch-size', type=int, required=True, group.add_argument('--batch-size', type=int, default=None,
help='Batch size per model instance (local batch size). ' help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data ' 'Global batch size is local batch size times data '
'parallel size.') 'parallel size.')
...@@ -301,7 +303,7 @@ def _add_data_args(parser): ...@@ -301,7 +303,7 @@ def _add_data_args(parser):
help='Path to the vocab file.') help='Path to the vocab file.')
group.add_argument('--merge-file', type=str, default=None, group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.') help='Path to the BPE merge file.')
group.add_argument('--seq-length', type=int, required=True, group.add_argument('--seq-length', type=int, default=None,
help="Maximum sequence length to process.") help="Maximum sequence length to process.")
group.add_argument('--mask-prob', type=float, default=0.15, group.add_argument('--mask-prob', type=float, default=0.15,
help='Probability of replacing a token with mask.') help='Probability of replacing a token with mask.')
...@@ -356,32 +358,6 @@ def _add_gpt2_args(parser): ...@@ -356,32 +358,6 @@ def _add_gpt2_args(parser):
def add_text_generate_args(parser):
"""Text generate arguments."""
group = parser.add_argument_group('Text generation', 'configurations')
group.add_argument("--temperature", type=float, default=1.0)
group.add_argument("--greedy", action='store_true', default=False)
group.add_argument("--top_p", type=float, default=0.0)
group.add_argument("--top_k", type=int, default=0)
group.add_argument("--out-seq-length", type=int, default=1024)
group.add_argument("--sample-input-file", type=str, default="",
help='get input from file instead of interactive mode, '
'each line is an input' )
group.add_argument("--sample-output-file", type=str, default="",
help='output file got from --sample-input-file')
group.add_argument("--num-samples", type=int, default=0,
help='number of samples to generate unconditionally, '
'defaults to 0 and interactive conditional sampling')
group.add_argument("--genfile", type=str,
help='output file when generating unconditionally')
group.add_argument("--recompute", action='store_true',
help='during generation recompute all attention '
'instead of using previously computed keys/values.')
return parser
def add_data_args_(parser): def add_data_args_(parser):
"""Train/valid/test data arguments.""" """Train/valid/test data arguments."""
......
...@@ -137,8 +137,7 @@ class BertModel(MegatronModule): ...@@ -137,8 +137,7 @@ class BertModel(MegatronModule):
self._binary_head_key = 'binary_head' self._binary_head_key = 'binary_head'
def forward(self, input_ids, attention_mask, def forward(self, input_ids, attention_mask, tokentype_ids=None):
tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask( extended_attention_mask = bert_extended_attention_mask(
attention_mask, next(self.language_model.parameters()).dtype) attention_mask, next(self.language_model.parameters()).dtype)
......
...@@ -51,7 +51,8 @@ class GPT2Model(MegatronModule): ...@@ -51,7 +51,8 @@ class GPT2Model(MegatronModule):
def forward(self, input_ids, position_ids, attention_mask, def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False): tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
# Language model. # Language model.
lm_output = self.language_model(input_ids, lm_output = self.language_model(input_ids,
...@@ -65,10 +66,13 @@ class GPT2Model(MegatronModule): ...@@ -65,10 +66,13 @@ class GPT2Model(MegatronModule):
lm_output, presents = lm_output lm_output, presents = lm_output
# Output. # Output.
parallel_output = self.parallel_output
if forward_method_parallel_output is not None:
parallel_output = forward_method_parallel_output
output = parallel_lm_logits( output = parallel_lm_logits(
lm_output, lm_output,
self.language_model.embedding.word_embeddings.weight, self.language_model.embedding.word_embeddings.weight,
self.parallel_output) parallel_output)
if get_key_value: if get_key_value:
output = [output, presents] output = [output, presents]
......
This diff is collapsed.
...@@ -91,6 +91,10 @@ class AbstractTokenizer(ABC): ...@@ -91,6 +91,10 @@ class AbstractTokenizer(ABC):
def tokenize(self, text): def tokenize(self, text):
pass pass
def detokenize(self, token_ids):
raise NotImplementedError('detokenizer is not implemented for {} '
'tokenizer'.format(self.name))
@property @property
def cls(self): def cls(self):
raise NotImplementedError('CLS is not provided for {} ' raise NotImplementedError('CLS is not provided for {} '
...@@ -190,6 +194,9 @@ class _GPT2BPETokenizer(AbstractTokenizer): ...@@ -190,6 +194,9 @@ class _GPT2BPETokenizer(AbstractTokenizer):
def tokenize(self, text): def tokenize(self, text):
return self.tokenizer.encode(text) return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property @property
def eod(self): def eod(self):
return self.eod_id return self.eod_id
File mode changed from 100755 to 100644
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment