Unverified Commit a4192489 authored by kwrobel.eth, committed by GitHub
Browse files

use BOS token in loglikelihood (#1588)



* use BOS token in loglikelihood

* improve comments

* add model arg

* log prefix token id

* log prefix token id

* Update lm_eval/api/model.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* change name to prefix_token_id

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 4600d6bf
...@@ -66,11 +66,11 @@ class LM(abc.ABC): ...@@ -66,11 +66,11 @@ class LM(abc.ABC):
multiple chunks, the last input will still be a full-sized context. multiple chunks, the last input will still be a full-sized context.
Example: Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ] Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: EOT Prefix: BOS/EOS
Max context length: 4 Max context length: 4
Resulting input/prediction pairs: Resulting input/prediction pairs:
INPUT: EOT 0 1 2 INPUT: BOS 0 1 2
PRED: 0 1 2 3 PRED: 0 1 2 3
INPUT: 3 4 5 6 INPUT: 3 4 5 6
...@@ -90,7 +90,8 @@ class LM(abc.ABC): ...@@ -90,7 +90,8 @@ class LM(abc.ABC):
:return: list[tuple[float]] :return: list[tuple[float]]
A list of tuples (logprob,) A list of tuples (logprob,)
logprob: float logprob: float
The log probability of `context` conditioned on the EOT token. The log probability of `context` conditioned on the BOS/EOS token.
Can also be overridden for custom cases by `prefix_token_id`.
""" """
pass pass
...@@ -283,6 +284,12 @@ class TemplateLM(LM): ...@@ -283,6 +284,12 @@ class TemplateLM(LM):
def eot_token_id(self): def eot_token_id(self):
pass pass
@property
@abc.abstractmethod
def prefix_token_id(self):
# it is used as prefix for loglikelihood
pass
@abc.abstractmethod @abc.abstractmethod
def tok_encode(self, string: str, **kwargs): def tok_encode(self, string: str, **kwargs):
pass pass
...@@ -316,9 +323,9 @@ class TemplateLM(LM): ...@@ -316,9 +323,9 @@ class TemplateLM(LM):
new_reqs = [] new_reqs = []
for context, continuation in [req.args for req in requests]: for context, continuation in [req.args for req in requests]:
if context == "": if context == "":
# end of text as context # BOS or EOS as context
context_enc, continuation_enc = ( context_enc, continuation_enc = (
[self.eot_token_id], [self.prefix_token_id],
self.tok_encode(continuation), self.tok_encode(continuation),
) )
else: else:
......
...@@ -109,6 +109,7 @@ class HFLM(TemplateLM): ...@@ -109,6 +109,7 @@ class HFLM(TemplateLM):
# PEFT and quantization options # PEFT and quantization options
peft: Optional[str] = None, peft: Optional[str] = None,
autogptq: Optional[Union[bool, str]] = False, autogptq: Optional[Union[bool, str]] = False,
prefix_token_id: Optional[int] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -340,6 +341,11 @@ class HFLM(TemplateLM): ...@@ -340,6 +341,11 @@ class HFLM(TemplateLM):
self._rank = 0 self._rank = 0
self._world_size = 1 self._world_size = 1
self.custom_prefix_token_id = prefix_token_id
eval_logger.info(
f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
)
@property @property
def config(self): def config(self):
# return the associated transformers.AutoConfig for the given pretrained model. # return the associated transformers.AutoConfig for the given pretrained model.
...@@ -358,6 +364,15 @@ class HFLM(TemplateLM): ...@@ -358,6 +364,15 @@ class HFLM(TemplateLM):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id return self.tokenizer.eos_token_id
@property
def prefix_token_id(self):
# it is used as prefix for loglikelihood
if self.custom_prefix_token_id is not None:
return self.custom_prefix_token_id
if self.tokenizer.bos_token_id is not None:
return self.tokenizer.bos_token_id
return self.tokenizer.eos_token_id
@property @property
def max_length(self): def max_length(self):
if self._max_length: # if max length manually set, return it if self._max_length: # if max length manually set, return it
...@@ -815,7 +830,7 @@ class HFLM(TemplateLM): ...@@ -815,7 +830,7 @@ class HFLM(TemplateLM):
utils.make_disjoint_window, utils.make_disjoint_window,
utils.get_rolling_token_windows( utils.get_rolling_token_windows(
token_list=self.tok_encode(string), token_list=self.tok_encode(string),
prefix_token=self.eot_token_id, prefix_token=self.prefix_token_id,
max_seq_len=self.max_length, max_seq_len=self.max_length,
context_len=1, context_len=1,
), ),
......
...@@ -305,6 +305,11 @@ class NEURON_HF(TemplateLM): ...@@ -305,6 +305,11 @@ class NEURON_HF(TemplateLM):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id return self.tokenizer.eos_token_id
@property
def prefix_token_id(self):
# it is used as prefix for loglikelihood
return self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
@property @property
def max_length(self): def max_length(self):
if self._max_length: # if max length manually set, return it if self._max_length: # if max length manually set, return it
...@@ -460,7 +465,7 @@ class NEURON_HF(TemplateLM): ...@@ -460,7 +465,7 @@ class NEURON_HF(TemplateLM):
utils.make_disjoint_window, utils.make_disjoint_window,
utils.get_rolling_token_windows( utils.get_rolling_token_windows(
token_list=self.tok_encode(string), token_list=self.tok_encode(string),
prefix_token=self.eot_token_id, prefix_token=self.prefix_token_id,
max_seq_len=self.max_length, max_seq_len=self.max_length,
context_len=1, context_len=1,
), ),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment