Commit 34b32f77 authored by daniel-furman

first stab at wrap_chat_template, various

parent 6c68fd16
@@ -671,27 +671,19 @@ class HFLM(LM):
         for req in requests:
             context, continuation = req.args[0].strip(), req.args[1].strip()
             chat = [
-                {"role": "user", "content": context},
-                {"role": "assistant", "content": continuation},
+                #{"role": "system", "content": "You are a helpful, respectful and honest assistant."},
+                {"role": "user", "content": context},
             ]
-            single_tokenized_conversation = self.tokenizer.apply_chat_template(
+            context = self.tokenizer.apply_chat_template(
                 chat,
                 tokenize=False,
                 add_generation_prompt=True,
             )
-            rfind_continuation = single_tokenized_conversation.rfind(continuation)
-            context = single_tokenized_conversation[:rfind_continuation]
-            continuation = single_tokenized_conversation[rfind_continuation:]
-            # remove special chars from continuation
-            continuation = self.tokenizer.decode(
-                self.tokenizer.encode(continuation), skip_special_tokens=True
-            )
             req.args = (context, continuation)
             new_reqs.append(req)
         return new_reqs
 
     def _model_call(self, inps, attn_mask=None, labels=None):
         """
@@ -773,9 +765,12 @@ class HFLM(LM):
         return context_enc, continuation_enc
 
     def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+        print("Loglikelihood invoked")
         print(f"First element before prompt formatting...\n{requests[0].args}")
         requests = self.tok_wrap_chat_template(requests)
         print(f"First element after prompt formatting...\n{requests[0].args}")
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
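The wrapping only rewrites the context string; the continuation is scored the same way afterwards. As a reminder of what loglikelihood computes for each (context, continuation) pair, a single-example sketch (gpt2 for brevity; the harness's real implementation batches, reorders, and caches):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    context = "Question: What is the capital of France?\nAnswer:"
    continuation = " Paris"

    ctx_ids = tok(context, return_tensors="pt").input_ids
    cont_ids = tok(continuation, return_tensors="pt").input_ids
    input_ids = torch.cat([ctx_ids, cont_ids], dim=1)

    with torch.no_grad():
        logits = model(input_ids).logits

    # Position t predicts token t+1, so drop the last row and take the
    # log-probs of the continuation tokens given everything before them.
    log_probs = torch.log_softmax(logits[0, :-1], dim=-1)
    cont_len = cont_ids.shape[1]
    cont_logprob = log_probs[-cont_len:].gather(
        1, input_ids[0, -cont_len:].unsqueeze(1)
    ).sum()
    print(float(cont_logprob))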
@@ -1049,6 +1044,12 @@ class HFLM(LM):
         return re_ord.get_original(res)
 
     def generate_until(self, requests: List[Instance]) -> List[str]:
+        print("Generate_until invoked")
+        print(f"First element before prompt formatting...\n{requests[0].args}")
+        requests = self.tok_wrap_chat_template(requests)
+        print(f"First element after prompt formatting...\n{requests[0].args}")
         res = []
 
         def _collate(x):
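generate_until now goes through the same wrapping, so the model is prompted in its chat format before the task's stop sequences are applied. A standalone sketch of that flow (illustrative checkpoint, greedy decoding; not the harness's batched path):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "meta-llama/Llama-2-7b-chat-hf"  # illustrative, not from the commit
    tok = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(
        name, torch_dtype=torch.float16, device_map="auto"
    )

    chat = [{"role": "user", "content": "What is the capital of France?"}]
    prompt = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    # The rendered template already contains the BOS token, so don't add it again.
    inputs = tok(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    out = model.generate(**inputs, max_new_tokens=32, do_sample=False)

    # Strip the prompt tokens; the harness additionally truncates the output
    # at the task's "until" stop strings.
    print(tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))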