Unverified commit 4c51111c, authored by Hailey Schoelkopf, committed by GitHub

Add Gemma support (Add flag to control BOS token usage) (#1465)



* add add_bos_token to HFLM

* add BOS token flag to other local model classes

---------
Co-authored-by: Lintang Sutawika <lintang@eleuther.ai>
parent 7de7b27e
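
In lm-evaluation-harness, keyword arguments passed via --model_args are forwarded to the model class constructor, so the new flag is reachable from the CLI as well as from Python. A minimal sketch of programmatic use (the import path and the pretrained kwarg are assumptions from the v0.4 layout, not shown in this diff):

    from lm_eval.models.huggingface import HFLM

    # Force a BOS token even for models the harness does not auto-detect.
    lm = HFLM(pretrained="EleutherAI/pythia-160m", add_bos_token=True)

Appending add_bos_token=True to --model_args on the lm_eval command line should reach the same constructor.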
@@ -274,8 +274,8 @@ class TemplateLM(LM):
             continuation = context[-n_spaces:] + continuation
             context = context[:-n_spaces]
 
-        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
-        context_enc = self.tok_encode(context, add_special_tokens=False)
+        whole_enc = self.tok_encode(context + continuation)
+        context_enc = self.tok_encode(context)
 
         context_enc_len = len(context_enc)
         continuation_enc = whole_enc[context_enc_len:]
...
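
Dropping the explicit add_special_tokens=False here defers the decision to tok_encode, which now consults the new flag (see the HFLM hunks below). The surrounding logic encodes the concatenation and then slices, rather than encoding context and continuation separately, because BPE merges can cross the boundary. A small illustration, assuming the transformers library and the gpt2 checkpoint (not part of this diff):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    context, continuation = "The quick brown f", "oxes"
    whole = tok.encode(context + continuation)
    piecewise = tok.encode(context) + tok.encode(continuation)
    # Typically unequal: a merge spans the context/continuation boundary.
    print(whole == piecewise)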
@@ -98,6 +98,7 @@ class HFLM(TemplateLM):
         max_batch_size: Optional[int] = 64,
         trust_remote_code: Optional[bool] = True,
         use_fast_tokenizer: Optional[bool] = True,
+        add_bos_token: Optional[bool] = False,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
         parallelize: Optional[bool] = False,
@@ -265,6 +266,14 @@ class HFLM(TemplateLM):
             else:
                 self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
 
+        # TODO: override this for Gemma
+        self.add_bos_token = add_bos_token
+        if self.config.model_type == "gemma":
+            eval_logger.info(
+                "Model is of type 'gemma', will use a BOS token as Gemma underperforms without it."
+            )
+            self.add_bos_token = True
+
         self._max_length = max_length
         self.batch_schedule = 1
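
Gemma's tokenizer is trained with a BOS token prepended, which is why the harness force-enables the flag for that model type. A quick way to observe the difference, assuming tokenizer access to google/gemma-2b (gated on the Hub; not part of this diff):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("google/gemma-2b")
    print(tok.encode("hello", add_special_tokens=False))  # content tokens only
    print(tok.encode("hello", add_special_tokens=True))   # BOS id prepended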
@@ -657,8 +666,9 @@ class HFLM(TemplateLM):
         """ """
         if add_special_tokens is None:
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                add_special_tokens = False
+                add_special_tokens = False or self.add_bos_token
             elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+                # TODO: investigate best practices for enc-dec models + special tokens
                 add_special_tokens = True
 
         encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
@@ -681,7 +691,7 @@
         self.tokenizer.padding_side = padding_side
 
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            add_special_tokens = False
+            add_special_tokens = False or self.add_bos_token
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
             add_special_tokens = True
...
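
Across both HFLM hunks the resolution is the same, and `False or self.add_bos_token` simplifies to `self.add_bos_token`. A condensed sketch of the effective policy (resolve_special_tokens is a hypothetical name, not repo code):

    def resolve_special_tokens(requested, is_causal_lm, add_bos_token):
        # An explicit caller choice wins. Otherwise causal LMs add special
        # tokens only when a BOS is requested; seq2seq models always do.
        if requested is not None:
            return requested
        return add_bos_token if is_causal_lm else True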
@@ -195,8 +195,7 @@ class NEURON_HF(TemplateLM):
         low_cpu_mem_usage: Optional[bool] = True,
         trust_remote_code: Optional[bool] = False,
         use_fast_tokenizer: Optional[bool] = True,
-        # arguments used for splitting a model across GPUs naively.
-        # only used if `parallelize=True`.
+        add_bos_token: Optional[bool] = False,
     ) -> None:
         if not NEURON_AVAILABLE:
             raise Exception(
@@ -289,6 +288,7 @@ class NEURON_HF(TemplateLM):
         self.vocab_size = self.tokenizer.vocab_size
         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        self.add_bos_token = add_bos_token
 
         self._max_length = max_length
@@ -343,7 +343,7 @@
     def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None):
         """ """
         if add_special_tokens is None:
-            add_special_tokens = False
+            add_special_tokens = False or self.add_bos_token
         encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
@@ -364,7 +364,7 @@
         old_padding_side = self.tokenizer.padding_side
         self.tokenizer.padding_side = padding_side
-        add_special_tokens = False
+        add_special_tokens = False or self.add_bos_token
         encoding = self.tokenizer(
             strings,
...
@@ -47,6 +47,7 @@ class VLLM(TemplateLM):
         tokenizer: Optional[str] = None,
         tokenizer_mode: Literal["auto", "slow"] = "auto",
         tokenizer_revision: Optional[str] = None,
+        add_bos_token: Optional[bool] = False,
         tensor_parallel_size: int = 1,
         quantization: Optional[str] = None,
         max_gen_toks: int = 256,
@@ -114,6 +115,7 @@ class VLLM(TemplateLM):
             trust_remote_code=trust_remote_code,
             tokenizer_revision=tokenizer_revision,
         )
+        self.add_bos_token = add_bos_token
 
         self._max_gen_toks = max_gen_toks
@@ -147,10 +149,12 @@
         self,
         string: str,
         left_truncate_len=None,
-        add_special_tokens=False,
+        add_special_tokens=None,
         truncation=False,
     ):
         """ """
+        if not add_special_tokens:
+            add_special_tokens = False or self.add_bos_token
         encoding = self.tokenizer.encode(
             string, add_special_tokens=add_special_tokens, truncation=truncation
         )
...
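
One caveat worth noting: the gemma auto-detection above lives only in HFLM. NEURON_HF and VLLM accept the flag but never inspect model_type, so Gemma users on those backends must pass add_bos_token=True explicitly. Also, because the vLLM tok_encode default changed from False to None and the guard is a truthiness check (`if not add_special_tokens`), an explicit add_special_tokens=False is overridden whenever add_bos_token is set. A usage sketch (the import path and pretrained kwarg are assumptions from the v0.4 layout):

    from lm_eval.models.vllm_causallms import VLLM

    # Explicitly request BOS; the vLLM backend does not auto-detect Gemma.
    lm = VLLM(pretrained="google/gemma-2b", add_bos_token=True)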