".github/vscode:/vscode.git/clone" did not exist on "574b13390b9883c15409bd67e255720f492ec6a6"
Commit 8f992eb3 authored by Matt Hoffner

add test against llama-cpp-python server

parent 932f9db4
import json
import logging
import time

import requests
from requests.exceptions import RequestException
from tqdm import tqdm

from lm_eval.base import BaseLM

logger = logging.getLogger(__name__)


def llama_completion(base_url, prompt, **kwargs):
    try:
        # POST the prompt along with any extra sampling parameters to the
        # server's OpenAI-style completions endpoint
        response = requests.post(
            f"{base_url}/v1/completions", json={"prompt": prompt, **kwargs}
        )
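        # (The remainder of this function is elided by the hunk below. Given
        # the RequestException and time imports above, an assumed sketch of
        # the missing error handling, not the commit's actual code, might be:
        #
        #         response.raise_for_status()
        #         return response.json()
        #     except RequestException as e:
        #         logger.error(f"Request to {base_url} failed: {e}")
        #         time.sleep(1)  # brief backoff before any caller-side retry
        #         return None
        # )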
@@ -23,12 +28,15 @@ class LlamaLM(BaseLM):
        res = []
        for context, continuation in tqdm(requests):
            response = llama_completion(self.base_url, context, continuation=continuation)
            print(f"Loglikelihood response: {response}")
            if response and "choices" in response and response["choices"]:
                choice = response["choices"][0]
                logprobs = choice.get("logprobs")
                # Fall back to a sentinel value when the server returns no logprobs
                logprob = (
                    logprobs["token_logprobs"][0]
                    if logprobs and logprobs["token_logprobs"]
                    else -1.2345
                )
                # Heuristic: treat a completion that ran to the length limit as greedy
                is_greedy = choice["finish_reason"] == "length"
                res.append((logprob, is_greedy))
            else:
                # Fail fast on malformed responses
                logger.error(f"Invalid response for loglikelihood. Response: {response}")
                assert False
        return res
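        # For reference, the parsing above assumes an OpenAI-style response
        # from llama-cpp-python's /v1/completions endpoint, roughly of this
        # shape (illustrative values, not captured from a real server):
        #
        #     {"choices": [{"text": " Paris",
        #                   "finish_reason": "length",
        #                   "logprobs": {"token_logprobs": [-0.41]}}]}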
@@ -41,11 +49,53 @@ class LlamaLM(BaseLM):
            inp = request[0]
            request_args = request[1]
            until = request_args["until"]
            response = self.llama_completion(inp, context=res, stop=until)  # Pass the context
            print(f"Greedy_until response: {response}")
            if response and "text" in response:
                generated_text = response["text"].strip()
                res.append(generated_text)
            else:
                # Skip bad responses instead of aborting; note this leaves res
                # shorter than requests for any failed entries
                logger.error(f"Invalid response for greedy_until. Response: {response}")
                continue
        return res
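    # greedy_until resolves llama_completion through the instance, which is
    # what lets test_greedy_until below swap in a mock by assigning an
    # attribute. The diff does not show such a method on the class, so the
    # wrapper here is an assumed sketch (not part of the original commit)
    # that delegates to the module-level helper by default:
    def llama_completion(self, prompt, **kwargs):
        # Assumed convenience wrapper around the module-level llama_completion
        return llama_completion(self.base_url, prompt, **kwargs)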
    def _model_call(self, inps):
        # Placeholder implementation
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def batch_size(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def device(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def eot_token_id(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def max_length(self):
        # Placeholder implementation
        raise NotImplementedError()

    @property
    def max_gen_toks(self):
        # Placeholder implementation
        raise NotImplementedError()

    def tok_encode(self, string: str):
        # Placeholder implementation
        raise NotImplementedError()

    def tok_decode(self, tokens):
        # Placeholder implementation
        raise NotImplementedError()
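A minimal usage sketch for the class above (assumptions for illustration only: a llama-cpp-python server reachable at a local URL, and a LlamaLM constructor that takes the base URL, as the tests below assume):

    lm = LlamaLM("http://localhost:8000")  # hypothetical local server
    print(lm.greedy_until([("Q: What is 2+2?\nA:", {"until": ["\n"]})]))

The test file that follows exercises the same two entry points with mocks, so it runs without a live server.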
import unittest
from unittest.mock import MagicMock, patch

from lm_eval.models.llama import LlamaLM


class LlamaLMTest(unittest.TestCase):
    def test_loglikelihood(self):
        base_url = "https://matthoffner-ggml-llm-api.hf.space"
        lm = LlamaLM(base_url)

        # loglikelihood calls the module-level llama_completion, so the mock
        # has to be patched there; its return value mirrors the "choices"
        # shape the parser expects
        llama_completion_mock = MagicMock(return_value={
            "choices": [
                {
                    "logprobs": {"token_logprobs": [-1.2345]},
                    "finish_reason": "length",
                }
            ]
        })

        with patch("lm_eval.models.llama.llama_completion", llama_completion_mock):
            # Test loglikelihood
            requests = [("context1", "continuation1"), ("context2", "continuation2")]
            res = lm.loglikelihood(requests)

        # Assert the loglikelihood response is correct
        expected_res = [(-1.2345, True), (-1.2345, True)]
        self.assertEqual(res, expected_res)
    def test_greedy_until(self):
        base_url = "https://matthoffner-ggml-llm-api.hf.space"
        lm = LlamaLM(base_url)

        # Define a llama_completion stand-in with the desired behavior; the
        # first positional argument receives the prompt, since greedy_until
        # resolves llama_completion through the instance
        def llama_completion_mock(prompt, context=None, stop=None):
            if stop is not None:
                return {"text": f"generated_text{stop[-1]}"}
            return {"text": "generated_text"}

        # Assigning the instance attribute is enough to intercept the call
        lm.llama_completion = llama_completion_mock

        # Test greedy_until
        requests = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})]
        res = lm.greedy_until(requests)

        # Assert the greedy_until response is correct
        expected_res = ["generated_text1", "generated_text2"]
        self.assertEqual(res, expected_res)
if __name__ == "__main__":
    unittest.main()