Unverified Commit 97bc9780 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Update and rename test_ggml.py to test_gguf.py

parent 2d433a4a
......@@ -4,11 +4,11 @@ import hashlib
import json
import os
import pickle
from lm_eval.models.ggml import GGMLLM
from lm_eval.models.gguf import GGUFLM
base_url = "https://matthoffner-ggml-llm-api.hf.space"
def ggml_completion_mock(base_url, **kwargs):
def gguf_completion_mock(base_url, **kwargs):
# Generate a hash from the parameters
hash_kwargs = {'base_url': base_url, **kwargs}
hash = hashlib.sha256(json.dumps(hash_kwargs, sort_keys=True).encode('utf-8')).hexdigest()
......@@ -37,10 +37,10 @@ def ggml_completion_mock(base_url, **kwargs):
return result
class GGMLLMTest(unittest.TestCase):
@patch('lm_eval.models.ggml.GGMLLM.ggml_completion', side_effect=ggml_completion_mock)
def test_loglikelihood(self, ggml_completion_mock):
lm = GGMLLM(base_url)
class GGUFLMTest(unittest.TestCase):
@patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
def test_loglikelihood(self, gguf_completion_mock):
lm = GGUFLM(base_url)
# Test loglikelihood
requests = [("context1", "continuation1"), ("context2", "continuation2")]
......@@ -50,9 +50,9 @@ class GGMLLMTest(unittest.TestCase):
expected_res = [(logprob, True) for logprob in [-1.2345, -1.2345]]
self.assertEqual(res, expected_res)
@patch('lm_eval.models.ggml.GGMLLM.ggml_completion', side_effect=ggml_completion_mock)
def test_greedy_until(self, ggml_completion_mock):
lm = GGMLLM(base_url)
@patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
def test_greedy_until(self, gguf_completion_mock):
lm = GGUFLM(base_url)
# Test greedy_until
requests = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})]
......@@ -62,9 +62,9 @@ class GGMLLMTest(unittest.TestCase):
expected_res = ["generated text until stop1", "generated text until stop2"]
self.assertEqual(res, expected_res)
@patch('lm_eval.models.ggml.GGMLLM.ggml_completion', side_effect=ggml_completion_mock)
def test_loglikelihood_rolling(self, ggml_completion_mock):
lm = GGMLLM(base_url)
@patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
def test_loglikelihood_rolling(self, gguf_completion_mock):
lm = GGUFLM(base_url)
# Test loglikelihood_rolling
requests = ["input1", "input2"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment