# llama.py (1.64 KB) — last committed by Matt Hoffner.
import requests
import json
from tqdm import tqdm
from requests.exceptions import RequestException
import time

def llama_completion(base_url, prompt, **kwargs):
    """POST a completion request to ``{base_url}/v1/completions``.

    Args:
        base_url: Root URL of the completion server. A trailing ``/`` is
            ignored so ``"http://host/"`` and ``"http://host"`` hit the same
            endpoint.
        prompt: Text sent as the ``"prompt"`` field of the JSON body.
        **kwargs: Extra fields merged into the JSON body alongside
            ``prompt`` (e.g. ``stop=[...]``).

    Returns:
        The decoded JSON body of the response. Returns ``None`` when the
        request raises a ``requests`` ``RequestException``, including
        non-2xx statuses surfaced by ``raise_for_status``.
    """
    # The prompt must be part of the request body; without it the server
    # has no context to complete.
    payload = {"prompt": prompt, **kwargs}
    url = f"{base_url.rstrip('/')}/v1/completions"
    try:
        response = requests.post(url, json=payload)
        response.raise_for_status()
        return response.json()
    except RequestException as e:
        print(f"RequestException: {e}")
        return None

class LlamaLM(BaseLM):
    """Evaluation adapter that scores and generates text through an HTTP server.

    Every request is forwarded to ``llama_completion`` (a POST to
    ``{base_url}/v1/completions``). Methods raise ``RuntimeError`` when the
    server returns nothing, or returns a body missing a field they read.
    """

    def __init__(self, base_url, truncate=False):
        """Store connection settings.

        Args:
            base_url: Root URL of the completion server.
            truncate: Stored on the instance; not read by any method in this
                class.
        """
        super().__init__()
        self.base_url = base_url
        self.truncate = truncate

    @staticmethod
    def _require_fields(response, fields, method_name):
        """Return *response* if it is non-empty and contains every key in *fields*.

        Raises:
            RuntimeError: If *response* is falsy (e.g. ``None`` after a failed
                request) or lacks any key in *fields*. An explicit raise is
                used rather than ``assert`` so the check still runs under
                ``python -O``.
        """
        if not response or any(field not in response for field in fields):
            raise RuntimeError(f"Invalid response for {method_name}: {response!r}")
        return response

    def loglikelihood(self, requests):
        """Score each ``(context, continuation)`` pair via the server.

        Note: the ``requests`` parameter shadows the ``requests`` HTTP module
        inside this method. The name is kept for interface compatibility, and
        the module is not used here.

        Args:
            requests: Iterable of ``(context, continuation)`` pairs. ``context``
                is sent as the prompt and ``continuation`` as an extra
                ``"continuation"`` body field.

        Returns:
            List of ``(response["logprob"], response["is_greedy"])`` tuples,
            one per request, in order.

        Raises:
            RuntimeError: If any response is missing or lacks either field.
        """
        res = []
        for context, continuation in tqdm(requests):
            response = llama_completion(
                self.base_url, context, continuation=continuation
            )
            # Both keys are read below, so validate both before indexing.
            response = self._require_fields(
                response, ("logprob", "is_greedy"), "loglikelihood"
            )
            res.append((response["logprob"], response["is_greedy"]))
        return res

    def greedy_until(self, requests):
        """Generate a completion for each request, stopping at its ``until`` value.

        Args:
            requests: Sequence of requests where ``request[0]`` is the prompt
                text and ``request[1]`` is a mapping whose ``"until"`` entry is
                sent to the server as the ``stop`` field.

        Returns:
            List of the ``"text"`` field from each server response, in order.
            An empty input returns ``[]`` without touching the server.

        Raises:
            RuntimeError: If any response is missing or lacks ``"text"``.
        """
        if not requests:
            return []

        res = []
        for request in tqdm(requests):
            inp = request[0]
            request_args = request[1]
            until = request_args["until"]
            response = llama_completion(self.base_url, inp, stop=until)
            response = self._require_fields(response, ("text",), "greedy_until")
            res.append(response["text"])
        return res