vllm.py 4.65 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
from dataclasses import field
from typing import Any, Optional

from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.worker import Operator, RemoteInferenceRequest

from .vllm_disaggregated.pipelines import (
    GenerateStage,
    PrefillStage,
    SingleComputePipeline,
)
from .vllm_disaggregated.stage_executor import PiplineStageExecutor


class VllmContextOperator(Operator):
    """Operator that drives a ``PrefillStage`` via a ``PiplineStageExecutor``.

    The stage is configured entirely from the ``parameters`` mapping handed
    to the constructor; incoming requests are forwarded to the executor.
    """

    def __init__(
        self,
        name: str,
        version: int,
        triton_core,
        request_plane: RequestPlane,
        data_plane: DataPlane,
        parameters: Optional[dict[str, str | int | bool | bytes]] = None,
        repository: Optional[str] = None,
        logger: Optional[Any] = None,
    ):
        """Build the prefill stage and the executor that serves it.

        Args:
            name: Accepted for the operator interface; not used here.
            version: Accepted for the operator interface; not used here.
            triton_core: Accepted for the operator interface; not used here.
            request_plane: Passed through to ``PiplineStageExecutor``.
            data_plane: Accepted for the operator interface; not used here.
            parameters: Stage configuration. The keys read are
                ``model_name``, ``context_tp_size``, ``generate_tp_size``,
                ``gpu_memory_utilization``, ``max_model_len``, ``dtype``,
                ``kv_cache_dtype``, ``enable_prefix_caching``,
                ``enable_chunked_prefill``, ``enforce_eager``,
                ``ignore_eos``, ``max_num_seqs``,
                ``disable_async_output_proc`` and ``disable_log_stats``.
                ``None`` is treated as an empty mapping.
            repository: Accepted for the operator interface; not used here.
            logger: Accepted for the operator interface; not used here.

        Raises:
            AttributeError: If one of the keys listed above is missing from
                ``parameters``.
        """
        # The default used to be ``dataclasses.field(default_factory=dict)``,
        # which is only meaningful inside a dataclass body: as a plain
        # function default it is a ``Field`` object, so ``**parameters``
        # raised ``TypeError``. Normalise a missing/None value to an empty
        # mapping so the failure is the clearer missing-key AttributeError.
        args = argparse.Namespace(**({} if parameters is None else parameters))
        stage = PrefillStage(
            model=args.model_name,
            tensor_parallel_size=args.context_tp_size,
            generate_tensor_parallel_size=args.generate_tp_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
            max_model_len=args.max_model_len,
            dtype=args.dtype,
            kv_cache_dtype=args.kv_cache_dtype,
            enable_prefix_caching=args.enable_prefix_caching,
            enable_chunked_prefill=args.enable_chunked_prefill,
            enforce_eager=args.enforce_eager,
            ignore_eos=args.ignore_eos,
            max_num_seqs=args.max_num_seqs,
            disable_async_output_proc=args.disable_async_output_proc,
            disable_log_stats=args.disable_log_stats,
        )
        # NOTE(review): "prefill" and "generate" are presumably this stage's
        # name and the name of the following stage -- confirm against
        # PiplineStageExecutor.
        self.executor = PiplineStageExecutor(
            args, request_plane, stage, "prefill", "generate"
        )

    async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
        """Hand the batch of requests to the stage executor and await it."""
        await self.executor.process_requests(requests)


