# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""

import copy
import os
from typing import Any, Optional

import torch
import torch.distributed
import torch.nn as nn

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                          has_kv_transfer_group)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS
from vllm.tasks import SupportedTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.utils import bind_kv_cache
31
32
33

# Module-level logger created through vLLM's `init_logger` for this module.
logger = init_logger(__name__)

34
35
36
37
38
39
40
41
42
if not USE_TPU_COMMONS:
    # torch_xla and vLLM's own TPU model runner are imported only on this
    # path. When USE_TPU_COMMONS is set, the `TPUWorker` name is rebound to
    # the tpu_commons implementation at the bottom of this module instead.
    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.profiler as xp
    import torch_xla.runtime as xr

    from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
    from vllm.v1.worker.tpu_model_runner import TPUModelRunner

43
44
45
46
47
48
49
50
51
52
53

class TPUWorker:
    """Worker that owns one XLA (TPU) device and its ``TPUModelRunner``.

    The worker copies the relevant sub-configs out of ``VllmConfig``, sets up
    the distributed/XLA runtime in ``init_device`` and forwards model-level
    operations (loading, execution, LoRA, KV-cache setup) to
    ``self.model_runner``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Store configuration; no device work happens here.

        Args:
            vllm_config: Top-level config; its sub-configs are cached as
                attributes on the worker.
            local_rank: Stored and later passed to the distributed
                initialization.
            rank: Rank of this worker; also written to
                ``parallel_config.rank``.
            distributed_init_method: Forwarded to
                ``init_distributed_environment`` in ``init_device``.
            is_driver_worker: When true, ``execute_model`` always returns the
                model runner's output.
        """
        self.is_driver_worker = is_driver_worker
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.use_spmd = envs.VLLM_XLA_USE_SPMD
        self.original_parallel_config = None
        if self.use_spmd:
            # Under SPMD mode, distributed env is initialized as if there is
            # only one worker/device.
            self._collapse_parallel_config_for_spmd()
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Delay profiler initialization to the start of the profiling.
        # This is because in vLLM V1, MP runtime is initialized before the
        # TPU Worker is initialized. The profiler server needs to start after
        # MP runtime is initialized.
        self.profiler = None
        self.profile_dir = None
        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
            # For TPU, we can only have 1 active profiler session for 1 profiler
            # server. So we only profile on rank0.
            self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        self.profile_dir)

        # Later seeding calls always receive a concrete seed.
        if self.model_config.seed is None:
            self.model_config.seed = 0

    def _collapse_parallel_config_for_spmd(self) -> None:
        """Make ``self.parallel_config`` describe a single worker/device.

        A shallow copy of the untouched config is stored in
        ``self.original_parallel_config`` *before* the sizes are overwritten.
        Keeping a plain reference instead would alias the very object that is
        mutated below, and the original sizes would be lost.
        """
        self.original_parallel_config = copy.copy(self.parallel_config)
        self.parallel_config.tensor_parallel_size = 1
        self.parallel_config.pipeline_parallel_size = 1
        self.parallel_config.world_size = 1

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Record the given block counts on ``self.cache_config``."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Set up the TPU/XLA runtime and create ``self.model_runner``.

        In order: exports the TPU environment variables, disables autograd
        and sets the default dtype, initializes the distributed environment,
        acquires the XLA device, seeds the RNGs, optionally enables the
        per-rank persistent XLA compilation cache, and builds the
        ``TPUModelRunner``. Rank 0 additionally reports usage stats.
        """
        os.environ["PJRT_DEVICE"] = "TPU"
        # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
        # ring, the xla tpu compiler flag
        # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
        # fix this. It will be removed after the bug in XLA compiler is fixed.
        os.environ["LIBTPU_INIT_ARGS"] = (
            os.environ.get("LIBTPU_INIT_ARGS", "") +
            " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
            " --xla_jf_conv_input_fusion=False")
        # --xla_jf_conv_input_fusion=False is used to improve the perf of
        # quantized matmul.
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)

        # Initialize the distributed environment.
        self._init_tpu_worker_distributed_environment(
            self.vllm_config, self.rank, self.distributed_init_method,
            self.local_rank)

        # Device initialization should happen after initializing
        # the distributed runtime.
        self.device = xm.xla_device()
        self.device_config.device = self.device

        # Set random seed.
        set_random_seed(self.model_config.seed)
        if self.model_config.seed is not None:
            xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
        # TODO (NickLucche) On gsm we compile 80+ graphs.
        # Re-evaluate limit, with MM we may get close to this limit.
        torch._dynamo.config.cache_size_limit = 128
        # Use persistent cache to avoid XLA recompilation.
        # NOTE(woosuk): Set per-rank cache path since different ranks
        # can have slightly different XLA graphs.
        world_size = self.parallel_config.world_size
        rank = xr.global_ordinal()
        # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
        # Consequently, changes in optimization flags, which affect compilation
        # results, don't change the cache key. This can result in the wrong
        # compilation being used. To prevent this, disabling the XLA compilation
        # cache during development is recommended. We can disable it by
        # `export VLLM_XLA_CACHE_PATH=`
        if envs.VLLM_XLA_CACHE_PATH:
            per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
                                         f"tp{world_size}_rank{rank}")
            xr.initialize_cache(per_rank_path, readonly=False)

        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = \
            TPUModelRunner(self.vllm_config, self.device,
                           self.original_parallel_config)

        if rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    def determine_available_memory(self) -> int:
        """Run a profiling pass and return the bytes left for the KV cache.

        The model is executed once with empty placeholder KV caches, the
        device memory in use is read back, and the KV-cache budget is
        ``total * gpu_memory_utilization - 1.02 * used`` (clamped at 0),
        scaled down by ``head_size / padded_head_size`` when the head size
        is padded to ``TPU_HEAD_SIZE_ALIGNMENT``.

        Raises:
            NotImplementedError: If a layer's KV cache spec is not an
                ``AttentionSpec``.
        """
        kv_caches: dict[str, torch.Tensor] = {}
        kv_cache_spec = self.model_runner.get_kv_cache_spec()
        for layer_name, layer_spec in kv_cache_spec.items():
            if isinstance(layer_spec, AttentionSpec):
                dtype = layer_spec.dtype

                # Use an empty tensor instead of `None` to force Dynamo to pass
                # it by reference, rather than by specializing on the value
                # `None`.
                tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
                kv_caches[layer_name] = tpu_kv_cache
            else:
                raise NotImplementedError(
                    f"Unsupported KV cache spec '{type(layer_spec)}'")

        runner_kv_caches: list[torch.Tensor] = []
        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            runner_kv_caches)

        # `max_num_tokens >= max_num_batched_tokens` due to padding.
        with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
            self.model_runner.profile_run(self.model_runner.max_num_tokens)

        # Synchronize before measuring the memory usage.
        xm.wait_device_ops()

        # During the profiling run, the model runs without KV cache. After
        # the profiling run, the model always runs with KV cache. Here we clear
        # the dynamo cache and cached bytecode to ensure the model always has
        # one compiled bytecode. Having one FX graph/cached bytecode per
        # compiled model is required for `support_torch_compile` decorator to
        # skip dynamo guard.
        self.model_runner.reset_dynamo_cache()

        # Get the maximum amount of memory used by the model weights and
        # intermediate activations.
        if self.use_spmd:
            # This is a workaround for the TPU SPMD mode. The get_memory_info
            # API doesn't work with SPMD mode in PyTorch/XLA.
            # TODO: use xm.get_memory_info for SPMD once it's supported in
            # PyTorch/XLA.
            import tpu_info
            chip_type, _ = tpu_info.device.get_local_chips()
            device_usage = tpu_info.metrics.get_chip_usage(chip_type)
            total_memory_size = device_usage[0].total_memory
            current_mem = device_usage[0].memory_usage
        else:
            m = xm.get_memory_info(self.device)
            total_memory_size = m["bytes_limit"]
            current_mem = m["bytes_used"]
        # Ideally we would use profiled = m["peak_bytes_used"] to
        # get weights + activations. But there is memory used during
        # compilation / weight loading that impacts the peak and
        # there is no way to reset peak memory in XLA, So we
        # use the heuristic of 2% of weights.
        profiled = current_mem * 1.02

        # Calculate the TPU KV cache size based on profiling.
        usable_memory_size = int(total_memory_size *
                                 self.cache_config.gpu_memory_utilization)
        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
        head_size = self.model_config.get_head_size()
        if head_size > 0:
            padded_head_size = cdiv(
                head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
            if padded_head_size != head_size:
                logger.warning_once("head size is padded to %d",
                                    padded_head_size)
            # We adjust the usable memory size for the KV cache to prevent OOM
            # errors, even after padding the head_size.
            tpu_kv_cache_bytes = (tpu_kv_cache_bytes * head_size //
                                  padded_head_size)
        return int(tpu_kv_cache_bytes)

    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one step on the model runner.

        Returns:
            The runner's output on the driver worker, or on any worker when a
            KV transfer group exists; ``None`` otherwise.
        """
        output = self.model_runner.execute_model(scheduler_output)
        # every worker's output is needed when kv_transfer_group is setup
        return output if self.is_driver_worker or has_kv_transfer_group(
        ) else None

    def profile(self, is_start: bool = True):
        """Start or stop an XLA trace; a no-op unless ``self.rank < 1``.

        Args:
            is_start: ``True`` starts a trace into ``self.profile_dir``
                (starting the profiler server on port 9012 on first use);
                ``False`` stops the current trace.

        Raises:
            RuntimeError: If no profile directory was configured.
        """
        if self.rank < 1:
            if self.profile_dir is None:
                raise RuntimeError("Profiler is not enabled.")
            if is_start:
                if self.profiler is None:
                    self.profiler = xp.start_server(9012)
                xp.start_trace(self.profile_dir)
            else:
                xp.stop_trace()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward the LoRA request to the model runner."""
        return self.model_runner.add_lora(lora_request)

    def load_model(self) -> None:
        """Ask the model runner to load the model."""
        self.model_runner.load_model()

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config overrides to the model runner."""
        self.model_runner.update_config(overrides)

    def reload_weights(self) -> None:
        """Ask the model runner to reload its weights."""
        self.model_runner.reload_weights()

    def compile_or_warm_up_model(self) -> None:
        """Capture the model unless ``enforce_eager`` is set, then reseed."""
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the module held by the model runner."""
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks reported by the model runner."""
        return self.model_runner.get_supported_tasks()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the model runner's KV cache spec."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate the KV cache described by ``kv_cache_config``."""
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def check_health(self) -> None:
        """Health probe; returning at all means the worker is alive."""
        # worker will always be healthy as long as it's running.
        return

    def _init_tpu_worker_distributed_environment(
        self,
        vllm_config: VllmConfig,
        rank: int,
        distributed_init_method: Optional[str] = None,
        local_rank: int = -1,
    ) -> None:
        """Initialize the distributed environment.

        Enables XLA SPMD first when ``self.use_spmd`` is set, then initializes
        the distributed environment, the model-parallel groups and the KV
        transfer state from ``vllm_config``.
        """
        if self.use_spmd:
            xr.use_spmd()
        # NOTE(woosuk): This is just to initialize the TP group and broadcast
        # the input objects on CPU. The all-reduce and all-gather ops on TPU
        # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
        # own context.
        parallel_config = vllm_config.parallel_config
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            local_rank=local_rank,
            distributed_init_method=distributed_init_method,
            backend=current_platform.dist_backend,
        )
        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)

        ensure_kv_transfer_initialized(vllm_config)

333

334
if USE_TPU_COMMONS:
    # Rebind the public `TPUWorker` name to the tpu_commons implementation so
    # importers of this module get that worker under the same name.
    from tpu_commons.worker import TPUWorker as TPUCommonsWorker

    TPUWorker = TPUCommonsWorker  # type: ignore