Unverified commit c7b03ad4 authored by Hailey Schoelkopf, committed by GitHub

Fixes to Loglikelihood prefix token / VLLM (#1611)

* make vllm use prefix_token_id; have prefix_token_id be an optional method to define

* custom_prefix_token_id wasn't set if not passed
parent d4b8fc13
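The new prefix_token_id argument is a plain constructor kwarg on both backends, so it can be set directly from Python (or through model_args, assuming those are forwarded to __init__ as usual). A minimal sketch; the import path and model name below are assumptions for illustration, not taken from this commit:

# Minimal sketch of using the new kwarg; import path / model name are assumptions.
from lm_eval.models.huggingface import HFLM

# Use token id 0 as the loglikelihood prefix instead of the BOS/EOS fallback.
lm = HFLM(pretrained="EleutherAI/pythia-70m", prefix_token_id=0)
print(lm.prefix_token_id)  # -> 0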
@@ -284,10 +284,9 @@ class TemplateLM(LM):
         pass

     @property
-    @abc.abstractmethod
     def prefix_token_id(self):
         # it is used as prefix for loglikelihood
-        pass
+        return self.eot_token_id

     @abc.abstractmethod
     def tok_encode(self, string: str, **kwargs):
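Making prefix_token_id a concrete property defaulting to eot_token_id (rather than an abstract method every backend must implement) matters because the base class consumes it whenever a loglikelihood request arrives with an empty context. Roughly, condensed from the base class rather than quoted verbatim:

# Condensed sketch of how TemplateLM.loglikelihood uses prefix_token_id;
# simplified, not the verbatim implementation.
def loglikelihood(self, requests):
    new_reqs = []
    for context, continuation in [req.args for req in requests]:
        if context == "":
            # No context: the prefix token (BOS/EOS, or a custom id) stands in,
            # so the continuation is still conditioned on at least one token.
            context_enc, continuation_enc = [self.prefix_token_id], self.tok_encode(continuation)
        else:
            context_enc, continuation_enc = self._encode_pair(context, continuation)
        new_reqs.append(((context, continuation), context_enc, continuation_enc))
    return self._loglikelihood_tokens(new_reqs)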
@@ -99,6 +99,7 @@ class HFLM(TemplateLM):
         trust_remote_code: Optional[bool] = False,
         use_fast_tokenizer: Optional[bool] = True,
         add_bos_token: Optional[bool] = False,
+        prefix_token_id: Optional[int] = None,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
         parallelize: Optional[bool] = False,
@@ -109,7 +110,6 @@ class HFLM(TemplateLM):
         # PEFT and quantization options
         peft: Optional[str] = None,
         autogptq: Optional[Union[bool, str]] = False,
-        prefix_token_id: Optional[int] = None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -342,9 +342,10 @@ class HFLM(TemplateLM):
             self._world_size = 1

         self.custom_prefix_token_id = prefix_token_id
-        eval_logger.info(
-            f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
-        )
+        if prefix_token_id is not None:
+            eval_logger.info(
+                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
+            )

     @property
     def config(self):
@@ -42,6 +42,7 @@ class VLLM(TemplateLM):
         tokenizer_mode: Literal["auto", "slow"] = "auto",
         tokenizer_revision: Optional[str] = None,
         add_bos_token: Optional[bool] = False,
+        prefix_token_id: Optional[int] = None,
         tensor_parallel_size: int = 1,
         quantization: Optional[str] = None,
         max_gen_toks: int = 256,
@@ -118,6 +119,11 @@ class VLLM(TemplateLM):
             tokenizer_revision=tokenizer_revision,
         )
         self.add_bos_token = add_bos_token
+        self.custom_prefix_token_id = prefix_token_id
+        if prefix_token_id is not None:
+            eval_logger.info(
+                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
+            )
         self._max_gen_toks = max_gen_toks
@@ -126,6 +132,15 @@ class VLLM(TemplateLM):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
         return self.tokenizer.eos_token_id

+    @property
+    def prefix_token_id(self):
+        # it is used as prefix for loglikelihood
+        if self.custom_prefix_token_id is not None:
+            return self.custom_prefix_token_id
+        if self.tokenizer.bos_token_id is not None:
+            return self.tokenizer.bos_token_id
+        return self.tokenizer.eos_token_id
+
     @property
     def max_length(self):
         if self._max_length:  # if max length manually set, return it
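The fallback order in the new property (explicit override, then BOS, then EOS) mirrors HFLM. A standalone illustration of that resolution, assuming a standard transformers tokenizer; the helper name is hypothetical:

# Hypothetical helper mirroring the new property's fallback order, only meant to
# illustrate custom -> BOS -> EOS resolution.
from transformers import AutoTokenizer

def resolve_prefix_token_id(tokenizer, custom_prefix_token_id=None):
    if custom_prefix_token_id is not None:
        return custom_prefix_token_id
    if tokenizer.bos_token_id is not None:
        return tokenizer.bos_token_id
    return tokenizer.eos_token_id

tok = AutoTokenizer.from_pretrained("gpt2")  # for gpt2, bos == eos == 50256
print(resolve_prefix_token_id(tok))          # 50256
print(resolve_prefix_token_id(tok, 0))       # 0 (explicit override wins)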
@@ -25,8 +25,8 @@ class TEST_VLLM:
     multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
     MULTIPLE_CH: List[Instance] = multiple_choice_task.instances
     generate_until_task = task_list["gsm8k"]  # type: ignore
-    generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
     generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
+    generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
     generate_until: List[Instance] = generate_until_task.instances
     rolling_task = task_list["wikitext"]  # type: ignore
     rolling_task.build_all_requests(limit=10, rank=0, world_size=1)