Commit 79b31dad authored by Baber's avatar Baber
Browse files

Merge branch 'bos' into mrl

parents cbb8f5a4 7e5f909b
...@@ -16,7 +16,7 @@ classifiers = [ ...@@ -16,7 +16,7 @@ classifiers = [
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
] ]
requires-python = ">=3.9" requires-python = ">=3.10"
license = { "text" = "MIT" } license = { "text" = "MIT" }
dependencies = [ dependencies = [
"accelerate>=0.26.0", "accelerate>=0.26.0",
......
...@@ -226,3 +226,99 @@ def test_get_batched_requests_with_no_ssl( ...@@ -226,3 +226,99 @@ def test_get_batched_requests_with_no_ssl(
mock_connector.assert_called_with(limit=2, ssl=False) mock_connector.assert_called_with(limit=2, ssl=False)
assert result_batches assert result_batches
def test_local_completionsapi_remote_tokenizer_authenticated(monkeypatch):
    """LocalCompletionsAPI forwards TLS and auth settings to RemoteTokenizer."""
    seen = {}

    class DummyTokenizer:
        def __init__(
            self, base_url, timeout, verify_certificate, ca_cert_path, auth_token
        ):
            # Record every constructor argument for later inspection.
            seen.update(locals())

    # Swap the real remote tokenizer for the recording stub.
    monkeypatch.setattr("lm_eval.utils.RemoteTokenizer", DummyTokenizer)

    LocalCompletionsAPI(
        base_url="https://secure-server",
        tokenizer_backend="remote",
        verify_certificate=True,
        ca_cert_path="secure.crt",
        auth_token="secure-token",
    )

    assert seen["base_url"] == "https://secure-server"
    assert seen["verify_certificate"] is True
    assert seen["ca_cert_path"] == "secure.crt"
    assert seen["auth_token"] == "secure-token"
def test_local_completionsapi_remote_tokenizer_unauthenticated(monkeypatch):
    """Remote tokenizer settings pass through unchanged when TLS/auth are off."""
    recorded = {}

    class DummyTokenizer:
        def __init__(
            self, base_url, timeout, verify_certificate, ca_cert_path, auth_token
        ):
            # Capture the arguments the API layer hands to the tokenizer.
            recorded.update(locals())

    monkeypatch.setattr("lm_eval.utils.RemoteTokenizer", DummyTokenizer)

    LocalCompletionsAPI(
        base_url="http://localhost:8000",
        tokenizer_backend="remote",
        verify_certificate=False,
        ca_cert_path=None,
        auth_token=None,
    )

    assert recorded["base_url"] == "http://localhost:8000"
    assert recorded["verify_certificate"] is False
    assert recorded["ca_cert_path"] is None
    assert recorded["auth_token"] is None
def test_localchatcompletion_remote_tokenizer_authenticated(monkeypatch):
    """LocalChatCompletion forwards TLS and auth settings to RemoteTokenizer."""
    seen = {}

    class DummyTokenizer:
        def __init__(
            self, base_url, timeout, verify_certificate, ca_cert_path, auth_token
        ):
            # Record every constructor argument for later inspection.
            seen.update(locals())

    monkeypatch.setattr("lm_eval.utils.RemoteTokenizer", DummyTokenizer)

    # Imported lazily so the monkeypatch above is already in place.
    from lm_eval.models.openai_completions import LocalChatCompletion

    LocalChatCompletion(
        base_url="https://secure-server",
        tokenizer_backend="remote",
        verify_certificate=True,
        ca_cert_path="secure.crt",
        auth_token="secure-token",
    )

    assert seen["base_url"] == "https://secure-server"
    assert seen["verify_certificate"] is True
    assert seen["ca_cert_path"] == "secure.crt"
    assert seen["auth_token"] == "secure-token"
def test_localchatcompletion_remote_tokenizer_unauthenticated(monkeypatch):
    """LocalChatCompletion passes plain-HTTP settings through untouched."""
    recorded = {}

    class DummyTokenizer:
        def __init__(
            self, base_url, timeout, verify_certificate, ca_cert_path, auth_token
        ):
            # Capture the arguments the model class hands to the tokenizer.
            recorded.update(locals())

    monkeypatch.setattr("lm_eval.utils.RemoteTokenizer", DummyTokenizer)

    # Imported lazily so the monkeypatch above is already in place.
    from lm_eval.models.openai_completions import LocalChatCompletion

    LocalChatCompletion(
        base_url="http://localhost:8000",
        tokenizer_backend="remote",
        verify_certificate=False,
        ca_cert_path=None,
        auth_token=None,
    )

    assert recorded["base_url"] == "http://localhost:8000"
    assert recorded["verify_certificate"] is False
    assert recorded["ca_cert_path"] is None
    assert recorded["auth_token"] is None
"""
Test suite for TemplateLM class from lm_eval.api.model
This file provides boilerplate mocking and test fixtures for testing
the TemplateLM abstract base class methods.
"""
from __future__ import annotations
import os
import random
from typing import Optional
from unittest.mock import Mock, patch
import pytest
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
# Disable HuggingFace tokenizers' worker parallelism for the whole test run;
# set before any tokenizer is created so the setting actually takes effect.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ============================================================================
# Mock TemplateLM Implementation
# ============================================================================
class MockTemplateLM(TemplateLM):
    """
    Minimal concrete TemplateLM used as a test double.

    Every abstract hook is implemented as a small deterministic stub so the
    non-abstract TemplateLM machinery can be exercised in isolation; tests
    override individual methods as needed.
    """

    def __init__(
        self,
        tokenizer=None,
        eot_token_id: int = 0,
        prefix_token_id: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self._eot_token_id = eot_token_id
        self._prefix_token_id = prefix_token_id
        # Tests may assign a concrete model class here when behavior depends on it.
        self.AUTO_MODEL_CLASS = None

    @property
    def eot_token_id(self) -> int:
        return self._eot_token_id

    @property
    def prefix_token_id(self) -> int:
        # Fall back to the EOT token when no explicit prefix/BOS token was set.
        if self._prefix_token_id is None:
            return self.eot_token_id
        return self._prefix_token_id

    def tok_encode(self, string: str, **kwargs) -> list[int]:
        """Tokenize via the attached tokenizer; else one token per character."""
        encode = getattr(self.tokenizer, "encode", None) if self.tokenizer is not None else None
        if encode is None:
            # No usable tokenizer: fall back to raw character codes.
            return [ord(ch) for ch in string]
        ids = encode(string, **kwargs)
        # Mock tokenizers may return non-list iterables; normalize to a list.
        if isinstance(ids, list):
            return ids
        return list(ids)

    def _loglikelihood_tokens(self, requests, *args, **kwargs):
        """Identity stub — tests override this as needed."""
        return requests

    def loglikelihood_rolling(
        self, requests, disable_tqdm: bool = False
    ) -> list[float]:
        """Constant stub: one -1.0 score per request."""
        scores = []
        for _ in requests:
            scores.append(-1.0)
        return scores

    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
        """Constant stub: one canned string per request."""
        outputs = []
        for _ in requests:
            outputs.append("mock_output")
        return outputs
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def mock_tokenizer():
    """A Mock tokenizer whose encode/decode round-trip via character codes."""
    tok = Mock()
    tok.chat_template = None
    tok.default_chat_template = "default template"

    def _encode(text, **kwargs):
        # One token per character; prepend token 0 when special tokens are on.
        ids = [ord(ch) for ch in text]
        if kwargs.get("add_special_tokens"):
            return [0] + ids
        return ids

    def _decode(ids, **kwargs):
        # Inverse of _encode (ignoring any leading special token).
        return "".join(chr(i) for i in ids)

    tok.encode = Mock(side_effect=_encode)
    tok.decode = Mock(side_effect=_decode)
    return tok
@pytest.fixture
def mock_lm():
    """A basic MockTemplateLM with EOT token id 0 and no prefix token."""
    lm = MockTemplateLM(eot_token_id=0)
    return lm
@pytest.fixture
def mock_lm_bos():
    """A MockTemplateLM configured with a distinct BOS/prefix token (id 1)."""
    lm = MockTemplateLM(eot_token_id=0, prefix_token_id=1)
    return lm
@pytest.fixture
def mock_lm_with_tokenizer(mock_tokenizer):
    """A MockTemplateLM wired up with the shared mock tokenizer fixture."""
    lm = MockTemplateLM(tokenizer=mock_tokenizer, eot_token_id=0)
    return lm
@pytest.fixture
def sample_instances():
    """Two loglikelihood Instance objects with distinct context/continuation pairs."""
    specs = [
        ("test", "output", "context1", "continuation1", 0),
        ("test2", "output2", "context2", "continuation2", 1),
    ]
    return [
        Instance(
            request_type="loglikelihood",
            doc={"context": doc_ctx, "continuation": doc_cont},
            arguments=(arg_ctx, arg_cont),
            idx=i,
        )
        for doc_ctx, doc_cont, arg_ctx, arg_cont, i in specs
    ]
@pytest.fixture
def sample_empty_context_instance():
    """A loglikelihood Instance whose context is empty (edge-case input)."""
    doc = {"context": "", "continuation": "output"}
    return Instance(
        request_type="loglikelihood",
        doc=doc,
        arguments=("", "continuation"),
        idx=0,
    )
# ============================================================================
# Test Class
# ============================================================================
class TestTemplateLM:
    """Test suite for TemplateLM methods.

    Many tests are still TODO placeholders; implemented tests cover the
    token-id properties and `_encode_pair` whitespace handling.
    """

    # ------------------------------------------------------------------------
    # Property Tests
    # ------------------------------------------------------------------------
    def test_eot_token_id(self, mock_lm):
        """eot_token_id returns the id the mock was constructed with."""
        assert mock_lm.eot_token_id == 0

    def test_prefix_token_id_default(self, mock_lm):
        """prefix_token_id falls back to eot_token_id when not configured."""
        assert mock_lm.prefix_token_id == mock_lm.eot_token_id

    def test_prefix_token_id_custom(self, mock_lm_bos):
        """An explicitly configured prefix_token_id is returned as-is."""
        assert mock_lm_bos.prefix_token_id == 1

    # ------------------------------------------------------------------------
    # tok_encode Tests
    # ------------------------------------------------------------------------
    def test_tok_encode_empty_string(self, mock_lm, mock_tokenizer):
        """_encode_pair raises AssertionError when the context is empty."""
        # NOTE(review): despite the section header, this exercises
        # _encode_pair's empty-context guard, not tok_encode itself.
        mock_lm.tokenizer = mock_tokenizer
        with pytest.raises(AssertionError):
            mock_lm._encode_pair("", "hello")

    # ------------------------------------------------------------------------
    # _encode_pair Tests
    # ------------------------------------------------------------------------
    def test_encode_pair(self, mock_lm, sample_instances, mock_tokenizer):
        """_encode_pair round-trips context and continuation through the tokenizer."""
        mock_lm.tokenizer = mock_tokenizer
        for instance in sample_instances:
            context, cont = instance.args
            context_enc, cont_enc = mock_lm._encode_pair(context, cont)
            assert context == mock_lm.tokenizer.decode(context_enc)
            assert cont == mock_lm.tokenizer.decode(cont_enc)

    def test_encode_pair_context_trailing_spaces(
        self, mock_lm, sample_instances, mock_tokenizer
    ):
        """_encode_pair moves a trailing space from context onto the continuation."""
        mock_lm.tokenizer = mock_tokenizer
        for instance in sample_instances:
            context, cont = instance.args
            context_enc, cont_enc = mock_lm._encode_pair(context + " ", cont)
            assert context == mock_lm.tokenizer.decode(context_enc)
            assert " " + cont == mock_lm.tokenizer.decode(cont_enc)

    def test_encode_pair_multiple_trailing_spaces(
        self, mock_lm, sample_instances, mock_tokenizer
    ):
        """_encode_pair moves an arbitrary run of trailing spaces to the continuation."""
        mock_lm.tokenizer = mock_tokenizer
        # Random (unseeded) space counts: the property must hold for any count.
        spaces = [random.randint(4, 10) for _ in range(len(sample_instances))]
        for i, instance in zip(spaces, sample_instances):
            context, cont = instance.args
            context_enc, cont_enc = mock_lm._encode_pair(context + " " * i, cont)
            assert context == mock_lm.tokenizer.decode(context_enc)
            assert " " * i + cont == mock_lm.tokenizer.decode(cont_enc)

    @patch("transformers.AutoModelForSeq2SeqLM")
    def test_encode_pair_seq2seq_model(self, mock_seq2seq, mock_lm):
        """Test _encode_pair behavior with Seq2Seq models"""
        # TODO: Implement test
        pass

    def test_encode_pair_decoder_only_model(self, mock_lm):
        """Test _encode_pair behavior with decoder-only models"""
        # TODO: Implement test
        pass

    # ------------------------------------------------------------------------
    # loglikelihood Tests
    # ------------------------------------------------------------------------
    def test_loglikelihood_add_special_adds_bos(
        self, mock_lm, mock_tokenizer, sample_instances
    ):
        """Edge case: empty context with add_special_tokens=True, so
        encode(context + cont) should yield cont == [0] + ...

        TODO: incomplete — currently only attaches the tokenizer; the actual
        loglikelihood call and assertions still need to be written.
        """
        mock_lm.tokenizer = mock_tokenizer

    def test_loglikelihood_disable_tqdm(self, mock_lm, sample_instances):
        """Test loglikelihood with disable_tqdm=True"""
        # TODO: Implement test
        pass

    def test_loglikelihood_calls_loglikelihood_tokens(self, mock_lm, sample_instances):
        """Test that loglikelihood properly calls _loglikelihood_tokens"""
        # TODO: Implement test
        # Mock _loglikelihood_tokens and verify it's called with correct args
        pass

    # ------------------------------------------------------------------------
    # chat_template Tests
    # ------------------------------------------------------------------------
    def test_chat_template_no_tokenizer(self, mock_lm):
        """Test chat_template returns empty string when tokenizer is None"""
        # TODO: Implement test
        pass

    def test_chat_template_false_returns_none(self, mock_lm_with_tokenizer):
        """Test chat_template returns None when passed False"""
        # TODO: Implement test
        pass

    def test_chat_template_none_returns_none(self, mock_lm_with_tokenizer):
        """Test chat_template returns None when passed None"""
        # TODO: Implement test
        pass

    def test_chat_template_single_template(self, mock_lm_with_tokenizer):
        """Test chat_template with single template string"""
        # TODO: Implement test
        pass

    def test_chat_template_dict_with_default(self, mock_lm_with_tokenizer):
        """Test chat_template with dict containing 'default' key"""
        # TODO: Implement test
        pass

    def test_chat_template_dict_with_specific_name(self, mock_lm_with_tokenizer):
        """Test chat_template with dict and specific template name"""
        # TODO: Implement test
        pass

    def test_chat_template_dict_no_default_raises_error(self, mock_lm_with_tokenizer):
        """Test chat_template raises error when dict has no default"""
        # TODO: Implement test
        pass

    def test_chat_template_dict_invalid_name_raises_error(self, mock_lm_with_tokenizer):
        """Test chat_template raises error for invalid template name"""
        # TODO: Implement test
        pass

    def test_chat_template_uses_default_template(self, mock_lm_with_tokenizer):
        """Test chat_template falls back to default_chat_template"""
        # TODO: Implement test
        pass

    def test_chat_template_warning_for_default(self, mock_lm_with_tokenizer):
        """Test that using default template generates warning"""
        # TODO: Implement test
        pass

    # ------------------------------------------------------------------------
    # Integration Tests
    # ------------------------------------------------------------------------
    def test_loglikelihood_encode_pair_integration(self, mock_lm):
        """Integration test: loglikelihood properly uses _encode_pair"""
        # TODO: Implement test
        pass

    def test_tokenization_consistency(self, mock_lm):
        """Test that tokenization is consistent across multiple calls"""
        # TODO: Implement test
        pass
# ============================================================================
# Additional Helper Functions
# ============================================================================
def create_mock_instance(
    context: str, continuation: str, request_type: str = "loglikelihood"
) -> Instance:
    """Build a single Instance wrapping the given context/continuation pair.

    Args:
        context: Context string.
        continuation: Continuation string.
        request_type: Type of request (default: "loglikelihood").

    Returns:
        Instance object with the specified parameters.
    """
    doc = {"context": context, "continuation": continuation}
    return Instance(
        request_type=request_type,
        doc=doc,
        arguments=(context, continuation),
        idx=0,
    )
def create_mock_tokenizer_with_chat_templates(
templates: dict | str | None = None,
) -> Mock:
"""
Helper function to create a mock tokenizer with specific chat templates.
Args:
templates: Chat template(s) - can be None, str, or dict
Returns:
Mock tokenizer object with chat_template set
"""
tokenizer = Mock()
if isinstance(templates, dict):
tokenizer.chat_template = templates
tokenizer.default_chat_template = "default"
elif isinstance(templates, str):
tokenizer.chat_template = templates
tokenizer.default_chat_template = None
else:
tokenizer.chat_template = None
tokenizer.default_chat_template = "default"
return tokenizer
# Allow running this test module directly by delegating to pytest's CLI.
if __name__ == "__main__":
    pytest.main()
...@@ -12,6 +12,8 @@ from lm_eval.api.metrics import ( ...@@ -12,6 +12,8 @@ from lm_eval.api.metrics import (
) )
from lm_eval.models.utils import Collator from lm_eval.models.utils import Collator
from lm_eval.utils import ( from lm_eval.utils import (
RemoteTokenizer,
check_remote_tokenizer_support,
get_rolling_token_windows, get_rolling_token_windows,
make_disjoint_window, make_disjoint_window,
) )
...@@ -396,3 +398,146 @@ def test_aggregate_stderrs(samples): ...@@ -396,3 +398,146 @@ def test_aggregate_stderrs(samples):
mean_stderr(list(itertools.chain.from_iterable(samples))), mean_stderr(list(itertools.chain.from_iterable(samples))),
atol=1.0e-3, atol=1.0e-3,
) )
def test_remote_tokenizer_custom_cert_and_token(monkeypatch):
    """RemoteTokenizer uses the supplied CA bundle path and bearer token."""

    class FakeResponse:
        # Minimal stand-in for requests.Response.
        status_code = 200

        def json(self):
            return {
                "name_or_path": "mock",
                "chat_template": "{{ messages[0].content }}",
            }

        def raise_for_status(self):
            pass

    def fake_request(self, method, url, **kwargs):
        return FakeResponse()

    # Pretend the cert file exists and stub out all HTTP traffic.
    monkeypatch.setattr("os.path.exists", lambda path: True)
    monkeypatch.setattr("requests.Session.request", fake_request)

    tokenizer = RemoteTokenizer(
        base_url="https://mock-server",
        verify_certificate=True,
        ca_cert_path="dummy.crt",
        auth_token="dummy-token",
    )

    assert tokenizer.cert_config == "dummy.crt"
    assert tokenizer.headers["Authorization"] == "Bearer dummy-token"
    assert tokenizer.tokenizer_info["name_or_path"] == "mock"
def test_remote_tokenizer_no_cert(monkeypatch):
    """Without a CA bundle, cert verification falls back to `True` (system CAs)."""

    class FakeResponse:
        # Minimal stand-in for requests.Response.
        status_code = 200

        def json(self):
            return {"name_or_path": "mock"}

        def raise_for_status(self):
            pass

    def fake_request(self, method, url, **kwargs):
        return FakeResponse()

    # Stub the filesystem check and all HTTP traffic.
    monkeypatch.setattr("os.path.exists", lambda path: True)
    monkeypatch.setattr("requests.Session.request", fake_request)

    tokenizer = RemoteTokenizer(
        base_url="https://mock-server",
        verify_certificate=True,
        ca_cert_path=None,
        auth_token="dummy-token",
    )

    assert tokenizer.cert_config is True
    assert tokenizer.headers["Authorization"] == "Bearer dummy-token"
    assert tokenizer.tokenizer_info["name_or_path"] == "mock"
def test_remote_tokenizer_http_url(monkeypatch):
    """A plain http:// base URL is accepted and preserved."""

    class FakeResponse:
        # Minimal stand-in for requests.Response.
        status_code = 200

        def json(self):
            return {"name_or_path": "mock"}

        def raise_for_status(self):
            pass

    def fake_request(self, method, url, **kwargs):
        return FakeResponse()

    monkeypatch.setattr("os.path.exists", lambda path: True)
    monkeypatch.setattr("requests.Session.request", fake_request)

    tokenizer = RemoteTokenizer(
        base_url="http://mock-server",
        verify_certificate=True,
        ca_cert_path="dummy.crt",
        auth_token="dummy-token",
    )

    assert tokenizer.base_url.startswith("http://")
    assert tokenizer.tokenizer_info["name_or_path"] == "mock"
def test_check_remote_tokenizer_support(monkeypatch):
    """check_remote_tokenizer_support succeeds against a fully stubbed server."""

    class FakeResponse:
        # Serves a canned payload depending on which endpoint was requested.
        status_code = 200

        def __init__(self, url, json=None):
            # NOTE: "tokenize" is a substring of "tokenizer_info", so the more
            # specific endpoint must be matched first.
            if "tokenizer_info" in url:
                self._json = {
                    "name_or_path": "mock",
                    "eos_token": "</s>",
                    "bos_token": "<s>",
                    "pad_token": "<pad>",
                    "chat_template": "{{ messages[0].content }}",
                }
            elif "tokenize" in url:
                self._json = {"tokens": [1, 2, 3]}
            else:
                self._json = {}

        def json(self):
            return self._json

        def raise_for_status(self):
            pass

    def fake_request(self, method, url, **kwargs):
        return FakeResponse(url, json=kwargs.get("json"))

    monkeypatch.setattr("os.path.exists", lambda path: True)
    monkeypatch.setattr("requests.Session.request", fake_request)

    assert check_remote_tokenizer_support(
        base_url="https://mock-server",
        verify_certificate=True,
        ca_cert_path="dummy.crt",
        auth_token="dummy-token",
    )
def test_apply_chat_template(monkeypatch):
    """apply_chat_template renders the server-provided Jinja template."""

    class FakeResponse:
        # Minimal stand-in for requests.Response.
        status_code = 200

        def json(self):
            return {
                "name_or_path": "mock",
                "chat_template": "{{ messages[0].content }}",
            }

        def raise_for_status(self):
            pass

    def fake_request(self, method, url, **kwargs):
        return FakeResponse()

    monkeypatch.setattr("os.path.exists", lambda path: True)
    monkeypatch.setattr("requests.Session.request", fake_request)

    tokenizer = RemoteTokenizer(
        base_url="https://mock-server",
        verify_certificate=True,
        ca_cert_path="dummy.crt",
        auth_token="dummy-token",
    )

    history = [{"role": "user", "content": "Hello"}]
    # The template emits the first message's content verbatim.
    assert tokenizer.apply_chat_template(history) == "Hello"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment