"deploy/snapshot/protocol/checkpoint_observation_test.go" did not exist on "23144df513723f919ddc0a13f6f46e1bd6da822a"
tpu_worker.py 12.2 KB
Newer Older
1
import os
from typing import List, Optional, Tuple, Union

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
import torch_xla.runtime as xr

import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                         MultiModalConfig, ParallelConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase

# Module-level logger, named after this module via `__name__`.
logger = init_logger(__name__)


class TPUWorker(LoraNotSupportedWorkerBase):
    """Worker that runs a `TPUModelRunner` on an XLA (TPU) device.

    The worker keeps the configuration objects it is constructed with, sets
    up the XLA device and the distributed environment, sizes and allocates
    the per-layer KV caches on the TPU and on the CPU, and forwards execution
    requests to the model runner.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        multimodal_config: Optional[MultiModalConfig],
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool,
    ) -> None:
        """Store the configuration and build the model runner.

        Args:
            model_config: Model configuration (dtype, seed, layer/head sizes).
            parallel_config: Parallelism configuration; its `rank` attribute
                is overwritten with `rank`.
            scheduler_config: Scheduler configuration.
            device_config: Device configuration; `device_type` must be
                "tpu".
            cache_config: KV cache configuration.
            load_config: Model loading configuration.
            multimodal_config: Optional multimodal configuration.
            local_rank: Local rank passed to the distributed initialization.
            rank: Global rank of this worker.
            distributed_init_method: Init method string passed to the
                distributed initialization.
            is_driver_worker: Whether this worker is the driver worker.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.parallel_config.rank = rank
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.multimodal_config = multimodal_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker

        assert self.device_config.device_type == "tpu"
        # "auto" means the KV cache uses the model dtype; any other value is
        # looked up in the string-to-torch-dtype table.
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        self.model_runner = TPUModelRunner(model_config,
                                           parallel_config,
                                           scheduler_config,
                                           device_config,
                                           cache_config,
                                           load_config,
                                           multimodal_config,
                                           is_driver_worker=is_driver_worker)

    def init_device(self) -> None:
        """Select the XLA device and initialize process-wide runtime state.

        Sets `PJRT_DEVICE`, disables autograd, sets the default dtype to the
        model dtype, initializes the distributed environment and model
        parallel groups, seeds the RNGs, raises the dynamo cache size limit
        and enables the persistent XLA compilation cache.
        """
        os.environ["PJRT_DEVICE"] = "TPU"
        self.device = xm.xla_device()
        self.device_config.device = self.device
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)

        # NOTE(woosuk): This is just a hack to initialize the TP group.
        # This cannot perform the actual communication ops.
        init_distributed_environment(
            world_size=self.parallel_config.world_size,
            rank=self.rank,
            local_rank=self.local_rank,
            distributed_init_method=self.distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)

        # Set random seed.
        set_random_seed(self.model_config.seed)
        xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
        # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
        # 30-40 graphs for decode. 128 is an arbitrary safe number.
        torch._dynamo.config.cache_size_limit = 128
        # Use persistent cache to avoid XLA recompilation.
        # NOTE(woosuk): This does not completely eliminate the recompilation
        # overhead because dynamo does not cache the compiled results.
        xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH, readonly=False)

    def load_model(self):
        """Load the model by delegating to the model runner."""
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profile memory usage and derive the KV cache block counts.

        Runs one dummy prompt pass (batch size 1, sequence length
        `max_num_batched_tokens`) with empty KV caches, waits for the device
        to finish, and reads the device memory info. The TPU block count is
        what fits into `bytes_limit * gpu_memory_utilization` minus the bytes
        in use after the dummy run; the CPU block count is what fits into
        `swap_space_bytes`. Both are rounded down to a multiple of 8.

        Returns:
            A `(num_tpu_blocks, num_cpu_blocks)` tuple.
        """
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        # One (None, None) placeholder per layer: no KV cache exists yet.
        kv_caches = [(None, None) for _ in range(num_layers)]
        self.model_runner._dummy_run(
            batch_size=1,
            seq_len=self.scheduler_config.max_num_batched_tokens,
            kv_caches=kv_caches,
            is_prompt=True,
        )
        # Synchronize before measuring the memory usage.
        xm.wait_device_ops()

        # Size in bytes of one block across all layers, keys and values.
        # Shared with `get_cache_block_size_bytes` so both agree.
        block_size_bytes = self.get_cache_block_size_bytes()

        # Calculate the TPU KV cache size based on profiling.
        m = xm.get_memory_info(self.device)
        total_memory_size = m["bytes_limit"]
        usable_memory_size = int(total_memory_size *
                                 self.cache_config.gpu_memory_utilization)
        profiled = m["bytes_used"]  # Weights + intermediate activations.
        # Clamp at zero in case profiling already exceeds the usable budget.
        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
        num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
        num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.

        # Calculate the CPU KV cache size based on the config.
        num_cpu_blocks = (self.cache_config.swap_space_bytes //
                          block_size_bytes)
        num_cpu_blocks = (num_cpu_blocks // 8) * 8  # Round down to 8.
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Allocate zero-filled KV caches on the TPU and CPU, then warm up.

        Records the block counts on the cache config, allocates one
        (key, value) tensor pair per layer on the device and one pair per
        layer on the CPU, and finally calls `_warmup_model`.

        Args:
            num_gpu_blocks: Number of blocks in each device-side cache.
            num_cpu_blocks: Number of blocks in each CPU-side cache.
        """
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self.block_size = self.cache_config.block_size

        dtype = self.cache_dtype
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()

        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        # The attention backend decides the tensor layout of a cache.
        tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_gpu_blocks, self.block_size, num_kv_heads, head_size)
        cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_cpu_blocks, self.block_size, num_kv_heads, head_size)
        for _ in range(num_layers):
            tpu_k_cache = torch.zeros(tpu_cache_shape,
                                      dtype=dtype,
                                      device=self.device)
            tpu_v_cache = torch.zeros_like(tpu_k_cache)
            self.tpu_cache.append((tpu_k_cache, tpu_v_cache))

            cpu_k_cache = torch.zeros(cpu_cache_shape,
                                      dtype=dtype,
                                      device="cpu")
            cpu_v_cache = torch.zeros_like(cpu_k_cache)
            self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
        self._warmup_model()

    def _warmup_model(self) -> None:
        """Ask the model runner to warm up unless `enforce_eager` is set."""
        # FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
        # for CUDA graphs. We should refactor this part.
        if not self.model_config.enforce_eager:
            # Warm up the model with all possible input shapes so that
            # compilation never happens during the actual execution.
            # This may take ~30 mins for the first run and ~20 mins for the
            # subsequent runs.
            # If `enforce_eager` is True, the ahead-of-time compilation is
            # skipped and the compilation happens during the actual execution,
            # which is bad for performance but useful for development.
            self.model_runner.warmup_model(self.tpu_cache)

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of one KV cache block over all layers.

        The result is `dtype_size * num_layers * 2 * block_size *
        num_kv_heads * head_size`, where the factor 2 accounts for the key
        and the value cache.
        """
        head_size = self.model_config.get_head_size()
        num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        key_cache_block = self.cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.cache_dtype)
        return dtype_size * total

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        """Apply the requested cache operations and run the model.

        A non-driver worker ignores the request, runs its non-driver step
        and returns an empty list. The driver worker requires a request with
        a non-empty `seq_group_metadata_list`.

        Args:
            execute_model_req: The request to execute; must not be None on
                the driver worker.

        Returns:
            The model runner output on the driver worker, `[]` otherwise.
        """
        if not self.is_driver_worker:
            self._execute_model_non_driver()
            return []
        assert execute_model_req is not None
        # Issue cache operations.
        self.cache_swap(
            execute_model_req.blocks_to_swap_in,
            execute_model_req.blocks_to_swap_out,
            execute_model_req.blocks_to_copy,
        )
        # Run the model.
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        assert len(seq_group_metadata_list) > 0
        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.tpu_cache)
        return output

    def cache_swap(
        self,
        blocks_to_swap_in: List[Tuple[int, int]],
        blocks_to_swap_out: List[Tuple[int, int]],
        blocks_to_copy: List[Tuple[int, int]],
    ) -> None:
        """Move or copy KV cache blocks between and within the caches.

        Each argument is a list of `(source_block, destination_block)`
        pairs; an empty list skips the corresponding operation.

        Args:
            blocks_to_swap_in: Pairs to copy from the CPU cache to the TPU
                cache.
            blocks_to_swap_out: Pairs to copy from the TPU cache to the CPU
                cache.
            blocks_to_copy: Pairs to copy within the TPU cache, handled by
                the attention backend.
        """
        if blocks_to_swap_in:
            # Swap from CPU to TPU.
            src_indices, dst_indices = _make_src_to_dst(
                blocks_to_swap_in, "cpu", self.device)
            for (tpu_k_cache, tpu_v_cache), (cpu_k_cache, cpu_v_cache) in zip(
                    self.tpu_cache, self.cpu_cache):
                # Gather the source blocks on the CPU, move them to the
                # device, then scatter them into the TPU caches.
                k = cpu_k_cache[:, src_indices].to(self.device)
                v = cpu_v_cache[:, src_indices].to(self.device)
                _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)

        if blocks_to_swap_out:
            # Swap from TPU to CPU.
            src_indices, dst_indices = _make_src_to_dst(
                blocks_to_swap_out, self.device, "cpu")
            for (tpu_k_cache, tpu_v_cache), (cpu_k_cache, cpu_v_cache) in zip(
                    self.tpu_cache, self.cpu_cache):
                cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
                cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()

        if blocks_to_copy:
            src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
                                          self.device)
            attn_backend = self.model_runner.attn_backend
            attn_backend.copy_blocks(self.tpu_cache, src_to_dst)

    def start_worker_execution_loop(self) -> None:
        """Run non-driver steps for as long as each step returns truthy."""
        while self._execute_model_non_driver():
            pass

    def _execute_model_non_driver(self) -> bool:
        """Run one model step without a request; always returns True.

        The model runner is called with `None` in place of the sequence
        group metadata. Because the return value is unconditionally True,
        `start_worker_execution_loop` only stops if this call raises.
        """
        self.model_runner.execute_model(None, self.tpu_cache)
        return True


def _make_src_to_dst(
    mapping: List[Tuple[int, int]],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    src_indices = [i for i, _ in mapping]
    dst_indices = [i for _, i in mapping]
    src_indices = torch.tensor(src_indices,
                               device=src_device,
                               dtype=torch.int64)
    dst_indices = torch.tensor(dst_indices,
                               device=dst_device,
                               dtype=torch.int64)
    return src_indices, dst_indices
278
279
280
281
282
283
284
285
286
287
288
289
290
291


@torch.compile(backend="openxla")
def _insert_kv(
    k: torch.Tensor,
    v: torch.Tensor,
    indices: torch.Tensor,
    tpu_k_cache: torch.Tensor,
    tpu_v_cache: torch.Tensor,
) -> None:
    """Write `k` and `v` into the TPU key/value caches at `indices`.

    The function is compiled with the "openxla" backend. Both cache tensors
    are updated in place by index assignment along dimension 1; nothing is
    returned.

    Args:
        k: Values assigned to `tpu_k_cache[:, indices]`.
        v: Values assigned to `tpu_v_cache[:, indices]`.
        indices: Index tensor selecting the destination positions along
            dimension 1 of both caches.
        tpu_k_cache: Key cache tensor, modified in place.
        tpu_v_cache: Value cache tensor, modified in place.
    """
    # Flag both cache tensors as buffer donors before writing to them.
    # NOTE(review): presumably this lets XLA reuse the donated cache buffers
    # for the updated result instead of allocating a copy -- confirm against
    # the torch_xla buffer donation docs.
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
    # In-place scatter of the new keys/values into the selected positions.
    tpu_k_cache[:, indices] = k
    tpu_v_cache[:, indices] = v