# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
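
"""Sharding manager that keeps an FSDP-trained actor and a colocated vLLM engine in sync."""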

import warnings
from typing import Dict, Iterable, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from vllm import LLM
from vllm.distributed import parallel_state as vllm_ps

from ...protocol import DataProto, all_gather_data_proto
from ...utils.model_utils import print_gpu_memory_usage
from .base import BaseShardingManager


class FSDPVLLMShardingManager(BaseShardingManager):
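    """Context manager that syncs FSDP actor weights into a colocated vLLM engine.

    Entering the context gathers the FSDP state dict, wakes vLLM from sleep mode, and
    loads the fresh weights; exiting puts vLLM back to sleep and returns the freed
    memory to training. CUDA rng states are swapped on entry/exit so that all
    tensor-parallel ranks generate with identical randomness.
    """
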
    def __init__(
        self,
        module: FSDP,
        inference_engine: LLM,
        device_mesh: Optional[DeviceMesh] = None,
    ):
        self.module = module
        self.inference_engine = inference_engine
        self.device_mesh = device_mesh
        # Configure FSDP to produce a sharded state dict so that each rank only
        # materializes its own local shards when exporting weights to vLLM.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            FSDP.set_state_dict_type(
                self.module,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )

        self.world_size = dist.get_world_size()
        self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
        self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
        self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group

        # Record freed bytes to estimate memory usage correctly
        # https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119
        self.freed_bytes = 0

        # Note that torch_random_states may differ across dp ranks
        self.torch_random_states = torch.cuda.get_rng_state()
        # Derive a dedicated generation rng state, seeded per dp rank
        if self.device_mesh is not None:
            gen_dp_rank = self.device_mesh["dp"].get_local_rank()
            torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks share the same random state
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)
        else:
            self.gen_random_states = None

    def _make_weight_iterator(
        self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        for name, tensor in actor_weights.items():
            # In distributed runs the sharded state dict yields DTensors, which must
            # be gathered into full tensors before vLLM can load them.
            yield name, tensor.full_tensor() if self.world_size != 1 else tensor

    def __enter__(self):
        # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
        # after vllm sleep, since vllm has its own caching memory allocator (CuMemAllocator).
        # Outside the vllm scope, we should avoid emptying the cache so that pytorch can use
        # its caching allocator to speed up memory allocations.
        #
        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
        torch.cuda.empty_cache()
        print_gpu_memory_usage("Before state_dict() in sharding manager")
        # Export the FSDP-sharded state dict; full tensors are materialized lazily
        # by _make_weight_iterator when vLLM consumes them.
        actor_weights = self.module.state_dict()
        print_gpu_memory_usage("After state_dict() in sharding manager")

        # Wake vLLM from sleep mode so that its weight buffers are back on GPU.
        self.inference_engine.wake_up()
        # Reach into the vLLM engine for the underlying model and load the fresh actor weights.
        model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
        model.load_weights(self._make_weight_iterator(actor_weights))
        print_gpu_memory_usage("After sync model weights in sharding manager")

        del actor_weights
        torch.cuda.empty_cache()
        print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
        # important: manually switch to the generation rng state so that all tp ranks
        # sample with identical random states.
        if self.device_mesh is not None:
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.gen_random_states)

    def __exit__(self, exc_type, exc_value, traceback):
        print_gpu_memory_usage("Before vllm offload in sharding manager")
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
        # Level 1 sleep offloads vLLM's model weights to CPU and discards the KV cache.
        self.inference_engine.sleep(level=1)
        free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
        self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        print_gpu_memory_usage("After vllm offload in sharding manager")

        self.module.train()
        torch.cuda.empty_cache()  # empty the cache after each rollout so training can reuse the memory

        # restore random states
        if self.device_mesh is not None:
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)

    def preprocess_data(self, data: DataProto) -> DataProto:
        """All gather across tp group to make each rank has identical input."""
        all_gather_data_proto(data, size=self.tp_size, group=self.tp_group)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Get chunk data of this tp rank since we do all gather in preprocess."""
        if self.tp_size > 1:
            data = data.chunk(chunks=self.tp_size)[self.tp_rank]

        return data
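

# A minimal usage sketch (illustrative only): `fsdp_actor`, `vllm_engine`, `mesh`,
# and `batch` are hypothetical objects assumed to be set up by the caller.
#
#     manager = FSDPVLLMShardingManager(fsdp_actor, vllm_engine, device_mesh=mesh)
#     with manager:  # actor weights are synced into vLLM here
#         batch = manager.preprocess_data(batch)       # all-gather inputs across tp ranks
#         outputs = ...                                # run vLLM generation on the batch
#         outputs = manager.postprocess_data(outputs)  # keep only this tp rank's chunk
#     # on exit, vLLM sleeps and the actor module returns to training mode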