Unverified commit da0a5e36, authored by Anjor Kanekar, committed by GitHub

Remove tokenizer for openai chat completions (#1191)

* remove tokenizer for openai chat completions

* reordering function

* linter

* remove tiktoken import
parent 84790e99
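
The net effect: `OpenaiChatCompletionsLM` no longer loads any tokenizer (tiktoken or HuggingFace), since the chat completions endpoint takes plain message strings. A minimal sketch of what the constructor reduces to after this commit, assembled from the diff below (a reading aid, not the verbatim result):

```python
import openai


class OpenaiChatCompletionsLM:
    def __init__(self, model="gpt-3.5-turbo", base_url=None, truncate=False, **kwargs):
        self.model = model
        self.base_url = base_url
        self.truncate = truncate
        # The client reads OPENAI_API_KEY from the environment; set it to
        # EMPTY when pointing base_url at a local server.
        if self.base_url:
            self.client = openai.OpenAI(base_url=self.base_url)
        else:
            self.client = openai.OpenAI()
```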
@@ -3,9 +3,8 @@ import os
 import time
 from collections import defaultdict
 from importlib.util import find_spec
-from typing import List, Literal, Optional, Tuple
+from typing import List, Optional, Tuple

-import transformers
 from tqdm import tqdm

 from lm_eval import utils
@@ -360,11 +359,7 @@ class OpenaiChatCompletionsLM(LM):
         self,
         model: str = "gpt-3.5-turbo",  # GPT model or Local model using HuggingFace model paths
         base_url: str = None,
-        tokenizer_backend: Literal["tiktoken", "huggingface"] = "tiktoken",
         truncate: bool = False,
-        revision: Optional[str] = "main",
-        trust_remote_code: Optional[bool] = False,
-        use_fast_tokenizer: Optional[bool] = True,
         **kwargs,
     ) -> None:
         """
@@ -381,7 +376,6 @@ class OpenaiChatCompletionsLM(LM):
         super().__init__()
         try:
             import openai  # noqa: E401
-            import tiktoken
         except ModuleNotFoundError:
             raise Exception(
                 "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
@@ -389,32 +383,8 @@ class OpenaiChatCompletionsLM(LM):
             )
         self.model = model
         self.base_url = base_url
-        self.tokenizer_backend = tokenizer_backend
         self.truncate = truncate
-
-        # if we have a local model, use HF tokenizer over tiktoken
-        if self.tokenizer_backend == "huggingface":
-            self.revision = revision
-            self.trust_remote_code = trust_remote_code
-            self.use_fast_tokenizer = use_fast_tokenizer
-
-            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-                self.model,
-                revision=self.revision,
-                trust_remote_code=self.trust_remote_code,
-                use_fast_tokenizer=self.use_fast_tokenizer,
-            )
-            self.vocab_size = self.tokenizer.vocab
-            self.end_of_text_token_id = self.tokenizer.eos_token
-        elif self.tokenizer_backend == "tiktoken":
-            self.tokenizer = tiktoken.encoding_for_model(self.model)
-            self.vocab_size = self.tokenizer.n_vocab
-            self.end_of_text_token_id = self.tokenizer.eot_token
-        else:
-            raise ValueError(
-                f"Expected tokenizer_backend to be one of ['tiktoken', 'huggingface'] but got {self.tokenizer_backend}"
-            )

         # Read from environment variable OPENAI_API_KEY
         # Set to EMPTY for local
         if self.base_url:
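
For context, the deleted block above dispatched between two tokenizer backends. A standalone sketch of that pattern, with the HuggingFace attribute names corrected to the actual `transformers` API (`use_fast`, `vocab_size`, `eos_token_id`; the removed code used `use_fast_tokenizer`, `.vocab`, and `.eos_token`); assumes `tiktoken` and `transformers` are installed:

```python
import tiktoken
import transformers


def load_tokenizer(model: str, backend: str = "tiktoken"):
    """Return (tokenizer, vocab_size, eot_token_id) for the given backend."""
    if backend == "huggingface":
        # Local models: load the model's own tokenizer from the Hub.
        tok = transformers.AutoTokenizer.from_pretrained(model, use_fast=True)
        return tok, tok.vocab_size, tok.eos_token_id
    elif backend == "tiktoken":
        # OpenAI-hosted models: use the official BPE encoding for the model.
        enc = tiktoken.encoding_for_model(model)
        return enc, enc.n_vocab, enc.eot_token
    raise ValueError(
        f"Expected backend to be one of ['tiktoken', 'huggingface'] but got {backend}"
    )
```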
@@ -422,10 +392,6 @@ class OpenaiChatCompletionsLM(LM):
         else:
             self.client = openai.OpenAI()  # openai.AsyncOpenAI()

-    @property
-    def eot_token_id(self):
-        return self.end_of_text_token_id
-
     @property
     def max_length(self) -> int:
         # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
@@ -445,53 +411,19 @@ class OpenaiChatCompletionsLM(LM):
         # Isn't used because we override _loglikelihood_tokens
         raise NotImplementedError()

-    def tok_encode(self, string: str) -> List[int]:
-        return self.tokenizer.encode(string)
-
-    def tok_decode(self, tokens: List[int]) -> str:
-        return self.tokenizer.decode(tokens)
-
-    def _encode_pair(
-        self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
-        n_spaces = len(context) - len(context.rstrip())
-        if n_spaces > 0:
-            continuation = context[-n_spaces:] + continuation
-            context = context[:-n_spaces]
-        whole_enc = self.tok_encode(context + continuation)
-        context_enc = self.tok_encode(context)
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
-        return context_enc, continuation_enc
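
The removed `_encode_pair` shifts trailing whitespace from the context onto the continuation before encoding, so that BPE merges spanning the boundary (e.g. " world") are counted as continuation tokens. A self-contained illustration, assuming `tiktoken`'s `cl100k_base` encoding stands in for `self.tok_encode`:

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")


def encode_pair(context: str, continuation: str):
    # Move trailing spaces from context to continuation so boundary-spanning
    # merges like " world" belong to the continuation.
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation
        context = context[:-n_spaces]
    whole_enc = enc.encode(context + continuation)
    context_enc = enc.encode(context)
    # Continuation tokens = whatever the joint encoding adds past the context.
    return context_enc, whole_enc[len(context_enc):]


ctx, cont = encode_pair("Hello ", "world")
assert enc.decode(ctx + cont) == "Hello world"
```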
-
     def generate_until(self, requests) -> List[str]:
         res = defaultdict(list)
         re_ords = {}

-        def _collate(x):
-            toks = self.tok_encode(x[0])
-            return -len(toks), x[0]
-
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
         grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
         for key, reqs in grouper.get_grouped().items():
             # within each set of reqs for given kwargs, we reorder by token length, descending.
-            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
-
-        def sameuntil_chunks(xs, size):
-            ret = []
-            lastuntil = xs[0][1]
-            for x in xs:
-                if len(ret) >= size or x[1] != lastuntil:
-                    yield ret, lastuntil
-                    ret = []
-                    lastuntil = x[1]
-                ret.append(x)
-
-            if ret:
-                yield ret, lastuntil
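
The removed `sameuntil_chunks` grouped consecutive requests sharing the same stop sequence (`until`) into batches of at most `size`. A quick usage example on hypothetical toy data:

```python
def sameuntil_chunks(xs, size):
    # Yield (chunk, until) runs: consecutive items with the same stop
    # sequence, capped at `size` items per chunk.
    ret = []
    lastuntil = xs[0][1]
    for x in xs:
        if len(ret) >= size or x[1] != lastuntil:
            yield ret, lastuntil
            ret = []
            lastuntil = x[1]
        ret.append(x)
    if ret:
        yield ret, lastuntil


reqs = [("a", ["\n"]), ("b", ["\n"]), ("c", ["stop"]), ("d", ["stop"]), ("e", ["stop"])]
for chunk, until in sameuntil_chunks(reqs, size=2):
    print([prompt for prompt, _ in chunk], until)
# ['a', 'b'] ['\n']
# ['c', 'd'] ['stop']
# ['e'] ['stop']
```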
+            re_ords[key] = utils.Reorderer(
+                [req.args for req in reqs], lambda x: (-len(x[0]), x[0])
+            )

         pbar = tqdm(total=len(requests), disable=(self.rank != 0))
         for key, re_ord in re_ords.items():
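
Note the replacement sort key: requests are now ordered by descending character length of the context string rather than token length, which is what lets the tokenizer go away entirely. A toy illustration of the same key:

```python
contexts = ["a much longer prompt", "short", "medium prompt"]
# Same key as the new Reorderer call: longest (by characters) first,
# ties broken lexicographically.
print(sorted(contexts, key=lambda s: (-len(s), s)))
# ['a much longer prompt', 'medium prompt', 'short']
```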
......