# test_flash_phi.py (1.83 KB) -- committed by drbh
import pytest


@pytest.fixture(scope="module")
def flash_phi_handle(launcher):
    """Yield a handle for a ``microsoft/phi-2`` launch, shared by the module.

    The ``launcher`` fixture is called with ``num_shard=1`` and its result
    is used as a context manager. The handle is yielded from inside the
    ``with`` block, so the context is only exited when pytest finalizes
    this module-scoped fixture.
    """
    launch = launcher("microsoft/phi-2", num_shard=1)
    with launch as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_phi(flash_phi_handle):
    """Return the handle's ``client`` once its health check has completed.

    ``health(300)`` is awaited before ``client`` is read, so tests that
    request this fixture only receive the client after that call returns.

    NOTE(review): 300 is presumably a timeout in seconds -- confirm
    against the launcher handle's ``health`` signature.
    """
    handle = flash_phi_handle
    await handle.health(300)
    client = handle.client
    return client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi(flash_phi, response_snapshot):
    """Generate 10 new tokens for a fixed prompt and compare to the snapshot.

    Args:
        flash_phi: Client fixture whose ``generate`` coroutine is awaited.
        response_snapshot: Snapshot fixture the whole response must equal.
    """
    response = await flash_phi.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    # Pin the token count and the exact text first, so a mismatch is
    # reported on a small value before the full snapshot comparison.
    assert response.details.generated_tokens == 10
    assert response.generated_text == ': {request}")\n        response = self'
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_all_params(flash_phi, response_snapshot):
    """Call ``generate`` with many options set at once and pin the result.

    The request uses a fixed ``seed`` and the stop sequence ``"network"``;
    the test expects 6 generated tokens, the full text including the
    prompt (``return_full_text=True``), and equality with the snapshot.
    """
    # Keyword arguments are collected first and forwarded unchanged.
    generation_options = {
        "max_new_tokens": 10,
        "repetition_penalty": 1.2,
        "return_full_text": True,
        "stop_sequences": ["network"],
        "temperature": 0.5,
        "top_p": 0.9,
        "top_k": 10,
        "truncate": 5,
        "typical_p": 0.9,
        "watermark": True,
        "decoder_input_details": True,
        "seed": 0,
    }
    response = await flash_phi.generate("Test request", **generation_options)

    expected_text = "Test request to send data over a network"
    assert response.details.generated_tokens == 6
    assert response.generated_text == expected_text
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
    """Send the same prompt 4 times via ``generate_load`` and compare results.

    Args:
        flash_phi: Client fixture passed through to ``generate_load``.
        generate_load: Fixture awaited with the client, the prompt,
            ``max_new_tokens=10`` and ``n=4``; expected to return 4 responses.
        response_snapshot: Snapshot fixture the list of responses must equal.
    """
    responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)

    assert len(responses) == 4

    # Every response must carry the same text; on failure the assertion
    # message lists all generated texts.
    generated_texts = [r.generated_text for r in responses]
    assert all(
        text == generated_texts[0] for text in generated_texts
    ), f"{generated_texts}"
    assert responses[0].generated_text == ': {request}")\n        response = self'

    assert responses == response_snapshot