# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Dict, Iterable, Tuple, Union

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import get_model_state_dict
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from vllm import LLM
from vllm.distributed import parallel_state as vllm_ps

from ...protocol import DataProto, all_gather_data_proto
from ...utils.model_utils import print_gpu_memory_usage
from .base import BaseShardingManager


class FSDPVLLMShardingManager(BaseShardingManager):
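    """Sharding manager that shares weights between an FSDP training module and a colocated vLLM engine.

    On entry, the FSDP state dict is gathered and loaded into the vLLM engine for rollout;
    on exit, the module is switched back to training mode and the training RNG state is restored.
    """
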
    def __init__(
        self,
        module: FSDP,
        inference_engine: LLM,
        device_mesh: DeviceMesh,
    ):
        self.module = module
        self.inference_engine = inference_engine
        self.device_mesh = device_mesh

        self.world_size = dist.get_world_size()
        self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
        self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
        self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group

        # Record freed bytes to estimate memory usage correctly
        # https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119
        self.freed_bytes = 0

        # Note that torch_random_states may be different on each dp rank
        self.torch_random_states = torch.cuda.get_rng_state()
        # create a dedicated rng state for generation
        gen_dp_rank = self.device_mesh["dp"].get_local_rank()
        torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
        self.gen_random_states = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.torch_random_states)

    def _make_weight_iterator(
        self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
    ) -> Iterable[Tuple[str, torch.Tensor]]:
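        """Yield each weight as a plain tensor; on multiple ranks the FSDP state dict holds DTensor shards to gather."""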
        for name, tensor in actor_weights.items():
            yield name, tensor.full_tensor() if self.world_size != 1 else tensor

    def __enter__(self):
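        """Gather the FSDP state dict, wake up vLLM, and load the weights into the inference engine."""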
        # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
        # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
        # Outside of vllm's scope, we should avoid emptying the cache so that pytorch can use
        # its caching allocator to speed up memory allocations.
        #
        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
        torch.cuda.empty_cache()
        print_gpu_memory_usage("Before state_dict() in sharding manager")
        actor_weights = get_model_state_dict(self.module)
        print_gpu_memory_usage("After state_dict() in sharding manager")

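        # Newer vLLM versions accept a `tags` argument so weights and KV cache can be woken
        # up separately; fall back to a plain wake_up() when `tags` is not supported.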
        if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
            self.inference_engine.wake_up(tags=["weights"])
        else:
            self.inference_engine.wake_up()

        model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
        model.load_weights(self._make_weight_iterator(actor_weights))
        print_gpu_memory_usage("After sync model weights in sharding manager")

        del actor_weights
        torch.cuda.empty_cache()

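        # Wake up the KV cache only after the gathered weights were freed above, so the full
        # state dict and the KV cache never occupy GPU memory at the same time.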
        if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
            self.inference_engine.wake_up(tags=["kv_cache"])

        print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
        # Important: manually set the random states of all tp ranks to be identical.
        if self.device_mesh is not None:
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.gen_random_states)

    def __exit__(self, exc_type, exc_value, traceback):
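        """Record memory released by the vLLM offload, return the module to train mode, and restore RNG state."""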
        print_gpu_memory_usage("Before vllm offload in sharding manager")
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
        # self.inference_engine.sleep(level=1)
        ## rocm alternative to sleep():
        # self.inference_engine.offload_model_weights()
        free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
        self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        print_gpu_memory_usage("After vllm offload in sharding manager")

        self.module.train()
        torch.cuda.empty_cache()  # empty cache after each compute stage

        # restore random states
        if self.device_mesh is not None:
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)

    def preprocess_data(self, data: DataProto) -> DataProto:
        """All gather across tp group to make each rank has identical input."""
        all_gather_data_proto(data, size=self.tp_size, group=self.tp_group)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Get chunk data of this tp rank since we do all gather in preprocess."""
        if self.tp_size > 1:
            data = data.chunk(chunks=self.tp_size)[self.tp_rank]

        return data
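

# Example usage (a minimal sketch; `actor_module`, `llm`, `rollout_mesh`, `rollout`, and
# `prompts` are hypothetical names that are not defined in this module):
#
#     sharding_manager = FSDPVLLMShardingManager(
#         module=actor_module,       # FSDP-wrapped actor model
#         inference_engine=llm,      # vllm.LLM instance with sleep mode enabled
#         device_mesh=rollout_mesh,  # device mesh that includes a "dp" dimension
#     )
#     with sharding_manager:
#         prompts = sharding_manager.preprocess_data(prompts)
#         outputs = rollout.generate_sequences(prompts)
#         outputs = sharding_manager.postprocess_data(outputs)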