Commit 30296795 authored by lintangsutawika

update

parent 12e92616
import os
import time
from typing import List, Tuple
+import copy
+from collections import defaultdict
from tqdm import tqdm
from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
@@ -10,6 +14,7 @@ from openai import OpenAI
+client = OpenAI()
def oa_chat_completion(**kwargs):
    """Query OpenAI API for chat completion.
@@ -40,7 +45,7 @@ class OpenaiChatCompletionsLM(LM):
    REQ_CHUNK_SIZE = 20

    def __init__(
        self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
    ) -> None:
        """
@@ -70,7 +75,6 @@ class OpenaiChatCompletionsLM(LM):
        self.end_of_text_token_id = self.tokenizer.eot_token
-        # Read from environment variable OPENAI_API_SECRET_KEY
    @property
    def eot_token_id(self):
@@ -102,7 +106,7 @@ class OpenaiChatCompletionsLM(LM):
        return self.tokenizer.decode(tokens)
    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
@@ -115,16 +119,20 @@ class OpenaiChatCompletionsLM(LM):
        return context_enc, continuation_enc
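The collapsed middle of `_encode_pair` implements a standard detail: trailing whitespace is moved from the context onto the continuation before encoding, so BPE merges across the boundary are attributed to the continuation. A self-contained sketch of that logic, with a generic `encode` callable standing in for `self.tok_encode`:

from typing import Callable, List, Tuple

def encode_pair(
    context: str, continuation: str, encode: Callable[[str], List[int]]
) -> Tuple[List[int], List[int]]:
    # Move trailing spaces from the context onto the continuation, so that
    # e.g. " dog" is tokenized as a space-prefixed word.
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation
        context = context[:-n_spaces]
    # Encode the whole string once and slice, so tokens that merge across the
    # context/continuation boundary count toward the continuation.
    whole_enc = encode(context + continuation)
    context_enc = encode(context)
    continuation_enc = whole_enc[len(context_enc):]
    return context_enc, continuation_enc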
    def generate_until(self, requests) -> List[str]:
+        if not requests:
+            return []
-        res = []
-        requests = [req.args for req in requests]
+        res = defaultdict(list)
+        re_ords = {}

        def _collate(x):
            toks = self.tok_encode(x[0])
-            return len(toks), x[0]
+            return -len(toks), x[0]

-        re_ord = utils.Reorderer(requests, _collate)
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
+        for key, reqs in grouper.get_grouped().items():
+            # within each set of reqs for given kwargs, we reorder by token length, descending.
+            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
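`utils.Grouper` buckets requests by their stringified generation kwargs so that, e.g., greedy and temperature-0.8 requests never share a batch, while remembering each request's original position. A hypothetical stand-in showing the contract (`SimpleGrouper` is not the harness class):

from collections import defaultdict

class SimpleGrouper:
    """Hypothetical stand-in for utils.Grouper, for illustration only."""

    def __init__(self, items, key_fn):
        self._keys = [key_fn(x) for x in items]
        self.grouped = defaultdict(list)
        for key, item in zip(self._keys, items):
            self.grouped[key].append(item)

    def get_grouped(self):
        return self.grouped

    def get_original(self, grouped_results):
        # Re-interleave per-group result lists back into the original order.
        iters = {key: iter(vals) for key, vals in grouped_results.items()}
        return [next(iters[key]) for key in self._keys]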
        def sameuntil_chunks(xs, size):
            ret = []
@@ -139,25 +147,41 @@ class OpenaiChatCompletionsLM(LM):
            if ret:
                yield ret, lastuntil
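`sameuntil_chunks` (the old batching helper, shown here as unchanged context) yields runs of consecutive requests that share the same generation args, capped at `size` requests per chunk. Its collapsed body plausibly reads:

def sameuntil_chunks(xs, size):
    # Yield (chunk, args) where every request in a chunk shares the same
    # generation args; assumes xs is non-empty, as the caller guarantees.
    ret = []
    lastuntil = xs[0][1]
    for x in xs:
        if len(ret) >= size or x[1] != lastuntil:
            yield ret, lastuntil
            ret = []
            lastuntil = x[1]
        ret.append(x)
    if ret:
        yield ret, lastuntil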
-        # todo: more intelligent batching for heterogeneous `until`
-        for chunk, request_args in tqdm(
-            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
-        ):
-            inps = []
-            for context, _ in chunk:
-                # context_enc = self.tok_encode(context)
-                # inp = context_enc[-(self.max_length - self.max_gen_toks):]
-                inps.append({"role": "user", "content": context})
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+        for key, re_ord in re_ords.items():
+            chunks = utils.chunks(re_ord.get_reordered(), n=self.REQ_CHUNK_SIZE)
+            for chunk in chunks:
+                contexts, all_gen_kwargs = zip(*chunk)
+                inps = [{"role": "user", "content": context} for context in contexts]
+                gen_kwargs = all_gen_kwargs[0]
+                until = None
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    if "until" in kwargs.keys():
+                        until = kwargs.pop("until")
+                        if isinstance(until, str):
+                            until = [until]
+                        elif not isinstance(until, list):
+                            raise ValueError(
+                                f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+                            )
+                else:
+                    raise ValueError(
+                        f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
+                    )
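Concretely, the normalization above accepts `until` as either a single stop string or a list of them, and rejects anything else. The same logic as a hypothetical standalone helper:

import copy

def normalize_gen_kwargs(gen_kwargs):  # hypothetical helper mirroring the logic above
    if not isinstance(gen_kwargs, dict):
        raise ValueError(f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}")
    kwargs = copy.deepcopy(gen_kwargs)
    until = kwargs.pop("until", None)
    if isinstance(until, str):
        until = [until]  # a single stop string becomes a one-element list
    elif until is not None and not isinstance(until, list):
        raise ValueError(f"Expected `until` to be str or list but got {until}")
    return until, kwargs

# e.g. normalize_gen_kwargs({"until": "\n\n", "temperature": 0})
# -> (["\n\n"], {"temperature": 0})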
-            # until = request_args.get("until", ["<|endoftext|>"])
-            until = request_args.get("until", None)
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks
                response = oa_chat_completion(
                    messages=inps,
                    model=self.model,
                    frequency_penalty=self.frequency_penalty,
                    # logit_bias=self.logit_bias,
-                    max_tokens=self.max_gen_toks,
+                    max_tokens=max_gen_toks,
                    n=self.n,
                    presence_penalty=self.presence_penalty,
                    temperature=self.temperature,
@@ -167,21 +191,23 @@ class OpenaiChatCompletionsLM(LM):
                for resp, (context, args_) in zip(response.choices, chunk):
                    s = resp.message.content
-                    # until_ = args_.get("until", ["<|endoftext|>"])
-                    until_ = args_.get("until", None)
-                    if until_ is not None:
-                        for term in until_:
+                    if until is not None:
+                        for term in until:
                            if len(term) > 0:
                                s = s.split(term)[0]
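Splitting on each stop string and keeping the first piece truncates the completion at the earliest occurrence of any `until` term. For example:

s = "Paris is the capital.\n\nQuestion: next..."
for term in ["\n\n", "Question:"]:
    if len(term) > 0:  # guard against empty stop strings
        s = s.split(term)[0]
# s == "Paris is the capital."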
                    # partial caching
+                    res[key].append(s)
                    self.cache_hook.add_partial(
-                        "generate_until", (context, {"until": until_}), s
+                        "generate_until", (context, {"until": until}), s
                    )
+                    pbar.update(1)
+            res[key] = re_ord.get_original(res[key])
+        pbar.close()
-            res.append(s)
-        return re_ord.get_original(res)
+        return grouper.get_original(res)
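Results are gathered in two steps: `re_ord.get_original` undoes the within-group length sort, then `grouper.get_original` re-interleaves the per-group lists into the original request order. Continuing the `SimpleGrouper` sketch above (fake outputs, for illustration):

reqs = [("a", "greedy"), ("b", "temp0.8"), ("c", "greedy"), ("d", "temp0.8")]
grouper = SimpleGrouper(reqs, key_fn=lambda x: x[1])
results = {key: [ctx.upper() for ctx, _ in group]  # fake per-group outputs
           for key, group in grouper.get_grouped().items()}
print(grouper.get_original(results))  # ['A', 'B', 'C', 'D'] -- original order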
    def loglikelihood(self, requests):
        raise NotImplementedError("No support for logits.")
......
import hashlib
import json
-from openai import OpenAI
-client = OpenAI()
import os
import pickle
import pytest
@@ -10,6 +7,10 @@ import unittest.mock as mock
import lm_eval.models as models
+from openai import OpenAI
+client = OpenAI()
LOGLIKELIHOOD_TEST_CASES = [
    ("The quick brown fox jumps over the lazy", " dog"),
......
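The test-case list is truncated here. Pairs like these are typically fed to a parametrized test; a rough sketch under assumed names (the import path and constructor are assumptions, and `loglikelihood` is expected to raise per the diff above):

import pytest
from lm_eval.models.openai_completions import OpenaiChatCompletionsLM  # assumed path

@pytest.mark.parametrize("context,continuation", LOGLIKELIHOOD_TEST_CASES)
def test_chat_completions_rejects_loglikelihood(context, continuation):
    lm = OpenaiChatCompletionsLM(model="gpt-3.5-turbo")
    with pytest.raises(NotImplementedError):  # chat API exposes no logprobs here
        lm.loglikelihood([(context, continuation)])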