Unverified commit 690ef8ba authored by Mac Misiura, committed by GitHub

Leverage vllm's `tokenizer_info` endpoint to avoid manual duplication (#3185)

* added an approach to use the tokenizer_info endpoint from vllm

Signed-off-by: m-misiura <mmisiura@redhat.com>

* 🚧 removed all auto-detection and tokenization logic from `LocalChatCompletion`

* pacify pre-commit

---------
Signed-off-by: m-misiura <mmisiura@redhat.com>
Co-authored-by: Baber <baber@hey.com>
parent 655718d0
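
A hedged sketch of what this change enables; the URL and model name are placeholders, and `model` is assumed from `TemplateAPI`'s existing signature:

```python
# Hypothetical usage: local-completions now defaults to tokenizer_backend="auto",
# probing the vLLM server's /tokenizer_info endpoint and falling back to the
# huggingface backend when it is unavailable.
from lm_eval.models.openai_completions import LocalCompletionsAPI

lm = LocalCompletionsAPI(
    base_url="http://localhost:8000/v1/completions",  # placeholder vLLM server
    model="my-model",  # assumed TemplateAPI argument; any served model name
)
```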
@@ -114,7 +114,7 @@ class TemplateAPI(TemplateLM):
         # however the requests can be sent as a string if the API doesn't support token inputs.
         # use tokenized_requests=False
         tokenizer_backend: Optional[
-            Literal["tiktoken", "huggingface", "None", "none"]
+            Literal["tiktoken", "huggingface", "remote", "None", "none"]
         ] = "huggingface",
         truncate: bool = False,
         # number of concurrent requests. More useful if not batching
@@ -132,6 +132,8 @@ class TemplateAPI(TemplateLM):
         revision: Optional[str] = "main",
         use_fast_tokenizer: bool = True,
         verify_certificate: bool = True,
+        ca_cert_path: Optional[str] = None,
+        auth_token: Optional[str] = None,
         eos_string: str = None,
         # timeout in seconds
         timeout: int = 300,
@@ -182,6 +184,8 @@ class TemplateAPI(TemplateLM):
         self.tokenized_requests = tokenized_requests
         self.max_retries = int(max_retries)
         self.verify_certificate = verify_certificate
+        self.ca_cert_path = ca_cert_path
+        self.auth_token = auth_token
         self._eos_string = eos_string
         self.timeout = int(timeout)
         self.max_images = int(max_images)
@@ -218,6 +222,21 @@ class TemplateAPI(TemplateLM):
                 f"Passed `base_url={self.base_url}` but using (OpenAI) Tiktoken tokenizer backend. "
                 "Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
             )
+        elif self.tokenizer_backend == "remote":
+            from lm_eval.utils import RemoteTokenizer
+
+            if not self.base_url:
+                raise ValueError(
+                    "base_url is required for remote tokenizer backend"
+                )
+            self.tokenizer = RemoteTokenizer(
+                self.base_url,
+                self.timeout,
+                self.verify_certificate,
+                self.ca_cert_path,
+                self.auth_token,
+            )
+            eval_logger.info(f"Using remote tokenizer from {self.base_url}")
         else:
             import transformers
@@ -310,7 +329,7 @@ class TemplateAPI(TemplateLM):
     def apply_chat_template(
         self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
-    ) -> Union[str, JsonChatStr]:
+    ) -> Union[str, JsonChatStr, List[Dict]]:
         """Applies a chat template to a list of chat history between user and model."""
         if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
             return self.tokenizer.apply_chat_template(
@@ -319,6 +338,8 @@ class TemplateAPI(TemplateLM):
                 add_generation_prompt=add_generation_prompt,
                 continue_final_message=not add_generation_prompt,
             )
+        elif self.tokenizer_backend == "remote" and self.tokenized_requests:
+            return chat_history
         else:
             # bit of a hack. We'll load back before sending to the API
             return JsonChatStr(
@@ -337,6 +358,8 @@ class TemplateAPI(TemplateLM):
             return self.tokenizer.eos_token_id
         elif self.tokenizer_backend == "tiktoken":
             return self.tokenizer.eot_token
+        elif self.tokenizer_backend == "remote":
+            return self.tokenizer.eos_token_id

     @cached_property
     def eos_string(self) -> Optional[str]:
@@ -347,6 +370,8 @@ class TemplateAPI(TemplateLM):
             return self.tokenizer.eos_token
         elif self.tokenizer_backend == "tiktoken":
             return self.tokenizer.decode([self.tokenizer.eot_token])
+        elif self.tokenizer_backend == "remote":
+            return self.tokenizer.eos_token
         else:
             eval_logger.warning(
                 "Cannot determine EOS string to pass to stop sequence. Manually set by passing `eos_string` to model_args."
@@ -364,6 +389,8 @@ class TemplateAPI(TemplateLM):
             if self.tokenizer.bos_token_id is not None:
                 return self.tokenizer.bos_token_id
             return self.tokenizer.eos_token_id
+        elif self.tokenizer_backend == "remote":
+            return self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
         else:
             return self.tokenizer.eot_token
@@ -396,7 +423,19 @@ class TemplateAPI(TemplateLM):
                 encoding = encoding[-left_truncate_len:]
             return encoding
+        elif self.tokenizer_backend == "remote":
+            if isinstance(string, str):
+                encoding = self.tokenizer.encode(string)
+            else:
+                encoding = [self.tokenizer.encode(s) for s in string]
+            if left_truncate_len:
+                if isinstance(string, str):
+                    encoding = encoding[-left_truncate_len:]
+                else:
+                    encoding = [enc[-left_truncate_len:] for enc in encoding]
+            return encoding
         else:
             try:
                 encoding = self.tokenizer.encode(string)
@@ -409,6 +448,8 @@ class TemplateAPI(TemplateLM):
             return self.tokenizer.batch_decode(tokens)
         elif self.tokenizer_backend == "tiktoken":
             return self.tokenizer.decode_batch(tokens)
+        elif self.tokenizer_backend == "remote":
+            return self.tokenizer.batch_decode(tokens)

     def model_call(
         self,
...
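
The remote branch of `tok_encode` above mirrors the single-string vs. batch semantics of the other backends. A minimal self-contained sketch of that logic; the fake byte-level tokenizer stands in for `RemoteTokenizer.encode`:

```python
from typing import Callable, List, Union

def remote_tok_encode(
    encode_one: Callable[[str], List[int]],  # stand-in for RemoteTokenizer.encode
    string: Union[str, List[str]],
    left_truncate_len: int = 0,
) -> Union[List[int], List[List[int]]]:
    """Mirrors the new `remote` branch of TemplateAPI.tok_encode (sketch)."""
    if isinstance(string, str):
        encoding = encode_one(string)
        return encoding[-left_truncate_len:] if left_truncate_len else encoding
    encodings = [encode_one(s) for s in string]
    return [e[-left_truncate_len:] for e in encodings] if left_truncate_len else encodings

# Demo with a fake tokenizer (no server required):
fake = lambda s: list(s.encode("utf-8"))
print(remote_tok_encode(fake, "hi"))                               # [104, 105]
print(remote_tok_encode(fake, ["hi", "yo"], left_truncate_len=1))  # [[105], [111]]
```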
@@ -16,12 +16,46 @@ eval_logger = logging.getLogger(__name__)

 class LocalCompletionsAPI(TemplateAPI):
     def __init__(
         self,
-        base_url: str = None,
-        tokenizer_backend: str = "huggingface",
+        base_url=None,
+        tokenizer_backend="auto",
+        verify_certificate=True,
+        ca_cert_path=None,
+        auth_token=None,
         **kwargs,
     ):
+        # Auto-detect tokenizer backend
+        if tokenizer_backend == "auto":
+            if base_url:
+                from lm_eval.utils import check_remote_tokenizer_support
+
+                if check_remote_tokenizer_support(
+                    base_url,
+                    verify_certificate=verify_certificate,
+                    ca_cert_path=ca_cert_path,
+                    auth_token=auth_token,
+                ):
+                    eval_logger.info(
+                        "Auto-detected remote tokenizer support. Using remote tokenizer backend."
+                    )
+                    tokenizer_backend = "remote"
+                else:
+                    eval_logger.info(
+                        "Remote tokenizer not supported. Using huggingface tokenizer backend."
+                    )
+                    tokenizer_backend = "huggingface"
+            else:
+                eval_logger.warning(
+                    "No base_url provided. Using huggingface tokenizer backend."
+                )
+                tokenizer_backend = "huggingface"
         super().__init__(
-            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
+            base_url=base_url,
+            tokenizer_backend=tokenizer_backend,
+            verify_certificate=verify_certificate,
+            ca_cert_path=ca_cert_path,
+            auth_token=auth_token,
+            **kwargs,
         )

     def _create_payload(
@@ -106,20 +140,28 @@ class LocalCompletionsAPI(TemplateAPI):

 @register_model("local-chat-completions")
 class LocalChatCompletion(LocalCompletionsAPI):
+    """
+    Minimal chat-completions wrapper.
+
+    - Only accepts messages as list[dict].
+    - No tokenization or template logic.
+    - Use with --apply_chat_template or ensure upstream formats messages correctly.
+    """
+
     def __init__(
         self,
-        base_url: str = None,
-        tokenizer_backend: str = None,
-        tokenized_requests: bool = False,
+        base_url=None,
+        verify_certificate=True,
+        ca_cert_path=None,
+        auth_token=None,
         **kwargs,
     ):
+        eval_logger.warning(
+            "chat-completions endpoint requires the `--apply_chat_template` flag."
+        )
         super().__init__(
             base_url=base_url,
-            tokenizer_backend=tokenizer_backend,
-            tokenized_requests=tokenized_requests,
+            tokenizer_backend=None,
+            tokenized_requests=None,
+            verify_certificate=verify_certificate,
+            ca_cert_path=ca_cert_path,
+            auth_token=auth_token,
             **kwargs,
         )
         if self._batch_size > 1:
@@ -137,9 +179,13 @@ class LocalChatCompletion(LocalCompletionsAPI):
         eos=None,
         **kwargs,
     ) -> dict:
-        assert type(messages) is not str, (
-            "chat-completions require the --apply_chat_template flag."
-        )
+        assert isinstance(messages, list) and all(
+            isinstance(m, dict) for m in messages
+        ), (
+            "LocalChatCompletion expects messages as list[dict]. "
+            "If you see this error, ensure --apply_chat_template is set or upstream code formats messages correctly."
+        )
+        gen_kwargs = gen_kwargs or {}
         gen_kwargs.pop("do_sample", False)
         if "max_tokens" in gen_kwargs:
             max_tokens = gen_kwargs.pop("max_tokens")
...
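
Auto-detection can still be overridden per run. A hedged sketch; argument names beyond this diff, such as `tokenizer`, are assumed from `TemplateAPI`:

```python
from lm_eval.models.openai_completions import LocalCompletionsAPI

# Force the remote backend; __init__ raises if base_url is missing, and
# RemoteTokenizer validates /tokenizer_info at construction time.
lm_remote = LocalCompletionsAPI(
    base_url="http://localhost:8000/v1/completions",  # placeholder
    tokenizer_backend="remote",
)

# Or skip the probe entirely and tokenize locally.
lm_local = LocalCompletionsAPI(
    base_url="http://localhost:8000/v1/completions",  # placeholder
    tokenizer_backend="huggingface",
    tokenizer="gpt2",  # assumed TemplateAPI argument naming
)
```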
@@ -8,12 +8,14 @@ import json
 import logging
 import os
 import re
+import threading
 from dataclasses import asdict, is_dataclass
 from itertools import islice
 from pathlib import Path
 from typing import Any, Callable, Generator, List, Optional, Tuple

 import numpy as np
+import requests
 import yaml
 from jinja2 import BaseLoader, Environment, StrictUndefined
@@ -623,3 +625,218 @@ def hash_dict_images(data_dict):
         if importlib.util.find_spec("PIL")
         else data_dict
     )
class RemoteTokenizer:
    """
    Minimal robust tokenizer that uses vLLM server's tokenizer endpoints.
    """

    def __init__(
        self,
        base_url: str,
        timeout: int = 30,
        verify_certificate: bool = True,
        ca_cert_path: Optional[str] = None,
        auth_token: Optional[str] = None,
        max_retries: int = 3,
    ):
        self.timeout = timeout
        self.max_retries = max_retries
        self._lock = threading.RLock()
        self._tokenizer_info = None
        self._chat_template_obj = None
        # Certificate logic: a CA bundle path if one is given, else a bool for
        # requests' `verify` argument
        self.cert_config = (
            ca_cert_path if verify_certificate and ca_cert_path else verify_certificate
        )
        # Auth header logic
        self.headers = {"Content-Type": "application/json"}
        if auth_token:
            self.headers["Authorization"] = f"Bearer {auth_token}"
        # Normalize base URL - remove API endpoints to get server base
        self.base_url = (
            base_url.replace("/v1/completions", "")
            .replace("/v1/chat/completions", "")
            .rstrip("/")
        )
        # Use a session for connection pooling
        self.session = requests.Session()
        self.session.headers.update(self.headers)
        # Validate server supports tokenizer_info endpoint
        self._validate_server()

    def _request_with_retries(self, method, url, **kwargs):
        last_exc = None
        for _ in range(self.max_retries):
            try:
                resp = self.session.request(
                    method,
                    url,
                    timeout=kwargs.pop("timeout", self.timeout),
                    verify=self.cert_config,
                    **kwargs,
                )
                resp.raise_for_status()
                return resp
            except requests.RequestException as e:
                last_exc = e
        raise RuntimeError(
            f"RemoteTokenizer: {method} {url} failed after {self.max_retries} attempts: {last_exc}"
        )

    def _validate_server(self):
        url = f"{self.base_url}/tokenizer_info"
        resp = self._request_with_retries("GET", url)
        if resp.status_code != 200:
            raise RuntimeError(
                f"Server does not support tokenizer_info endpoint. Status: {resp.status_code}"
            )

    @property
    def tokenizer_info(self) -> dict:
        with self._lock:
            if self._tokenizer_info is None:
                url = f"{self.base_url}/tokenizer_info"
                resp = self._request_with_retries("GET", url)
                self._tokenizer_info = resp.json()
            return self._tokenizer_info

    @property
    def eos_token(self) -> Optional[str]:
        return self.tokenizer_info.get("eos_token")

    @property
    def bos_token(self) -> Optional[str]:
        return self.tokenizer_info.get("bos_token")

    @property
    def pad_token(self) -> Optional[str]:
        return self.tokenizer_info.get("pad_token")

    @property
    def eos_token_id(self) -> Optional[int]:
        if self.eos_token is None:
            return None
        return self.encode(self.eos_token)[0]

    @property
    def bos_token_id(self) -> Optional[int]:
        if self.bos_token is None:
            return None
        return self.encode(self.bos_token)[0]

    @property
    def eot_token(self) -> Optional[int]:
        return self.eos_token_id

    def encode(self, text: str) -> List[int]:
        url = f"{self.base_url}/tokenize"
        payload = {"prompt": text, "add_special_tokens": False}
        resp = self._request_with_retries("POST", url, json=payload)
        tokens = resp.json().get("tokens")
        if not isinstance(tokens, list):
            raise RuntimeError("Malformed response from /tokenize endpoint.")
        return tokens

    def decode(self, tokens: List[int]) -> str:
        url = f"{self.base_url}/detokenize"
        payload = {"tokens": tokens}
        resp = self._request_with_retries("POST", url, json=payload)
        prompt = resp.json().get("prompt")
        if not isinstance(prompt, str):
            raise RuntimeError("Malformed response from /detokenize endpoint.")
        return prompt

    def batch_decode(self, tokens_list: List[List[int]]) -> List[str]:
        return [self.decode(tokens) for tokens in tokens_list]

    def apply_chat_template(
        self, chat_history: list, add_generation_prompt: bool = True, **kwargs
    ) -> str:
        with self._lock:
            if self._chat_template_obj is None:
                template_str = self.tokenizer_info.get("chat_template")
                if not template_str:
                    raise ValueError("No chat template available from server")
                # `env` is the module-level jinja2 Environment defined earlier
                # in lm_eval.utils (see the jinja2 imports above)
                self._chat_template_obj = env.from_string(template_str)
            return self._chat_template_obj.render(
                messages=chat_history,
                add_generation_prompt=add_generation_prompt,
                **kwargs,
            )

    def __call__(self, text: str, add_special_tokens: bool = False, **kwargs) -> dict:
        tokens = self.encode(text)
        return {"input_ids": tokens}
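
A hedged usage sketch of the class above; it requires a running vLLM server exposing /tokenizer_info, /tokenize, and /detokenize, and the URL is a placeholder:

```python
from lm_eval.utils import RemoteTokenizer

tok = RemoteTokenizer("http://localhost:8000", timeout=10)  # placeholder server
ids = tok.encode("Hello world")   # POST /tokenize
text = tok.decode(ids)            # POST /detokenize
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Hello"}],
    add_generation_prompt=True,   # rendered with the server-provided chat_template
)
print(ids, text, prompt)
```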
def check_remote_tokenizer_support(
    base_url: str,
    timeout: int = 5,
    verify_certificate: bool = True,
    ca_cert_path: Optional[str] = None,
    auth_token: Optional[str] = None,
    max_retries: int = 3,
) -> bool:
    """
    Check if the server supports remote tokenizer endpoints.

    Returns True if both the /tokenizer_info and /tokenize endpoints are
    available and functional, False otherwise.
    """
    if not base_url:
        return False

    server_base = (
        base_url.replace("/v1/completions", "")
        .replace("/v1/chat/completions", "")
        .rstrip("/")
    )
    cert_config = (
        ca_cert_path if verify_certificate and ca_cert_path else verify_certificate
    )
    headers = {"Content-Type": "application/json"}
    if auth_token:
        headers["Authorization"] = f"Bearer {auth_token}"
    session = requests.Session()
    session.headers.update(headers)

    def _request_with_retries(method, url, **kwargs):
        for _ in range(max_retries):
            try:
                resp = session.request(
                    method,
                    url,
                    timeout=kwargs.pop("timeout", timeout),
                    verify=cert_config,
                    **kwargs,
                )
                resp.raise_for_status()
                return resp
            except requests.RequestException:
                pass
        return None

    # Check /tokenizer_info
    info_url = f"{server_base}/tokenizer_info"
    resp = _request_with_retries("GET", info_url)
    if not resp:
        return False
    info = resp.json()
    if not isinstance(info, dict) or "eos_token" not in info:
        return False

    # Check /tokenize
    tokenize_url = f"{server_base}/tokenize"
    test_payload = {"prompt": "test", "add_special_tokens": False}
    resp = _request_with_retries("POST", tokenize_url, json=test_payload)
    if not resp:
        return False
    tokens = resp.json().get("tokens")
    if not isinstance(tokens, list):
        return False
    return True
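
And a sketch of the probe as `LocalCompletionsAPI` uses it; the URL is a placeholder:

```python
from lm_eval.utils import check_remote_tokenizer_support

backend = (
    "remote"
    if check_remote_tokenizer_support("http://localhost:8000/v1/completions")
    else "huggingface"
)
print(f"Selected tokenizer backend: {backend}")
```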
@@ -226,3 +226,99 @@ def test_get_batched_requests_with_no_ssl(
     mock_connector.assert_called_with(limit=2, ssl=False)
     assert result_batches
def test_local_completionsapi_remote_tokenizer_authenticated(monkeypatch):
    captured = {}

    class DummyTokenizer:
        def __init__(
            self, base_url, timeout, verify_certificate, ca_cert_path, auth_token
        ):
            captured.update(locals())

    monkeypatch.setattr("lm_eval.utils.RemoteTokenizer", DummyTokenizer)
    LocalCompletionsAPI(
        base_url="https://secure-server",
        tokenizer_backend="remote",
        verify_certificate=True,
        ca_cert_path="secure.crt",
        auth_token="secure-token",
    )
    assert captured["base_url"] == "https://secure-server"
    assert captured["verify_certificate"] is True
    assert captured["ca_cert_path"] == "secure.crt"
    assert captured["auth_token"] == "secure-token"


def test_local_completionsapi_remote_tokenizer_unauthenticated(monkeypatch):
    captured = {}

    class DummyTokenizer:
        def __init__(
            self, base_url, timeout, verify_certificate, ca_cert_path, auth_token
        ):
            captured.update(locals())

    monkeypatch.setattr("lm_eval.utils.RemoteTokenizer", DummyTokenizer)
    LocalCompletionsAPI(
        base_url="http://localhost:8000",
        tokenizer_backend="remote",
        verify_certificate=False,
        ca_cert_path=None,
        auth_token=None,
    )
    assert captured["base_url"] == "http://localhost:8000"
    assert captured["verify_certificate"] is False
    assert captured["ca_cert_path"] is None
    assert captured["auth_token"] is None


def test_localchatcompletion_remote_tokenizer_authenticated(monkeypatch):
    captured = {}

    class DummyTokenizer:
        def __init__(
            self, base_url, timeout, verify_certificate, ca_cert_path, auth_token
        ):
            captured.update(locals())

    monkeypatch.setattr("lm_eval.utils.RemoteTokenizer", DummyTokenizer)
    from lm_eval.models.openai_completions import LocalChatCompletion

    LocalChatCompletion(
        base_url="https://secure-server",
        tokenizer_backend="remote",
        verify_certificate=True,
        ca_cert_path="secure.crt",
        auth_token="secure-token",
    )
    assert captured["base_url"] == "https://secure-server"
    assert captured["verify_certificate"] is True
    assert captured["ca_cert_path"] == "secure.crt"
    assert captured["auth_token"] == "secure-token"


def test_localchatcompletion_remote_tokenizer_unauthenticated(monkeypatch):
    captured = {}

    class DummyTokenizer:
        def __init__(
            self, base_url, timeout, verify_certificate, ca_cert_path, auth_token
        ):
            captured.update(locals())

    monkeypatch.setattr("lm_eval.utils.RemoteTokenizer", DummyTokenizer)
    from lm_eval.models.openai_completions import LocalChatCompletion

    LocalChatCompletion(
        base_url="http://localhost:8000",
        tokenizer_backend="remote",
        verify_certificate=False,
        ca_cert_path=None,
        auth_token=None,
    )
    assert captured["base_url"] == "http://localhost:8000"
    assert captured["verify_certificate"] is False
    assert captured["ca_cert_path"] is None
    assert captured["auth_token"] is None
@@ -12,6 +12,8 @@ from lm_eval.api.metrics import (
 )
 from lm_eval.models.utils import Collator
 from lm_eval.utils import (
+    RemoteTokenizer,
+    check_remote_tokenizer_support,
     get_rolling_token_windows,
     make_disjoint_window,
 )
@@ -396,3 +398,146 @@ def test_aggregate_stderrs(samples):
         mean_stderr(list(itertools.chain.from_iterable(samples))),
         atol=1.0e-3,
     )
def test_remote_tokenizer_custom_cert_and_token(monkeypatch):
    class DummyResponse:
        status_code = 200

        def json(self):
            return {
                "name_or_path": "mock",
                "chat_template": "{{ messages[0].content }}",
            }

        def raise_for_status(self):
            pass

    monkeypatch.setattr("os.path.exists", lambda path: True)
    monkeypatch.setattr(
        "requests.Session.request", lambda self, method, url, **kwargs: DummyResponse()
    )
    tokenizer = RemoteTokenizer(
        base_url="https://mock-server",
        verify_certificate=True,
        ca_cert_path="dummy.crt",
        auth_token="dummy-token",
    )
    assert tokenizer.cert_config == "dummy.crt"
    assert tokenizer.headers["Authorization"] == "Bearer dummy-token"
    assert tokenizer.tokenizer_info["name_or_path"] == "mock"


def test_remote_tokenizer_no_cert(monkeypatch):
    class DummyResponse:
        status_code = 200

        def json(self):
            return {"name_or_path": "mock"}

        def raise_for_status(self):
            pass

    monkeypatch.setattr("os.path.exists", lambda path: True)
    monkeypatch.setattr(
        "requests.Session.request", lambda self, method, url, **kwargs: DummyResponse()
    )
    tokenizer = RemoteTokenizer(
        base_url="https://mock-server",
        verify_certificate=True,
        ca_cert_path=None,
        auth_token="dummy-token",
    )
    assert tokenizer.cert_config is True
    assert tokenizer.headers["Authorization"] == "Bearer dummy-token"
    assert tokenizer.tokenizer_info["name_or_path"] == "mock"


def test_remote_tokenizer_http_url(monkeypatch):
    class DummyResponse:
        status_code = 200

        def json(self):
            return {"name_or_path": "mock"}

        def raise_for_status(self):
            pass

    monkeypatch.setattr("os.path.exists", lambda path: True)
    monkeypatch.setattr(
        "requests.Session.request", lambda self, method, url, **kwargs: DummyResponse()
    )
    tokenizer = RemoteTokenizer(
        base_url="http://mock-server",
        verify_certificate=True,
        ca_cert_path="dummy.crt",
        auth_token="dummy-token",
    )
    assert tokenizer.base_url.startswith("http://")
    assert tokenizer.tokenizer_info["name_or_path"] == "mock"


def test_check_remote_tokenizer_support(monkeypatch):
    class DummyResponse:
        status_code = 200

        def __init__(self, url, json=None):
            # /tokenizer_info must be matched first: "tokenize" is a substring of it
            if "tokenizer_info" in url:
                self._json = {
                    "name_or_path": "mock",
                    "eos_token": "</s>",
                    "bos_token": "<s>",
                    "pad_token": "<pad>",
                    "chat_template": "{{ messages[0].content }}",
                }
            elif "tokenize" in url:
                self._json = {"tokens": [1, 2, 3]}
            else:
                self._json = {}

        def json(self):
            return self._json

        def raise_for_status(self):
            pass

    monkeypatch.setattr("os.path.exists", lambda path: True)

    def dummy_request(self, method, url, **kwargs):
        return DummyResponse(url, json=kwargs.get("json"))

    monkeypatch.setattr("requests.Session.request", dummy_request)
    assert check_remote_tokenizer_support(
        base_url="https://mock-server",
        verify_certificate=True,
        ca_cert_path="dummy.crt",
        auth_token="dummy-token",
    )


def test_apply_chat_template(monkeypatch):
    class DummyResponse:
        status_code = 200

        def json(self):
            return {
                "name_or_path": "mock",
                "chat_template": "{{ messages[0].content }}",
            }

        def raise_for_status(self):
            pass

    monkeypatch.setattr("os.path.exists", lambda path: True)
    monkeypatch.setattr(
        "requests.Session.request", lambda self, method, url, **kwargs: DummyResponse()
    )
    tokenizer = RemoteTokenizer(
        base_url="https://mock-server",
        verify_certificate=True,
        ca_cert_path="dummy.crt",
        auth_token="dummy-token",
    )
    chat_history = [{"role": "user", "content": "Hello"}]
    rendered = tokenizer.apply_chat_template(chat_history)
    assert rendered == "Hello"