# Integration tests for the microsoft/Phi-3.5-MoE-instruct model served
# with 4 shards: greedy generation, full sampling parameters, and a
# concurrent-load determinism check, each compared against snapshots.
import pytest


@pytest.fixture(scope="module")
def flash_phi35_moe_handle(launcher):
    """Launch a 4-shard Phi-3.5-MoE-instruct server for the whole module.

    Yields the launcher handle; the server is torn down when the module's
    tests finish (the ``with`` block exits).
    """
    server = launcher(
        "microsoft/Phi-3.5-MoE-instruct",
        num_shard=4,
    )
    with server as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_phi35_moe(flash_phi35_moe_handle):
    """Wait (up to 300s) for the launched server to be healthy, then return its client."""
    handle = flash_phi35_moe_handle
    await handle.health(300)
    return handle.client


@pytest.mark.asyncio
async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
    """Greedy generation: 10 new tokens with a deterministic expected completion.

    Also checks the full response (including decoder input details) against
    the stored snapshot.
    """
    response = await flash_phi35_moe.generate(
        "What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "Gradient descent is an optimization algorithm commonly used in"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
    """Sampling path: exercise every generation parameter at once.

    ``seed=0`` pins the sampling RNG so the completion is reproducible;
    ``truncate=5`` shortens the prompt before generation. The response is
    also compared against the stored snapshot.
    """
    response = await flash_phi35_moe.generate(
        "What is gradient descent?\n",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "What is gradient descent?\nGradient Descent (GD) is an"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_snapshot):
    """Concurrent load: 4 identical greedy requests must all return the same text.

    Verifies batching does not change greedy-decoding output, then compares
    the response set against the stored snapshot.
    """
    responses = await generate_load(
        flash_phi35_moe, "What is gradient descent?\n\n", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert responses[0].details.generated_tokens == 10
    assert (
        responses[0].generated_text
        == "Gradient descent is an optimization algorithm commonly used in"
    )
    # All concurrent responses must match the first one exactly.
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"{[r.generated_text for r in responses]}"

    assert responses == response_snapshot