Commit d7b8fd9c authored by Konrad

tokenizer name property for cache_key

parent 615352c0
@@ -183,6 +183,12 @@ class LM(abc.ABC):
         # not support multi-device parallelism nor expect it.
         return self._world_size
 
+    @property
+    def get_tokenizer_name(self) -> str:
+        raise NotImplementedError(
+            "To use this model with chat templates, please implement the 'get_tokenizer_name' property."
+        )
+
     def set_cache_hook(self, cache_hook) -> None:
         self.cache_hook = cache_hook
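For readers skimming the diff: the base class deliberately raises until a subclass overrides the property, so only backends that support chat templating have to provide a tokenizer name. A minimal self-contained sketch of that pattern, where BaseLM and ToyLM are hypothetical stand-ins for LM and its concrete subclasses:

import abc


class BaseLM(abc.ABC):
    @property
    def get_tokenizer_name(self) -> str:
        # Default mirrors the hunk above: raise until a chat-template-capable
        # subclass overrides the property.
        raise NotImplementedError(
            "To use this model with chat templates, please implement the "
            "'get_tokenizer_name' property."
        )


class ToyLM(BaseLM):
    @property
    def get_tokenizer_name(self) -> str:
        # Return a filesystem-safe identifier used to fingerprint caches.
        return "toy-org__toy-model"


print(ToyLM().get_tokenizer_name)  # -> toy-org__toy-model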
@@ -391,8 +391,7 @@ class Task(abc.ABC):
             if system_instruction is not None
             else ""
         )
-        if lm is not None and hasattr(lm, "tokenizer"):
-            cache_key += f"-{lm.tokenizer.name_or_path.replace('/', '__')}"
+        cache_key += f"-tokenizer{lm.get_tokenizer_name}" if apply_chat_template else ""
         cached_instances = load_from_cache(file_name=cache_key)
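Note that despite the get_ prefix, the new attribute is a property and is read without parentheses. A sketch of how the cache key is fingerprinted once this lands; FakeLM, the base key string, and the value of apply_chat_template are all illustrative, and only the += line mirrors the diff:

class FakeLM:
    @property
    def get_tokenizer_name(self) -> str:
        return "EleutherAI__pythia-70m"


lm = FakeLM()
apply_chat_template = True
cache_key = "requests-mytask-0shot"  # illustrative base key

# Same expression as in the hunk above: the conditional appends nothing when
# chat templating is off, so plain runs keep their old cache keys.
cache_key += f"-tokenizer{lm.get_tokenizer_name}" if apply_chat_template else ""
print(cache_key)  # -> requests-mytask-0shot-tokenizerEleutherAI__pythia-70m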
@@ -415,6 +415,10 @@ class HFLM(TemplateLM):
     def world_size(self):
         return self._world_size
 
+    @property
+    def get_tokenizer_name(self) -> str:
+        return self.tokenizer.name_or_path.replace("/", "__")
+
     def _get_backend(
         self,
         config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
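The replace("/", "__") keeps namespaced Hub ids like org/model from introducing path separators into the cache file name. A quick illustration, where the model id is an arbitrary example:

name_or_path = "EleutherAI/pythia-70m"  # what tokenizer.name_or_path might hold
print(name_or_path.replace("/", "__"))  # -> EleutherAI__pythia-70m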