Commit 2369eb2b authored by chenych

update

parent ac9d2b05
......@@ -16,7 +16,7 @@ Rollout config
"""
from dataclasses import asdict, dataclass, field
from typing import Any, Dict
from typing import Any, Dict, Optional
@dataclass
......@@ -26,6 +26,7 @@ class RolloutConfig:
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
seed: int = 1
limit_images: int = 0
dtype: str = "bf16"
gpu_memory_utilization: float = 0.6
......@@ -33,13 +34,14 @@ class RolloutConfig:
enforce_eager: bool = False
enable_chunked_prefill: bool = False # only for v0 engine
tensor_parallel_size: int = 2
max_model_len: Optional[int] = None
max_num_batched_tokens: int = 8192
max_num_seqs: int = 1024
disable_log_stats: bool = True
val_override_config: Dict[str, Any] = field(default_factory=dict)
"""auto keys"""
prompt_length: int = field(default=-1, init=False)
response_length: int = field(default=-1, init=False)
trust_remote_code: bool = field(default=False, init=False)
def to_dict(self):
return asdict(self)
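
For context, a hypothetical usage sketch of the updated RolloutConfig (the import path is assumed, not taken from this commit), exercising the newly added seed and max_model_len fields:

# Hypothetical usage; adjust the import to the actual package layout.
from workers.rollout.config import RolloutConfig

config = RolloutConfig(
    temperature=1.0,
    top_k=-1,
    seed=1,              # new field: forwarded to the vLLM engine for reproducible sampling
    tensor_parallel_size=2,
    max_model_len=4096,  # new field: overrides prompt_length + response_length when set
)
print(config.to_dict())  # prompt_length / response_length stay at -1 until filled in later (init=False fields)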
......@@ -11,16 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The vLLM rollout that can be applied to different backends.
When working with FSDP:
- Use the DTensor weight loader (recommended) or the HF weight loader
- Use the state_dict from FSDP to synchronize the weights among the TP ranks in vLLM
"""
import os
from contextlib import contextmanager
from typing import Any, List, Union
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
......@@ -31,6 +25,7 @@ from vllm import LLM, RequestOutput, SamplingParams
from ...protocol import DataProto
from ...utils import torch_functional as VF
from ...utils.tokenizer import get_processor
from ...utils.torch_dtypes import PrecisionType
from .base import BaseRollout
from .config import RolloutConfig
......@@ -43,6 +38,15 @@ def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) ->
return np.repeat(value, repeats, axis=0)
def _get_logit_bias(model_path: str, trust_remote_code: bool) -> Optional[Dict[int, float]]:
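    """Return a logit bias that heavily penalizes the image placeholder token (bias of -100),
    so that multimodal models are effectively prevented from emitting it in generated responses."""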
processor = get_processor(model_path, trust_remote_code=trust_remote_code)
if processor is not None and hasattr(processor, "image_token"):
image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
return {image_token_id: -100}
else:
return None
class vLLMRollout(BaseRollout):
def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer):
"""A vLLM rollout. It requires the module is supported by the vllm.
......@@ -62,33 +66,38 @@ class vLLMRollout(BaseRollout):
if config.max_num_batched_tokens < config.prompt_length + config.response_length:
raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.")
vllm_init_kwargs = {}
if config.limit_images > 0:
vllm_init_kwargs = {"limit_mm_per_prompt": {"image": config.limit_images}}
self.inference_engine = LLM(
model=model_path,
skip_tokenizer_init=False,
tensor_parallel_size=config.tensor_parallel_size,
trust_remote_code=config.trust_remote_code,
load_format="dummy",
dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
seed=config.seed,
max_model_len=config.max_model_len or config.prompt_length + config.response_length,
distributed_executor_backend="external_launcher",
tensor_parallel_size=config.tensor_parallel_size,
gpu_memory_utilization=config.gpu_memory_utilization,
enforce_eager=config.enforce_eager,
max_model_len=config.prompt_length + config.response_length,
max_num_batched_tokens=config.max_num_batched_tokens,
enable_sleep_mode=False,
distributed_executor_backend="external_launcher",
disable_log_stats=config.disable_log_stats,
enforce_eager=config.enforce_eager,
disable_custom_all_reduce=True,
limit_mm_per_prompt={"image": config.limit_images} if config.limit_images > 0 else None,
disable_mm_preprocessor_cache=True,
disable_log_stats=config.disable_log_stats,
enable_chunked_prefill=config.enable_chunked_prefill,
seed=self.rank // config.tensor_parallel_size, # dp rank
**vllm_init_kwargs,
enable_sleep_mode=False,  # True on NVIDIA, False on ROCm
# swap_space=20,
)
# Offload vllm model to reduce peak memory usage
# self.inference_engine.sleep(level=1)
sampling_kwargs = {"max_tokens": config.response_length, "detokenize": False}
## TODO(DCU): how to release GPU memory here?
# self.inference_engine.offload_model_weights()
sampling_kwargs = {
"max_tokens": config.response_length,
"detokenize": False,
"logit_bias": _get_logit_bias(model_path, trust_remote_code=config.trust_remote_code),
}
default_sampling_params = SamplingParams()
for key in config.to_dict().keys():
if hasattr(default_sampling_params, key):
......@@ -152,10 +161,6 @@ class vLLMRollout(BaseRollout):
input_ids = _repeat_interleave(input_ids, self.sampling_params.n)
attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
if "multi_modal_inputs" in non_tensor_batch.keys():
non_tensor_batch["multi_modal_inputs"] = _repeat_interleave(
non_tensor_batch["multi_modal_inputs"], self.sampling_params.n
)
sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
response_length = response_ids.size(1)
......
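
As a hypothetical illustration of the repeat-interleave step above (values assumed): with a batch of 2 prompts and sampling_params.n == 4, every prompt-side tensor is expanded to 8 rows so that it lines up with the 8 generated responses.

import torch

input_ids = torch.arange(6).reshape(2, 3)                 # 2 prompts, 3 tokens each
expanded = torch.repeat_interleave(input_ids, repeats=4, dim=0)
assert expanded.shape == (8, 3)                           # each prompt repeated 4 times, in order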
......@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import inspect
from typing import Dict, Iterable, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import get_model_state_dict
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from vllm import LLM
from vllm.distributed import parallel_state as vllm_ps
......@@ -34,18 +34,11 @@ class FSDPVLLMShardingManager(BaseShardingManager):
self,
module: FSDP,
inference_engine: LLM,
device_mesh: DeviceMesh = None,
device_mesh: DeviceMesh,
):
self.module = module
self.inference_engine = inference_engine
self.device_mesh = device_mesh
with warnings.catch_warnings():
warnings.simplefilter("ignore")
FSDP.set_state_dict_type(
self.module,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)
self.world_size = dist.get_world_size()
self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
......@@ -59,13 +52,10 @@ class FSDPVLLMShardingManager(BaseShardingManager):
# Note that torch_random_states may be different on each dp rank
self.torch_random_states = torch.cuda.get_rng_state()
# set up dedicated rng states for generation
if self.device_mesh is not None:
gen_dp_rank = self.device_mesh["dp"].get_local_rank()
torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states
self.gen_random_states = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(self.torch_random_states)
else:
self.gen_random_states = None
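
Since device_mesh is now a required constructor argument, a minimal sketch (mesh layout inferred from the device_mesh["dp"] lookup above; not code from this commit) of how such a ("dp", "tp") mesh could be built:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from vllm.distributed import parallel_state as vllm_ps

world_size = dist.get_world_size()
tp_size = vllm_ps.get_tensor_model_parallel_world_size()
device_mesh = init_device_mesh(
    "cuda", mesh_shape=(world_size // tp_size, tp_size), mesh_dim_names=("dp", "tp")
)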
def _make_weight_iterator(
self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
......@@ -83,16 +73,24 @@ class FSDPVLLMShardingManager(BaseShardingManager):
# vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
torch.cuda.empty_cache()
print_gpu_memory_usage("Before state_dict() in sharding manager")
actor_weights = self.module.state_dict()
actor_weights = get_model_state_dict(self.module)
print_gpu_memory_usage("After state_dict() in sharding manager")
if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
self.inference_engine.wake_up(tags=["weights"])
else:
self.inference_engine.wake_up()
model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
model.load_weights(self._make_weight_iterator(actor_weights))
print_gpu_memory_usage("After sync model weights in sharding manager")
del actor_weights
torch.cuda.empty_cache()
if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
self.inference_engine.wake_up(tags=["kv_cache"])
print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
......@@ -103,6 +101,8 @@ class FSDPVLLMShardingManager(BaseShardingManager):
print_gpu_memory_usage("Before vllm offload in sharding manager")
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
# self.inference_engine.sleep(level=1)
## ROCm alternative:
# self.inference_engine.offload_model_weights()
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
print_gpu_memory_usage("After vllm offload in sharding manager")
......
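
For reference, a minimal sketch of the DTensor-to-vLLM weight-sync pattern that the module docstring and _make_weight_iterator refer to (helper name and details assumed, not the repository's exact implementation): each sharded DTensor in the FSDP state dict is gathered into a full tensor before being handed to vLLM's model.load_weights().

from typing import Dict, Iterable, Tuple, Union

import torch
from torch.distributed._tensor import DTensor


def iter_full_weights(
    actor_weights: Dict[str, Union[torch.Tensor, DTensor]],
) -> Iterable[Tuple[str, torch.Tensor]]:
    for name, tensor in actor_weights.items():
        if isinstance(tensor, DTensor):
            # full_tensor() all-gathers the shards across the device mesh.
            yield name, tensor.full_tensor()
        else:
            yield name, tensor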