# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tensor parallel worker."""

import logging
import threading
from typing import Optional

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed

logger = logging.getLogger(__name__)


class TpModelWorker:
    """A tensor parallel model worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
        is_draft_worker: bool = False,
    ):
        # Parse args
        self.tp_rank = tp_rank

        # Init model and tokenizer
        self.model_config = ModelConfig(
            (
                server_args.model_path
                if not is_draft_worker
                else server_args.speculative_draft_model_path
            ),
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=nccl_port,
            server_args=server_args,
            is_draft_worker=is_draft_worker,
        )
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )
        self.device = self.model_runner.device

        # Profile number of tokens
        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
        self.max_prefill_tokens = server_args.max_prefill_tokens
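        # Concurrency cap: default to half the KV token budget; when DP
        # attention is enabled, the configured limit is split evenly across
        # the dp_size data parallel ranks.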
        self.max_running_requests = min(
            (
                self.max_total_num_tokens // 2
                if server_args.max_running_requests is None
                else server_args.max_running_requests
                // (server_args.dp_size if server_args.enable_dp_attention else 1)
            ),
            self.model_runner.req_to_token_pool.size,
        )
        self.max_req_len = min(
            self.model_config.context_len - 1,
            self.max_total_num_tokens - 1,
        )
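        # Keep a small headroom below max_req_len (presumably for special
        # tokens appended during processing).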
        self.max_req_input_len = self.max_req_len - 5
        assert (
            self.max_req_len > 0 and self.max_req_input_len > 0
        ), "Memory pool size is too small"

        # Sync random seed across TP workers
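        # (rank 0's seed is broadcast to the whole TP group so that all ranks
        # sample identically)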
        self.random_seed = broadcast_pyobj(
            [server_args.random_seed],
            self.tp_rank,
            self.model_runner.tp_group.cpu_group,
        )[0]
        set_random_seed(self.random_seed)

    def get_worker_info(self):
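        """Return capacity limits, random seed, device, and memory pool sizes."""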
        return (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            global_server_args_dict,
            self.model_runner.req_to_token_pool.size,
            self.model_runner.req_to_token_pool.max_context_len,
            self.model_runner.token_to_kv_pool.size,
        )

    def get_pad_input_ids_func(self):
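        """Return the model's `pad_input_ids` hook if it defines one, else None."""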
        return getattr(self.model_runner.model, "pad_input_ids", None)

    def get_tp_cpu_group(self):
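        """Return the CPU communication group of the TP group."""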
        return self.model_runner.tp_group.cpu_group

    def get_attention_tp_cpu_group(self):
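        """Return the CPU communication group of the attention TP group."""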
        return self.model_runner.attention_tp_group.cpu_group

    def get_memory_pool(self):
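        """Return the request-to-token and token-to-KV-cache pools."""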
        return (
            self.model_runner.req_to_token_pool,
            self.model_runner.token_to_kv_pool,
        )

    def forward_batch_generation(
        self,
        model_worker_batch: ModelWorkerBatch,
        launch_done: Optional[threading.Event] = None,
        skip_sample: bool = False,
    ):
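        """Run one forward pass and, unless skip_sample is set, sample next tokens."""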
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
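        # Signal that the forward pass has been launched; likely used by a
        # scheduler that overlaps CPU work with GPU execution.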
        if launch_done:
            launch_done.set()

        if skip_sample:
            next_token_ids = None
        else:
            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)

        return logits_output, next_token_ids

    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
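        """Run one forward pass and return the computed embeddings."""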
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
        embeddings = logits_output.embeddings
        return embeddings

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
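        """Reload model weights from a checkpoint on disk."""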
        success, message = self.model_runner.update_weights_from_disk(
            recv_req.model_path, recv_req.load_format
        )
        return success, message

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
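        """Initialize a process group for receiving weight updates."""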
        success, message = self.model_runner.init_weights_update_group(
            recv_req.master_address,
            recv_req.master_port,
            recv_req.rank_offset,
            recv_req.world_size,
            recv_req.group_name,
            recv_req.backend,
        )
        return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
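        """Receive a named weight tensor over the distributed update group."""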
        success, message = self.model_runner.update_weights_from_distributed(
            recv_req.name, recv_req.dtype, recv_req.shape
        )
        return success, message

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
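        """Update weights from tensors serialized in another process."""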
        success, message = self.model_runner.update_weights_from_tensor(
            named_tensors=MultiprocessingSerializer.deserialize(
                recv_req.serialized_named_tensors
            ),
            load_format=recv_req.load_format,
        )
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
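        """Fetch a named parameter (truncated according to `truncate_size`)."""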
        parameter = self.model_runner.get_weights_by_name(
            recv_req.name, recv_req.truncate_size
        )
        return parameter