# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
from vllm.logger import init_logger
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
    initialize_model,
    process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype

# Module-level logger used for the load/partition timing messages below.
logger = init_logger(__name__)


class TPUModelLoader(DefaultModelLoader):
    """
    A TPU model loader for model loading under SPMD mode.

    The model is built and its weights are loaded on CPU; it is then moved
    to the XLA device, sharded with `shard_model`, validated, and finally
    wrapped with `torch.compile` using the "openxla" backend.
    """

    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
        mesh: xs.Mesh | None = None,
    ) -> nn.Module:
        """
        Load, shard and compile a model for TPU execution.

        Args:
            vllm_config: Top-level configuration. Its `model_config` and
                `load_config.load_format` drive the loading process.
            model_config: Kept for interface compatibility. The value
                passed in is not used; `vllm_config.model_config` is used
                instead.
            mesh: Mesh forwarded to `shard_model`. It is also used by the
                post-load check: when it is not None, no module may still
                be a plain `QKVParallelLinear`.

        Returns:
            The model in eval mode, on the XLA device, with its inner
            `model` (or `language_model.model` for multimodal models)
            replaced by its `torch.compile`d version.

        Raises:
            AssertionError: If the model config requests quantization, or
                if the post-load checks in `_check_model_is_loaded` fail.
            ValueError: If some model parameters were not initialized from
                the checkpoint (only checked when `load_weights` reports
                which weights it loaded).
        """
        # Initialize model and load weights on CPU. Then, during SPMD partition,
        # weights are sharded and transferred to TPUs.
        self.counter_before_loading_weights = time.perf_counter()
        # NOTE: the `model_config` argument is intentionally superseded by the
        # one carried in `vllm_config`.
        model_config = vllm_config.model_config
        assert model_config.quantization is None, "Quantization not supported"
        target_device = torch.device("cpu")
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config)

            load_format = vllm_config.load_config.load_format
            if load_format != "dummy":
                # Snapshot the expected parameter names before loading so we
                # can detect parameters the checkpoint did not cover.
                weights_to_load = {name for name, _ in model.named_parameters()}
                all_weights = self.get_all_weights(model_config, model)
                loaded_weights = model.load_weights(all_weights)
                self.counter_after_loading_weights = time.perf_counter()
                logger.info(
                    "Loading weights took %.2f seconds",
                    self.counter_after_loading_weights
                    - self.counter_before_loading_weights,
                )
                # We only enable strict check for non-quantized models
                # that have loaded weights tracking currently.
                if model_config.quantization is None and loaded_weights is not None:
                    weights_not_loaded = weights_to_load - loaded_weights
                    if weights_not_loaded:
                        raise ValueError(
                            "Following weights were not initialized from "
                            f"checkpoint: {weights_not_loaded}"
                        )
            else:
                logger.info("Use dummy weight during weight loading.")

            process_weights_after_loading(model, model_config, target_device)

        counter_before_partition = time.perf_counter()
        model = model.eval()
        model = model.to("xla")
        shard_model(model, mesh)
        counter_after_partition = time.perf_counter()
        logger.info(
            "Partition model took %.2f seconds",
            counter_after_partition - counter_before_partition,
        )

        # Ensure the model is properly loaded.
        self._check_model_is_loaded(mesh, model)

        # Need to torch compile after model sharding are done. Because the
        # compiler hints ('xs.mark_sharding') are torch ops.
        if not model_config.is_multimodal_model:
            model.model = torch.compile(model.model, backend="openxla")
        else:
            model.language_model.model = torch.compile(
                model.language_model.model, backend="openxla"
            )
        return model

    def _check_model_is_loaded(self, mesh: xs.Mesh | None, model: nn.Module) -> None:
        """
        Ensure the model is properly loaded.
        1. All model parameters and buffers are on XLA device.
        2. Non-SPMD friendly layers are replaced as expected.

        Args:
            mesh: The mesh used for sharding. The layer-replacement check
                (2) only runs when this is not None.
            model: The model to validate.

        Raises:
            AssertionError: If a parameter or buffer is not on the XLA
                device, or if a `QKVParallelLinear` module is still present
                while a mesh is in use.
        """
        device = xm.xla_device()
        device_type = str(device.type)

        # Check parameters
        for name, param in model.named_parameters():
            assert param.device.type == device_type, (
                f"Parameter {name} is on {param.device.type} instead of {device_type}"
            )

        # Check buffers
        for name, buffer in model.named_buffers():
            assert buffer.device.type == device_type, (
                f"Buffer {name} is on {buffer.device.type} instead of {device_type}"
            )

        # The layer-replacement check only applies when a mesh is in use, so
        # skip the module walk entirely otherwise.
        if mesh is None:
            return

        for module in model.modules():
            if get_fqn(module) == "QKVParallelLinear":
                # Adjacent string literals are used instead of a backslash
                # continuation so no indentation leaks into the message.
                raise AssertionError(
                    "QKVParallelLinear should be replaced by "
                    "XlaQKVParallelLinear under SPMD mode."
                )