Commit 06a04113 authored by Baber's avatar Baber
Browse files

send messages

parent 52c96866
...@@ -69,11 +69,17 @@ class JudgeFilter(Filter): ...@@ -69,11 +69,17 @@ class JudgeFilter(Filter):
base_url=url, pretrained=model, num_concurrent=2, **kwargs base_url=url, pretrained=model, num_concurrent=2, **kwargs
) )
@staticmethod
def create_message(str) -> list[dict]:
return [{"role": "user", "content": str}]
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
inputs = [ inputs = [
self.PROMPT self.create_message(
+ "\n\n" self.PROMPT
+ f"Question: {doc['question']}\nAnswer: {resp}\nGround Truth: {doc['answer']}" + "\n\n"
+ f"Question: {doc['question']}\nAnswer: {resp}\nGround Truth: {doc['answer']}"
)
for resp, doc in zip(resps, docs) for resp, doc in zip(resps, docs)
] ]
......
...@@ -200,6 +200,9 @@ class TemplateAPI(TemplateLM): ...@@ -200,6 +200,9 @@ class TemplateAPI(TemplateLM):
) )
# list[dict["role":..., "content":...],...] # list[dict["role":..., "content":...],...]
return json.loads(messages[0].prompt) return json.loads(messages[0].prompt)
elif isinstance(messages[0], dict):
# list[dict["role":..., "content":...],...]
return messages
if not self.tokenized_requests: if not self.tokenized_requests:
# if messages are tokenized: # if messages are tokenized:
...@@ -721,7 +724,9 @@ class TemplateAPI(TemplateLM): ...@@ -721,7 +724,9 @@ class TemplateAPI(TemplateLM):
return loglikelihoods return loglikelihoods
def simple_async_generate( def simple_async_generate(
self, requests: Union[List[List[str]], List[List[dict]]], gen_kwargs: dict self,
requests: Union[List[List[str], list[list[dict]]], List[List[dict]]],
gen_kwargs: dict,
): ):
results = itertools.chain.from_iterable( results = itertools.chain.from_iterable(
asyncio.run( asyncio.run(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment