Unverified Commit 2eea3f50 authored by Avelina Asada Hadji-Kyriacou, committed by GitHub

Added `chat_template_args` to pass additional kwargs to tokenizer.apply_chat_template (#3164)



* added support for additional chat template arguments

* use `enable_thinking`

* add wrap logging function

* add `chat_template_args` back to HF

---------
Co-authored-by: Baber <baber@hey.com>
parent 091aaf6f
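As a usage sketch (not part of the diff; the model id and task are illustrative), the new argument can be passed through `model_args` so that extra kwargs reach `tokenizer.apply_chat_template`:

```python
# Hypothetical invocation showing the new `chat_template_args` plumbing.
# `simple_evaluate` accepts `model_args` as a dict; everything under
# `chat_template_args` is forwarded to tokenizer.apply_chat_template.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args={
        "pretrained": "Qwen/Qwen3-8B",  # illustrative model id
        "chat_template_args": {"enable_thinking": False},
    },
    tasks=["gsm8k"],
    apply_chat_template=True,
)
```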
lm_eval/evaluator.py
@@ -35,6 +35,7 @@ from lm_eval.utils import (
positional_deprecated,
setup_logging,
simple_parse_args_string,
wrap_text,
)
@@ -169,8 +170,11 @@ def simple_evaluate(
)
) and not apply_chat_template:
eval_logger.warning(
"Model appears to be an instruct or chat variant but chat template is not applied. "
"Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
wrap_text(
f"""pretrained={model_args.get("pretrained") if isinstance(model_args, dict) else model_args} appears to be an
instruct or chat variant but chat template is not applied.
Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`).""",
)
)
if delete_requests_cache:
@@ -234,7 +238,9 @@ def simple_evaluate(
else:
eval_logger.info(
f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
wrap_text(
f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
)
)
lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
model_args,
lm_eval/models/huggingface.py
from __future__ import annotations
import copy
import logging
import os
from datetime import timedelta
from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
import jinja2
import torch
@@ -99,6 +101,8 @@ class HFLM(TemplateLM):
# end token for thinking, either the string or int token id.
# splits to get response after this token (if provided).
think_end_token: Union[str, int, None] = None,
enable_thinking: bool | None = None,
chat_template_args: Optional[dict[str, Any]] = None,
**kwargs,
) -> None:
super().__init__()
@@ -238,6 +242,11 @@ class HFLM(TemplateLM):
self.vocab_size = self.tokenizer.vocab_size
# select (or create) a pad token to use
self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)
        # An explicit `enable_thinking` argument is merged in on top of
        # `chat_template_args`, so it overrides any value already present there.
        self.chat_template_args = (
            (chat_template_args or {}) | dict(enable_thinking=enable_thinking)
            if enable_thinking is not None
            else (chat_template_args or {})
        )
self.add_bos_token = add_bos_token
if "gemma" in getattr(self.config, "model_type", ""):
@@ -1483,6 +1492,7 @@ class HFLM(TemplateLM):
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
**self.chat_template_args,
)
except jinja2.exceptions.TemplateError:
eval_logger.warning(
@@ -1494,6 +1504,7 @@ class HFLM(TemplateLM):
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
**self.chat_template_args,
)
return chat_templated
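At the tokenizer level, the forwarded kwargs become variables available to the Jinja chat template; a hedged sketch (assumes a template that reads `enable_thinking`, e.g. Qwen3-style chat templates; the model id is illustrative):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")  # illustrative model id
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "What is 2 + 2?"}],
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,  # extra kwargs are passed through to the chat template
)
print(prompt)
```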
lm_eval/models/vllm_causallms.py
@@ -136,6 +136,7 @@ class VLLM(TemplateLM):
lora_local_path: str = None,
# VLLM: enable thinking tags in the prompt.
enable_thinking: bool = True,
chat_template_args: Optional[dict] = None,
# End marker for thinking tags - splits to get response after this token (if provided).
think_end_token: Optional[str] = None,
max_lora_rank: int = 16,
@@ -209,7 +210,10 @@ class VLLM(TemplateLM):
add_bos_token=add_bos_token,
)
self.tokenizer = configure_pad_token(self.tokenizer, model_config=self._config)
-        self.enable_thinking = enable_thinking
+        self.chat_template_args = chat_template_args or {}
+        # `enable_thinking` may be passed directly or inside `chat_template_args`;
+        # a value in `chat_template_args` takes precedence and is popped so the
+        # kwarg is not forwarded twice to apply_chat_template.
+        self.enable_thinking = self.chat_template_args.pop(
+            "enable_thinking", enable_thinking
+        )
self.add_bos_token = add_bos_token
if "gemma" in pretrained.lower():
self.add_bos_token = True
@@ -317,6 +321,7 @@ class VLLM(TemplateLM):
continue_final_message=not add_generation_prompt,
chat_template=self.hf_chat_template,
enable_thinking=self.enable_thinking,
**self.chat_template_args,
)
except jinja2.exceptions.TemplateError:
eval_logger.warning(
@@ -329,6 +334,7 @@ class VLLM(TemplateLM):
continue_final_message=not add_generation_prompt,
chat_template=self.hf_chat_template,
enable_thinking=self.enable_thinking,
**self.chat_template_args,
)
return chat_templated
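The pop-with-default pattern above in plain Python (no lm-eval imports): the `chat_template_args` entry wins when present and is consumed, otherwise the legacy flag is used.

```python
# Value present in chat_template_args: it wins and the key is consumed.
args = {"enable_thinking": False, "custom_flag": True}
enable_thinking = args.pop("enable_thinking", True)
assert enable_thinking is False
assert args == {"custom_flag": True}  # key removed before **args is forwarded

# Key absent: fall back to the legacy `enable_thinking` argument.
args = {}
assert args.pop("enable_thinking", True) is True
```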
lm_eval/utils.py
@@ -26,6 +26,23 @@ HIGHER_IS_BETTER_SYMBOLS = {
}
def wrap_text(string: str, width: int = 140, **kwargs) -> str:
    """
    Dedents `string` with `inspect.cleandoc`, then wraps it to `width` columns,
    indenting continuation lines by eight spaces so long log messages stay
    readable. Extra kwargs are forwarded to `textwrap.fill`.
    """
    import inspect
    import textwrap

    return textwrap.fill(
        inspect.cleandoc(string),
        width=width,
        initial_indent="",
        subsequent_indent=" " * 8,
        break_long_words=False,
        break_on_hyphens=False,
        **kwargs,
    )
def setup_logging(verbosity=logging.INFO):
# Configure the root logger
class CustomFormatter(logging.Formatter):
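A quick usage sketch of the new helper (output elided; the docstring above describes its shape):

```python
from lm_eval.utils import wrap_text

print(
    wrap_text(
        """pretrained=my/model appears to be an
        instruct or chat variant but chat template is not applied.""",
        width=60,
    )
)
# Leading indentation from the triple-quoted literal is stripped by
# inspect.cleandoc, and every wrapped continuation line starts with
# eight spaces of indent.
```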