Unverified Commit 2eea3f50 authored by Avelina Asada Hadji-Kyriacou, committed by GitHub

Added `chat_template_args` to pass additional kwargs to tokenizer.apply_chat_template (#3164)



* added support for additional chat template arguments

* use `enable_thinking`

* add wrap logging function

* add `chat_template_args` back to HF

---------
Co-authored-by: Baber <baber@hey.com>
parent 091aaf6f
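
A minimal usage sketch (not taken from the commit itself; the model name and task are placeholders, and it assumes the `hf` backend with a chat model whose template understands `enable_thinking`):

```python
import lm_eval

# Hypothetical invocation: the new `enable_thinking` kwarg is folded into
# `chat_template_args` by HFLM and forwarded to tokenizer.apply_chat_template.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args={
        "pretrained": "Qwen/Qwen3-8B",   # placeholder: a chat model with a thinking toggle
        "enable_thinking": False,        # new kwarg added in this commit
    },
    tasks=["gsm8k"],                     # placeholder task
    apply_chat_template=True,
)
```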
@@ -35,6 +35,7 @@ from lm_eval.utils import (
     positional_deprecated,
     setup_logging,
     simple_parse_args_string,
+    wrap_text,
 )
@@ -169,8 +170,11 @@ def simple_evaluate(
         )
     ) and not apply_chat_template:
         eval_logger.warning(
-            "Model appears to be an instruct or chat variant but chat template is not applied. "
-            "Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
+            wrap_text(
+                f"""pretrained={model_args.get("pretrained") if isinstance(model_args, dict) else model_args} appears to be an
+                instruct or chat variant but chat template is not applied.
+                Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`).""",
+            )
         )
     if delete_requests_cache:
@@ -234,8 +238,10 @@ def simple_evaluate(
         else:
             eval_logger.info(
+                wrap_text(
                 f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
+                )
             )
         lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
             model_args,
             {
...
+from __future__ import annotations
 import copy
 import logging
 import os
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
 import jinja2
 import torch
@@ -99,6 +101,8 @@ class HFLM(TemplateLM):
         # end token for thinking, either the string or int token id.
         # splits to get response after this token (if provided).
         think_end_token: Union[str, int, None] = None,
+        enable_thinking: bool | None = None,
+        chat_template_args: Optional[dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -238,6 +242,11 @@ class HFLM(TemplateLM):
         self.vocab_size = self.tokenizer.vocab_size
         # select (or create) a pad token to use
         self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)
+        self.chat_template_args = (
+            chat_template_args or {} | dict(enable_thinking=enable_thinking)
+            if enable_thinking is not None
+            else {}
+        )
         self.add_bos_token = add_bos_token
         if "gemma" in getattr(self.config, "model_type", ""):
@@ -1483,6 +1492,7 @@ class HFLM(TemplateLM):
                 tokenize=False,
                 add_generation_prompt=add_generation_prompt,
                 continue_final_message=not add_generation_prompt,
+                **self.chat_template_args,
             )
         except jinja2.exceptions.TemplateError:
             eval_logger.warning(
@@ -1494,6 +1504,7 @@ class HFLM(TemplateLM):
                 tokenize=False,
                 add_generation_prompt=add_generation_prompt,
                 continue_final_message=not add_generation_prompt,
+                **self.chat_template_args,
             )
         return chat_templated
...
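
The kwargs collected in `self.chat_template_args` above are simply forwarded to the tokenizer's chat template. A rough sketch of the underlying mechanism, assuming a transformers tokenizer whose Jinja template reads an `enable_thinking` flag (Qwen3-style templates; the model name is a placeholder):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")  # placeholder model
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "What is 2 + 2?"}],
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,  # extra kwargs are passed through to the Jinja template
)
print(prompt)
```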
@@ -136,6 +136,7 @@ class VLLM(TemplateLM):
         lora_local_path: str = None,
         # VLLM: enable thinking tags in the prompt.
         enable_thinking: bool = True,
+        chat_template_args: Optional[dict] = None,
         # End marker for thinking tags - splits to get response after this token (if provided).
         think_end_token: Optional[str] = None,
         max_lora_rank: int = 16,
@@ -209,7 +210,10 @@ class VLLM(TemplateLM):
             add_bos_token=add_bos_token,
         )
         self.tokenizer = configure_pad_token(self.tokenizer, model_config=self._config)
-        self.enable_thinking = enable_thinking
+        self.chat_template_args = chat_template_args or {}
+        self.enable_thinking = self.chat_template_args.pop(
+            "enable_thinking", enable_thinking
+        )
         self.add_bos_token = add_bos_token
         if "gemma" in pretrained.lower():
             self.add_bos_token = True
@@ -317,6 +321,7 @@ class VLLM(TemplateLM):
                 continue_final_message=not add_generation_prompt,
                 chat_template=self.hf_chat_template,
                 enable_thinking=self.enable_thinking,
+                **self.chat_template_args,
             )
         except jinja2.exceptions.TemplateError:
             eval_logger.warning(
@@ -329,6 +334,7 @@ class VLLM(TemplateLM):
                 continue_final_message=not add_generation_prompt,
                 chat_template=self.hf_chat_template,
                 enable_thinking=self.enable_thinking,
+                **self.chat_template_args,
            )
         return chat_templated
...
@@ -26,6 +26,23 @@ HIGHER_IS_BETTER_SYMBOLS = {
 }
 
 
+def wrap_text(string: str, width: int = 140, **kwargs) -> Optional[str]:
+    """
+    Wraps the given string to the specified width.
+    """
+    import textwrap
+
+    return textwrap.fill(
+        inspect.cleandoc(string),
+        width=width,
+        initial_indent="",
+        subsequent_indent=" " * 8,
+        break_long_words=False,
+        break_on_hyphens=False,
+        **kwargs,
+    )
+
+
 def setup_logging(verbosity=logging.INFO):
     # Configure the root logger
     class CustomFormatter(logging.Formatter):
...
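
For reference, a small sketch of how the new `wrap_text` helper behaves: it runs the message through `inspect.cleandoc` and `textwrap.fill`, indenting continuation lines by eight spaces and leaving long tokens unbroken:

```python
from lm_eval.utils import wrap_text

# Illustrative only; the message text is a placeholder.
msg = wrap_text(
    """pretrained=my-model appears to be an
    instruct or chat variant but chat template is not applied.
    Recommend setting `apply_chat_template`.""",
    width=60,  # the default width is 140
)
print(msg)
```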