Unverified commit a0cfe3f6 authored by Anjor Kanekar, committed by GitHub

Add tokenizer backend (#1186)

* separate local flag

* tokenizer_backend

* import order
parent 2b0b6fd8
```diff
@@ -3,7 +3,7 @@ import os
 import time
 from collections import defaultdict
 from importlib.util import find_spec
-from typing import List, Optional, Tuple
+from typing import List, Literal, Optional, Tuple

 import transformers
 from tqdm import tqdm
@@ -360,6 +360,7 @@ class OpenaiChatCompletionsLM(LM):
         self,
         model: str = "gpt-3.5-turbo",  # GPT model or Local model using HuggingFace model paths
         base_url: str = None,
+        tokenizer_backend: Literal["tiktoken", "huggingface"] = "tiktoken",
         truncate: bool = False,
         revision: Optional[str] = "main",
         trust_remote_code: Optional[bool] = False,
@@ -388,10 +389,11 @@ class OpenaiChatCompletionsLM(LM):
         )
         self.model = model
         self.base_url = base_url
+        self.tokenizer_backend = tokenizer_backend
         self.truncate = truncate
         # if we have a local model, use HF tokenizer over tiktoken
-        if self.base_url:
+        if self.tokenizer_backend == "huggingface":
             self.revision = revision
             self.trust_remote_code = trust_remote_code
             self.use_fast_tokenizer = use_fast_tokenizer
@@ -404,10 +406,14 @@ class OpenaiChatCompletionsLM(LM):
             )
             self.vocab_size = self.tokenizer.vocab
             self.end_of_text_token_id = self.tokenizer.eos_token
-        else:
+        elif self.tokenizer_backend == "tiktoken":
             self.tokenizer = tiktoken.encoding_for_model(self.model)
             self.vocab_size = self.tokenizer.n_vocab
             self.end_of_text_token_id = self.tokenizer.eot_token
+        else:
+            raise ValueError(
+                f"Expected tokenizer_backend to be one of ['tiktoken', 'huggingface'] but got {self.tokenizer_backend}"
+            )

         # Read from environment variable OPENAI_API_KEY
         # Set to EMPTY for local
```
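For reference, a standalone sketch of the dispatch this commit introduces. The helper `resolve_tokenizer` is hypothetical (not part of the commit); it only mirrors the `__init__` logic in the diff above, with `transformers.AutoTokenizer` and `tiktoken.encoding_for_model` as the two backends:

```python
from typing import Literal

import tiktoken
import transformers


def resolve_tokenizer(
    model: str,
    tokenizer_backend: Literal["tiktoken", "huggingface"] = "tiktoken",
):
    # Hypothetical helper mirroring OpenaiChatCompletionsLM.__init__ above.
    if tokenizer_backend == "huggingface":
        # Local/HF models: use the model's own HuggingFace tokenizer.
        return transformers.AutoTokenizer.from_pretrained(model)
    elif tokenizer_backend == "tiktoken":
        # OpenAI API models: use tiktoken's encoding for the model name.
        return tiktoken.encoding_for_model(model)
    else:
        raise ValueError(
            f"Expected tokenizer_backend to be one of ['tiktoken', 'huggingface'] "
            f"but got {tokenizer_backend}"
        )


enc = resolve_tokenizer("gpt-3.5-turbo")  # tiktoken Encoding by default
```

Before this change the backend was inferred from `base_url` being set; the explicit flag decouples where the server runs from how prompts are tokenized.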