Unverified Commit 3b7dbef9 authored by daniel-salib's avatar daniel-salib Committed by GitHub
Browse files

use verify_certificate flag in batch requests (#2785)

parent 91264653
......@@ -475,7 +475,7 @@ class TemplateAPI(TemplateLM):
**kwargs,
) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
ctxlens = ctxlens if ctxlens else [None] * len(requests)
conn = TCPConnector(limit=self._concurrent)
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
async with ClientSession(
connector=conn, timeout=ClientTimeout(total=self.timeout)
) as session:
......
from unittest.mock import MagicMock, patch
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
......@@ -21,6 +22,17 @@ def api_tokenized():
)
@pytest.fixture
def api_batch_ssl_tokenized():
return LocalCompletionsAPI(
base_url="https://test-url.com",
model="EleutherAI/pythia-1b",
verify_certificate=False,
num_concurrent=2,
tokenizer_backend="huggingface",
)
def test_create_payload_generate(api):
messages = ["Generate a story"]
gen_kwargs = {
......@@ -147,3 +159,70 @@ def test_model_tokenized_call_usage(
assert "json" in kwargs
assert kwargs["json"] == expected_payload
assert result == {"result": "success"}
class DummyAsyncContextManager:
def __init__(self, result):
self.result = result
async def __aenter__(self):
return self.result
async def __aexit__(self, exc_type, exc, tb):
pass
@pytest.mark.parametrize(
"expected_inputs, expected_ctxlens, expected_cache_keys",
[
(
[
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10],
[11, 12, 13, 14, 15],
[16, 17, 18, 19, 20],
],
[3, 3, 3, 3],
["cache_key1", "cache_key2", "cache_key3", "cache_key4"],
),
],
)
def test_get_batched_requests_with_no_ssl(
api_batch_ssl_tokenized, expected_inputs, expected_ctxlens, expected_cache_keys
):
with (
patch(
"lm_eval.models.api_models.TCPConnector", autospec=True
) as mock_connector,
patch(
"lm_eval.models.api_models.ClientSession", autospec=True
) as mock_client_session,
patch(
"lm_eval.models.openai_completions.LocalCompletionsAPI.parse_logprobs",
autospec=True,
) as mock_parse,
):
mock_session_instance = AsyncMock()
mock_post_response = AsyncMock()
mock_post_response.status = 200
mock_post_response.ok = True
mock_post_response.json = AsyncMock(return_value={"mocked": "response"})
mock_post_response.raise_for_status = lambda: None
mock_session_instance.post = lambda *args, **kwargs: DummyAsyncContextManager(
mock_post_response
)
mock_client_session.return_value.__aenter__.return_value = mock_session_instance
mock_parse.return_value = [(1.23, True), (4.56, False)]
async def run():
return await api_batch_ssl_tokenized.get_batched_requests(
expected_inputs,
expected_cache_keys,
generate=False,
ctxlens=expected_ctxlens,
)
result_batches = asyncio.run(run())
mock_connector.assert_called_with(limit=2, ssl=False)
assert result_batches
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment