Commit 718eaec8 authored by baberabb
Browse files

added methods to openai

parent eabadf46
...@@ -3,7 +3,7 @@ from lm_eval.api.registry import register_model ...@@ -3,7 +3,7 @@ from lm_eval.api.registry import register_model
from tqdm import tqdm from tqdm import tqdm
import time import time
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
from typing import List, Literal, Any, Tuple, Optional from typing import List, Any, Tuple
def anthropic_completion( def anthropic_completion(
......
import os import os
import time import time
import transformers import transformers # type: ignore
from typing import List, Tuple
import numpy as np
from tqdm import tqdm from tqdm import tqdm
from lm_eval import utils from lm_eval import utils
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
def get_result(response, ctxlen): def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
"""Process results from OpenAI API response. """Process results from OpenAI API response.
:param response: dict :param response: dict
...@@ -61,7 +59,12 @@ def oa_completion(**kwargs): ...@@ -61,7 +59,12 @@ def oa_completion(**kwargs):
class OpenaiCompletionsLM(LM): class OpenaiCompletionsLM(LM):
REQ_CHUNK_SIZE = 20 REQ_CHUNK_SIZE = 20
def __init__(self, engine, truncate=False): def __init__(
self,
engine: str = "text-davinci-003",
truncate: bool = False,
batch_size: int = 1,
):
""" """
:param engine: str :param engine: str
...@@ -112,13 +115,41 @@ class OpenaiCompletionsLM(LM): ...@@ -112,13 +115,41 @@ class OpenaiCompletionsLM(LM):
# Isn't used because we override _loglikelihood_tokens # Isn't used because we override _loglikelihood_tokens
raise NotImplementedError() raise NotImplementedError()
def tok_encode(self, string: str): def tok_encode(self, string: str) -> List[int]:
return self.tokenizer.encode(string, add_special_tokens=False) return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens): def tok_decode(self, tokens: List[int]) -> str:
return self.tokenizer.decode(tokens) return self.tokenizer.decode(tokens)
def _encode_pair(
    self, context: str, continuation: str
) -> Tuple[List[int], List[int]]:
    """Tokenize a (context, continuation) pair.

    Any trailing whitespace on ``context`` is moved to the front of
    ``continuation`` before encoding.

    :param context: str
        Prompt text preceding the continuation.
    :param continuation: str
        Text that follows the context.
    :return: Tuple[List[int], List[int]]
        The token ids of the (whitespace-trimmed) context, and the token
        ids of the joined text that lie beyond the context's token count.
    """
    trimmed = context.rstrip()
    trailing_ws = context[len(trimmed):]
    if trailing_ws:
        continuation = trailing_ws + continuation
        context = trimmed

    # The joined text is encoded as one string; the continuation ids are
    # whatever remains after skipping as many ids as the context alone yields.
    joined_ids = self.tok_encode(context + continuation)
    context_ids = self.tok_encode(context)
    return context_ids, joined_ids[len(context_ids):]
def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
    """Tokenize each request and score it via ``_loglikelihood_tokens``.

    :param requests: iterable of request objects
        Each request exposes ``args`` as a ``(context, continuation)`` pair
        of strings.
    :return: List[Tuple[float, bool]]
        Whatever ``_loglikelihood_tokens`` returns for the encoded requests.
        NOTE(review): annotated as one ``(logprob, is_greedy)`` tuple per
        request to match ``get_result``; confirm against
        ``_loglikelihood_tokens``.
    """
    new_reqs = []
    for request in requests:
        context, continuation = request.args
        if context == "":
            # No context to condition on: use the end-of-text token instead.
            context_enc = [self.eot_token_id]
            continuation_enc = self.tok_encode(continuation)
        else:
            context_enc, continuation_enc = self._encode_pair(context, continuation)

        new_reqs.append(((context, continuation), context_enc, continuation_enc))

    return self._loglikelihood_tokens(new_reqs)
def _loglikelihood_tokens(self, requests, disable_tqdm=False) -> List[List[float]]:
res = [] res = []
def _collate(x): def _collate(x):
...@@ -169,7 +200,7 @@ class OpenaiCompletionsLM(LM): ...@@ -169,7 +200,7 @@ class OpenaiCompletionsLM(LM):
return re_ord.get_original(res) return re_ord.get_original(res)
def greedy_until(self, requests): def greedy_until(self, requests) -> List[str]:
if not requests: if not requests:
return [] return []
res = [] res = []
...@@ -244,3 +275,7 @@ class OpenaiCompletionsLM(LM): ...@@ -244,3 +275,7 @@ class OpenaiCompletionsLM(LM):
def _model_generate(self, context, max_length, eos_token_id): def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until # Isn't used because we override greedy_until
raise NotImplementedError() raise NotImplementedError()
def loglikelihood_rolling(self, requests):
    """Rolling log-likelihood is not implemented for this model class.

    :param requests: unused
    :raises NotImplementedError: always
    """
    raise NotImplementedError(
        "loglikelihood_rolling is not implemented for OpenaiCompletionsLM"
    )
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment