# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Dict, Iterable, Tuple, Union

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import get_model_state_dict
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from vllm import LLM
from vllm.distributed import parallel_state as vllm_ps

from ...protocol import DataProto, all_gather_data_proto
from ...utils.model_utils import print_gpu_memory_usage
from .base import BaseShardingManager


class FSDPVLLMShardingManager(BaseShardingManager):
    def __init__(
        self,
        module: FSDP,
        inference_engine: LLM,
        device_mesh: DeviceMesh,
    ):
        self.module = module
        self.inference_engine = inference_engine
        self.device_mesh = device_mesh
        self.world_size = dist.get_world_size()
        self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
        self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
        self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group
        # Record freed bytes to estimate memory usage correctly
        # https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119
        self.freed_bytes = 0

        # Note that torch_random_states may be different on each dp rank
        self.torch_random_states = torch.cuda.get_rng_state()
        # derive a dedicated rng state per dp rank
        gen_dp_rank = self.device_mesh["dp"].get_local_rank()
        torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
        self.gen_random_states = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.torch_random_states)

    def _make_weight_iterator(
        self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        for name, tensor in actor_weights.items():
            # gather DTensor shards into full tensors so vllm can load them directly
            yield name, tensor.full_tensor() if self.world_size != 1 else tensor
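    # For illustration, a minimal standalone sketch of the gather that `full_tensor()`
    # performs (hypothetical shapes; assumes a 1D device mesh and an initialized
    # process group):
    #
    #     from torch.distributed._tensor import Shard, distribute_tensor
    #     from torch.distributed.device_mesh import init_device_mesh
    #
    #     mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    #     dt = distribute_tensor(torch.randn(1024, 1024), mesh, [Shard(0)])
    #     full = dt.full_tensor()  # plain torch.Tensor holding the complete weight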
    def __enter__(self):
        # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
        # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
        # Out of vllm scope, we should avoid emptying the cache so that pytorch can reuse
        # its caching memory to speed up allocations.
        #
        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
        torch.cuda.empty_cache()
        print_gpu_memory_usage("Before state_dict() in sharding manager")
        actor_weights = get_model_state_dict(self.module)
        print_gpu_memory_usage("After state_dict() in sharding manager")

        # newer vllm versions support staged wake up: load weights first, allocate kv cache later
        if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
            self.inference_engine.wake_up(tags=["weights"])
        else:
            self.inference_engine.wake_up()

        model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
        model.load_weights(self._make_weight_iterator(actor_weights))
        print_gpu_memory_usage("After sync model weights in sharding manager")

        del actor_weights
        torch.cuda.empty_cache()
        if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
            self.inference_engine.wake_up(tags=["kv_cache"])

        print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
        # important: need to manually set the random states of each tp to be identical.
        if self.device_mesh is not None:
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.gen_random_states)

    def __exit__(self, exc_type, exc_value, traceback):
        print_gpu_memory_usage("Before vllm offload in sharding manager")
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
        self.inference_engine.sleep(level=1)
        # on ROCm, use `self.inference_engine.offload_model_weights()` instead of sleep()
        free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
        self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        print_gpu_memory_usage("After vllm offload in sharding manager")

        self.module.train()
        torch.cuda.empty_cache()  # add empty cache after each compute

        # restore random states
        if self.device_mesh is not None:
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)

    def preprocess_data(self, data: DataProto) -> DataProto:
        """All gather across tp group so that every rank has identical input."""
        all_gather_data_proto(data, size=self.tp_size, group=self.tp_group)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Get the chunk of data belonging to this tp rank, since we all-gathered in preprocess."""
        if self.tp_size > 1:
            data = data.chunk(chunks=self.tp_size)[self.tp_rank]
        return data
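
# A minimal usage sketch (hypothetical names: fsdp_module, llm, mesh, batch, prompts,
# sampling_params; the real trainer wires these up elsewhere). Entering the manager
# syncs FSDP weights into vllm and wakes it up; exiting puts vllm to sleep and returns
# the module to training mode:
#
#     sharding_manager = FSDPVLLMShardingManager(fsdp_module, llm, mesh)
#     with sharding_manager:
#         batch = sharding_manager.preprocess_data(batch)   # all-gather across tp group
#         outputs = llm.generate(prompts, sampling_params)  # rollout with synced weights
#         batch = sharding_manager.postprocess_data(batch)  # keep this tp rank's chunk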