Commit 06a04113 authored by Baber's avatar Baber
Browse files

send messages

parent 52c96866
......@@ -69,11 +69,17 @@ class JudgeFilter(Filter):
base_url=url, pretrained=model, num_concurrent=2, **kwargs
)
@staticmethod
def create_message(str) -> list[dict]:
return [{"role": "user", "content": str}]
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
inputs = [
self.PROMPT
+ "\n\n"
+ f"Question: {doc['question']}\nAnswer: {resp}\nGround Truth: {doc['answer']}"
self.create_message(
self.PROMPT
+ "\n\n"
+ f"Question: {doc['question']}\nAnswer: {resp}\nGround Truth: {doc['answer']}"
)
for resp, doc in zip(resps, docs)
]
......
......@@ -200,6 +200,9 @@ class TemplateAPI(TemplateLM):
)
# list[dict["role":..., "content":...],...]
return json.loads(messages[0].prompt)
elif isinstance(messages[0], dict):
# list[dict["role":..., "content":...],...]
return messages
if not self.tokenized_requests:
# if messages are tokenized:
......@@ -721,7 +724,9 @@ class TemplateAPI(TemplateLM):
return loglikelihoods
def simple_async_generate(
self, requests: Union[List[List[str]], List[List[dict]]], gen_kwargs: dict
self,
requests: Union[List[List[str], list[list[dict]]], List[List[dict]]],
gen_kwargs: dict,
):
results = itertools.chain.from_iterable(
asyncio.run(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment