"tests/others/test_utils.py" did not exist on "8e2c4cd56cd75c076b04ad0869aca074f307bea7"
test_mllama.py 2.72 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import pytest
import asyncio


@pytest.fixture(scope="module")
def mllama_handle(launcher):
    """Launch a sharded mllama server once for the whole test module.

    Yields the launcher handle so dependent fixtures can reach the client.
    """
    # NOTE(review): num_shard=2 presumably needed to fit the 11B vision
    # model on the CI hardware — confirm against the launcher defaults.
    server = launcher("meta-llama/Llama-3.2-11B-Vision-Instruct", num_shard=2)
    with server as handle:
        yield handle


@pytest.fixture(scope="module")
async def mllama(mllama_handle):
    """Wait for the mllama server to become healthy, then expose its client."""
    # Allow up to 300 seconds for model weights to load before any test runs.
    await mllama_handle.health(300)
    client = mllama_handle.client
    return client


@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
    """Single chat completion over a text+image prompt, checked against
    exact token usage, exact greedy output, and the stored snapshot."""
    text_part = {
        "type": "text",
        "text": "Can you tell me a very short story based on the image?",
    }
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
        },
    }
    user_message = {"role": "user", "content": [text_part, image_part]}

    # temperature=0.0 makes decoding greedy, so the completion is deterministic.
    response = await mllama.chat(
        max_tokens=10,
        temperature=0.0,
        messages=[user_message],
    )

    expected_usage = {
        "completion_tokens": 10,
        "prompt_tokens": 50,
        "total_tokens": 60,
    }
    assert response.usage == expected_usage

    expected_text = "In a bustling city, a chicken named Cluck"
    assert response.choices[0].message.content == expected_text

    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_mllama_load(mllama, generate_load, response_snapshot):
    """Fire four identical chat requests concurrently and smoke-check that
    every response exposes generated text.

    The strict content/snapshot assertions are currently disabled (see the
    TODO below); for now this only verifies the server survives concurrent
    multimodal load without erroring.
    """
    futures = [
        mllama.chat(
            max_tokens=10,
            temperature=0.0,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Can you tell me a very short story based on the image?",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
                            },
                        },
                    ],
                },
            ],
        )
        # Loop index is unused — fire four identical requests.
        for _ in range(4)
    ]
    responses = await asyncio.gather(*futures)

    # Smoke check only: accessing .content raises if any response is malformed.
    _ = [response.choices[0].message.content for response in responses]

    # XXX: TODO: Fix this test.
    # assert generated_texts[0] == "In a bustling city, a chicken named Cluck"
    # assert len(generated_texts) == 4
    # assert generated_texts, all(
    #     [text == generated_texts[0] for text in generated_texts]
    # )
    # assert responses == response_snapshot