Commit ce29d4d5 authored by Mohammad

working on refactoring text generation

parent a0bcee94
@@ -357,29 +357,7 @@ def _add_gpt2_args(parser):

-def add_text_generate_args(parser):
-    """Text generate arguments."""
-    group = parser.add_argument_group('Text generation', 'configurations')
-    group.add_argument("--temperature", type=float, default=1.0)
-    group.add_argument("--greedy", action='store_true', default=False)
-    group.add_argument("--top_p", type=float, default=0.0)
-    group.add_argument("--top_k", type=int, default=0)
-    group.add_argument("--out-seq-length", type=int, default=1024)
-    group.add_argument("--sample-input-file", type=str, default="",
-                       help='get input from file instead of interactive mode, '
-                       'each line is an input' )
-    group.add_argument("--sample-output-file", type=str, default="",
-                       help='output file got from --sample-input-file')
-    group.add_argument("--num-samples", type=int, default=0,
-                       help='number of samples to generate unconditionally, '
-                       'defaults to 0 and interactive conditional sampling')
-    group.add_argument("--genfile", type=str,
-                       help='output file when generating unconditionally')
-    group.add_argument("--recompute", action='store_true',
-                       help='during generation recompute all attention '
-                       'instead of using previously computed keys/values.')
-    return parser
 def add_data_args_(parser):
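For context on what the removed sampling flags control, the snippet below is a minimal illustrative sketch, not code from this repository, of how temperature, greedy, top_k, and top_p are typically applied to the logits of the last generated position:

# Illustrative only; the refactored Megatron generation utilities may differ.
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0, greedy=False, top_k=0, top_p=0.0):
    # logits: [vocab_size] tensor for the last position.
    if greedy:
        return torch.argmax(logits).item()
    logits = logits / max(temperature, 1e-5)
    if top_k > 0:
        # Mask everything outside the k highest-scoring tokens.
        kth = torch.topk(logits, top_k).values[-1]
        logits[logits < kth] = -float('inf')
    if top_p > 0.0:
        # Nucleus filtering: keep the smallest set of tokens whose
        # cumulative probability exceeds top_p, always keeping the best one.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cumprobs > top_p
        remove[1:] = remove[:-1].clone()
        remove[0] = False
        logits[sorted_idx[remove]] = -float('inf')
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()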
@@ -137,8 +137,7 @@ class BertModel(MegatronModule):
         self._binary_head_key = 'binary_head'
 
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
+    def forward(self, input_ids, attention_mask, tokentype_ids=None):
 
         extended_attention_mask = bert_extended_attention_mask(
             attention_mask, next(self.language_model.parameters()).dtype)
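For readers unfamiliar with the helper called above: bert_extended_attention_mask turns a per-token padding mask into a mask that can be combined with the attention scores. The following is a hypothetical sketch of that idea under common conventions (1 = real token, additive -10000 for masked positions), not the implementation in this repository:

import torch

def extended_attention_mask_sketch(attention_mask, dtype):
    # attention_mask: [batch, seq_len] with 1 for real tokens, 0 for padding.
    # Returns a [batch, 1, seq_len, seq_len] additive mask in the model dtype:
    # 0 where attention is allowed, -10000 where it is blocked, so it can be
    # added directly to the raw attention scores before softmax.
    mask_b1s = attention_mask.unsqueeze(1)          # [b, 1, s]
    mask_bs1 = attention_mask.unsqueeze(2)          # [b, s, 1]
    mask_bss = (mask_b1s * mask_bs1).unsqueeze(1)   # [b, 1, s, s]
    return (1.0 - mask_bss.to(dtype)) * -10000.0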
@@ -51,7 +51,8 @@ class GPT2Model(MegatronModule):
 
     def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
+                tokentype_ids=None, layer_past=None, get_key_value=False,
+                forward_method_parallel_output=None):
 
         # Language model.
         lm_output = self.language_model(input_ids,
@@ -65,10 +66,13 @@ class GPT2Model(MegatronModule):
             lm_output, presents = lm_output
 
         # Output.
+        parallel_output = self.parallel_output
+        if forward_method_parallel_output is not None:
+            parallel_output = forward_method_parallel_output
         output = parallel_lm_logits(
             lm_output,
             self.language_model.embedding.word_embeddings.weight,
-            self.parallel_output)
+            parallel_output)
 
         if get_key_value:
             output = [output, presents]
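The new forward_method_parallel_output argument lets a caller override the model's configured parallel_output for a single call, which is what generation code wants when it needs logits over the whole vocabulary. A hypothetical caller-side sketch (names like model, tokens, position_ids, and attention_mask are assumed here, not taken from the diff):

# Ask for gathered logits for this one forward pass, regardless of how the
# model was configured for training.
logits = model(tokens, position_ids, attention_mask,
               forward_method_parallel_output=False)
next_token_logits = logits[:, -1, :]  # logits of the last position, full vocabulary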
@@ -91,6 +91,10 @@ class AbstractTokenizer(ABC):
     def tokenize(self, text):
         pass
 
+    def detokenize(self, token_ids):
+        raise NotImplementedError('detokenizer is not implemented for {} '
+                                  'tokenizer'.format(self.name))
+
     @property
     def cls(self):
         raise NotImplementedError('CLS is not provided for {} '
@@ -190,6 +194,9 @@ class _GPT2BPETokenizer(AbstractTokenizer):
     def tokenize(self, text):
         return self.tokenizer.encode(text)
 
+    def detokenize(self, token_ids):
+        return self.tokenizer.decode(token_ids)
+
     @property
     def eod(self):
         return self.eod_id
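With detokenize now implemented for the GPT-2 BPE tokenizer, generated token ids can be mapped back to text. An illustrative round trip, assuming an already constructed tokenizer instance:

ids = tokenizer.tokenize("Hello world")   # text -> BPE token ids
text = tokenizer.detokenize(ids)          # token ids -> text
print(text)                               # "Hello world" for plain ASCII input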