"vscode:/vscode.git/clone" did not exist on "44be6438065ed7b4c094aeae10a4677bfadcd0ed"
Unverified Commit da0a5e36 authored by Anjor Kanekar, committed by GitHub

Remove tokenizer for openai chat completions (#1191)

* remove tokenizer for openai chat completions

* reordering function

* linter

* remove tiktoken import
parent 84790e99
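
For context, here is a minimal sketch (import path and argument values assumed, not part of this commit) of constructing the class after the change: the tokenizer-related arguments (`tokenizer_backend`, `revision`, `trust_remote_code`, `use_fast_tokenizer`) are gone, leaving only the API-facing options.

```python
# Hypothetical usage sketch; the module path is assumed from the harness layout.
from lm_eval.models.openai_completions import OpenaiChatCompletionsLM

lm = OpenaiChatCompletionsLM(
    model="gpt-3.5-turbo",  # or a local model served behind an OpenAI-compatible API
    base_url=None,          # e.g. "http://localhost:8000/v1" for a local server
    truncate=False,
)
# The client reads OPENAI_API_KEY from the environment (set it to "EMPTY" for local use).
```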
```diff
@@ -3,9 +3,8 @@ import os
 import time
 from collections import defaultdict
 from importlib.util import find_spec
-from typing import List, Literal, Optional, Tuple
+from typing import List, Optional, Tuple
 
-import transformers
 from tqdm import tqdm
 
 from lm_eval import utils
```
```diff
@@ -360,11 +359,7 @@ class OpenaiChatCompletionsLM(LM):
         self,
         model: str = "gpt-3.5-turbo",  # GPT model or Local model using HuggingFace model paths
         base_url: str = None,
-        tokenizer_backend: Literal["tiktoken", "huggingface"] = "tiktoken",
         truncate: bool = False,
-        revision: Optional[str] = "main",
-        trust_remote_code: Optional[bool] = False,
-        use_fast_tokenizer: Optional[bool] = True,
         **kwargs,
     ) -> None:
         """
```
```diff
@@ -381,7 +376,6 @@ class OpenaiChatCompletionsLM(LM):
         super().__init__()
         try:
             import openai  # noqa: E401
-            import tiktoken
         except ModuleNotFoundError:
             raise Exception(
                 "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
```
```diff
@@ -389,32 +383,8 @@ class OpenaiChatCompletionsLM(LM):
             )
         self.model = model
         self.base_url = base_url
-        self.tokenizer_backend = tokenizer_backend
         self.truncate = truncate
 
-        # if we have a local model, use HF tokenizer over tiktoken
-        if self.tokenizer_backend == "huggingface":
-            self.revision = revision
-            self.trust_remote_code = trust_remote_code
-            self.use_fast_tokenizer = use_fast_tokenizer
-            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-                self.model,
-                revision=self.revision,
-                trust_remote_code=self.trust_remote_code,
-                use_fast_tokenizer=self.use_fast_tokenizer,
-            )
-            self.vocab_size = self.tokenizer.vocab
-            self.end_of_text_token_id = self.tokenizer.eos_token
-        elif self.tokenizer_backend == "tiktoken":
-            self.tokenizer = tiktoken.encoding_for_model(self.model)
-            self.vocab_size = self.tokenizer.n_vocab
-            self.end_of_text_token_id = self.tokenizer.eot_token
-        else:
-            raise ValueError(
-                f"Expected tokenizer_backend to be one of ['tiktoken', 'huggingface'] but got {self.tokenizer_backend}"
-            )
-
         # Read from environment variable OPENAI_API_KEY
         # Set to EMPTY for local
         if self.base_url:
```
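
Since the class no longer carries a tokenizer of its own, a caller who still needs token counts (for example, to truncate prompts before sending them) can compute them externally. A sketch, not part of this diff, using the same `tiktoken` lookup the removed branch performed internally:

```python
import tiktoken

# Same encoding lookup the deleted tiktoken branch did inside the class.
enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
prompt = "Translate to French: cheese"
n_tokens = len(enc.encode(prompt))  # token count, now computed outside the LM class
```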
```diff
@@ -422,10 +392,6 @@ class OpenaiChatCompletionsLM(LM):
         else:
             self.client = openai.OpenAI()  # openai.AsyncOpenAI()
 
-    @property
-    def eot_token_id(self):
-        return self.end_of_text_token_id
-
     @property
     def max_length(self) -> int:
         # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
```
```diff
@@ -445,53 +411,19 @@ class OpenaiChatCompletionsLM(LM):
         # Isn't used because we override _loglikelihood_tokens
         raise NotImplementedError()
 
-    def tok_encode(self, string: str) -> List[int]:
-        return self.tokenizer.encode(string)
-
-    def tok_decode(self, tokens: List[int]) -> str:
-        return self.tokenizer.decode(tokens)
-
-    def _encode_pair(
-        self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
-        n_spaces = len(context) - len(context.rstrip())
-        if n_spaces > 0:
-            continuation = context[-n_spaces:] + continuation
-            context = context[:-n_spaces]
-        whole_enc = self.tok_encode(context + continuation)
-        context_enc = self.tok_encode(context)
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
-        return context_enc, continuation_enc
-
     def generate_until(self, requests) -> List[str]:
         res = defaultdict(list)
         re_ords = {}
 
-        def _collate(x):
-            toks = self.tok_encode(x[0])
-            return -len(toks), x[0]
-
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
         grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
         for key, reqs in grouper.get_grouped().items():
             # within each set of reqs for given kwargs, we reorder by token length, descending.
-            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
+            re_ords[key] = utils.Reorderer(
+                [req.args for req in reqs], lambda x: (-len(x[0]), x[0])
+            )
 
-        def sameuntil_chunks(xs, size):
-            ret = []
-            lastuntil = xs[0][1]
-            for x in xs:
-                if len(ret) >= size or x[1] != lastuntil:
-                    yield ret, lastuntil
-                    ret = []
-                    lastuntil = x[1]
-                ret.append(x)
-            if ret:
-                yield ret, lastuntil
-
         pbar = tqdm(total=len(requests), disable=(self.rank != 0))
         for key, re_ord in re_ords.items():
```
...
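
The batching strategy in `generate_until` is unchanged in spirit: group requests by their generation kwargs so that incompatible sampling settings never share a batch, then sort each group longest-first. What changes here is the sort key: without a tokenizer, ordering falls back from token length to character length. A standalone sketch using only the standard library (`utils.Grouper` and `utils.Reorderer` are the harness's own helpers, simplified away here):

```python
from collections import defaultdict

# Toy stand-ins for (context, gen_kwargs) request args.
requests = [
    ("short prompt", {"temperature": 0.0}),
    ("a considerably longer prompt string", {"temperature": 0.0}),
    ("another prompt", {"temperature": 0.8}),
]

# Group by generation kwargs, keyed on their string form.
groups = defaultdict(list)
for context, gen_kwargs in requests:
    groups[str(gen_kwargs)].append((context, gen_kwargs))

# Within each group, order by character length, descending (the new sort key).
for key, reqs in groups.items():
    reqs.sort(key=lambda x: (-len(x[0]), x[0]))
```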