Commit d7b8fd9c authored by Konrad

tokenizer name property for cache_key

parent 615352c0
@@ -183,6 +183,12 @@ class LM(abc.ABC):
         # not support multi-device parallelism nor expect it.
         return self._world_size
 
+    @property
+    def get_tokenizer_name(self) -> str:
+        raise NotImplementedError(
+            "To use this model with chat templates, please implement the 'get_tokenizer_name' property."
+        )
+
     def set_cache_hook(self, cache_hook) -> None:
         self.cache_hook = cache_hook
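The base class now raises until a backend opts in. A minimal sketch of how a custom subclass could satisfy the contract (the subclass name and constructor are hypothetical, not part of this diff; other abstract LM methods are omitted):

    # Hypothetical backend: any model that wants chat-template request
    # caching exposes its tokenizer name through the new property.
    class MyAPIModel(LM):
        def __init__(self, tokenizer_name: str) -> None:
            super().__init__()
            self._tokenizer_name = tokenizer_name

        @property
        def get_tokenizer_name(self) -> str:
            # Replace '/' so the name is safe inside a cache file name,
            # mirroring the HFLM convention further down in this commit.
            return self._tokenizer_name.replace("/", "__")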
@@ -391,8 +391,7 @@ class Task(abc.ABC):
             if system_instruction is not None
             else ""
         )
-        if lm is not None and hasattr(lm, "tokenizer"):
-            cache_key += f"-{lm.tokenizer.name_or_path.replace('/', '__')}"
+        cache_key += f"-tokenizer{lm.get_tokenizer_name}" if apply_chat_template else ""
         cached_instances = load_from_cache(file_name=cache_key)
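Note that `get_tokenizer_name` is read without call parentheses because it is a `@property`, and the suffix only lands on the key when `apply_chat_template` is set. A worked example with hypothetical values:

    # Worked example (illustrative base key and model name):
    cache_key = "requests-mytask-0shot"
    tokenizer_name = "meta-llama/Llama-2-7b-hf".replace("/", "__")
    cache_key += f"-tokenizer{tokenizer_name}"
    # -> "requests-mytask-0shot-tokenizermeta-llama__Llama-2-7b-hf"

Keying the cache on tokenizer identity keeps chat-templated requests built for one model from being replayed against a model with a different chat template.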
@@ -415,6 +415,10 @@ class HFLM(TemplateLM):
     def world_size(self):
         return self._world_size
 
+    @property
+    def get_tokenizer_name(self) -> str:
+        return self.tokenizer.name_or_path.replace("/", "__")
+
     def _get_backend(
         self,
         config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
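A minimal usage sketch, assuming `lm` is an HFLM instance wrapping the example checkpoint `EleutherAI/pythia-70m` (an assumed value for illustration):

    print(lm.get_tokenizer_name)  # property access, no parentheses
    # -> "EleutherAI__pythia-70m"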