Unverified commit 8941c067 authored by Michael Goin, committed by GitHub

Merge branch 'EleutherAI:master' into deepsparselm

parents ff24ddbd 008fc2a2
* @haileyschoelkopf @lintangsutawika
* @haileyschoelkopf @lintangsutawika @StellaAthena
FROM nvidia/cuda:11.2.0-cudnn8-runtime-ubuntu20.04
FROM nvidia/cuda:11.2.2-cudnn8-runtime-ubuntu20.04
### Install python 3.10 and set it as default python interpreter
@@ -141,6 +141,15 @@ python main.py \
--tasks hellaswag
```
GGUF or GGML quantized models can be loaded by using the `llama-cpp-python` server:
```bash
python main.py \
--model gguf \
--model_args base_url=http://localhost:8000 \
--tasks hellaswag
```
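The `base_url` should point at a running `llama-cpp-python` server. As a minimal sketch, assuming the package's bundled server and its default port of 8000 (the model path below is only a placeholder), such a server can be started with:

```bash
# Install the server extra and launch it; by default it listens on localhost:8000.
# The GGUF model path below is only a placeholder.
pip install "llama-cpp-python[server]"
python -m llama_cpp.server --model ./models/your-model.gguf
```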
We support wildcards in task names; for example, you can run all of the machine-translated lambada tasks via `--tasks lambada_openai_mt_*`, as in the sketch below.
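A hypothetical invocation (the model and `model_args` here are placeholders, not part of the original example) might look like:

```bash
# Quote the pattern so the shell does not glob-expand it before main.py sees it.
python main.py \
    --model hf \
    --model_args pretrained=gpt2 \
    --tasks "lambada_openai_mt_*"
```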
We currently only support one prompt per task, which we strive to make the "standard" as defined by the benchmark's authors. If you would like to study how varying prompts causes changes in the evaluation score, check out the [BigScience fork](https://github.com/bigscience-workshop/lm-evaluation-harness) of this repo. We are currently working on upstreaming this capability to `main`.
@@ -7,7 +7,6 @@ import os
import json
import hashlib
import datasets
from sqlitedict import SqliteDict
from tqdm import tqdm
import torch
import torch.nn.functional as F
@@ -891,6 +890,7 @@ class CachingLM:
:param cache_db: str
Path to cache db
"""
from sqlitedict import SqliteDict
self.lm = lm
self.cache_db = cache_db
if os.path.dirname(cache_db):
@@ -43,7 +43,7 @@ level (for indicating the level of difficulty).
_HOMEPAGE = "https://github.com/chaochun/nlu-asdiv-dataset"
# License available at https://github.com/chaochun/nlu-asdiv-dataset/blob/master/README.md
# License declared at https://github.com/chaochun/nlu-asdiv-dataset/blob/master/README.md
_LICENSE = "CC BY-NC 4.0"
_URLS = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccfa5053194b25732534696b50.zip"
@@ -43,7 +43,7 @@ and perform discrete operations over them (such as addition, counting, or sortin
_HOMEPAGE = "https://allenai.org/data/drop"
# License available at https://allenai.org/data/drop
# License declared at https://allenai.org/data/drop
_LICENSE = "CC BY"
_URLS = {
@@ -51,8 +51,10 @@ The dataset contains questions about the following topics: medicine, nursing, ps
_HOMEPAGE = "https://aghie.github.io/head-qa/"
# License available at https://github.com/aghie/head-qa/blob/master/LICENSE
_LICENSE = "MIT License"
# The Spanish data comes from the "Ministerio de Sanidad, Consumo y Bienestar Social", as indicated here: https://github.com/aghie/head-qa
# This Spanish data seems to follow the intellectual property rights stated here: https://www.sanidad.gob.es/avisoLegal/home.htm
# The English data was translated by the authors of head-qa (https://arxiv.org/pdf/1906.04701.pdf).
_LICENSE = "Custom license"
_URL = "https://drive.google.com/uc?export=download&confirm=t&id=1a_95N5zQQoUCq8IBNVZgziHbeM-QxG2t"
@@ -41,8 +41,10 @@ learning agents.
_HOMEPAGE = "https://github.com/hendrycks/ethics"
# License available at https://github.com/hendrycks/ethics/blob/master/LICENSE
_LICENSE = "MIT License"
# The authors declared that the dataset is not distributed under a copyright or intellectual property (https://arxiv.org/pdf/2008.02275.pdf)
# On Hugging Face, the dataset is distributed under the MIT license (https://huggingface.co/datasets/hendrycks/ethics)
# The common sense portion is from Reddit and might incur some licensing complications.
_LICENSE = "Ambiguous"
_URLS = "https://people.eecs.berkeley.edu/~hendrycks/ethics.tar"
@@ -38,7 +38,7 @@ models to generate answer derivations and explanations.
_HOMEPAGE = "https://github.com/hendrycks/math"
# License available at https://github.com/hendrycks/math/blob/main/LICENSE
# License declared at https://arxiv.org/pdf/2103.03874.pdf
_LICENSE = "MIT License"
_URLS = "https://people.eecs.berkeley.edu/~hendrycks/MATH.tar"
@@ -38,8 +38,8 @@ math, computer science, and philosophy papers.
_HOMEPAGE = "https://pile.eleuther.ai/"
# License available at https://github.com/EleutherAI/the-pile/blob/master/LICENSE
_LICENSE = "MIT License"
# More details at https://arxiv.org/pdf/2201.07311.pdf
_LICENSE = "Multiple licenses"
_URLS = {
"validation": "https://the-eye.eu/public/AI/pile/val.jsonl.zst",
@@ -39,7 +39,7 @@ a teacher who answers the questions by providing short excerpts (spans) from the
_HOMEPAGE = "https://quac.ai/"
# License available at https://quac.ai/
# License declared at https://quac.ai/
_LICENSE = "CC BY-SA 4.0"
_URLS = {
@@ -5,6 +5,7 @@ from . import huggingface
from . import textsynth
from . import deepsparse
from . import dummy
from . import gguf
MODEL_REGISTRY = {
"hf": gpt2.HFLM,
@@ -17,6 +18,7 @@ MODEL_REGISTRY = {
"textsynth": textsynth.TextSynthLM,
"deepsparse": deepsparse.DeepSparseLM,
"dummy": dummy.DummyLM,
"gguf": gguf.GGUFLM
}
import requests
import logging
import time
from tqdm import tqdm
from requests.exceptions import RequestException
import transformers
from lm_eval.utils import Reorderer
from lm_eval.base import BaseLM
logger = logging.getLogger(__name__)
def get_result(logprobs, context_length):
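    # Sums the log-probabilities of the continuation tokens returned by the server
    # and reports whether every continuation token was the top (greedy) choice.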
is_greedy = True
offsets = logprobs['text_offset']
tokens = logprobs['tokens']
tokens_logprobs = logprobs['token_logprobs']
idx = 0
while offsets[idx] < context_length:
idx += 1
continuation_logprobs = sum(tokens_logprobs[idx:-1])
for i in range(idx, len(tokens)):
token = tokens[i]
top_tokens = logprobs["top_logprobs"][i]
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
if top_token != token:
is_greedy = False
break
return continuation_logprobs, is_greedy
class GGUFLM(BaseLM):
def __init__(self, base_url, max_length=2048):
super().__init__()
self.base_url = base_url
self.logprobs = 10
self.temperature = 0.0
        self._max_length = max_length
def gguf_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
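        # POSTs to the server's /v1/completions endpoint, retrying on transient request
        # errors; when a continuation is given it is appended to the prompt and echoed
        # back so its token logprobs can be read from the response.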
for _ in range(retries):
try:
prompt = context
request = {'prompt': prompt, 'logprobs': self.logprobs,
'temperature': self.temperature}
if continuation:
prompt += continuation
request.update({'prompt': prompt, 'max_tokens': 1, 'echo': True})
if stop is not None:
request['stop'] = stop
response = requests.post(f"{self.base_url}/v1/completions", json=request)
response.raise_for_status()
return response.json()
except RequestException as e:
logger.error(f"RequestException: {e}")
time.sleep(delay) # wait before retrying
else:
raise Exception(f"Failed to get a valid response after {retries} retries.")
def loglikelihood(self, requests):
if not requests:
return []
res = []
for context, continuation in tqdm(requests):
response = self.gguf_completion(context=context, continuation=continuation)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
logprobs = choice.get("logprobs")
if logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]:
logprob, is_greedy = get_result(logprobs, len(context))
res.append((logprob, is_greedy))
else:
logger.warning("Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list.")
else:
logger.error(f"Invalid response for loglikelihood. Response: {response}")
assert False
return res
def greedy_until(self, requests):
if not requests:
return []
res = []
for request in tqdm(requests):
inp = request[0]
request_args = request[1]
until = request_args["until"]
response = self.gguf_completion(context=inp, stop=until)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
if "text" in choice:
generated_text = choice["text"].strip()
res.append(generated_text)
else:
logger.error(f"Invalid response for greedy_until. Response: {response}")
res.append(None) # Add default value in case of error
else:
logger.error(f"Invalid response for greedy_until. Response: {response}")
res.append(None) # Add default value in case of error
return res
def loglikelihood_rolling(self, requests):
raise NotImplementedError("loglikelihood_rolling not yet supported for GGUF models")
def _model_call(self, inps):
# Placeholder implementation
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Placeholder implementation
raise NotImplementedError()
def tok_encode(self, string: str):
raise NotImplementedError()
def tok_decode(self, tokens):
raise NotImplementedError()
@property
def batch_size(self):
# Placeholder implementation
raise NotImplementedError()
@property
def device(self):
# Placeholder implementation
raise NotImplementedError()
@property
def eot_token_id(self):
# Placeholder implementation
raise NotImplementedError()
    @property
    def max_length(self):
        return self._max_length
@property
def max_gen_toks(self):
# Placeholder implementation
raise NotImplementedError()
@@ -19,7 +19,6 @@ _DeviceMapping = NewType("DeviceMapping", Mapping[str, Union[int, str, torch.dev
def _get_accelerate_args(
low_cpu_mem_usage: Optional[bool] = True,
device_map_option: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None,
@@ -39,7 +38,6 @@ def _get_accelerate_args(
args = {}
if max_memory:
args["max_memory"] = max_memory
args["low_cpu_mem_usage"] = low_cpu_mem_usage
args["device_map"] = device_map_option
args["offload_folder"] = offload_folder
return args
@@ -222,7 +220,6 @@ class HuggingFaceAutoLM(BaseLM):
model_kwargs = {}
if use_accelerate:
model_kwargs = _get_accelerate_args(
low_cpu_mem_usage,
device_map_option,
max_memory_per_gpu,
max_cpu_memory,
@@ -242,6 +239,7 @@ class HuggingFaceAutoLM(BaseLM):
bnb_4bit_quant_type=bnb_4bit_quant_type,
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
low_cpu_mem_usage=low_cpu_mem_usage,
**model_kwargs,
)
# note: peft_path can be different than pretrained model path
@@ -349,7 +349,7 @@ TASK_REGISTRY = {
**mgsm.construct_tasks(),
**scrolls.construct_tasks(),
**ceval.create_all_tasks(),
**cmmlu.create_all_tasks(),
**cmmlu.create_all_tasks()
}
@@ -2,16 +2,23 @@
CMMLU: Measuring massive multitask language understanding in Chinese
https://arxiv.org/abs/2306.09212
CMMLU is a comprehensive evaluation benchmark specifically designed to evaluate the knowledge and reasoning abilities of LLMs within the context of Chinese language and culture.
CMMLU covers a wide range of subjects, comprising 67 topics that span from elementary to advanced professional levels.
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge
and reasoning abilities of LLMs within the Chinese language and cultural context. CMMLU covers a wide range of
subjects, comprising 67 topics that span from elementary to advanced professional levels. It includes subjects that
require computational expertise, such as physics and mathematics, as well as disciplines within humanities and
social sciences. Many of these tasks are not easily translatable from other languages due to their specific
contextual nuances and wording. Furthermore, numerous tasks within CMMLU have answers that are specific to
China and may not be universally applicable or considered correct in other regions or languages.
Homepage: https://github.com/haonan-li/CMMLU
Huggingface homepage: https://huggingface.co/datasets/haonan-li/cmmlu
"""
from lm_eval.base import MultipleChoiceTask
import os
from lm_eval.base import MultipleChoiceTask, rf
_CITATION = """
@misc{li2023cmmlu,
title={CMMLU: Measuring massive multitask language understanding in Chinese},
title={CMMLU: Measuring massive multitask language understanding in Chinese},
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
year={2023},
eprint={2306.09212},
@@ -21,7 +28,77 @@ _CITATION = """
"""
SUBJECTS = {
SUBJECTS = [
"agronomy",
"anatomy",
"ancient_chinese",
"arts",
"astronomy",
"business_ethics",
"chinese_civil_service_exam",
"chinese_driving_rule",
"chinese_food_culture",
"chinese_foreign_policy",
"chinese_history",
"chinese_literature",
"chinese_teacher_qualification",
"clinical_knowledge",
"college_actuarial_science",
"college_education",
"college_engineering_hydrology",
"college_law",
"college_mathematics",
"college_medical_statistics",
"college_medicine",
"computer_science",
"computer_security",
"conceptual_physics",
"construction_project_management",
"economics",
"education",
"electrical_engineering",
"elementary_chinese",
"elementary_commonsense",
"elementary_information_and_technology",
"elementary_mathematics",
"ethnology",
"food_science",
"genetics",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_geography",
"high_school_mathematics",
"high_school_physics",
"high_school_politics",
"human_sexuality",
"international_law",
"journalism",
"jurisprudence",
"legal_and_moral_basis",
"logical",
"machine_learning",
"management",
"marketing",
"marxist_theory",
"modern_chinese",
"nutrition",
"philosophy",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_study",
"sociology",
"sports_science",
"traditional_chinese_medicine",
"virology",
"world_history",
"world_religions"
]
SUBJECT_MAPPING = {
"agronomy": "农学",
"anatomy": "解剖学",
"ancient_chinese": "古汉语",
@@ -91,26 +168,103 @@ SUBJECTS = {
"world_religions": "世界宗教",
}
SUBJECT_CATEGORIES = {
"agronomy": ['other'],
"anatomy": ['biology'],
"ancient_chinese": ['linguistics','china specific'],
"arts": ['arts'],
"astronomy": ['physics'],
"business_ethics": ['business'],
"chinese_civil_service_exam": ['politics','china specific'],
"chinese_driving_rule": ['other','china specific'],
"chinese_food_culture": ['culture','china specific'],
"chinese_foreign_policy": ['politics','china specific'],
"chinese_history":['history','china specific'],
"chinese_literature": ['literature','china specific'],
"chinese_teacher_qualification": ['education','china specific'],
"college_actuarial_science":['math'],
"college_education":['education'],
"college_engineering_hydrology": ['engineering'],
"college_law": ['law'],
"college_mathematics": ['math'],
"college_medical_statistics":['statistics'],
"clinical_knowledge": ['other'],
"college_medicine": ['other'],
"computer_science": ['computer science'],
"computer_security": ['other'],
"conceptual_physics": ['physics'],
"construction_project_management": ['other','china specific'],
"economics": ['economics'],
"education": ['education'],
"elementary_chinese":['linguistics','china specific'],
"elementary_commonsense":['other','china specific'],
"elementary_information_and_technology": ['other'],
"electrical_engineering": ['engineering'],
"elementary_mathematics": ['math'],
"ethnology": ['culture','china specific'],
"food_science": ['other'],
"genetics": ['biology'],
"global_facts": ['global'],
"high_school_biology": ['biology'],
"high_school_chemistry": ['chemistry'],
"high_school_geography": ['geography'],
"high_school_mathematics": ['math'],
"high_school_physics": ['physics'],
"high_school_politics": ['politics','china specific'],
"human_sexuality": ['other'],
"international_law": ['law'],
"journalism": ['sociology'],
"jurisprudence": ['law'],
"legal_and_moral_basis": ['other'],
"logical": ['philosophy'],
"machine_learning": ['computer science'],
"management": ['business'],
"marketing": ['business'],
"marxist_theory": ['philosophy'],
"modern_chinese": ['linguistics','china specific'],
"nutrition": ['other'],
"philosophy": ['philosophy'],
"professional_accounting": ['business'],
"professional_law": ['law'],
"professional_medicine": ['other'],
"professional_psychology": ['psychology'],
"public_relations": ['politics'],
"security_study": ['politics'],
"sociology": ['culture'],
"sports_science": ['other'],
"traditional_chinese_medicine": ['other','china specific'],
"virology": ['biology'],
"world_history":['history'],
"world_religions": ['global'],
}
CATEGORIES = {
"STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering", "statistics"],
"Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
"Social Science": ['linguistics',"business", "politics", "culture", "economics", "geography", "psychology", "education", "sociology"],
"Other":["other"],
"China specific": ["china specific"],
}
def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
:return: {task_name: task}
e.g. {cmmlu-world_history: Task, cmmlu-virology: Task}
e.g. {cmmlu-physician: Task, cmmlu-tax_accountant: Task}
"""
return {f"cmmlu-{sub}": create_task(sub) for sub in SUBJECTS.keys()}
return {f"cmmlu-{sub}": create_task(sub) for sub in SUBJECTS}
def create_task(subject):
class Cmmlu(CmmluSubject):
class CmmluTest(GeneralCmmluTest):
def __init__(self):
super().__init__(subject)
return Cmmlu
return CmmluTest
class CmmluSubject(MultipleChoiceTask):
class GeneralCmmluTest(MultipleChoiceTask):
VERSION = 1
DATASET_PATH = "haonan-li/cmmlu"
DATASET_PATH = os.path.join("haonan-li/cmmlu")
DATASET_NAME = None
def __init__(self, subject):
@@ -121,61 +275,62 @@ class CmmluSubject(MultipleChoiceTask):
return False
def has_validation_docs(self):
return True
return False
def has_test_docs(self):
return True
def validation_docs(self):
if self.has_validation_docs():
return map(self._process_doc, self.dataset["dev"])
def test_docs(self):
if self.has_test_docs():
return map(self._process_doc, self.dataset["test"])
def _format_subject(self, subject):
words = subject.split("_")
return " ".join(words)
return map(self._process_doc, self.dataset["test"])
def fewshot_context(self, doc, num_fewshot, **kwargs):
subject = self.DATASET_NAME
description = f"以下是关于{SUBJECTS[subject]}的单项选择题,请直接给出正确答案的选项。"
description = f"以下是关于{SUBJECT_MAPPING[subject]}的单项选择题,请直接给出正确答案的选项。"
kwargs["description"] = description
return super().fewshot_context(doc=doc, num_fewshot=num_fewshot, **kwargs)
def _process_doc(self, doc):
def format_example(doc, keys):
"""
<prompt>
题目:<prompt>
A. <choice1>
B. <choice2>
C. <choice3>
D. <choice4>
答案:
答案
"""
question = doc["Question"].strip()
choices = "".join([f"{key}. {doc[key]}\n" for key in keys])
prompt = f"{question}\n{choices}答案:"
choices = "".join(
[f"{key}. {doc[key]}\n" for key in keys]
)
prompt = f"题目:{question}\n{choices}答案是:"
return prompt
keys = ["A", "B", "C", "D"]
return {
"query": format_example(doc, keys),
"choices": keys,
"gold": ord(doc["Answer"]) - ord("A"),
"gold": keys.index(doc["Answer"]),
}
def fewshot_examples(self, k, rnd):
if self._fewshot_docs is None:
self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))
        # Use the unchanged order of the dev set without sampling.
return self._fewshot_docs[:k]
def construct_requests(self, doc, ctx):
lls = [
rf.loglikelihood(ctx, "{}".format(choice))[0] for choice in doc["choices"]
]
return lls
def doc_to_text(self, doc):
return doc["query"]
def doc_to_target(self, doc):
return doc["choices"][doc["gold"]]
def should_decontaminate(self):
return True
@@ -33,34 +33,43 @@ _CITATION = """
class Pubmed_QA(Task):
VERSION = 0
DATASET_PATH = "pubmed_qa"
DATASET_NAME = "pqa_labeled"
DATASET_PATH = "bigbio/pubmed_qa"
DATASET_NAME = "pubmed_qa_labeled_fold0_source"
def has_training_docs(self):
return False
return True
def has_validation_docs(self):
return False
return True
def has_test_docs(self):
return True
def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = self.dataset["train"]
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
# HF is labelled as train but its really just for testing
return self.dataset["train"]
return self.dataset["test"]
def doc_to_text(self, doc):
ctxs = "\n".join(doc["context"]["contexts"])
ctxs = "\n".join(doc["CONTEXTS"])
return "Abstract: {}\nQuestion: {}\nAnswer:".format(
ctxs, doc["question"], doc["final_decision"]
ctxs, doc["QUESTION"], doc["final_decision"]
)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"] + " " + "\n".join(doc["context"]["contexts"])
return doc["question"] + " " + "\n".join(doc["CONTEXTS"])
def doc_to_target(self, doc):
return " {}".format(doc["final_decision"])
@@ -12,7 +12,7 @@ setuptools.setup(
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/EleutherAI/lm-evaluation-harness",
packages=setuptools.find_packages(),
packages=setuptools.find_packages(exclude=["scripts.*", "scripts"]),
package_data={"lm_eval": ["**/*.json"]},
include_package_data=True,
classifiers=[
import unittest
from unittest.mock import patch
import hashlib
import json
import os
import pickle
from lm_eval.models.gguf import GGUFLM
base_url = "https://matthoffner-ggml-llm-api.hf.space"
def gguf_completion_mock(base_url, **kwargs):
# Generate a hash from the parameters
hash_kwargs = {'base_url': base_url, **kwargs}
hash = hashlib.sha256(json.dumps(hash_kwargs, sort_keys=True).encode('utf-8')).hexdigest()
fname = f"./tests/testdata/ggml_test_{hash}.pkl"
if os.path.exists(fname):
with open(fname, "rb") as fh:
return pickle.load(fh)
else:
print("The file does not exist, attempting to write...")
if 'stop' in kwargs:
result = {"choices": [{"text": f"generated text until {kwargs['stop']}", "logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}
else:
result = {"choices": [{"logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}
try:
os.makedirs(os.path.dirname(fname), exist_ok=True)
print('Writing file at', fname)
with open(fname, "wb") as fh:
pickle.dump(result, fh)
print('File written successfully')
except Exception as e:
print('File writing failed:', e)
return result
class GGUFLMTest(unittest.TestCase):
@patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
def test_loglikelihood(self, gguf_completion_mock):
lm = GGUFLM(base_url)
# Test loglikelihood
requests = [("context1", "continuation1"), ("context2", "continuation2")]
res = lm.loglikelihood(requests)
# Assert the loglikelihood response is correct
expected_res = [(logprob, True) for logprob in [-1.2345, -1.2345]]
self.assertEqual(res, expected_res)
@patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
def test_greedy_until(self, gguf_completion_mock):
lm = GGUFLM(base_url)
# Test greedy_until
requests = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})]
res = lm.greedy_until(requests)
# Assert the greedy_until response is correct
expected_res = ["generated text until stop1", "generated text until stop2"]
self.assertEqual(res, expected_res)
    @patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
    def test_loglikelihood_rolling(self, gguf_completion_mock):
        lm = GGUFLM(base_url)
        # loglikelihood_rolling is not implemented for GGUF models yet, so the
        # call is expected to raise rather than return scores.
        requests = ["input1", "input2"]
        with self.assertRaises(NotImplementedError):
            lm.loglikelihood_rolling(requests)
if __name__ == "__main__":
unittest.main()