Commit 42e83ee0 authored by rprenger

Changing the interface to the lm eval harness and fixing bugs caused by misunderstanding out_seq_length
parent 8fe6f9bf
@@ -186,8 +186,7 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs)
if output is not None:
decode_tokens, output_logits, full_logits = output
if torch.distributed.get_rank() == 0:
args = get_args()
tokenizer = get_tokenizer()
resp_sentences = []
@@ -206,7 +205,7 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
output_logits = output_logits.cpu().numpy().tolist()
if all_probs:
full_logits = full_logits.cpu().numpy().tolist()
return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
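
A minimal sketch (not part of the commit) of how a rank-0 caller might consume the five values generate() returns; `my_model` and `prompts` are hypothetical names, and the diff above suggests that ranks other than 0 fall through and return None, so the result is guarded:

```python
# Hypothetical caller of generate(); names below are illustrative only.
prompts = ["Megatron-LM is", "The eval harness expects"]
result = generate(my_model, sentences=prompts, max_len=32, all_probs=False)
if result is not None:  # non-zero ranks appear to return None
    resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens = result
    for completion in resp_sentences:
        print(completion)
```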
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
@@ -214,12 +213,15 @@ def generate_samples_eval(model, context, max_gen_length, eos_token_id):
This function is here to provide a matching API for a legacy task
This implementation hasn't been tested yet to make sure it matches
"""
assert False, "Implementation untested"
#assert False, "Implementation untested"
args = get_args()
args.eos_id = eos_token_id
raw_text_len = len(context)
resp_sentences = generate(model, [context], max_gen_length)
return resp_sentences[0][raw_text_len:]
if resp_sentences:
return resp_sentences[0][raw_text_len:]
else:
return [None] # This is horrible
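
A minimal sketch (hypothetical, not from the commit) of the legacy-API behaviour this hunk changes: the returned string has the raw prompt stripped off the front, and a missing result now comes back as [None] instead of tripping the old assert:

```python
# Hypothetical harness-side call; `my_model` is illustrative.
prompt = "Question: What is the capital of France?\nAnswer:"
continuation = generate_samples_eval(my_model, prompt,
                                     max_gen_length=16, eos_token_id=0)
if continuation != [None]:   # [None] signals that generate() produced nothing
    print(continuation)      # only the text generated after the prompt
```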
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
@@ -262,7 +264,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
def sample_sequence_batch(model, context_tokens, context_lengths,
attention_mask, position_ids,
maxlen=None, all_probs=False, type_ids=None):
maxlen, all_probs=False, type_ids=None):
args = get_args()
tokenizer = get_tokenizer()
@@ -285,14 +287,13 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
output_logits = None
if maxlen is None:
maxlen = args.seq_length - 1
# TODO(rprenger) maxlen should be renamed; here it is the number of new tokens to generate
maxlen = maxlen + org_context_length
if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length
# TODO(rprenger) Need a better understanding of how args.seq_length and args.out_seq_length differ (and they shouldn't come from "args")
if maxlen > args.seq_length:
maxlen = args.seq_length
lengths = torch.ones([batch_size]).long().cuda() * maxlen
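
Read on its own, the new maxlen handling amounts to: treat the incoming maxlen as a count of new tokens, add the prompt length, then clamp first by args.out_seq_length past the prompt and then by the absolute args.seq_length. A standalone paraphrase (illustrative variable names, not the commit's code):

```python
def effective_maxlen(new_tokens, org_context_length, out_seq_length, seq_length):
    # The caller passes the number of new tokens to generate; convert to an
    # absolute position by adding the prompt length.
    maxlen = new_tokens + org_context_length
    # Generate at most out_seq_length tokens past the prompt ...
    maxlen = min(maxlen, org_context_length + out_seq_length)
    # ... and never run past the model's maximum sequence length.
    return min(maxlen, seq_length)
```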