Unverified Commit 59f9ad4b authored by Wang, Yi's avatar Wang, Yi Committed by GitHub
Browse files

fix extra_match low if batch_size > 1 (#2595)



* fix extra_match low if batch_size > 1
Signed-off-by: default avatarWang, Yi A <yi.a.wang@intel.com>

* add sorting to logprobs

* nit

---------
Signed-off-by: default avatarWang, Yi A <yi.a.wang@intel.com>
Co-authored-by: default avatarBaber <baber@hey.com>
parent 932e8f9e
import os
from functools import cached_property
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple, Union
from lm_eval.api.registry import register_model
......@@ -68,7 +69,9 @@ class LocalCompletionsAPI(TemplateAPI):
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
for choice, ctxlen in zip(out["choices"], ctxlens):
for choice, ctxlen in zip(
sorted(out["choices"], key=itemgetter("index")), ctxlens
):
assert ctxlen > 0, "Context length must be greater than 0"
logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1])
tokens_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:-1]
......@@ -87,8 +90,10 @@ class LocalCompletionsAPI(TemplateAPI):
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
tmp = [None] * len(out["choices"])
for choices in out["choices"]:
res.append(choices["text"])
tmp[choices["index"]] = choices["text"]
res = res + tmp
return res
@property
......@@ -157,8 +162,10 @@ class LocalChatCompletion(LocalCompletionsAPI):
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
tmp = [None] * len(out["choices"])
for choices in out["choices"]:
res.append(choices["message"]["content"])
tmp[choices["index"]] = choices["message"]["content"]
res = res + tmp
return res
def tok_encode(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment