# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""

import os
6
7
from collections.abc import Callable
from typing import Any, TypeVar
8
9
10
11
12

import torch
import torch.nn as nn

import vllm.envs as envs
13
from vllm.config import VllmConfig
14
15
16
17
18
19
20
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
)
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
)
21
from vllm.logger import init_logger
22
from vllm.lora.request import LoRARequest
23
from vllm.model_executor import set_random_seed
24
from vllm.platforms import current_platform
25
from vllm.platforms.tpu import USE_TPU_INFERENCE
26
from vllm.tasks import SupportedTask
27
from vllm.utils.math_utils import cdiv
28
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
29
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
30
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
31
from vllm.v1.outputs import ModelRunnerOutput
32
33
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.utils import bind_kv_cache
34
35
36

logger = init_logger(__name__)

# Return type of the callable passed to `TPUWorker.apply_model`.
_R = TypeVar("_R")

# The PyTorch/XLA-backed worker below is only used when USE_TPU_INFERENCE is
# false; its torch_xla / Pallas / runner dependencies are therefore imported
# only in that case.
if not USE_TPU_INFERENCE:
    logger.info("tpu_inference not found, using vLLM's TPUWorker.")
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.profiler as xp
    import torch_xla.runtime as xr

    from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
    from vllm.v1.worker.tpu_model_runner import TPUModelRunner


class TPUWorker:
    """vLLM worker that runs a model on a TPU device through PyTorch/XLA.

    The worker owns device and distributed-environment setup, KV-cache
    sizing, profiling and LoRA bookkeeping. Model execution is delegated to
    a ``TPUModelRunner`` that is created in ``init_device``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Capture configuration; no device work happens until ``init_device``.

        Args:
            vllm_config: Engine configuration. When ``VLLM_XLA_USE_SPMD`` is
                set, its ``parallel_config`` is modified in place so the
                distributed environment is set up as a single worker/device.
            local_rank: Local rank of this worker.
            rank: Global rank of this worker; also written to
                ``parallel_config.rank``.
            distributed_init_method: Init method forwarded to
                ``init_distributed_environment`` (``"env://"`` if empty).
            is_driver_worker: Whether this worker is the driver worker.
        """
        import copy

        self.is_driver_worker = is_driver_worker

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config

        self.use_spmd = envs.VLLM_XLA_USE_SPMD
        self.original_parallel_config = None
        if self.use_spmd:
            # Under SPMD mode, distributed env is initialized as if there is
            # only one worker/device. Snapshot the requested parallel layout
            # first: a plain assignment would alias the same object, and the
            # "original" config would then also observe the overrides below.
            self.original_parallel_config = copy.copy(self.parallel_config)
            self.parallel_config.tensor_parallel_size = 1
            self.parallel_config.pipeline_parallel_size = 1
            self.parallel_config.world_size = 1

        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils.import_utils import init_cached_hf_modules

            init_cached_hf_modules()

        # Delay profiler initialization to the start of the profiling.
        # This is because in vLLM V1, MP runtime is initialized before the
        # TPU Worker is initialized. The profiler server needs to start after
        # MP runtime is initialized.
        self.profiler = None
        self.profile_dir = None
        if vllm_config.profiler_config.profiler == "torch" and self.rank < 1:
            # For TPU, we can only have 1 active profiler session for 1 profiler
            # server. So we only profile on rank0.
            self.profile_dir = vllm_config.profiler_config.torch_profiler_dir
            logger.info(
                "Profiling enabled. Traces will be saved to: %s", self.profile_dir
            )

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Record the number of device and CPU KV-cache blocks in the config."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Configure PJRT/libtpu, initialize distributed state and the device.

        Also seeds the RNGs, configures the persistent XLA compilation cache
        and creates ``self.model_runner``.
        """
        os.environ["PJRT_DEVICE"] = "TPU"
        # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
        # ring, the xla tpu compiler flag
        # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
        # fix this. It will be removed after the bug in XLA compiler is fixed.
        # --xla_jf_conv_input_fusion=False is used to improve the perf of
        # quantized matmul.
        os.environ["LIBTPU_INIT_ARGS"] = (
            os.environ.get("LIBTPU_INIT_ARGS", "")
            + " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
            " --xla_jf_conv_input_fusion=False"
        )
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)

        # Initialize the distributed environment.
        self._init_tpu_worker_distributed_environment(
            self.vllm_config, self.rank, self.distributed_init_method, self.local_rank
        )

        # Device initialization should happen after initializing
        # the distributed runtime.
        self.device = xm.xla_device()
        self.device_config.device = self.device

        # Set random seed.
        set_random_seed(self.model_config.seed)
        xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
        # TODO (NickLucche) On gsm we compile 80+ graphs.
        # Re-evaluate limit, with MM we may get close to this limit.
        torch._dynamo.config.cache_size_limit = 128
        # Use persistent cache to avoid XLA recompilation.
        # NOTE(woosuk): Set per-rank cache path since different ranks
        # can have slightly different XLA graphs.
        world_size = self.parallel_config.world_size
        rank = xr.global_ordinal()
        # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
        # Consequently, changes in optimization flags, which affect compilation
        # results, don't change the cache key. This can result in the wrong
        # compilation being used. To prevent this, disabling the XLA compilation
        # cache during development is recommended. We can disable it by
        # `export VLLM_XLA_CACHE_PATH=`
        if envs.VLLM_XLA_CACHE_PATH:
            per_rank_path = os.path.join(
                envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}"
            )
            xr.initialize_cache(per_rank_path, readonly=False)

        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = TPUModelRunner(
            self.vllm_config, self.device, self.original_parallel_config
        )

        if rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    def determine_available_memory(self) -> int:
        """Profile the model and return the bytes available for the KV cache.

        A dummy profiling run is executed with empty placeholder KV caches.
        The measured device usage (plus a 2% margin) is then subtracted from
        ``total * gpu_memory_utilization``. The remainder is scaled by
        ``head_size / padded_head_size`` because head sizes are padded to
        ``TPU_HEAD_SIZE_ALIGNMENT``.

        Raises:
            NotImplementedError: if a layer's KV-cache spec is not an
                ``AttentionSpec``.
        """
        kv_caches: dict[str, torch.Tensor] = {}
        kv_cache_spec = self.model_runner.get_kv_cache_spec()
        for layer_name, layer_spec in kv_cache_spec.items():
            if not isinstance(layer_spec, AttentionSpec):
                raise NotImplementedError(
                    f"Unsupported KV cache spec '{type(layer_spec)}'"
                )
            # Use an empty tensor instead of `None` to force Dynamo to pass
            # it by reference, rather than by specializing on the value `None`.
            kv_caches[layer_name] = torch.tensor([], dtype=layer_spec.dtype).to(
                self.device
            )

        runner_kv_caches: list[torch.Tensor] = []
        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            runner_kv_caches,
        )

        # `max_num_tokens >= max_num_batched_tokens` due to padding.
        with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
            self.model_runner.profile_run(self.model_runner.max_num_tokens)

        # Synchronize before measuring the memory usage.
        xm.wait_device_ops()

        # During the profiling run, the model runs without KV cache. After
        # the profiling run, the model always runs with KV cache. Here we clear
        # the dynamo cache and cached bytecode to ensure the model always has
        # one compiled bytecode. Having one FX graph/cached bytecode per
        # compiled model is required for `support_torch_compile` decorator to
        # skip dynamo guard.
        self.model_runner.reset_dynamo_cache()

        # Get the maximum amount of memory used by the model weights and
        # intermediate activations.
        total_memory_size, current_mem = self._get_device_memory_stats()

        # Ideally we would use profiled = m["peak_bytes_used"] to
        # get weights + activations. But there is memory used during
        # compilation / weight loading that impacts the peak and
        # there is no way to reset peak memory in XLA, So we
        # use the heuristic of 2% of weights.
        profiled = current_mem * 1.02

        # Calculate the TPU KV cache size based on profiling.
        usable_memory_size = int(
            total_memory_size * self.cache_config.gpu_memory_utilization
        )
        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)

        head_size = self.model_config.get_head_size()
        if head_size > 0:
            padded_head_size = (
                cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
            )
            if padded_head_size != head_size:
                logger.warning_once("head size is padded to %d", padded_head_size)
            # We adjust the usable memory size for the KV cache to prevent OOM
            # errors, even after padding the head_size.
            tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
        return int(tpu_kv_cache_bytes)

    def _get_device_memory_stats(self):
        """Return ``(total_memory, used_memory)`` in bytes for this device."""
        if self.use_spmd:
            # This is a workaround for the TPU SPMD mode. The get_memory_info
            # API doesn't work with SPMD mode in PyTorch/XLA.
            # TODO: use xm.get_memory_info for SPMD once it's supported in
            # PyTorch/XLA.
            import tpu_info

            chip_type, _ = tpu_info.device.get_local_chips()
            device_usage = tpu_info.metrics.get_chip_usage(chip_type)
            return device_usage[0].total_memory, device_usage[0].memory_usage
        m = xm.get_memory_info(self.device)
        return m["bytes_limit"], m["bytes_used"]

    def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
        """Delegate token sampling to the model runner."""
        return self.model_runner.sample_tokens(grammar_output)

    def execute_model(
        self, scheduler_output: "SchedulerOutput"
    ) -> ModelRunnerOutput | None:
        """Run one scheduler step on the model runner."""
        return self.model_runner.execute_model(scheduler_output)

    def profile(self, is_start: bool = True):
        """Start or stop an XLA profiler trace; a no-op for ranks >= 1.

        The profiler server is started lazily on port 9012 the first time a
        trace is started.

        Raises:
            RuntimeError: on rank 0 when profiling was not enabled at init.
        """
        if self.rank < 1:
            if self.profile_dir is None:
                raise RuntimeError("Profiler is not enabled.")
            if is_start:
                if self.profiler is None:
                    self.profiler = xp.start_server(9012)
                xp.start_trace(self.profile_dir)
            else:
                xp.stop_trace()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the model runner."""
        return self.model_runner.add_lora(lora_request)

    def load_model(self) -> None:
        """Load model weights via the model runner."""
        self.model_runner.load_model()

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward configuration overrides to the model runner."""
        self.model_runner.update_config(overrides)

    def reload_weights(self) -> None:
        """Reload model weights via the model runner."""
        self.model_runner.reload_weights()

    def compile_or_warm_up_model(self) -> None:
        """Capture the model (unless ``enforce_eager``) and re-seed the RNG."""
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def reset_mm_cache(self) -> None:
        """Clear the model runner's multimodal cache."""
        self.model_runner.reset_mm_cache()

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the loaded model."""
        return self.model_runner.get_supported_tasks()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the per-layer KV-cache specs reported by the model runner."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate TPU KV cache with the specified kv_cache_config."""
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

    def _init_tpu_worker_distributed_environment(
        self,
        vllm_config: VllmConfig,
        rank: int,
        distributed_init_method: str | None = None,
        local_rank: int = -1,
    ) -> None:
        """Initialize the distributed environment."""
        if self.use_spmd:
            xr.use_spmd()
        # NOTE(woosuk): This is just to initialize the TP group and broadcast
        # the input objects on CPU. The all-reduce and all-gather ops on TPU
        # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
        # own context.
        parallel_config = vllm_config.parallel_config
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            local_rank=local_rank,
            distributed_init_method=distributed_init_method or "env://",
            backend=current_platform.dist_backend,
        )
        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
        )

        ensure_kv_transfer_initialized(vllm_config)

    def shutdown(self) -> None:
        """Shut down KV-transfer state held by the model runner, if any."""
        # `model_runner` only exists once `init_device` has run; shutting down
        # a worker that never finished initialization should not raise.
        model_runner = getattr(self, "model_runner", None)
        if model_runner is not None:
            model_runner.ensure_kv_transfer_shutdown()

    def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
        """Apply a function on the model inside this worker."""
        return fn(self.get_model())

347

348
if USE_TPU_INFERENCE:
    # When USE_TPU_INFERENCE is set, the worker from the external
    # `tpu_inference` package replaces the class defined above under the same
    # public name, so importers of `TPUWorker` get that implementation.
    from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker

    TPUWorker = TpuInferenceWorker  # type: ignore