"vscode:/vscode.git/clone" did not exist on "ba0de4b129ffe6469d95291f36c2c50b6fd0fb92"
Commit 87dff8b0 authored by daniel-furman's avatar daniel-furman
Browse files

first stab at wrap_chat_template, print statements in loglikelihood for testing

parent 3e27f9dd
...@@ -664,10 +664,13 @@ class HFLM(LM): ...@@ -664,10 +664,13 @@ class HFLM(LM):
return self.tokenizer.decode(tokens, skip_special_tokens=True) return self.tokenizer.decode(tokens, skip_special_tokens=True)
def tok_wrap_chat_template(self, requests: List[Instance], system: bool = False) -> List[Instance]: def tok_wrap_chat_template(self, requests: List[Instance], system: bool = False) -> List[Instance]:
"""
Utility for adding chat templates via the apply_chat_template() method
"""
new_reqs = [] new_reqs = []
for req in requests: for req in requests:
context, continuation = req.args[0].strip(), req.args[1].strip() context, continuation = req.args[0].strip(), req.args[1].strip()
if system: if system:
chat = [ chat = [
{"role": "system", "content": system}, {"role": "system", "content": system},
...@@ -690,7 +693,7 @@ class HFLM(LM): ...@@ -690,7 +693,7 @@ class HFLM(LM):
continuation = single_tokenized_conversation[rfind_continuation:] continuation = single_tokenized_conversation[rfind_continuation:]
req.args = (context, continuation) req.args = (context, continuation)
new_reqs.append(req) new_reqs.append(req)
return new_reqs return new_reqs
...@@ -774,12 +777,9 @@ class HFLM(LM): ...@@ -774,12 +777,9 @@ class HFLM(LM):
return context_enc, continuation_enc return context_enc, continuation_enc
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
#print(requests) print(f"First element before prompt formatting...\n{requests[0].args}")
print(requests[0])
print(requests[0].args)
requests = self.tok_wrap_chat_template(requests) requests = self.tok_wrap_chat_template(requests)
print(requests[0]) print(f"First element after prompt formatting...\n{requests[0].args}")
print(requests[0].args)
new_reqs = [] new_reqs = []
for context, continuation in [req.args for req in requests]: for context, continuation in [req.args for req in requests]:
if context == "": if context == "":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment