Commit 38248282 authored by daniel-furman's avatar daniel-furman
Browse files

first stab at wrap_chat_template

parent 28ec7fa9
...@@ -31,3 +31,13 @@ class Instance: ...@@ -31,3 +31,13 @@ class Instance:
return ( return (
self.arguments if isinstance(self.arguments, tuple) else (self.arguments,) self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
) )
@args.setter
def args(self, new_arguments: tuple) -> None:
"""
Update the arguments of this instance with a new one
"""
if isinstance(new_arguments, tuple):
self.arguments = new_arguments
else:
print("Please enter a valid arguments tuple")
...@@ -663,6 +663,35 @@ class HFLM(LM): ...@@ -663,6 +663,35 @@ class HFLM(LM):
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
return self.tokenizer.decode(tokens, skip_special_tokens=True) return self.tokenizer.decode(tokens, skip_special_tokens=True)
def tok_wrap_chat_template(self, requests: List[Instance], system: bool = False) -> List[Instance]:
new_reqs = []
for req in requests:
context, continuation = req.args[0], req.args[1]
if system:
chat = [
{"role": "system", "content": system.strip()},
{"role": "user", "content": context.strip()},
{"role": "assistant", "content": continuation.strip()},
]
else:
chat = [
{"role": "user", "content": context.strip()},
{"role": "assistant", "content": continuation.strip()},
]
single_tokenized_conversation = self.tokenizer.apply_chat_template(
chat,
tokenize=False,
add_generation_prompt=True,
)
split_on_continuation = single_tokenized_conversation.split(continuation)
context = split_on_continuation[0]
continuation += split_on_continuation[1]
req.args = (context, continuation)
new_reqs.append(req)
return new_reqs
def _model_call(self, inps, attn_mask=None, labels=None): def _model_call(self, inps, attn_mask=None, labels=None):
""" """
:param inps: torch.Tensor :param inps: torch.Tensor
...@@ -743,6 +772,12 @@ class HFLM(LM): ...@@ -743,6 +772,12 @@ class HFLM(LM):
return context_enc, continuation_enc return context_enc, continuation_enc
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
#print(requests)
print(requests[0])
print(requests[0].args)
requests = self.tok_wrap_chat_template(requests)
print(requests[0])
print(requests[0].args)
new_reqs = [] new_reqs = []
for context, continuation in [req.args for req in requests]: for context, continuation in [req.args for req in requests]:
if context == "": if context == "":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment