# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import torch.distributed
from tensordict import TensorDict
from transformers import PreTrainedTokenizer
from vllm import LLM, RequestOutput, SamplingParams

from ...protocol import DataProto
from ...utils import torch_functional as VF
from ...utils.tokenizer import get_processor
from ...utils.torch_dtypes import PrecisionType
from .base import BaseRollout
from .config import RolloutConfig


def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:
    if isinstance(value, torch.Tensor):
        return value.repeat_interleave(repeats, dim=0)
    else:
        return np.repeat(value, repeats, axis=0)
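
# Example of the repeat semantics (illustrative note, not part of the original file):
#   _repeat_interleave(torch.tensor([[1, 2], [3, 4]]), 2)
#       -> tensor([[1, 2], [1, 2], [3, 4], [3, 4]])
#   _repeat_interleave(np.array(["a", "b"]), 2)
#       -> array(["a", "a", "b", "b"])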


def _get_logit_bias(model_path: str, trust_remote_code: bool) -> Optional[Dict[int, float]]:
    processor = get_processor(model_path, trust_remote_code=trust_remote_code)
    if processor is not None and hasattr(processor, "image_token"):
        image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
        return {image_token_id: -100}
    else:
        return None
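
# Note (added for clarity): the -100 logit bias makes the image token effectively
# unsampleable, so generated responses cannot emit image placeholders. Sketch:
#   _get_logit_bias("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)
#       -> {151655: -100}  # the exact token id is model-specific (assumed here)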


class vLLMRollout(BaseRollout):
    def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer):
        """A vLLM rollout. It requires the module is supported by the vllm.

        Args:
            module: module here follows huggingface APIs
            config: DictConfig
            tokenizer: the task/model tokenizer
        """
        super().__init__()
        self.rank = int(os.getenv("RANK", "0"))
        self.config = config
        self.pad_token_id = tokenizer.pad_token_id
        if config.tensor_parallel_size > torch.distributed.get_world_size():
            raise ValueError("Tensor parallelism size should be less than world size.")

        if config.max_num_batched_tokens < config.prompt_length + config.response_length:
            raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.")

        self.inference_engine = LLM(
            model=model_path,
            skip_tokenizer_init=False,
            trust_remote_code=True,
            load_format="dummy",
            dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
            seed=config.seed,
            max_model_len=config.max_model_len or config.prompt_length + config.response_length,
            distributed_executor_backend="external_launcher",
            tensor_parallel_size=config.tensor_parallel_size,
            gpu_memory_utilization=config.gpu_memory_utilization,
            max_num_batched_tokens=config.max_num_batched_tokens,
            disable_log_stats=config.disable_log_stats,
            enforce_eager=True,
            disable_custom_all_reduce=True,
            limit_mm_per_prompt={"image": config.limit_images} if config.limit_images > 0 else None,
            disable_mm_preprocessor_cache=True,
            enable_chunked_prefill=config.enable_chunked_prefill,
            enable_sleep_mode=False,  # sleep mode only supports CUDA GPUs
            # swap_space=20,
        )
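        # Note (added): distributed_executor_backend="external_launcher" makes vLLM
        # run inside the already-launched SPMD processes (e.g. via torchrun) instead
        # of spawning its own workers, so each trainer rank holds one TP shard.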

        # Offload vllm model to reduce peak memory usage
        # self.inference_engine.sleep(level=1)
        # TODO(DCU): figure out how to release device memory here
        # self.inference_engine.offload_model_weights()

        sampling_kwargs = {
            "max_tokens": config.response_length,
            "detokenize": False,
            "logit_bias": _get_logit_bias(model_path, trust_remote_code=config.trust_remote_code),
        }
        default_sampling_params = SamplingParams()
        for key in config.to_dict().keys():
            if hasattr(default_sampling_params, key):
                sampling_kwargs[key] = getattr(config, key)

        print(f"Sampling params: {sampling_kwargs}.")
        self.sampling_params = SamplingParams(**sampling_kwargs)
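        # Illustrative result (field names depend on RolloutConfig): config entries
        # whose names match SamplingParams attributes, e.g. temperature, top_p,
        # top_k, and n, are forwarded above, yielding something like
        #   SamplingParams(max_tokens=..., detokenize=False, temperature=1.0, n=5, ...)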

    @contextmanager
    def update_sampling_params(self, **kwargs):
        # update sampling params
        old_sampling_params_args = {}
        if kwargs:
            for key, value in kwargs.items():
                if hasattr(self.sampling_params, key):
                    old_value = getattr(self.sampling_params, key)
                    old_sampling_params_args[key] = old_value
                    setattr(self.sampling_params, key, value)

        yield
        # roll back to previous sampling params
        for key, value in old_sampling_params_args.items():
            setattr(self.sampling_params, key, value)
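
    # Usage sketch (illustrative, not from the original source):
    #   with self.update_sampling_params(temperature=0.0, n=1):
    #       completions = self.inference_engine.generate(...)  # greedy decoding
    #   # the previous sampling params are restored on exit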

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto) -> DataProto:
        # left-padded attention_mask
        input_ids: torch.Tensor = prompts.batch["input_ids"]  # (bs, prompt_length)
        attention_mask: torch.Tensor = prompts.batch["attention_mask"]
        position_ids: torch.Tensor = prompts.batch["position_ids"]
        eos_token_id: int = prompts.meta_info["eos_token_id"]
        batch_size = input_ids.size(0)

        non_tensor_batch = prompts.non_tensor_batch
        if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
            raise RuntimeError("vLLM sharding manager is not working properly.")

        if "multi_modal_data" in non_tensor_batch:
            vllm_inputs = []
            for raw_prompt_ids, multi_modal_data in zip(
                non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")
            ):
                vllm_inputs.append({"prompt_token_ids": list(raw_prompt_ids), "multi_modal_data": multi_modal_data})
        else:
            vllm_inputs = [
                {"prompt_token_ids": list(raw_prompt_ids)} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
            ]

        # users can customize sampling_params for each run via prompts.meta_info
        with self.update_sampling_params(**prompts.meta_info):
            completions: List[RequestOutput] = self.inference_engine.generate(
                prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=(self.rank == 0)
            )
            response_ids = [output.token_ids for completion in completions for output in completion.outputs]
            response_ids = VF.pad_2d_list_to_length(
                response_ids, self.pad_token_id, max_length=self.config.response_length
            ).to(input_ids.device)

            if self.sampling_params.n > 1:
                batch_size = batch_size * self.sampling_params.n
                input_ids = _repeat_interleave(input_ids, self.sampling_params.n)
                attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
                position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
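            # Added note: each prompt yields n completions when n > 1, hence the
            # prompt-side tensors above are repeated to stay row-aligned with the
            # flattened response_ids.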

        sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
        response_length = response_ids.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1)
        if position_ids.dim() == 3:  # qwen2vl mrope
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
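        # Shape note (added): for qwen2vl mrope, position_ids is
        # (batch_size, 3, seq_len), so delta_position_id is broadcast to
        # (batch_size, 3, response_length) before the concat below.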

        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
        response_position_ids = position_ids[..., -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
        response_mask = VF.get_response_mask(
            response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype
        )
        attention_mask = torch.cat((attention_mask, response_mask), dim=-1)

        # all TP ranks should contain the same data here; data in all ranks is valid
        batch = TensorDict(
            {
                "prompts": input_ids,
                "responses": response_ids,
                "input_ids": sequence_ids,  # here input_ids become the whole sentences
                "attention_mask": attention_mask,
                "response_mask": response_mask,
                "position_ids": position_ids,
            },
            batch_size=batch_size,
        )
        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
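
# Usage sketch (illustrative; the construction details are assumptions, not from this file):
#   rollout = vLLMRollout(model_path="Qwen/Qwen2-7B-Instruct", config=rollout_config, tokenizer=tokenizer)
#   output: DataProto = rollout.generate_sequences(prompts)  # prompts is a DataProto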