Commit f92481f0 authored by chenych

First commit.

parent 7121d0b0
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Rollout config
"""
from dataclasses import asdict, dataclass, field
@dataclass
class RolloutConfig:
name: str = "vllm"
temperature: float = 1.0
top_k: int = -1
top_p: float = 1.0
dtype: str = "bfloat16"
gpu_memory_utilization: float = 0.5
ignore_eos: bool = False
enforce_eager: bool = False
free_cache_engine: bool = False
enable_chunked_prefill: bool = False
tensor_parallel_size: int = 2
max_num_batched_tokens: int = 8192
max_num_seqs: int = 1024
disable_log_stats: bool = True
do_sample: bool = True
n: int = 1
limit_images: int = 0
"""auto keys"""
prompt_length: int = field(default=-1, init=False)
response_length: int = field(default=-1, init=False)
def to_dict(self):
return asdict(self)
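# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the original module): how a trainer
# might construct a RolloutConfig. `prompt_length` and `response_length` are
# auto fields (init=False), so they are assigned after construction.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = RolloutConfig(temperature=0.7, top_p=0.9, n=4, tensor_parallel_size=1)
    cfg.prompt_length = 1024  # normally filled in by the trainer
    cfg.response_length = 512
    print(cfg.to_dict())  # plain dict, usable e.g. when building vLLM SamplingParams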
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dtensor_weight_loaders import load_dtensor_weights
from .vllm_rollout_spmd import vLLMRollout
__all__ = ["vLLMRollout", "load_dtensor_weights"]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
from typing import Dict
import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import is_pp_missing_parameter
def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
stacked_name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if stacked_name.endswith(".bias") and stacked_name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[stacked_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight)
def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def qwen2vl_dtensor_weight_loader(actor_weights: Dict[str, torch.Tensor], vllm_model: nn.Module) -> None:
stacked_params_mapping = [
# (vllm_substr, hf_substr, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
vllm_params = dict(vllm_model.named_parameters(remove_duplicate=False))
for actor_name, actor_weight in actor_weights.items():
if "rotary_emb.inv_freq" in actor_name:
continue
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_name:
continue
for vllm_substr, hf_substr, shard_id in stacked_params_mapping:
if hf_substr not in actor_name:
continue
if "visual" in actor_name:
continue
vllm_name = "language_model." + actor_name.replace(hf_substr, vllm_substr)
if actor_name.endswith(".bias") and actor_name not in vllm_params:
continue # skip loading extra bias for GPTQ models
local_actor_weight = redistribute_dtensor(param_name=actor_name, loaded_weights=actor_weight)
vllm_param = vllm_params[vllm_name]
weight_loader = vllm_param.weight_loader
weight_loader(vllm_param, local_actor_weight.to(dtype=vllm_param.dtype), shard_id)
break
else:
if actor_name.endswith(".bias") and actor_name not in vllm_params:
continue # skip loading extra bias for GPTQ models
if "visual" in actor_name:
vllm_name = actor_name
else:
vllm_name = "language_model." + actor_name
vllm_param = vllm_params[vllm_name]
local_actor_weight = redistribute_dtensor(param_name=actor_name, loaded_weights=actor_weight)
weight_loader = getattr(vllm_param, "weight_loader", default_weight_loader)
weight_loader(vllm_param, local_actor_weight.to(dtype=vllm_param.dtype))
def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=vllm_model.config.n_routed_experts,
)
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks loading.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, vllm_model):
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, vllm_model):
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
local_loaded_weight.to(dtype=param.dtype),
weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, vllm_model):
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
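# Helper used by the loaders above: converts a (possibly sharded) DTensor coming
# from the FSDP actor into a plain local tensor. With a parallelize_plan the
# weight is redistributed to the requested placements; without one, the full
# (unsharded) tensor is gathered on every rank.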
def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None):
param_name = _process_parameter_names(name=param_name)
if parallelize_plan is not None:
assert param_name in parallelize_plan.keys(), (
f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}"
)
placement = parallelize_plan[param_name]
local_loaded_weights = loaded_weights.redistribute(
device_mesh=loaded_weights.device_mesh, placements=placement
).to_local()
else:
local_loaded_weights = loaded_weights.full_tensor()
return local_loaded_weights
def _process_parameter_names(name):
# Remove '.weight' if it exists at the end of the string
if name.endswith(".weight"):
name = name[:-7]
# Remove 'model.layers.x.' or 'model.' prefix
if "model.layers" in name:
parts = name.split(".")
# Reconstruct the string without 'model.layers.x.'
name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x'
elif name.startswith("model."):
name = name[6:] # Remove 'model.'
return name
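# Examples of the normalization performed by _process_parameter_names
# (derived from the logic above, shown here for illustration):
#   "model.layers.0.self_attn.qkv_proj.weight" -> "self_attn.qkv_proj"
#   "model.embed_tokens.weight"                -> "embed_tokens"
#   "lm_head.weight"                           -> "lm_head"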
__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
"LlamaForCausalLM": llama_dtensor_weight_loader,
"LLaMAForCausalLM": llama_dtensor_weight_loader,
"MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM
"InternLMForCausalLM": llama_dtensor_weight_loader,
"Phi3ForCausalLM": llama_dtensor_weight_loader,
"GemmaForCausalLM": gemma_dtensor_weight_loader,
"Gemma2ForCausalLM": gemma_dtensor_weight_loader,
"Qwen2ForCausalLM": qwen2_dtensor_weight_loader,
"DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader,
"Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader,
"Qwen2_5_VLForConditionalGeneration": qwen2vl_dtensor_weight_loader,
}
# The actor weights are the FSDP actor model's .state_dict(); dispatch to the
# architecture-specific DTensor weight loader registered above.
def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
weight_loader(actor_weights, vllm_model)
    # NOTE(sgm): to reduce peak memory usage, we offload the vLLM model to CPU after
    # init, so it must be moved back to GPU after syncing model weights in the first iteration.
vllm_model = vllm_model.cuda()
def _get_model_weight_loader(arch: str):
if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__:
return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch]
raise ValueError(
f"Model architectures {arch} are not supported for now. "
f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}"
)
# NOTE(sgm): we use a per-parameter weight loader in each vLLM sub-module.
def update_dtensor_weight_loader():
pass
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The vllm_rollout that can be applied in different backend
When working with FSDP:
- Use DTensor weight loader (recommended) or HF weight loader
- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
"""
from contextlib import contextmanager
from typing import Any, List, Union
import torch
import torch.distributed
from tensordict import TensorDict
from transformers import PreTrainedTokenizer
from vllm import LLM, RequestOutput, SamplingParams
from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_2d_list_to_length
from verl.workers.rollout.base import BaseRollout
from verl.workers.rollout.config import RolloutConfig
def _repeat_interleave(features: Union[torch.Tensor, List[Any]], repeats: int) -> Union[torch.Tensor, List[Any]]:
if isinstance(features, torch.Tensor):
return features.repeat_interleave(repeats, dim=0)
else:
return [feature for feature in features for _ in range(repeats)]
class vLLMRollout(BaseRollout):
def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer):
"""A vLLM rollout. It requires the module is supported by the vllm.
Args:
module: module here follows huggingface APIs
config: DictConfig
tokenizer: the task/model tokenizer
"""
super().__init__()
self.config = config
self.pad_token_id = tokenizer.pad_token_id
if config.tensor_parallel_size > torch.distributed.get_world_size():
raise ValueError("Tensor parallelism size should be less than world size.")
if not config.enforce_eager and config.free_cache_engine:
raise ValueError("CUDA graph should be disabled when `free_cache_engine` is True.")
if config.max_num_batched_tokens < config.prompt_length + config.response_length:
raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.")
vllm_init_kwargs = {}
if config.limit_images > 0:
vllm_init_kwargs = {"limit_mm_per_prompt": {"image": config.limit_images}}
self.inference_engine = LLM(
model=model_path,
skip_tokenizer_init=False,
tensor_parallel_size=config.tensor_parallel_size,
dtype=config.dtype,
gpu_memory_utilization=config.gpu_memory_utilization,
enforce_eager=config.enforce_eager,
max_model_len=config.prompt_length + config.response_length,
max_num_batched_tokens=config.max_num_batched_tokens,
enable_sleep_mode=True,
distributed_executor_backend="external_launcher",
disable_custom_all_reduce=True,
disable_log_stats=config.disable_log_stats,
enable_chunked_prefill=config.enable_chunked_prefill,
**vllm_init_kwargs,
)
# Offload vllm model to reduce peak memory usage
self.inference_engine.sleep(level=1)
sampling_kwargs = {"max_tokens": config.response_length, "detokenize": False}
default_sampling_params = SamplingParams()
for key in config.to_dict().keys():
if hasattr(default_sampling_params, key):
sampling_kwargs[key] = getattr(config, key)
print(f"Sampling params: {sampling_kwargs}.")
self.sampling_params = SamplingParams(**sampling_kwargs)
@contextmanager
def update_sampling_params(self, **kwargs):
# update sampling params
old_sampling_params_args = {}
if kwargs:
for key, value in kwargs.items():
if hasattr(self.sampling_params, key):
old_value = getattr(self.sampling_params, key)
old_sampling_params_args[key] = old_value
setattr(self.sampling_params, key, value)
yield
# roll back to previous sampling params
for key, value in old_sampling_params_args.items():
setattr(self.sampling_params, key, value)
@torch.no_grad()
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
# left-padded attention_mask
input_ids: torch.Tensor = prompts.batch["input_ids"] # (bs, prompt_length)
attention_mask: torch.Tensor = prompts.batch["attention_mask"]
position_ids: torch.Tensor = prompts.batch["position_ids"]
eos_token_id: int = prompts.meta_info["eos_token_id"]
batch_size = input_ids.size(0)
do_sample = prompts.meta_info.get("do_sample", True)
if not do_sample:
kwargs = {
"n": 1,
"temperature": 0.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
}
non_tensor_batch = prompts.non_tensor_batch
if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
raise RuntimeError("vllm sharding manager is not work properly.")
if "images" in non_tensor_batch:
vllm_inputs = []
for raw_prompt_ids, images in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("images")):
vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": {"image": images}})
else:
vllm_inputs = [
{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
]
        # users can customize sampling_params for each run
with self.update_sampling_params(**kwargs):
completions: List[RequestOutput] = self.inference_engine.generate(
prompts=vllm_inputs, sampling_params=self.sampling_params
)
response_ids = []
for completion in completions:
for output in completion.outputs:
response_ids.append(output.token_ids)
response_ids = pad_2d_list_to_length(
response_ids, self.pad_token_id, max_length=self.config.response_length
).to(input_ids.device)
if self.config.n > 1 and do_sample:
batch_size = batch_size * self.config.n
input_ids = _repeat_interleave(input_ids, self.config.n)
attention_mask = _repeat_interleave(attention_mask, self.config.n)
position_ids = _repeat_interleave(position_ids, self.config.n)
if "pixel_values" in non_tensor_batch.keys():
non_tensor_batch["pixel_values"] = _repeat_interleave(non_tensor_batch["pixel_values"], self.config.n)
non_tensor_batch["image_grid_thw"] = _repeat_interleave(
non_tensor_batch["image_grid_thw"], self.config.n
)
sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
response_length = response_ids.size(1)
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1)
if position_ids.dim() == 3: # qwen2vl mrope
delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
# prompt: left pad + response: right pad
# attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0]
# position_ids: [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
response_position_ids = position_ids[..., -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask(
response_ids=response_ids, eos_token=eos_token_id, dtype=attention_mask.dtype
)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
        # all TP ranks should contain the same data here; the data in all ranks is valid
batch = TensorDict(
{
"prompts": input_ids,
"responses": response_ids,
"input_ids": sequence_ids, # here input_ids become the whole sentences
"attention_mask": attention_mask,
"position_ids": position_ids,
},
batch_size=batch_size,
)
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BaseShardingManager
from .fsdp_ulysses import FSDPUlyssesShardingManager
from .fsdp_vllm import FSDPVLLMShardingManager
__all__ = ["BaseShardingManager", "FSDPUlyssesShardingManager", "FSDPVLLMShardingManager"]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sharding manager to implement HybridEngine
"""
from verl import DataProto
class BaseShardingManager:
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
pass
def preprocess_data(self, data: DataProto) -> DataProto:
return data
def postprocess_data(self, data: DataProto) -> DataProto:
return data
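# Concrete managers in this package override __enter__/__exit__ to move model
# weights into and out of the inference engine or to switch parallel groups, and
# preprocess_data/postprocess_data to reshard DataProto batches across ranks
# (see fsdp_ulysses.py and fsdp_vllm.py).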
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains a resharding manager that binds weights from FSDP zero3 to XPerfGPT
"""
from torch.distributed.device_mesh import DeviceMesh
from verl import DataProto
from verl.utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group
from .base import BaseShardingManager
class FSDPUlyssesShardingManager(BaseShardingManager):
"""
Sharding manager to support data resharding when using FSDP + Ulysses
"""
def __init__(self, device_mesh: DeviceMesh):
super().__init__()
self.device_mesh = device_mesh
def __enter__(self):
if self.device_mesh is not None:
self.prev_sp_group = get_ulysses_sequence_parallel_group()
set_ulysses_sequence_parallel_group(self.device_mesh["sp"].get_group())
def __exit__(self, exc_type, exc_value, traceback):
if self.device_mesh is not None:
set_ulysses_sequence_parallel_group(self.prev_sp_group)
def preprocess_data(self, data: DataProto) -> DataProto:
"""
AllGather data from sp region
This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE
In Ulysses, we need to make sure the same data is used across a SP group
"""
if self.device_mesh is not None:
sp_group = self.device_mesh["sp"].get_group()
data = data.to("cuda")
data.all_gather(sp_group)
return data
def postprocess_data(self, data: DataProto) -> DataProto:
"""
Split the data to follow FSDP partition
"""
if self.device_mesh is not None:
sp_size = self.device_mesh["sp"].size()
sp_rank = self.device_mesh["sp"].get_local_rank()
data = data.chunk(chunks=sp_size)[sp_rank]
return data
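# Example (illustrative): with sp_size == 2, preprocess_data all-gathers the two
# per-rank shards so both SP ranks see the same full batch; postprocess_data then
# keeps chunk 0 on sp_rank 0 and chunk 1 on sp_rank 1, restoring the original split.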
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from vllm import LLM
from vllm.distributed import parallel_state as vllm_ps
from verl import DataProto
from verl.utils.performance import log_gpu_memory_usage
from verl.workers.rollout.vllm_rollout import load_dtensor_weights
from .base import BaseShardingManager
class FSDPVLLMShardingManager(BaseShardingManager):
def __init__(
self,
module: FSDP,
inference_engine: LLM,
device_mesh: DeviceMesh = None,
):
self.module = module
self.inference_engine = inference_engine
self.device_mesh = device_mesh
FSDP.set_state_dict_type(
self.module,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)
# Note that torch_random_states may be different on each dp rank
self.torch_random_states = torch.cuda.get_rng_state()
        # build a dedicated rng state for generation
if self.device_mesh is not None:
gen_dp_rank = self.device_mesh["dp"].get_local_rank()
torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states
self.gen_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.torch_random_states)
else:
self.gen_random_states = None
def __enter__(self):
log_gpu_memory_usage("Before state_dict() in sharding manager")
actor_weights = self.module.state_dict()
log_gpu_memory_usage("After state_dict() in sharding manager")
self.inference_engine.wake_up()
load_dtensor_weights(
actor_weights, self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
)
log_gpu_memory_usage("After sync model weights in sharding manager")
del actor_weights
torch.cuda.empty_cache()
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
self.torch_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.gen_random_states)
def __exit__(self, exc_type, exc_value, traceback):
log_gpu_memory_usage("Before vllm offload in sharding manager")
self.inference_engine.sleep(level=1)
log_gpu_memory_usage("After vllm offload in sharding manager")
self.module.train()
torch.cuda.empty_cache() # add empty cache after each compute
# restore random states
if self.device_mesh is not None:
self.gen_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.torch_random_states)
def preprocess_data(self, data: DataProto) -> DataProto:
tp_group = vllm_ps.get_tensor_model_parallel_group().device_group
data = data.to("cuda")
data.all_gather(tp_group)
return data
def postprocess_data(self, data: DataProto) -> DataProto:
dp_rank = dist.get_rank()
tp_size = vllm_ps.get_tensor_model_parallel_world_size()
if tp_size > 1:
data = data.chunk(chunks=tp_size)[dp_rank % tp_size]
return data
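# Typical lifecycle (illustrative): entering the manager wakes vLLM up and loads
# the FSDP state_dict via load_dtensor_weights, then switches to a dedicated rng
# state so all TP ranks sample identically; exiting puts vLLM back to sleep,
# restores the trainer's rng state, and frees cached GPU memory.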