Unverified Commit 3d1b8f43 authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge branch 'main' into group-agg-rework

parents e200c24e d855d0ba
......@@ -20,13 +20,13 @@ jobs:
with:
fetch-depth: 2 # OR "2" -> To retrieve the preceding commit.
# Uses the tj-actions/changed-files@v37 action to check for changes.
# Uses the tj-actions/changed-files action to check for changes.
# Outputs provided here: https://github.com/tj-actions/changed-files#outputs
# The `files_yaml` input optionally takes a yaml string to specify filters,
# and prepends the filter name to the standard output names.
- name: Check task folders
id: changed-tasks
uses: tj-actions/changed-files@v37.1.2
uses: tj-actions/changed-files@v44.5.2
with:
# tasks checks the tasks folder and api checks the api folder for changes
files_yaml: |
......@@ -56,7 +56,7 @@ jobs:
if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
run: |
python -m pip install --upgrade pip
pip install -e '.[dev]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e '.[dev,ifeval]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
......
......@@ -32,7 +32,7 @@ jobs:
env:
SKIP: "no-commit-to-branch,mypy"
uses: pre-commit/action@v3.0.0
uses: pre-commit/action@v3.0.1
# # mypy turned off for now
# - name: Lint with mypy
# run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
......@@ -56,12 +56,37 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[dev,anthropic,sentencepiece,optimum,deepsparse,sparseml]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e '.[dev,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: python -m pytest --showlocals -s -vv -n=auto
run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/models/test_neuralmagic.py --ignore=tests/models/test_openvino.py
- name: Archive artifacts
uses: actions/upload-artifact@v3
with:
name: output_results
path: |
test_logs/*
testmodels:
name: External LM Tests
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Set up Python 3.8
uses: actions/setup-python@v5
with:
python-version: 3.8
cache: pip
cache-dependency-path: pyproject.toml
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[dev,optimum,deepsparse,sparseml]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Test with pytest
run: python -m pytest tests/models --showlocals -s -vv
- name: Archive artifacts
uses: actions/upload-artifact@v3
with:
......
......@@ -10,6 +10,7 @@ repos:
- id: check-case-conflict
- id: check-json
- id: check-merge-conflict
args: [--assume-in-merge]
- id: check-symlinks
- id: check-yaml
args: ["--unsafe"]
......@@ -28,8 +29,7 @@ repos:
- id: mixed-line-ending
args: [--fix=lf]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.2.2
rev: v0.4.8
hooks:
# Run the linter.
- id: ruff
......@@ -38,7 +38,7 @@ repos:
# Run the formatter.
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
rev: v2.3.0
hooks:
- id: codespell
exclude: >
......@@ -46,9 +46,9 @@ repos:
.*\.json|ignore.txt|lm_eval/tasks/.*|.*yaml|.*\.ipynb
)$
args: [--check-filenames, --check-hidden, --ignore-words=ignore.txt]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
hooks:
- id: mypy
additional_dependencies: [".[sentencepiece,multilingual,promptsource,gptq]", "types-PyYAML", "types-requests"]
exclude: ^tests/.*$
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.5.1
# hooks:
# - id: mypy
# additional_dependencies: [".[sentencepiece,multilingual,promptsource,gptq]", "types-PyYAML", "types-requests"]
# exclude: ^tests/.*$
......@@ -7,6 +7,7 @@
New updates and features include:
- **New Open LLM Leaderboard tasks have been added ! You can find them under the [leaderboard](lm_eval/tasks/leaderboard/README.md) task group.**
- Internal refactoring
- Config-based task creation and configuration
- Easier import and sharing of externally-defined task config YAMLs
......
......@@ -102,12 +102,10 @@ results = lm_eval.simple_evaluate( # call simple_evaluate
)
```
See https://github.com/EleutherAI/lm-evaluation-harness/blob/365fcda9b85bbb6e0572d91976b8daf409164500/lm_eval/evaluator.py#L35 for a full description of all arguments available. All keyword arguments to simple_evaluate share the same role as the command-line flags described previously.
See the `simple_evaluate()` and `evaluate()` functions in [lm_eval/evaluator.py](../lm_eval/evaluator.py#:~:text=simple_evaluate) for a full description of all arguments available. All keyword arguments to simple_evaluate share the same role as the command-line flags described previously.
Additionally, the `evaluate()` function offers the core evaluation functionality provided by the library, but without some of the special handling and simplification + abstraction provided by `simple_evaluate()`.
See https://github.com/EleutherAI/lm-evaluation-harness/blob/365fcda9b85bbb6e0572d91976b8daf409164500/lm_eval/evaluator.py#L173 for more details.
As a brief example usage of `evaluate()`:
```python
......@@ -147,7 +145,7 @@ task_dict = lm_eval.tasks.get_task_dict(
task_manager # A task manager that allows lm_eval to
# load the task during evaluation.
# If none is provided, `get_task_dict`
# will instantiated one itself, but this
# will instantiate one itself, but this
# only includes the stock tasks so users
# will need to set this if including
# custom paths is required.
......
......@@ -110,13 +110,15 @@
"cell_type": "markdown",
"id": "e974cabdbe70b667",
"metadata": {},
"source": ""
"source": []
},
{
"cell_type": "markdown",
"id": "5178ca9445b844e4",
"metadata": {},
"source": "W&B can also be initialized programmatically for use outside the CLI to parse and log the results."
"source": [
"W&B can also be initialized programmatically for use outside the CLI to parse and log the results."
]
},
{
"cell_type": "code",
......@@ -126,7 +128,7 @@
"outputs": [],
"source": [
"import lm_eval\n",
"from lm_eval.logging_utils import WandbLogger\n",
"from lm_eval.loggers import WandbLogger\n",
"\n",
"results = lm_eval.simple_evaluate(\n",
" model=\"hf\",\n",
......
......@@ -237,7 +237,7 @@ def setup_parser() -> argparse.ArgumentParser:
help=(
"Set seed for python's random, numpy, torch, and fewshot sampling.\n"
"Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
"respectively, or a single integer to set the same seed for all three.\n"
"respectively, or a single integer to set the same seed for all four.\n"
f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
"(for backward compatibility).\n"
"E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
......@@ -354,11 +354,17 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
if args.trust_remote_code:
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code)
args.model_args = (
args.model_args
+ f",trust_remote_code={os.environ['HF_DATASETS_TRUST_REMOTE_CODE']}"
eval_logger.info(
"Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
)
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
# because it's already been determined based on the prior env var before launching our
# script--`datasets` gets imported by lm_eval internally before these lines can update the env.
import datasets
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
args.model_args = args.model_args + ",trust_remote_code=True"
eval_logger.info(f"Selected Tasks: {task_names}")
......
import logging
import math
import random
import re
import string
from collections.abc import Iterable
from typing import List
import evaluate as hf_evaluate
import numpy as np
import sacrebleu
import sklearn.metrics
......@@ -166,7 +167,60 @@ def acc_mutual_info_fn(items): # This is a passthrough function
return items
exact_match = hf_evaluate.load("exact_match")
### the code used in the `exact_match_hf_evaluate` function is ported from
### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
### which is under the apache license.
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def exact_match_hf_evaluate(
predictions,
references,
regexes_to_ignore=None,
ignore_case=False,
ignore_punctuation=False,
ignore_numbers=False,
):
if regexes_to_ignore is not None:
for s in regexes_to_ignore:
predictions = np.array([re.sub(s, "", x) for x in predictions])
references = np.array([re.sub(s, "", x) for x in references])
else:
predictions = np.asarray(predictions)
references = np.asarray(references)
if ignore_case:
predictions = np.char.lower(predictions)
references = np.char.lower(references)
if ignore_punctuation:
repl_table = string.punctuation.maketrans("", "", string.punctuation)
predictions = np.char.translate(predictions, table=repl_table)
references = np.char.translate(references, table=repl_table)
if ignore_numbers:
repl_table = string.digits.maketrans("", "", string.digits)
predictions = np.char.translate(predictions, table=repl_table)
references = np.char.translate(references, table=repl_table)
score_list = predictions == references
return {"exact_match": np.mean(score_list)}
###
@register_metric(
......@@ -176,7 +230,7 @@ exact_match = hf_evaluate.load("exact_match")
aggregation="mean",
)
def exact_match_fn(**kwargs):
return exact_match.compute(**kwargs)
return exact_match_hf_evaluate(**kwargs)
@register_metric(
......
......@@ -246,9 +246,10 @@ class CachingLM:
# add hook to lm
lm.set_cache_hook(self.get_cache_hook())
def __getattr__(self, attr):
def __getattr__(self, attr: str):
lm_attr = getattr(self.lm, attr)
if not callable(lm_attr):
if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
return lm_attr
def fn(requests):
......
......@@ -187,9 +187,9 @@ class TaskConfig(dict):
training_split: Optional[str] = None
validation_split: Optional[str] = None
test_split: Optional[str] = None
fewshot_split: Optional[
str
] = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
fewshot_split: Optional[str] = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None
......@@ -212,9 +212,9 @@ class TaskConfig(dict):
filter_list: Optional[Union[str, list]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
metadata: Optional[
dict
] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
def __post_init__(self) -> None:
if self.generation_kwargs is not None:
......@@ -351,9 +351,9 @@ class Task(abc.ABC):
self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
self.fewshot_rnd: Optional[
random.Random
] = None # purposely induce errors in case of improper usage
self.fewshot_rnd: Optional[random.Random] = (
None # purposely induce errors in case of improper usage
)
def download(
self,
......@@ -490,15 +490,16 @@ class Task(abc.ABC):
def build_all_requests(
self,
*,
limit=None,
rank=None,
world_size=None,
cache_requests=False,
rewrite_requests_cache=False,
system_instruction=None,
apply_chat_template=False,
fewshot_as_multiturn=False,
lm=None,
limit: Union[int, None] = None,
rank: int = 0,
world_size: int = 1,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
system_instruction: Optional[str] = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
tokenizer_name: str = "",
) -> None:
"""Build a set of Instances for a task, and store them in task.instances"""
......@@ -513,7 +514,7 @@ class Task(abc.ABC):
if system_instruction is not None
else ""
)
cache_key += f"-tokenizer{lm.tokenizer_name}" if apply_chat_template else ""
cache_key += f"-tokenizer{tokenizer_name}"
cached_instances = load_from_cache(file_name=cache_key)
......@@ -558,7 +559,7 @@ class Task(abc.ABC):
system_instruction,
apply_chat_template,
fewshot_as_multiturn,
lm,
chat_template,
)
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
......@@ -1147,7 +1148,7 @@ class ConfigurableTask(Task):
system_instruction: Optional[str] = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
lm=None,
chat_template: Optional[Callable] = None,
) -> str:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -1162,8 +1163,8 @@ class ConfigurableTask(Task):
Whether to apply the chat template to the fewshot context.
:param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param lm:
Language model with definition of the tokenizer/function to use for applying the chat template.
:param chat_template: Callable
Chat template to be applied to the fewshot context.
:returns: str
The fewshot context.
"""
......@@ -1210,7 +1211,7 @@ class ConfigurableTask(Task):
example = self.doc_to_text(doc)
if apply_chat_template:
if self.multiple_input:
return lm.apply_chat_template(labeled_examples)
return chat_template(labeled_examples)
if isinstance(example, str):
self.append_target_question(
labeled_examples, example, fewshot_as_multiturn
......@@ -1222,7 +1223,7 @@ class ConfigurableTask(Task):
for ex in example:
chat = deepcopy(labeled_examples)
self.append_target_question(chat, ex, fewshot_as_multiturn)
labeled_examples_list.append(lm.apply_chat_template(chat))
labeled_examples_list.append(chat_template(chat))
return labeled_examples_list
# if example is an integer, append the choice or convert to string
elif isinstance(example, int):
......@@ -1236,7 +1237,7 @@ class ConfigurableTask(Task):
labeled_examples, str(example), fewshot_as_multiturn
)
# return lm.apply_chat_template(labeled_examples)
return lm.apply_chat_template(labeled_examples)
return chat_template(labeled_examples)
else:
if self.multiple_input:
return labeled_examples
......
......@@ -24,7 +24,7 @@ from lm_eval.evaluator_utils import (
run_task_tests,
)
from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, get_git_commit_hash
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
from lm_eval.tasks import (
ConfigurableGroup,
ConfigurableTask,
......@@ -289,6 +289,7 @@ def simple_evaluate(
model_args=model_args,
system_instruction=system_instruction,
chat_template=lm.chat_template if apply_chat_template else None,
fewshot_as_multiturn=fewshot_as_multiturn,
)
results = evaluate(
......@@ -343,6 +344,7 @@ def simple_evaluate(
results["git_hash"] = get_git_commit_hash()
results["date"] = start_date
add_env_info(results) # additional environment info to results
add_tokenizer_info(results, lm) # additional info about tokenizer
return results
else:
return None
......@@ -415,7 +417,12 @@ def evaluate(
system_instruction=system_instruction,
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
lm=lm,
chat_template=getattr(lm, "apply_chat_template")
if apply_chat_template
else None,
tokenizer_name=getattr(lm, "tokenizer_name", "")
if apply_chat_template
else "",
)
eval_logger.debug(
f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
......
......@@ -289,7 +289,7 @@ def prepare_print_tasks(
def consolidate_results(
eval_tasks: List[TaskOutput],
) -> Tuple[dict, dict, dict, dict, dict]:
) -> Tuple[dict, dict, dict, dict, dict, dict]:
"""
@param eval_tasks: list(TaskOutput).
@return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.
......@@ -306,6 +306,8 @@ def consolidate_results(
- configs: A defaultdict with task names as keys and task configurations as values.
- versions: A defaultdict with task names as keys and task versions as values.
- num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values.
- higher_is_better: A defaultdict with task names as keys and indicators of whether higher values are better
for each metric as values.
The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple.
"""
......
......@@ -4,7 +4,6 @@ from lm_eval.api.registry import register_filter
@register_filter("decontaminate")
class DecontaminationFilter(Filter):
"""
A filter which evaluates
"""
......
......@@ -62,11 +62,8 @@ class WhitespaceFilter(Filter):
def filter_set(inst):
filtered_resp = []
for resp in inst:
if resp.startswith(" "):
resp = resp[1:]
resp = resp.lstrip()
filtered_resp.append(resp)
return filtered_resp
filtered_resps = [filter_set(resp) for resp in resps]
......
......@@ -17,9 +17,15 @@ from huggingface_hub import (
from lm_eval.utils import (
eval_logger,
get_file_datetime,
get_file_task_name,
get_results_filenames,
get_sample_results_filenames,
handle_non_serializable,
hash_string,
sanitize_list,
sanitize_model_name,
sanitize_task_name,
)
......@@ -42,6 +48,7 @@ class GeneralConfigTracker:
model_name_sanitized: str = None
system_instruction: str = None
system_instruction_sha: str = None
fewshot_as_multiturn: bool = None
chat_template: str = None
chat_template_sha: str = None
start_time: float = None
......@@ -74,19 +81,19 @@ class GeneralConfigTracker:
model_args: str,
system_instruction: str,
chat_template: str,
fewshot_as_multiturn: bool,
) -> None:
"""Logs model parameters and job ID."""
self.model_source = model_source
self.model_name = GeneralConfigTracker._get_model_name(model_args)
self.model_name_sanitized = re.sub(
r"[\"<>:/\|\\?\*\[\]]+", "__", self.model_name
)
self.model_name_sanitized = sanitize_model_name(self.model_name)
self.system_instruction = system_instruction
self.system_instruction_sha = (
hash_string(system_instruction) if system_instruction else None
)
self.chat_template = chat_template
self.chat_template_sha = hash_string(chat_template) if chat_template else None
self.fewshot_as_multiturn = fewshot_as_multiturn
def log_end_time(self) -> None:
"""Logs the end time of the evaluation and calculates the total evaluation time."""
......@@ -255,7 +262,7 @@ class EvaluationTracker:
path.mkdir(parents=True, exist_ok=True)
file_results_samples = path.joinpath(
f"samples_{task_name}_{self.date_id}.json"
f"samples_{task_name}_{self.date_id}.jsonl"
)
for sample in samples:
......@@ -319,23 +326,14 @@ class EvaluationTracker:
Creates a metadata card for the evaluation results dataset and pushes it to the Hugging Face hub.
"""
def get_file_task_name(filename: str) -> str:
return filename[filename.find("_") + 1 : filename.rfind("_")]
def get_file_datetime(filename: str) -> str:
return filename[filename.rfind("_") + 1 :].replace(".json", "")
def sanitize_task_name(task_name: str) -> str:
return re.sub(r"\W", "_", task_name)
eval_logger.info("Recreating metadata card")
repo_id = (
self.hub_results_repo if self.public_repo else self.hub_results_repo_private
)
files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset")
results_files = [f for f in files_in_repo if "/results_" in f and ".json" in f]
sample_files = [f for f in files_in_repo if "/samples_" in f and ".json" in f]
results_files = get_results_filenames(files_in_repo)
sample_files = get_sample_results_filenames(files_in_repo)
# Build a dictionary to store the latest evaluation datetime for:
# - Each tested model and its aggregated results
......
......@@ -110,3 +110,34 @@ def add_env_info(storage: Dict[str, Any]):
"upper_git_hash": upper_dir_commit, # in case this repo is submodule
}
storage.update(added_info)
def add_tokenizer_info(storage: Dict[str, Any], lm):
if getattr(lm, "tokenizer", False):
try:
tokenizer_info = {
"tokenizer_pad_token": [
lm.tokenizer.pad_token,
lm.tokenizer.pad_token_id,
],
"tokenizer_eos_token": [
lm.tokenizer.eos_token,
lm.tokenizer.eos_token_id,
],
"tokenizer_bos_token": [
lm.tokenizer.bos_token,
lm.tokenizer.bos_token_id,
],
"eot_token_id": getattr(lm, "eot_token_id", None),
"max_length": getattr(lm, "max_length", None),
}
storage.update(tokenizer_info)
except Exception as err:
logger.debug(
f"Logging detailed tokenizer info failed with {err}, skipping..."
)
# seems gguf and textsynth do not have tokenizer
else:
logger.debug(
"LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
)
......@@ -307,7 +307,7 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
# defaults to os.environ.get("ANTHROPIC_API_KEY")
self.client = anthropic.Anthropic()
self.temperature = temperature
self.max_token = max_tokens
self.max_tokens = max_tokens
self.tokenizer = self.client.get_tokenizer()
self.kwargs = kwargs
......
......@@ -30,6 +30,7 @@ from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
Collator,
clear_torch_cache,
configure_pad_token,
get_dtype,
pad_and_concat,
stop_sequences_criteria,
......@@ -253,32 +254,10 @@ class HFLM(TemplateLM):
self.logits_cache = logits_cache
self.vocab_size = self.tokenizer.vocab_size
# select (or create) a pad token to use
if self.tokenizer.pad_token:
pass
elif self.tokenizer.unk_token:
self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
elif self.tokenizer.eos_token:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
else:
if getattr(self.config, "model_type", None) == "qwen":
# Qwen's trust_remote_code tokenizer does not allow for adding special tokens
self.tokenizer.pad_token = "<|endoftext|>"
elif (
self.tokenizer.__class__.__name__ == "RWKVWorldTokenizer"
or self.tokenizer.__class__.__name__ == "Rwkv5Tokenizer"
):
# The RWKV world tokenizer, does not allow for adding special tokens / setting the pad token (which is set as 0)
# The additional tokenizer name check is needed, as there exists rwkv4 models with neox tokenizer
# ---
# Note that the world tokenizer class name, might change in the future for the final huggingface merge
# https://github.com/huggingface/transformers/pull/26963
assert self.tokenizer.pad_token_id == 0
else:
self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)
# TODO: override this for Gemma
self.add_bos_token = add_bos_token
if getattr(self.config, "model_type", None) == "gemma":
if getattr(self.config, "model_type", None) in ["gemma", "gemma2"]:
self.add_bos_token = True
eval_logger.info(
f"Model type is '{self.config.model_type}', a BOS token will be used as Gemma underperforms without it."
......
......@@ -288,7 +288,7 @@ class NEURON_HF(TemplateLM):
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.add_bos_token = self.add_bos_token
self.add_bos_token = add_bos_token
self._max_length = max_length
......
""" TextSynth API
"""TextSynth API
Implementation provided by Fabrice Bellard:
https://github.com/EleutherAI/lm-evaluation-harness/issues/295
......@@ -11,6 +11,7 @@ Example usage:
Homepage: https://textsynth.com/index.html
"""
import logging
import os
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment