Commit 68c30aa7 authored by haileyschoelkopf

push most recent code

parent b8bda478
@@ -96,6 +96,10 @@ class HFLM(LM):
         # PEFT and quantization options
         peft: Optional[str] = None,
         autogptq: Optional[Union[bool, str]] = False,
+        # Chat templating settings
+        use_chat_template: Optional[bool] = False,
+        # TODO: validate a template exists in tokenizer config, if this flag is true
+        system_prompt: Optional[str] = None,
         **kwargs,
     ) -> None:
         super().__init__()
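
For reference, a minimal sketch of how the new kwargs might be passed when instantiating the model (the model name and import path are illustrative, not part of this commit):

    from lm_eval.models.huggingface import HFLM

    # Hypothetical usage of the kwargs added above.
    lm = HFLM(
        pretrained="gpt2",  # illustrative model name
        use_chat_template=True,  # format prompts with the tokenizer's chat template
        system_prompt="You are a helpful assistant.",  # optional system turn
    )
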
@@ -241,6 +245,9 @@ class HFLM(LM):
             else:
                 self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
 
+        self.system_prompt = system_prompt
+        self.use_chat_template = use_chat_template
+
         self._max_length = max_length
 
         self.batch_schedule = 1
@@ -691,9 +698,11 @@ class HFLM(LM):
             context, continuation = req.args[0].strip(), req.args[1]
 
             chat = []
             if self.system_prompt is not None:
-                chat += {"role": "system", "content": "You are a helpful assistant."}
+                chat += [{"role": "system", "content": "You are a helpful assistant."}]
-            chat += ({"role": "user", "content": context},)
+            chat += [
+                {"role": "user", "content": context},
+            ]
             # TODO: expose settings for chat formatting:
             # - whether some "trigger" / start of assistant response might be placed in assistant's generation for it
             # - if few-shot, should the fewshots be placed in separate convo turns? provided in user's single turn?...
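
The chat list built in this hunk matches the message format consumed by transformers' tokenizer.apply_chat_template. A minimal sketch of how such a list is typically rendered into a prompt string (illustrative use of the standard transformers API, not code from this commit; assumes the tokenizer config defines a chat template):

    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    # Render to a string; add_generation_prompt appends the assistant-turn prefix.
    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
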
@@ -786,9 +795,10 @@ class HFLM(LM):
         return context_enc, continuation_enc
 
     def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
-        print(f"First element before prompt formatting...\n{requests[0].args}")
-        requests = self.wrap_chat_template(requests)
-        print(f"First element after prompt formatting...\n{requests[0].args}")
+        if self.use_chat_template:
+            print(f"First element before prompt formatting...\n{requests[0].args}")
+            requests = self.wrap_chat_template(requests)
+            print(f"First element after prompt formatting...\n{requests[0].args}")
 
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
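
wrap_chat_template itself is not shown in this diff. A plausible sketch of such a helper, assuming it rewrites each request's (context, continuation) pair in place; note that the hunk around line 698 hardcodes the system message content, whereas this sketch uses self.system_prompt, and the reassignment of req.args is an assumption:

    def wrap_chat_template(self, requests):
        # Hypothetical helper; not taken from this commit.
        for req in requests:
            context, continuation = req.args[0].strip(), req.args[1]
            chat = []
            if self.system_prompt is not None:
                chat += [{"role": "system", "content": self.system_prompt}]
            chat += [{"role": "user", "content": context}]
            # Swap the templated conversation in as the new context.
            context = self.tokenizer.apply_chat_template(
                chat, tokenize=False, add_generation_prompt=True
            )
            req.args = (context, continuation)  # assumption: args can be reassigned
        return requests
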
@@ -1064,9 +1074,10 @@ class HFLM(LM):
         return re_ord.get_original(res)
 
     def generate_until(self, requests: List[Instance]) -> List[str]:
-        print(f"First element before prompt formatting...\n{requests[0].args}")
-        requests = self.tok_chat_template(requests)
-        print(f"First element after prompt formatting...\n{requests[0].args}")
+        if self.use_chat_template:
+            print(f"First element before prompt formatting...\n{requests[0].args}")
+            requests = self.tok_chat_template(requests)
+            print(f"First element after prompt formatting...\n{requests[0].args}")
 
         res = []