"vscode:/vscode.git/clone" did not exist on "7e35d32e2987493838779826155f7434bc30b81c"
Commit 6ca8ab15 authored by haileyschoelkopf

clean up wrap_chat_template + add TODOs

parent c47de8be
@@ -663,59 +663,34 @@ class HFLM(LM):
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
             return self.tokenizer.decode(tokens, skip_special_tokens=True)
 
-    def tok_wrap_chat_template(self, requests: List[Instance]) -> List[Instance]:
+    def wrap_chat_template(
+        self, requests: List[Instance], generate=False
+    ) -> List[Instance]:
         """
         Utility for adding chat templates via the apply_chat_template() method
         """
+        # TODO: handle repeats > 1 case?
+        # TODO: raise an error if system prompt not compatible with template
         new_reqs = []
         for req in requests:
             context, continuation = req.args[0].strip(), req.args[1]
-            chat = [
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": context},
-            ]
-            context = self.tokenizer.apply_chat_template(
-                chat,
-                tokenize=False,
-                add_generation_prompt=True,
-            )
-            req.args = (context, continuation)
-            new_reqs.append(req)
-        return new_reqs
-
-    def tok_wrap_chat_template(self, requests: List[Instance]) -> List[Instance]:
-        """
-        Utility for adding chat templates via the apply_chat_template() method
-        """
-        new_reqs = []
-        for req in requests:
-            context = req.args[0].strip()
-            #system_prompt = "You are a helpful assistant."
-            # arc experiment with few-shot formatting
-            import re
-            elements = re.split('Answer:|Question:', context.replace('\n', ' '))
-            new_elements = []
-            for element in elements[1:-1]:
-                new_elements.append(element.strip())
-            new_elements
-            #chat = [{"role": "system", "content": system_prompt}]
             chat = []
-            for i in range(len(new_elements)):
-                if i % 2 == 0:
-                    chat.append({"role": "user", "content": f"Question: {new_elements[i]} Answer:"})
-                else:
-                    chat.append({"role": "assistant", "content": f"{new_elements[i]}"})
+            if self.system_prompt is not None:
+                chat += {"role": "system", "content": "You are a helpful assistant."}
+            chat += ({"role": "user", "content": context},)
+            # TODO: expose settings for chat formatting:
+            # - whether some "trigger" / start of assistant response might be placed in assistant's generation for it
+            # - if few-shot, should the fewshots be placed in separate convo turns? provided in user's single turn?...
             context = self.tokenizer.apply_chat_template(
                 chat,
                 tokenize=False,
                 add_generation_prompt=True,
             )
-            req.args = (context, req.args[1].strip())
+            req.args = (context, continuation)
             new_reqs.append(req)
         return new_reqs
 
     def _model_call(self, inps, attn_mask=None, labels=None):
         """
         :param inps: torch.Tensor
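
For context on what the rewritten method produces: the sketch below shows how a Hugging Face tokenizer's apply_chat_template() renders a list of role/content messages into a single prompt string, which is all wrap_chat_template() does per request. This is illustrative only, not part of the commit; the model name and example question are placeholders, and the local system_prompt variable stands in for the harness's self.system_prompt setting.

    # Minimal sketch (not from the commit): build and render a chat-formatted prompt
    # the way wrap_chat_template() does for each request.
    from transformers import AutoTokenizer

    # Placeholder model; any model whose tokenizer ships a chat template works.
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

    context = "Question: What is the boiling point of water at sea level?\nAnswer:"

    chat = []
    system_prompt = "You are a helpful assistant."  # stands in for self.system_prompt
    if system_prompt is not None:
        # the committed code uses `chat += {...}`, which would extend the list with the
        # dict's keys; appending the message dict itself is shown here for illustration
        chat.append({"role": "system", "content": system_prompt})
    chat.append({"role": "user", "content": context})

    # tokenize=False returns the rendered prompt string rather than token ids;
    # add_generation_prompt=True appends the assistant-turn prefix so the model
    # continues as the assistant.
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    print(prompt)
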
@@ -796,10 +771,8 @@ class HFLM(LM):
         return context_enc, continuation_enc
 
     def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
-        print("Loglikelihood invoked")
         print(f"First element before prompt formatting...\n{requests[0].args}")
-        requests = self.tok_wrap_chat_template(requests)
+        requests = self.wrap_chat_template(requests)
         print(f"First element after prompt formatting...\n{requests[0].args}")
 
         new_reqs = []
@@ -820,6 +793,8 @@ class HFLM(LM):
     def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []
 
+        # TODO: add a warning that chat templates are ignored for ppl evals
+
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
@@ -896,7 +871,6 @@ class HFLM(LM):
         disable_tqdm: bool = False,
         override_bs: int = None,
     ) -> List[Tuple[float, bool]]:
-        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
 
         def _collate(x):
@@ -1075,10 +1049,8 @@ class HFLM(LM):
         return re_ord.get_original(res)
 
     def generate_until(self, requests: List[Instance]) -> List[str]:
-        print("Generate_until invoked")
         print(f"First element before prompt formatting...\n{requests[0].args}")
-        requests = self.tok_wrap_chat_template(requests)
+        requests = self.tok_chat_template(requests)
         print(f"First element after prompt formatting...\n{requests[0].args}")
         res = []
...
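
One of the TODOs added in wrap_chat_template() asks whether few-shot examples should be placed in separate conversation turns or in the user's single turn; the removed ARC experiment took the multi-turn route by regex-splitting the context on "Question:"/"Answer:". Below is a hedged sketch of the two placements, for illustration only; the fewshots pairs and the question are made up and not part of the commit.

    # Two ways to place few-shot examples in a chat, per the TODO in wrap_chat_template().
    # `fewshots` and `question` are made-up placeholders.
    fewshots = [
        ("What is 2 + 2?", "4"),
        ("What color is a clear daytime sky?", "Blue"),
    ]
    question = "What is the capital of France?"

    # Option A: each few-shot example becomes its own user/assistant turn pair,
    # roughly what the removed ARC experiment reconstructed from the flat context.
    chat_multi_turn = []
    for q, a in fewshots:
        chat_multi_turn.append({"role": "user", "content": f"Question: {q} Answer:"})
        chat_multi_turn.append({"role": "assistant", "content": a})
    chat_multi_turn.append({"role": "user", "content": f"Question: {question} Answer:"})

    # Option B: pack every few-shot example into a single user turn, which is what the
    # new wrap_chat_template() effectively does when the task's few-shot context arrives
    # as one string.
    packed = "\n\n".join(f"Question: {q}\nAnswer: {a}" for q, a in fewshots)
    chat_single_turn = [
        {"role": "user", "content": f"{packed}\n\nQuestion: {question}\nAnswer:"}
    ]
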