class VllmGenerateOperator(Operator):
    """Operator that drives a ``GenerateStage`` via a ``PiplineStageExecutor``.

    The stage is configured entirely from the ``parameters`` mapping handed
    to the constructor; incoming requests are forwarded to the executor.
    """

    def __init__(
        self,
        name: str,
        version: int,
        triton_core,
        request_plane: RequestPlane,
        data_plane: DataPlane,
        parameters: Optional[dict[str, str | int | bool | bytes]] = None,
        repository: Optional[str] = None,
        logger: Optional[Any] = None,
    ):
        """Build the generate stage and the executor that serves it.

        Args:
            name: Accepted for the operator interface; not used here.
            version: Accepted for the operator interface; not used here.
            triton_core: Accepted for the operator interface; not used here.
            request_plane: Passed through to ``PiplineStageExecutor``.
            data_plane: Accepted for the operator interface; not used here.
            parameters: Stage configuration. The keys read are
                ``model_name``, ``generate_tp_size``,
                ``gpu_memory_utilization``, ``max_model_len``, ``dtype``,
                ``kv_cache_dtype``, ``enable_prefix_caching``,
                ``enable_chunked_prefill``, ``enforce_eager``,
                ``ignore_eos``, ``max_num_seqs``,
                ``disable_async_output_proc`` and ``disable_log_stats``.
                ``None`` is treated as an empty mapping.
            repository: Accepted for the operator interface; not used here.
            logger: Accepted for the operator interface; not used here.

        Raises:
            AttributeError: If one of the keys listed above is missing from
                ``parameters``.
        """
        # The default used to be ``dataclasses.field(default_factory=dict)``,
        # which is only meaningful inside a dataclass body: as a plain
        # function default it is a ``Field`` object, so ``**parameters``
        # raised ``TypeError``. Normalise a missing/None value to an empty
        # mapping so the failure is the clearer missing-key AttributeError.
        args = argparse.Namespace(**({} if parameters is None else parameters))
        # Always forced to "generate", overriding any ``worker_name`` that
        # was supplied in ``parameters``.
        args.worker_name = "generate"
        stage = GenerateStage(
            model=args.model_name,
            tensor_parallel_size=args.generate_tp_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
            max_model_len=args.max_model_len,
            dtype=args.dtype,
            kv_cache_dtype=args.kv_cache_dtype,
            enable_prefix_caching=args.enable_prefix_caching,
            enable_chunked_prefill=args.enable_chunked_prefill,
            enforce_eager=args.enforce_eager,
            ignore_eos=args.ignore_eos,
            max_num_seqs=args.max_num_seqs,
            disable_async_output_proc=args.disable_async_output_proc,
            disable_log_stats=args.disable_log_stats,
        )
        # NOTE(review): "generate" is presumably this stage's name -- confirm
        # against PiplineStageExecutor.
        self.executor = PiplineStageExecutor(args, request_plane, stage, "generate")

    async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
        """Hand the batch of requests to the stage executor and await it."""
        await self.executor.process_requests(requests)


class VllmBaselineOperator(Operator):
    """Operator that drives a ``SingleComputePipeline`` via a ``PiplineStageExecutor``.

    The stage is configured entirely from the ``parameters`` mapping handed
    to the constructor; incoming requests are forwarded to the executor.
    """

    def __init__(
        self,
        name: str,
        version: int,
        triton_core,
        request_plane: RequestPlane,
        data_plane: DataPlane,
        parameters: Optional[dict[str, str | int | bool | bytes]] = None,
        repository: Optional[str] = None,
        logger: Optional[Any] = None,
    ):
        """Build the single-stage pipeline and the executor that serves it.

        Args:
            name: Accepted for the operator interface; not used here.
            version: Accepted for the operator interface; not used here.
            triton_core: Accepted for the operator interface; not used here.
            request_plane: Passed through to ``PiplineStageExecutor``.
            data_plane: Accepted for the operator interface; not used here.
            parameters: Stage configuration. The keys read are
                ``model_name``, ``baseline_tp_size``,
                ``gpu_memory_utilization``, ``max_model_len``, ``dtype``,
                ``kv_cache_dtype``, ``enable_prefix_caching``,
                ``enable_chunked_prefill``, ``enforce_eager``,
                ``ignore_eos``, ``max_num_seqs``,
                ``disable_async_output_proc`` and ``disable_log_stats``.
                ``None`` is treated as an empty mapping.
            repository: Accepted for the operator interface; not used here.
            logger: Accepted for the operator interface; not used here.

        Raises:
            AttributeError: If one of the keys listed above is missing from
                ``parameters``.
        """
        # The default used to be ``dataclasses.field(default_factory=dict)``,
        # which is only meaningful inside a dataclass body: as a plain
        # function default it is a ``Field`` object, so ``**parameters``
        # raised ``TypeError``. Normalise a missing/None value to an empty
        # mapping so the failure is the clearer missing-key AttributeError.
        args = argparse.Namespace(**({} if parameters is None else parameters))
        stage = SingleComputePipeline(
            model=args.model_name,
            tensor_parallel_size=args.baseline_tp_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
            max_model_len=args.max_model_len,
            dtype=args.dtype,
            kv_cache_dtype=args.kv_cache_dtype,
            enable_prefix_caching=args.enable_prefix_caching,
            enable_chunked_prefill=args.enable_chunked_prefill,
            enforce_eager=args.enforce_eager,
            ignore_eos=args.ignore_eos,
            max_num_seqs=args.max_num_seqs,
            disable_async_output_proc=args.disable_async_output_proc,
            disable_log_stats=args.disable_log_stats,
        )
        # NOTE(review): "baseline" is presumably this stage's name -- confirm
        # against PiplineStageExecutor.
        self.executor = PiplineStageExecutor(args, request_plane, stage, "baseline")

    async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
        """Hand the batch of requests to the stage executor and await it."""
        await self.executor.process_requests(requests)