Commit 1a0f57d7 authored by baberabb's avatar baberabb
Browse files

fix tokenizer

parent 2312c263
......@@ -11,6 +11,7 @@ from lm_eval import utils
try:
from vllm import LLM, SamplingParams
from ray.util.multiprocessing import Pool
from vllm.transformers_utils.tokenizer import get_tokenizer
except ModuleNotFoundError:
pass
......@@ -36,7 +37,9 @@ class VLLM(LM):
dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
revision: Optional[str] = None,
trust_remote_code: Optional[bool] = False,
tokenizer: Optional[str] = None,
tokenizer_mode: Literal["auto", "slow"] = "auto",
tokenizer_revision: Optional[str] = None,
tensor_parallel_size: int = 1,
quantization: Optional[Literal["awq"]] = None,
max_gen_toks: int = 256,
......@@ -78,12 +81,15 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
self.model = LLM(**self.model_args)
else:
self.model_args["worker_use_ray"] = True
self.tokenizer = AutoTokenizer.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
use_fast=True if tokenizer_mode == "auto" else False,
)
if tokenizer:
self.tokenizer = tokenizer
else:
self.tokenizer = get_tokenizer(
pretrained,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
)
self.batch_size = batch_size
self._max_length = max_length
self._max_gen_toks = max_gen_toks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment