"csrc/attention/attention_kernels.cuh" did not exist on "14dbd5a7674e5de2862c18adb711d9feecd35063"
test_tracing.py 5.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# type: ignore
from __future__ import annotations

import threading
from collections.abc import Iterable
from concurrent import futures
from typing import Callable, Generator, Literal

import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceResponse,
)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
    TraceServiceServicer,
    add_TraceServiceServicer_to_server,
)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_TRACES_INSECURE

from vllm import LLM, SamplingParams
from vllm.tracing import SpanAttributes

# Address the fake gRPC trace service listens on; the same value is handed to
# the LLM as its OTLP traces endpoint.
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"

# Names of the ``AnyValue`` fields that ``decode_value`` is able to decode.
FieldName = Literal[
    "bool_value", "string_value", "int_value", "double_value", "array_value"
]


def decode_value(value: AnyValue):
    """Convert an ``AnyValue`` message into a plain Python value.

    Each supported field name is probed with ``value.HasField`` in a fixed
    order; the first field that is set determines the result. Array values
    are decoded recursively, item by item.

    Args:
        value: The message to decode.

    Returns:
        The ``bool_value``, ``string_value``, ``int_value`` or
        ``double_value`` attribute of ``value``, or a list of decoded items
        when ``array_value`` is set.

    Raises:
        ValueError: If none of the supported fields is set on ``value``.
    """
    field_decoders: dict[FieldName, Callable] = {
        "bool_value": (lambda v: v.bool_value),
        "string_value": (lambda v: v.string_value),
        "int_value": (lambda v: v.int_value),
        "double_value": (lambda v: v.double_value),
        "array_value": (
            lambda v: [decode_value(item) for item in v.array_value.values]
        ),
    }
    for field, decoder in field_decoders.items():
        if value.HasField(field):
            return decoder(value)
    raise ValueError(f"Couldn't decode value: {value}")


def decode_attributes(attributes: Iterable[KeyValue]):
    """Build a dict mapping each attribute's key to its decoded value.

    Every item's ``value`` is passed through ``decode_value``. If a key
    occurs more than once, the last occurrence wins.
    """
    decoded = {}
    for pair in attributes:
        decoded[pair.key] = decode_value(pair.value)
    return decoded


class FakeTraceService(TraceServiceServicer):
    """Fake trace service that remembers the last export request.

    ``Export`` stores the request it receives and sets ``evt`` so that a
    caller can wait until an export has happened.
    """

    def __init__(self):
        # Set by Export(); callers wait on it to learn a request arrived.
        self.evt = threading.Event()
        # Most recent request passed to Export(); None until the first call.
        self.request = None

    def Export(self, request, context):
        """Record ``request``, set ``evt`` and return a new response."""
        # Store the request before setting the event so that a waiter woken
        # by the event finds it already populated.
        self.request = request
        self.evt.set()
        response = ExportTraceServiceResponse()
        return response


@pytest.fixture
def trace_service() -> Generator[FakeTraceService, None, None]:
    """Start a gRPC server backed by a ``FakeTraceService`` and yield the service.

    The server listens on ``FAKE_TRACE_SERVER_ADDRESS`` through an insecure
    port and is stopped once the fixture is torn down.
    """
    fake_service = FakeTraceService()
    # A single worker thread handles incoming RPCs.
    executor = futures.ThreadPoolExecutor(max_workers=1)
    grpc_server = grpc.server(executor)
    add_TraceServiceServicer_to_server(fake_service, grpc_server)
    grpc_server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
    grpc_server.start()

    yield fake_service

    # Teardown: stop the server, passing None as the grace period.
    grpc_server.stop(None)


def test_traces(
    monkeypatch: pytest.MonkeyPatch,
    trace_service: FakeTraceService,
):
    """Generate once with tracing enabled and check the exported span.

    An ``LLM`` is pointed at the fake trace service, a single prompt is
    generated, and the test then waits for the service to receive an export
    request. The request must contain exactly one resource span, one scope
    span and one span, whose attributes are compared against the sampling
    parameters and the generation output.

    Raises:
        TimeoutError: If the fake trace service receives nothing within the
            timeout.
    """
    with monkeypatch.context() as m:
        # NOTE(review): presumably needed so the exporter uses an insecure
        # channel, matching the insecure port opened by the fixture.
        m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")

        sampling_params = SamplingParams(
            temperature=0.01,
            top_p=0.1,
            max_tokens=256,
        )
        model = "facebook/opt-125m"
        llm = LLM(
            model=model,
            otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
            gpu_memory_utilization=0.3,
            disable_log_stats=False,
        )
        prompts = ["This is a short prompt"]
        outputs = llm.generate(prompts, sampling_params=sampling_params)
        print(f"test_traces outputs is : {outputs}")

        # Block until the fake service's Export() has been called.
        timeout = 10
        if not trace_service.evt.wait(timeout):
            raise TimeoutError(
                f"The fake trace service didn't receive a trace within "
                f"the {timeout} seconds timeout"
            )

        # Exactly one span is expected, nested one level deep at each layer.
        request = trace_service.request
        assert len(request.resource_spans) == 1, (
            f"Expected 1 resource span, but got {len(request.resource_spans)}"
        )
        assert len(request.resource_spans[0].scope_spans) == 1, (
            f"Expected 1 scope span, "
            f"but got {len(request.resource_spans[0].scope_spans)}"
        )
        assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
            f"Expected 1 span, "
            f"but got {len(request.resource_spans[0].scope_spans[0].spans)}"
        )

        attributes = decode_attributes(
            request.resource_spans[0].scope_spans[0].spans[0].attributes
        )
        # NOTE(review): the response-model check below is disabled in the
        # source; the reason is not visible here -- confirm before enabling.
        # assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
        assert (
            attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE)
            == sampling_params.temperature
        )
        assert (
            attributes.get(SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
        )
        assert (
            attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS)
            == sampling_params.max_tokens
        )
        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
        assert attributes.get(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
            outputs[0].prompt_token_ids
        )
        # Completion tokens are summed over every output of the one request.
        completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
        assert (
            attributes.get(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS)
            == completion_tokens
        )

        # Latency attributes only need to be positive; exact values vary.
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) > 0
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) > 0
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) > 0