Commit b6b7ba4d authored by rprenger

Added generate_samples_eval function

parent e718810e
@@ -151,6 +151,7 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
    return tokens[:, :context_length]

def generate(model, sentences=None, max_len=0):
    model.eval()
    if torch.distributed.get_rank() == 0:
        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
        send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
@@ -168,6 +169,18 @@ def generate(model, sentences=None, max_len=0):
        resp_sentences.append(tokenizer.detokenize(decode_token))
    return resp_sentences

def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """
    This function provides a matching API for a legacy task.
    The implementation has not been tested yet to confirm it matches the legacy behavior.
    """
    assert False, "Implementation untested"
    args = get_args()
    args.eos_id = eos_token_id
    raw_text_len = len(context)
    resp_sentences = generate(model, [context], max_gen_length)
    return resp_sentences[0][raw_text_len:]

def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2
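A minimal usage sketch for the new generate_samples_eval wrapper (illustrative only, not part of this commit): it assumes an already-loaded, eval-mode GPT model and takes the EOS token id from Megatron's tokenizer helper. Note that, as committed, the function still trips the assert False guard until the implementation is verified.

# Illustrative usage only; get_tokenizer() and tokenizer.eod are assumed
# Megatron helpers, and `model` is an already-loaded model in eval mode.
from megatron import get_tokenizer

prompt = "Deep learning is"
tokenizer = get_tokenizer()

# Returns only the generated continuation: generate() produces the full
# detokenized string and the prompt prefix is stripped via raw_text_len.
# As committed, this call fails the `assert False` guard until the
# implementation has been verified.
completion = generate_samples_eval(
    model,
    prompt,
    max_gen_length=32,
    eos_token_id=tokenizer.eod,
)
print(completion)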