Unverified Commit 4c51111c authored by Hailey Schoelkopf, committed by GitHub

Add Gemma support (Add flag to control BOS token usage) (#1465)



* add add_bos_token to HFLM

* add BOS token flag to other local model classes

---------
Co-authored-by: Lintang Sutawika <lintang@eleuther.ai>
parent 7de7b27e
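
For reference, the new flag is a constructor argument on each model class, so it can be passed through the harness's model_args string. A minimal sketch, assuming the v0.4-style lm_eval.simple_evaluate entry point; the checkpoint name is illustrative, not part of this commit:

import lm_eval

# Sketch only: "add_bos_token=True" is forwarded to HFLM.__init__ via
# model_args; for model_type == "gemma" the flag is now forced on anyway.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=google/gemma-2b,add_bos_token=True",
    tasks=["hellaswag"],
)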
@@ -274,8 +274,8 @@ class TemplateLM(LM):
             continuation = context[-n_spaces:] + continuation
             context = context[:-n_spaces]
-        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
-        context_enc = self.tok_encode(context, add_special_tokens=False)
+        whole_enc = self.tok_encode(context + continuation)
+        context_enc = self.tok_encode(context)
         context_enc_len = len(context_enc)
         continuation_enc = whole_enc[context_enc_len:]
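Why dropping the explicit add_special_tokens=False above is safe: when a model class now opts into BOS, the token lands at the front of both encodings, so slicing the whole encoding at len(context_enc) still recovers exactly the continuation tokens. A toy illustration with a hypothetical character-level encoder, not the harness's tokenizer:

# Hypothetical stand-in for tok_encode: one token per character,
# with a BOS id prepended the way a Gemma-style tokenizer would.
BOS = 0

def enc(text):
    return [BOS] + [ord(c) for c in text]

context, continuation = "The capital of France is", " Paris"
whole_enc = enc(context + continuation)
context_enc = enc(context)
# BOS appears once at the start of both encodings, so the slice
# past len(context_enc) is exactly the continuation's tokens.
continuation_enc = whole_enc[len(context_enc):]
assert continuation_enc == [ord(c) for c in " Paris"]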
@@ -98,6 +98,7 @@ class HFLM(TemplateLM):
         max_batch_size: Optional[int] = 64,
         trust_remote_code: Optional[bool] = True,
         use_fast_tokenizer: Optional[bool] = True,
+        add_bos_token: Optional[bool] = False,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
         parallelize: Optional[bool] = False,
@@ -265,6 +266,14 @@ class HFLM(TemplateLM):
         else:
             self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
 
+        # TODO: override this for Gemma
+        self.add_bos_token = add_bos_token
+        if self.config.model_type == "gemma":
+            eval_logger.info(
+                "Model is of type 'gemma', will use a BOS token as Gemma underperforms without it."
+            )
+            self.add_bos_token = True
+
         self._max_length = max_length
 
         self.batch_schedule = 1
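The gemma branch above keys off the model_type field of the checkpoint's config.json, which is what self.config exposes. A quick way to inspect it, assuming transformers is installed; the (gated) checkpoint name is only an example:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("google/gemma-2b")
print(config.model_type)  # "gemma" -> add_bos_token is forced to True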
@@ -657,8 +666,9 @@ class HFLM(TemplateLM):
         """ """
         if add_special_tokens is None:
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                add_special_tokens = False
+                add_special_tokens = False or self.add_bos_token
             elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+                # TODO: investigate best practices for enc-dec models + special tokens
                 add_special_tokens = True
 
         encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
@@ -681,7 +691,7 @@ class HFLM(TemplateLM):
         self.tokenizer.padding_side = padding_side
 
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            add_special_tokens = False
+            add_special_tokens = False or self.add_bos_token
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
             add_special_tokens = True
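Note that False or self.add_bos_token simply evaluates to self.add_bos_token; the explicit False preserves the shape of the old default. What flipping the resulting add_special_tokens to True does for a BOS-style tokenizer, assuming transformers and an accessible checkpoint (name illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-2b")
# Gemma's tokenizer prepends BOS when special tokens are enabled,
# which is exactly what add_special_tokens=True turns on here.
print(tok.encode("hello", add_special_tokens=True))   # [bos_id, ...]
print(tok.encode("hello", add_special_tokens=False))  # [...] (no BOS)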
@@ -195,8 +195,7 @@ class NEURON_HF(TemplateLM):
         low_cpu_mem_usage: Optional[bool] = True,
         trust_remote_code: Optional[bool] = False,
         use_fast_tokenizer: Optional[bool] = True,
-        # arguments used for splitting a model across GPUs naively.
-        # only used if `parallelize=True`.
+        add_bos_token: Optional[bool] = False,
     ) -> None:
         if not NEURON_AVAILABLE:
             raise Exception(
@@ -289,6 +288,7 @@ class NEURON_HF(TemplateLM):
         self.vocab_size = self.tokenizer.vocab_size
         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
 
+        self.add_bos_token = add_bos_token
 
         self._max_length = max_length
@@ -343,7 +343,7 @@ class NEURON_HF(TemplateLM):
     def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None):
         """ """
         if add_special_tokens is None:
-            add_special_tokens = False
+            add_special_tokens = False or self.add_bos_token
 
         encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
@@ -364,7 +364,7 @@ class NEURON_HF(TemplateLM):
         old_padding_side = self.tokenizer.padding_side
         self.tokenizer.padding_side = padding_side
 
-        add_special_tokens = False
+        add_special_tokens = False or self.add_bos_token
 
         encoding = self.tokenizer(
             strings,
@@ -47,6 +47,7 @@ class VLLM(TemplateLM):
         tokenizer: Optional[str] = None,
         tokenizer_mode: Literal["auto", "slow"] = "auto",
         tokenizer_revision: Optional[str] = None,
+        add_bos_token: Optional[bool] = False,
         tensor_parallel_size: int = 1,
         quantization: Optional[str] = None,
         max_gen_toks: int = 256,
@@ -114,6 +115,7 @@ class VLLM(TemplateLM):
             trust_remote_code=trust_remote_code,
             tokenizer_revision=tokenizer_revision,
         )
+        self.add_bos_token = add_bos_token
 
         self._max_gen_toks = max_gen_toks
@@ -147,10 +149,12 @@ class VLLM(TemplateLM):
         self,
         string: str,
         left_truncate_len=None,
-        add_special_tokens=False,
+        add_special_tokens=None,
         truncation=False,
     ):
         """ """
+        if add_special_tokens is None:
+            add_special_tokens = False or self.add_bos_token
         encoding = self.tokenizer.encode(
             string, add_special_tokens=add_special_tokens, truncation=truncation
         )
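The None default in tok_encode acts as a three-state sentinel: None means the caller didn't choose, so the instance-level flag decides, while an explicit True or False from the caller wins (an `if not add_special_tokens` check would silently override an explicit False). A minimal sketch of that resolution logic, using a hypothetical helper that is not in this diff:

def resolve_special_tokens(add_special_tokens, instance_add_bos):
    # None -> defer to the model-level add_bos_token flag;
    # an explicit True/False from the caller always wins.
    if add_special_tokens is None:
        return instance_add_bos
    return add_special_tokens

assert resolve_special_tokens(None, True) is True    # default follows the flag
assert resolve_special_tokens(False, True) is False  # explicit False is kept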