Unverified Commit ba5cdf0f authored by Anjor Kanekar, committed by GitHub

Add TemplateLM boilerplate LM class (#1279)

* loglikelihood refactor using template lm

* linter

* fix whitespace in target + prompt for CoT gsm8k (#1275)

* Make `parallelize=True` vs. `accelerate launch` distinction clearer in docs (#1261)

* Make parallelize=True distinction clearer in documentation.

* run linter

* Allow parameter edits for registered tasks when listed in a benchmark (#1273)

* benchmark yamls allow minor edits of already registered tasks

* add documentation

* removed print

* Fix data-parallel evaluation with quantized models (#1270)

* add WIP device_map overrides

* update handling outside of accelerate launcher

* change .to(device) log to debug level

* run linter

* Rework documentation for explaining local dataset (#1284)

* rework documentation for explaining local dataset

* fix typo

* Update new_task_guide.md

* Re-add citation

It looks like Google Scholar has [already noticed](https://scholar.google.com/scholar?hl=en&as_sdt=0%2C9...
parent c26a6ac7
`docs/model_guide.md`:

```diff
@@ -66,7 +66,7 @@ All three request types take as input `requests` of type `list[Instance]` that h
 - It should return `(ll,) : Tuple[float]` , a.k.a. solely the *loglikelihood* of producing each piece of text given no starting input.
 
-To allow a model to be evaluated on all types of tasks, you will need to implement these three types of measurements (note that `loglikelihood_rolling` is a special case of `loglikelihood`). For a reference implementation, check out `lm_eval/models/huggingface.py` !
+To allow a model to be evaluated on all types of tasks, you will need to implement these three types of measurements (note that `loglikelihood_rolling` is a special case of `loglikelihood`). For a reference implementation, check out `lm_eval/models/huggingface.py` ! Additionally, check out `lm_eval.api.model.TemplateLM` for a class that abstracts away some commonly used functions across LM subclasses, or see if your model would lend itself well to subclassing the `lm_eval.models.huggingface.HFLM` class and overriding just the initialization or a couple methods!
 
 **Tip: be careful of indexing in loglikelihood!**
```
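The documentation change above points new backend authors at `TemplateLM`. To make the division of labor concrete, here is a minimal sketch of what a subclass must supply; the shared `loglikelihood` plumbing then comes from `TemplateLM` itself. `MyBackendLM` and its tokenizer attribute are hypothetical illustrations, not part of this commit:

```python
# Hypothetical sketch of the minimal surface a TemplateLM subclass covers.
from typing import List, Tuple

from lm_eval.api.model import TemplateLM


class MyBackendLM(TemplateLM):
    def __init__(self, tokenizer):
        super().__init__()
        # assumed to expose .encode / .eos_token_id, like an HF tokenizer
        self.tokenizer = tokenizer

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    def tok_encode(self, string: str, **kwargs) -> List[int]:
        # **kwargs absorbs add_special_tokens=False passed by _encode_pair
        return self.tokenizer.encode(string)

    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        # requests are ((context, continuation), context_enc, continuation_enc)
        # tuples prepared by TemplateLM.loglikelihood; score them here.
        raise NotImplementedError("backend-specific scoring goes here")

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError

    def generate_until(self, requests) -> List[str]:
        raise NotImplementedError
```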
`lm_eval/api/model.py`:

```diff
@@ -247,3 +247,61 @@ class CachingLM:
     def get_cache_hook(self):
         return CacheHook(self)
+
+
+class TemplateLM(LM):
+    """
+    A class acting as intermediary between the LM base class
+    and boilerplate often included in other LM subclasses.
+    """
+
+    @property
+    @abc.abstractmethod
+    def eot_token_id(self):
+        pass
+
+    @abc.abstractmethod
+    def tok_encode(self, string: str, **kwargs):
+        pass
+
+    @abc.abstractmethod
+    def _loglikelihood_tokens(self, requests, **kwargs):
+        pass
+
+    def _encode_pair(self, context, continuation):
+        n_spaces = len(context) - len(context.rstrip())
+        if n_spaces > 0:
+            continuation = context[-n_spaces:] + continuation
+            context = context[:-n_spaces]
+
+        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
+        context_enc = self.tok_encode(context, add_special_tokens=False)
+
+        context_enc_len = len(context_enc)
+        continuation_enc = whole_enc[context_enc_len:]
+
+        return context_enc, continuation_enc
+
+    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
+        new_reqs = []
+        for context, continuation in [req.args for req in requests]:
+            if context == "":
+                # end of text as context
+                context_enc, continuation_enc = (
+                    [self.eot_token_id],
+                    self.tok_encode(continuation),
+                )
+            else:
+                context_enc, continuation_enc = self._encode_pair(context, continuation)
+
+            new_reqs.append(((context, continuation), context_enc, continuation_enc))
+
+        return self._loglikelihood_tokens(new_reqs)
+
+    @abc.abstractmethod
+    def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
+        pass
+
+    @abc.abstractmethod
+    def generate_until(self, requests) -> List[str]:
+        pass
```
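A note on the whitespace shuffle inside `_encode_pair`, since it recurs in every backend this commit touches: BPE-style tokenizers typically fold a leading space into the token that follows it, so if the context ends in whitespace, encoding context and continuation separately can disagree with encoding the concatenated string. Moving trailing context whitespace onto the continuation, then slicing the whole-string encoding at the context length, keeps the split consistent. A small illustration (not part of the diff; assumes the `transformers` GPT-2 tokenizer is available locally):

```python
# Illustrative only: shows the invariant _encode_pair relies on.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

context, continuation = "Question: 2+2=", " 4"

whole_enc = tok.encode(context + continuation)  # ground-truth token stream
context_enc = tok.encode(context)
# Continuation tokens are whatever the whole encoding adds past the context:
continuation_enc = whole_enc[len(context_enc):]

# For a context with no trailing whitespace, the context encoding is a
# prefix of the whole encoding, so the split round-trips:
print(whole_enc[: len(context_enc)] == context_enc)  # True here

# With a trailing space ("Question: 2+2= " + "4"), tok.encode(context) would
# end in a standalone space token that never appears in whole_enc, which is
# exactly the case _encode_pair normalizes away before tokenizing.
```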
`lm_eval/models/huggingface.py`:

```diff
@@ -24,7 +24,7 @@ from transformers.models.auto.modeling_auto import (
 from lm_eval import utils
 from lm_eval.api.instance import Instance
-from lm_eval.api.model import LM
+from lm_eval.api.model import TemplateLM
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import (
     Collator,
@@ -64,7 +64,7 @@ def _get_accelerate_args(
 @register_model("hf-auto", "hf", "huggingface")
-class HFLM(LM):
+class HFLM(TemplateLM):
     """
     An abstracted Huggingface model class. Enables usage with both models of
     `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.
@@ -780,39 +780,6 @@ class HFLM(LM):
         return logits
 
-    def _encode_pair(
-        self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
-        n_spaces = len(context) - len(context.rstrip())
-        if n_spaces > 0:
-            continuation = context[-n_spaces:] + continuation
-            context = context[:-n_spaces]
-
-        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
-        context_enc = self.tok_encode(context, add_special_tokens=False)
-
-        # whole_enc = self.tok_encode(context + continuation)
-        # context_enc = self.tok_encode(context, add_special_tokens=False)
-
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
-
-        return context_enc, continuation_enc
-
-    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
-        new_reqs = []
-        for context, continuation in [req.args for req in requests]:
-            if context == "":
-                # end of text as context
-                context_enc, continuation_enc = (
-                    [self.eot_token_id],
-                    self.tok_encode(continuation),
-                )
-            else:
-                context_enc, continuation_enc = self._encode_pair(context, continuation)
-
-            new_reqs.append(((context, continuation), context_enc, continuation_enc))
-
-        return self._loglikelihood_tokens(requests=new_reqs)
-
     def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []
```
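With the duplicated methods above deleted, `HFLM` should pick up `loglikelihood` purely through inheritance. A quick, hypothetical sanity check of that claim (assumes `lm-eval` and its `transformers` dependency are installed):

```python
# Hypothetical post-refactor check: HFLM no longer defines its own
# loglikelihood; it resolves to TemplateLM's implementation via the MRO.
from lm_eval.api.model import TemplateLM
from lm_eval.models.huggingface import HFLM

assert issubclass(HFLM, TemplateLM)
assert "loglikelihood" not in HFLM.__dict__   # the copy above is gone
print(HFLM.loglikelihood is TemplateLM.loglikelihood)  # True
```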
`lm_eval/models/neuron_optimum.py`:

```diff
@@ -15,7 +15,7 @@ from transformers.generation import StoppingCriteriaList
 import lm_eval.models.utils
 from lm_eval import utils
-from lm_eval.api.model import LM
+from lm_eval.api.model import TemplateLM
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import stop_sequences_criteria
@@ -172,7 +172,7 @@ class CustomNeuronModelForCausalLM(NeuronModelForCausalLM):
 @register_model("neuronx")
-class NEURON_HF(LM):
+class NEURON_HF(TemplateLM):
     """
     Enables usage on AWS Neuron
     using the HuggingFace Transformers + Transformers neuronx library.
@@ -447,37 +447,6 @@ class NEURON_HF(LM):
         return logits
 
-    def _encode_pair(self, context, continuation):
-        n_spaces = len(context) - len(context.rstrip())
-        if n_spaces > 0:
-            continuation = context[-n_spaces:] + continuation
-            context = context[:-n_spaces]
-
-        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
-        context_enc = self.tok_encode(context, add_special_tokens=False)
-
-        # whole_enc = self.tok_encode(context + continuation)
-        # context_enc = self.tok_encode(context, add_special_tokens=False)
-
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
-
-        return context_enc, continuation_enc
-
-    def loglikelihood(self, requests):
-        new_reqs = []
-        for context, continuation in [req.args for req in requests]:
-            if context == "":
-                # end of text as context
-                context_enc, continuation_enc = (
-                    [self.eot_token_id],
-                    self.tok_encode(continuation),
-                )
-            else:
-                context_enc, continuation_enc = self._encode_pair(context, continuation)
-
-            new_reqs.append(((context, continuation), context_enc, continuation_enc))
-
-        return self._loglikelihood_tokens(new_reqs)
-
     def loglikelihood_rolling(self, requests):
         loglikelihoods = []
```
`lm_eval/models/openai_completions.py`:

```diff
@@ -8,7 +8,7 @@ from tqdm import tqdm
 import lm_eval.models.utils
 from lm_eval import utils
-from lm_eval.api.model import LM
+from lm_eval.api.model import LM, TemplateLM
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import retry_on_specific_exceptions
 from lm_eval.utils import eval_logger
@@ -75,7 +75,7 @@ def oa_completion(client, chat: bool = False, **kwargs):
 @register_model("openai-completions", "local-completions")
-class OpenaiCompletionsLM(LM):
+class OpenaiCompletionsLM(TemplateLM):
     _DEFAULT_MAX_LENGTH = 2048
 
     def __init__(
@@ -171,41 +171,12 @@ class OpenaiCompletionsLM(LM):
         # Isn't used because we override _loglikelihood_tokens
         raise NotImplementedError()
 
-    def tok_encode(self, string: str) -> List[int]:
+    def tok_encode(self, string: str, **kwargs) -> List[int]:
         return self.tokenizer.encode(string)
 
     def tok_decode(self, tokens: List[int]) -> str:
         return self.tokenizer.decode(tokens)
 
-    def _encode_pair(
-        self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
-        n_spaces = len(context) - len(context.rstrip())
-        if n_spaces > 0:
-            continuation = context[-n_spaces:] + continuation
-            context = context[:-n_spaces]
-
-        whole_enc = self.tok_encode(context + continuation)
-        context_enc = self.tok_encode(context)
-
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
-
-        return context_enc, continuation_enc
-
-    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
-        new_reqs = []
-        for context, continuation in [req.args for req in requests]:
-            if context == "":
-                # end of text as context
-                context_enc, continuation_enc = (
-                    [self.eot_token_id],
-                    self.tok_encode(continuation),
-                )
-            else:
-                context_enc, continuation_enc = self._encode_pair(context, continuation)
-
-            new_reqs.append(((context, continuation), context_enc, continuation_enc))
-
-        return self._loglikelihood_tokens(new_reqs)
-
     def _loglikelihood_tokens(
         self, requests, disable_tqdm: bool = False
     ) -> List[Tuple[float, bool]]:
```
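The `**kwargs` added to `tok_encode` above is load-bearing: `TemplateLM._encode_pair` always passes `add_special_tokens=False`, a flag this backend's tokenizer does not take, so the signature must absorb and ignore it. A hedged sketch of the pattern, where `ToyTokenizer` and `KwargsAbsorbingLM` are hypothetical stand-ins, not from this diff:

```python
from typing import List


class ToyTokenizer:
    """Stand-in for a tokenizer whose encode() takes no add_special_tokens flag."""

    def encode(self, string: str) -> List[int]:
        return [ord(c) for c in string]  # toy "tokenization"


class KwargsAbsorbingLM:
    def __init__(self) -> None:
        self.tokenizer = ToyTokenizer()

    def tok_encode(self, string: str, **kwargs) -> List[int]:
        # add_special_tokens=False from TemplateLM._encode_pair lands in
        # kwargs and is deliberately dropped, since this tokenizer has no
        # notion of special tokens.
        return self.tokenizer.encode(string)


lm = KwargsAbsorbingLM()
print(lm.tok_encode("hi", add_special_tokens=False))  # no TypeError: [104, 105]
```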
`lm_eval/models/vllm_causallms.py`:

```diff
@@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Tuple, Union
 from tqdm import tqdm
 
 from lm_eval.api.instance import Instance
-from lm_eval.api.model import LM
+from lm_eval.api.model import TemplateLM
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import Collator, divide
 from lm_eval.utils import (
@@ -35,7 +35,7 @@ def run_inference_one_model(
 @register_model("vllm")
-class VLLM(LM):
+class VLLM(TemplateLM):
     _DEFAULT_MAX_LENGTH = 2048
 
     def __init__(
@@ -194,37 +194,6 @@ class VLLM(LM):
         )
         return outputs
 
-    def _encode_pair(
-        self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
-        n_spaces = len(context) - len(context.rstrip())
-        if n_spaces > 0:
-            continuation = context[-n_spaces:] + continuation
-            context = context[:-n_spaces]
-
-        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
-        context_enc = self.tok_encode(context, add_special_tokens=False)
-
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
-
-        return context_enc, continuation_enc
-
-    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
-        new_reqs = []
-        for context, continuation in [req.args for req in requests]:
-            if context == "":
-                # end of text as context
-                context_enc, continuation_enc = (
-                    [self.eot_token_id],
-                    self.tok_encode(continuation),
-                )
-            else:
-                context_enc, continuation_enc = self._encode_pair(context, continuation)
-
-            new_reqs.append(((context, continuation), context_enc, continuation_enc))
-
-        return self._loglikelihood_tokens(new_reqs)
-
     def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []
```