Commit 30296795 authored by lintangsutawika

update

parent 12e92616
 import os
 import time
 from typing import List, Tuple
+import copy
+from collections import defaultdict

 from tqdm import tqdm

 from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
@@ -10,6 +14,7 @@ from openai import OpenAI

 client = OpenAI()


 def oa_chat_completion(**kwargs):
     """Query OpenAI API for chat completion.
@@ -40,7 +45,7 @@ class OpenaiChatCompletionsLM(LM):
     REQ_CHUNK_SIZE = 20

     def __init__(
         self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
     ) -> None:
         """
@@ -70,7 +75,6 @@ class OpenaiChatCompletionsLM(LM):
         self.end_of_text_token_id = self.tokenizer.eot_token

         # Read from environment variable OPENAI_API_SECRET_KEY

     @property
     def eot_token_id(self):
@@ -102,7 +106,7 @@ class OpenaiChatCompletionsLM(LM):
         return self.tokenizer.decode(tokens)

     def _encode_pair(
         self, context: str, continuation: str
     ) -> Tuple[List[int], List[int]]:
         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
@@ -115,16 +119,20 @@ class OpenaiChatCompletionsLM(LM):
         return context_enc, continuation_enc
     def generate_until(self, requests) -> List[str]:
-        if not requests:
-            return []
-        res = []
-        requests = [req.args for req in requests]
+        res = defaultdict(list)
+        re_ords = {}

         def _collate(x):
             toks = self.tok_encode(x[0])
-            return len(toks), x[0]
+            return -len(toks), x[0]

-        re_ord = utils.Reorderer(requests, _collate)
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
+        for key, reqs in grouper.get_grouped().items():
+            # within each set of reqs for given kwargs, we reorder by token length, descending.
+            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)

         def sameuntil_chunks(xs, size):
             ret = []
@@ -139,25 +147,41 @@ class OpenaiChatCompletionsLM(LM):
             if ret:
                 yield ret, lastuntil
-        # todo: more intelligent batching for heterogeneous `until`
-        for chunk, request_args in tqdm(
-            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
-        ):
-            inps = []
-            for context, _ in chunk:
-                # context_enc = self.tok_encode(context)
-                # inp = context_enc[-(self.max_length - self.max_gen_toks):]
-                inps.append({"role": "user", "content": context})
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+        for key, re_ord in re_ords.items():
+            chunks = utils.chunks(re_ord.get_reordered(), n=self.REQ_CHUNK_SIZE)
+            for chunk in chunks:
+                contexts, all_gen_kwargs = zip(*chunk)
+                inps = [{"role": "user", "content": context} for context in contexts]
+
+                gen_kwargs = all_gen_kwargs[0]
+                until = None
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    if "until" in kwargs.keys():
+                        until = kwargs.pop("until")
+                        if isinstance(until, str):
+                            until = [until]
+                        elif not isinstance(until, list):
+                            raise ValueError(
+                                f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+                            )
+                else:
+                    raise ValueError(
+                        f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
+                    )
-            # until = request_args.get("until", ["<|endoftext|>"])
-            until = request_args.get("until", None)
+
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks

                 response = oa_chat_completion(
                     messages=inps,
                     model=self.model,
                     frequency_penalty=self.frequency_penalty,
                     # logit_bias=self.logit_bias,
-                    max_tokens=self.max_gen_toks,
+                    max_tokens=max_gen_toks,
                     n=self.n,
                     presence_penalty=self.presence_penalty,
                     temperature=self.temperature,
@@ -167,21 +191,23 @@ class OpenaiChatCompletionsLM(LM):
                 for resp, (context, args_) in zip(response.choices, chunk):
                     s = resp.message.content

-                # until_ = args_.get("until", ["<|endoftext|>"])
-                until_ = args_.get("until", None)
-                if until_ is not None:
-                    for term in until_:
-                        if len(term) > 0:
-                            s = s.split(term)[0]
+                    if until is not None:
+                        for term in until:
+                            if len(term) > 0:
+                                s = s.split(term)[0]

-                # partial caching
-                self.cache_hook.add_partial(
-                    "generate_until", (context, {"until": until_}), s
-                )
-
-                res.append(s)
-        return re_ord.get_original(res)
+                    res[key].append(s)
+
+                    self.cache_hook.add_partial(
+                        "generate_until", (context, {"until": until}), s
+                    )
+                    pbar.update(1)
+            res[key] = re_ord.get_original(res[key])
+        pbar.close()
+        return grouper.get_original(res)

     def loglikelihood(self, requests):
         raise NotImplementedError("No support for logits.")
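The core change in generate_until is a group-then-restore pattern: requests are bucketed by their stringified generation kwargs, each bucket is processed in length-sorted chunks, and per-bucket results are scattered back into the caller's original order. Below is a minimal self-contained sketch of that pattern; MiniGrouper is a hypothetical stand-in for utils.Grouper, whose exact interface in lm_eval.utils may differ.

from collections import defaultdict

class MiniGrouper:
    # Buckets items by key_fn and can scatter per-group results back
    # into the original submission order.
    def __init__(self, items, key_fn):
        self.size = len(items)
        self.indices = defaultdict(list)
        self.groups = defaultdict(list)
        for i, item in enumerate(items):
            key = key_fn(item)
            self.indices[key].append(i)
            self.groups[key].append(item)

    def get_grouped(self):
        return self.groups

    def get_original(self, grouped_results):
        # Place each group's results back at their original positions.
        out = [None] * self.size
        for key, values in grouped_results.items():
            for i, value in zip(self.indices[key], values):
                out[i] = value
        return out

requests = [("a", {"temperature": 0.0}),
            ("bb", {"temperature": 0.8}),
            ("c", {"temperature": 0.0})]
grouper = MiniGrouper(requests, lambda x: str(x[1]))
results = {key: [ctx.upper() for ctx, _ in group]
           for key, group in grouper.get_grouped().items()}
print(grouper.get_original(results))  # ['A', 'BB', 'C']

Grouping by str(gen_kwargs) keeps, e.g., greedy and temperature-0.8 requests out of the same API batch while still returning results in request order.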
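The per-chunk kwargs handling normalizes `until` to a list of stop strings, pops `max_gen_toks`, and truncates each completion at the first stop sequence. A standalone sketch of that logic follows; parse_gen_kwargs and truncate_at_stop are illustrative names, not part of lm_eval's API.

import copy
from typing import List, Tuple

def parse_gen_kwargs(gen_kwargs: dict, default_max_toks: int = 256) -> Tuple[List[str], int, dict]:
    if not isinstance(gen_kwargs, dict):
        raise ValueError(f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}")
    kwargs = copy.deepcopy(gen_kwargs)  # pop() below must not mutate the caller's dict
    until = kwargs.pop("until", [])
    if isinstance(until, str):
        until = [until]
    elif not isinstance(until, list):
        raise ValueError(f"Expected `until` to be of type Union[str, list] but got {until}")
    max_gen_toks = kwargs.pop("max_gen_toks", default_max_toks)
    return until, max_gen_toks, kwargs

def truncate_at_stop(s: str, until: List[str]) -> str:
    # Keep only the text before the first occurrence of any stop sequence.
    for term in until:
        if term:
            s = s.split(term)[0]
    return s

until, max_toks, rest = parse_gen_kwargs({"until": "\n\n", "max_gen_toks": 32, "temperature": 0.8})
print(until, max_toks, rest)                  # ['\n\n'] 32 {'temperature': 0.8}
print(truncate_at_stop("foo\n\nbar", until))  # foo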
 import hashlib
 import json
-from openai import OpenAI
-
-client = OpenAI()
 import os
 import pickle
 import pytest
@@ -10,6 +7,10 @@ import unittest.mock as mock

 import lm_eval.models as models

+from openai import OpenAI
+
+client = OpenAI()
+

 LOGLIKELIHOOD_TEST_CASES = [
     ("The quick brown fox jumps over the lazy", " dog"),
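Since the test module (like the model module) now constructs an OpenAI client at import time, one way to keep the suite runnable without a key or network is to stub oa_chat_completion. A hedged sketch follows; the patch target path and the response shape are assumptions based on the code above, not something this commit provides.

import unittest.mock as mock

def fake_chat_completion(**kwargs):
    # One fake choice per input message, shaped like the response fields the
    # model code reads (response.choices[i].message.content).
    choice = mock.Mock()
    choice.message.content = "stubbed completion"
    return mock.Mock(choices=[choice for _ in kwargs.get("messages", [])])

with mock.patch(
    "lm_eval.models.openai_completions.oa_chat_completion",  # assumed module path
    side_effect=fake_chat_completion,
):
    ...  # construct OpenaiChatCompletionsLM and exercise generate_until here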