"vllm/vscode:/vscode.git/clone" did not exist on "b764547616e6ae1517929985400b1dc62cdbc3a3"
test_tracing.py 8.15 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import threading
from concurrent import futures
from typing import Callable, Dict, Iterable, Literal

import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceResponse)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
    TraceServiceServicer, add_TraceServiceServicer_to_server)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
from opentelemetry.sdk.environment_variables import (
    OTEL_EXPORTER_OTLP_TRACES_INSECURE)

from vllm import LLM, SamplingParams
from vllm.tracing import SpanAttributes

FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"

FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
                    'array_value']


def decode_value(value: AnyValue):
    field_decoders: Dict[FieldName, Callable] = {
        "bool_value": (lambda v: v.bool_value),
        "string_value": (lambda v: v.string_value),
        "int_value": (lambda v: v.int_value),
        "double_value": (lambda v: v.double_value),
        "array_value":
        (lambda v: [decode_value(item) for item in v.array_value.values]),
    }
    for field, decoder in field_decoders.items():
        if value.HasField(field):
            return decoder(value)
    raise ValueError(f"Couldn't decode value: {value}")


def decode_attributes(attributes: Iterable[KeyValue]):
    return {kv.key: decode_value(kv.value) for kv in attributes}


class FakeTraceService(TraceServiceServicer):

    def __init__(self):
        self.request = None
        self.evt = threading.Event()

    def Export(self, request, context):
        self.request = request
        self.evt.set()
        return ExportTraceServiceResponse()


@pytest.fixture
def trace_service():
    """Fixture to set up a fake gRPC trace service"""
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    service = FakeTraceService()
    add_TraceServiceServicer_to_server(service, server)
    server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
    server.start()

    yield service

    server.stop(None)


def test_traces(trace_service):
    os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true"

    sampling_params = SamplingParams(temperature=0.01,
                                     top_p=0.1,
                                     max_tokens=256)
    model = "facebook/opt-125m"
    llm = LLM(
        model=model,
        otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
    )
    prompts = ["This is a short prompt"]
    outputs = llm.generate(prompts, sampling_params=sampling_params)

    timeout = 5
    if not trace_service.evt.wait(timeout):
        raise TimeoutError(
            f"The fake trace service didn't receive a trace within "
            f"the {timeout} seconds timeout")

90
91
92
93
94
95
96
97
98
99
100
101
102
    request = trace_service.request
    assert len(request.resource_spans) == 1, (
        f"Expected 1 resource span, "
        f"but got {len(request.resource_spans)}")
    assert len(request.resource_spans[0].scope_spans) == 1, (
        f"Expected 1 scope span, "
        f"but got {len(request.resource_spans[0].scope_spans)}")
    assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
        f"Expected 1 span, "
        f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")

    attributes = decode_attributes(
        request.resource_spans[0].scope_spans[0].spans[0].attributes)
103
    assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
104
    assert attributes.get(
105
106
107
        SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
    assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
                          ) == sampling_params.temperature
108
    assert attributes.get(
109
        SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
110
    assert attributes.get(
111
112
113
        SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
    assert attributes.get(SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
    assert attributes.get(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
114
115
116
        outputs[0].prompt_token_ids)
    completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
    assert attributes.get(
117
        SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
118
119
    metrics = outputs[0].metrics
    assert attributes.get(
120
        SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue
121
122
    ttft = metrics.first_token_time - metrics.arrival_time
    assert attributes.get(
123
        SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
124
    e2e_time = metrics.finished_time - metrics.arrival_time
125
    assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time
126
    assert metrics.scheduler_time > 0
127
128
    assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER
                          ) == metrics.scheduler_time
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
    # Model forward and model execute should be none, since detailed traces is
    # not enabled.
    assert metrics.model_forward_time is None
    assert metrics.model_execute_time is None


def test_traces_with_detailed_steps(trace_service):
    os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true"

    sampling_params = SamplingParams(temperature=0.01,
                                     top_p=0.1,
                                     max_tokens=256)
    model = "facebook/opt-125m"
    llm = LLM(
        model=model,
        otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
        collect_detailed_traces="all",
    )
    prompts = ["This is a short prompt"]
    outputs = llm.generate(prompts, sampling_params=sampling_params)

    timeout = 5
    if not trace_service.evt.wait(timeout):
        raise TimeoutError(
            f"The fake trace service didn't receive a trace within "
            f"the {timeout} seconds timeout")

156
157
158
159
160
161
162
163
164
165
166
167
168
    request = trace_service.request
    assert len(request.resource_spans) == 1, (
        f"Expected 1 resource span, "
        f"but got {len(request.resource_spans)}")
    assert len(request.resource_spans[0].scope_spans) == 1, (
        f"Expected 1 scope span, "
        f"but got {len(request.resource_spans[0].scope_spans)}")
    assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
        f"Expected 1 span, "
        f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")

    attributes = decode_attributes(
        request.resource_spans[0].scope_spans[0].spans[0].attributes)
169
    assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
170
    assert attributes.get(
171
172
173
        SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
    assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
                          ) == sampling_params.temperature
174
    assert attributes.get(
175
        SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
176
    assert attributes.get(
177
178
179
        SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
    assert attributes.get(SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
    assert attributes.get(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
180
181
182
        outputs[0].prompt_token_ids)
    completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
    assert attributes.get(
183
        SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
184
185
    metrics = outputs[0].metrics
    assert attributes.get(
186
        SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue
187
188
    ttft = metrics.first_token_time - metrics.arrival_time
    assert attributes.get(
189
        SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
190
    e2e_time = metrics.finished_time - metrics.arrival_time
191
    assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time
192
    assert metrics.scheduler_time > 0
193
194
    assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER
                          ) == metrics.scheduler_time
195
196
    assert metrics.model_forward_time > 0
    assert attributes.get(
197
        SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD) == pytest.approx(
198
199
            metrics.model_forward_time / 1000)
    assert metrics.model_execute_time > 0
200
    assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE
201
202
                          ) == metrics.model_execute_time
    assert metrics.model_forward_time < 1000 * metrics.model_execute_time