Commit 2369eb2b authored by chenych

update

parent ac9d2b05
@@ -16,7 +16,7 @@ Rollout config
 """
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 @dataclass
@@ -26,6 +26,7 @@ class RolloutConfig:
     temperature: float = 1.0
     top_p: float = 1.0
     top_k: int = -1
+    seed: int = 1
     limit_images: int = 0
     dtype: str = "bf16"
     gpu_memory_utilization: float = 0.6
@@ -33,13 +34,14 @@ class RolloutConfig:
     enforce_eager: bool = False
     enable_chunked_prefill: bool = False  # only for v0 engine
     tensor_parallel_size: int = 2
+    max_model_len: Optional[int] = None
     max_num_batched_tokens: int = 8192
-    max_num_seqs: int = 1024
     disable_log_stats: bool = True
     val_override_config: Dict[str, Any] = field(default_factory=dict)
     """auto keys"""
     prompt_length: int = field(default=-1, init=False)
     response_length: int = field(default=-1, init=False)
+    trust_remote_code: bool = field(default=False, init=False)
     def to_dict(self):
         return asdict(self)
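For orientation, a minimal sketch of how the updated RolloutConfig is meant to be used; the import path is hypothetical and the field values are placeholders rather than project defaults.

```python
# Hypothetical import path; RolloutConfig is the dataclass from the diff above.
from rollout.config import RolloutConfig

config = RolloutConfig(seed=1, max_model_len=4096, tensor_parallel_size=2)

# prompt_length, response_length and trust_remote_code are init=False "auto keys":
# they are filled in by the caller after construction rather than via __init__.
config.prompt_length = 2048
config.response_length = 2048
config.trust_remote_code = True

# to_dict() flattens the dataclass so later code can match individual keys
# against vLLM sampling parameters.
assert config.to_dict()["max_model_len"] == 4096
```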
@@ -11,16 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-The vllm_rollout that can be applied in different backend
-When working with FSDP:
-- Use DTensor weight loader (recommended) or HF weight loader
-- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
-"""
 import os
 from contextlib import contextmanager
-from typing import Any, List, Union
+from typing import Any, Dict, List, Optional, Union
 import numpy as np
 import torch
@@ -31,6 +25,7 @@ from vllm import LLM, RequestOutput, SamplingParams
 from ...protocol import DataProto
 from ...utils import torch_functional as VF
+from ...utils.tokenizer import get_processor
 from ...utils.torch_dtypes import PrecisionType
 from .base import BaseRollout
 from .config import RolloutConfig
@@ -43,6 +38,15 @@ def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) ->
         return np.repeat(value, repeats, axis=0)
+def _get_logit_bias(model_path: str, trust_remote_code: bool) -> Optional[Dict[int, float]]:
+    processor = get_processor(model_path, trust_remote_code=trust_remote_code)
+    if processor is not None and hasattr(processor, "image_token"):
+        image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
+        return {image_token_id: -100}
+    else:
+        return None
 class vLLMRollout(BaseRollout):
     def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer):
         """A vLLM rollout. It requires the module is supported by the vllm.
@@ -62,33 +66,38 @@ class vLLMRollout(BaseRollout):
         if config.max_num_batched_tokens < config.prompt_length + config.response_length:
             raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.")
-        vllm_init_kwargs = {}
-        if config.limit_images > 0:
-            vllm_init_kwargs = {"limit_mm_per_prompt": {"image": config.limit_images}}
         self.inference_engine = LLM(
             model=model_path,
             skip_tokenizer_init=False,
-            tensor_parallel_size=config.tensor_parallel_size,
+            trust_remote_code=config.trust_remote_code,
+            load_format="dummy",
             dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
+            seed=config.seed,
+            max_model_len=config.max_model_len or config.prompt_length + config.response_length,
+            distributed_executor_backend="external_launcher",
+            tensor_parallel_size=config.tensor_parallel_size,
             gpu_memory_utilization=config.gpu_memory_utilization,
-            enforce_eager=config.enforce_eager,
-            max_model_len=config.prompt_length + config.response_length,
             max_num_batched_tokens=config.max_num_batched_tokens,
-            enable_sleep_mode=False,
-            distributed_executor_backend="external_launcher",
+            disable_log_stats=config.disable_log_stats,
+            enforce_eager=config.enforce_eager,
             disable_custom_all_reduce=True,
+            limit_mm_per_prompt={"image": config.limit_images} if config.limit_images > 0 else None,
             disable_mm_preprocessor_cache=True,
-            disable_log_stats=config.disable_log_stats,
             enable_chunked_prefill=config.enable_chunked_prefill,
-            seed=self.rank // config.tensor_parallel_size,  # dp rank
-            **vllm_init_kwargs,
+            enable_sleep_mode=False,  # True on NVIDIA, False on ROCm
+            # swap_space=20,
         )
         # Offload vllm model to reduce peak memory usage
         # self.inference_engine.sleep(level=1)
+        ## TODO (DCU): how to release GPU memory here
+        # self.inference_engine.offload_model_weights()
-        sampling_kwargs = {"max_tokens": config.response_length, "detokenize": False}
+        sampling_kwargs = {
+            "max_tokens": config.response_length,
+            "detokenize": False,
+            "logit_bias": _get_logit_bias(model_path, trust_remote_code=config.trust_remote_code),
+        }
         default_sampling_params = SamplingParams()
         for key in config.to_dict().keys():
             if hasattr(default_sampling_params, key):
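For reference, a sketch of the override pattern used above: the fixed sampling_kwargs (including the logit bias) are combined with whichever config keys SamplingParams actually exposes, while engine-level keys are silently skipped. The values and token id below are placeholders, and it assumes the installed vLLM accepts a logit_bias mapping on SamplingParams, which the change above relies on.

```python
from vllm import SamplingParams

config_dict = {"temperature": 1.0, "top_p": 1.0, "top_k": -1, "gpu_memory_utilization": 0.6}
sampling_kwargs = {
    "max_tokens": 512,
    "detokenize": False,
    "logit_bias": {1234: -100},  # hypothetical image-token id
}

default_sampling_params = SamplingParams()
for key, value in config_dict.items():
    # Only keys that SamplingParams knows about are copied over;
    # gpu_memory_utilization is an engine option and gets skipped here.
    if hasattr(default_sampling_params, key):
        sampling_kwargs[key] = value

sampling_params = SamplingParams(**sampling_kwargs)
print(sampling_params.temperature, sampling_params.max_tokens)
```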
@@ -152,10 +161,6 @@ class vLLMRollout(BaseRollout):
             input_ids = _repeat_interleave(input_ids, self.sampling_params.n)
             attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
             position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
-            if "multi_modal_inputs" in non_tensor_batch.keys():
-                non_tensor_batch["multi_modal_inputs"] = _repeat_interleave(
-                    non_tensor_batch["multi_modal_inputs"], self.sampling_params.n
-                )
         sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
         response_length = response_ids.size(1)
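When sampling_params.n > 1, each prompt row has to be duplicated n times so the n sampled responses stay aligned with their prompt. A small sketch of that repeat step; the torch branch is inferred from the tensor usage above, while the numpy branch matches the line shown earlier in this diff.

```python
import numpy as np
import torch

def _repeat_interleave(value, repeats):
    # Duplicate each row `repeats` times along the batch dimension.
    if isinstance(value, torch.Tensor):
        return value.repeat_interleave(repeats, dim=0)
    return np.repeat(value, repeats, axis=0)

input_ids = torch.arange(6).reshape(2, 3)        # a batch of 2 prompts
print(_repeat_interleave(input_ids, 2).shape)    # torch.Size([4, 3]): rows 0,0,1,1
```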
...
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import warnings
+import inspect
 from typing import Dict, Iterable, Tuple, Union
 import torch
 import torch.distributed as dist
 from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint.state_dict import get_model_state_dict
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
 from vllm import LLM
 from vllm.distributed import parallel_state as vllm_ps
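A hedged sketch of what the import swap above buys: get_model_state_dict() from torch.distributed.checkpoint (PyTorch 2.2+) collects an FSDP module's weights without configuring FSDP.set_state_dict_type() up front, which is why the ShardedStateDictConfig/StateDictType import can be dropped. The wrapper below is illustrative and has to run inside an initialized process group with an FSDP-wrapped model.

```python
from torch.distributed.checkpoint.state_dict import get_model_state_dict

def collect_actor_weights(fsdp_module):
    # Returns a mapping of parameter name -> weight; with FSDP the values are
    # sharded DTensors under the default options, matching the
    # Dict[str, Union[torch.Tensor, DTensor]] signature of _make_weight_iterator.
    return get_model_state_dict(fsdp_module)
```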
@@ -34,18 +34,11 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         self,
         module: FSDP,
         inference_engine: LLM,
-        device_mesh: DeviceMesh = None,
+        device_mesh: DeviceMesh,
     ):
         self.module = module
         self.inference_engine = inference_engine
         self.device_mesh = device_mesh
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            FSDP.set_state_dict_type(
-                self.module,
-                state_dict_type=StateDictType.SHARDED_STATE_DICT,
-                state_dict_config=ShardedStateDictConfig(),
-            )
         self.world_size = dist.get_world_size()
         self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
@@ -59,13 +52,10 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         # Note that torch_random_states may be different on each dp rank
         self.torch_random_states = torch.cuda.get_rng_state()
         # get a random rng states
-        if self.device_mesh is not None:
-            gen_dp_rank = self.device_mesh["dp"].get_local_rank()
-            torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
-            self.gen_random_states = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(self.torch_random_states)
-        else:
-            self.gen_random_states = None
+        gen_dp_rank = self.device_mesh["dp"].get_local_rank()
+        torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
+        self.gen_random_states = torch.cuda.get_rng_state()
+        torch.cuda.set_rng_state(self.torch_random_states)
     def _make_weight_iterator(
         self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
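The RNG bookkeeping a few lines up derives one extra CUDA RNG state that is identical across all tensor-parallel ranks of a data-parallel group (seeded from the dp rank) while leaving the trainer's own RNG state untouched. A self-contained sketch of that pattern; it requires a CUDA device, and dp_rank is illustrative.

```python
import torch

def make_generation_rng_state(dp_rank: int) -> torch.Tensor:
    saved = torch.cuda.get_rng_state()       # the trainer's current RNG state
    torch.cuda.manual_seed(dp_rank + 1000)   # same seed for every tp rank in this dp group
    gen_state = torch.cuda.get_rng_state()   # dedicated rollout RNG state
    torch.cuda.set_rng_state(saved)          # restore the trainer's RNG state
    return gen_state
```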
@@ -83,16 +73,24 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
         torch.cuda.empty_cache()
         print_gpu_memory_usage("Before state_dict() in sharding manager")
-        actor_weights = self.module.state_dict()
+        actor_weights = get_model_state_dict(self.module)
         print_gpu_memory_usage("After state_dict() in sharding manager")
-        self.inference_engine.wake_up()
+        if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
+            self.inference_engine.wake_up(tags=["weights"])
+        else:
+            self.inference_engine.wake_up()
         model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
         model.load_weights(self._make_weight_iterator(actor_weights))
         print_gpu_memory_usage("After sync model weights in sharding manager")
         del actor_weights
         torch.cuda.empty_cache()
+        if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
+            self.inference_engine.wake_up(tags=["kv_cache"])
         print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
         # important: need to manually set the random states of each tp to be identical.
         if self.device_mesh is not None:
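The wake_up calls above use a small feature-detection idiom: newer vLLM versions accept a tags argument so weights and KV cache can be restored separately, while older ones take no arguments. A generic, self-contained sketch of the same idiom (the two engine classes are toy stand-ins, not vLLM APIs):

```python
import inspect

def supports_kwarg(fn, name: str) -> bool:
    # True if the callable declares a parameter with this name.
    return name in inspect.signature(fn).parameters

class OldEngine:
    def wake_up(self):                # pre-tags API
        return "woke everything"

class NewEngine:
    def wake_up(self, tags=None):     # tags-aware API
        return f"woke {tags}"

for engine in (OldEngine(), NewEngine()):
    if supports_kwarg(engine.wake_up, "tags"):
        print(engine.wake_up(tags=["weights"]))
    else:
        print(engine.wake_up())
```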
@@ -103,6 +101,8 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         print_gpu_memory_usage("Before vllm offload in sharding manager")
         free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
         # self.inference_engine.sleep(level=1)
+        ## ROCm
+        # self.inference_engine.offload_model_weights()
         free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
         self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
         print_gpu_memory_usage("After vllm offload in sharding manager")
...