# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tensor parallel worker."""

import logging
import threading
from typing import Optional, Tuple

import torch

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed

logger = logging.getLogger(__name__)


class TpModelWorker:
    """A tensor parallel model worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.tp_rank = tp_rank

        # Init model and tokenizer
        self.model_config = ModelConfig(
            (
                server_args.model_path
                if not is_draft_worker
                else server_args.speculative_draft_model_path
            ),
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=nccl_port,
            server_args=server_args,
            is_draft_worker=is_draft_worker,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
        )
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )
        self.device = self.model_runner.device

        # Profile number of tokens
        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
        self.max_prefill_tokens = server_args.max_prefill_tokens
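        # By default, cap concurrency at half the token pool; an explicit cap is
        # divided across DP ranks when DP attention is enabled. Either way, it is
        # bounded by the request pool size.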
        self.max_running_requests = min(
            (
                self.max_total_num_tokens // 2
                if server_args.max_running_requests is None
                else server_args.max_running_requests
                // (server_args.dp_size if server_args.enable_dp_attention else 1)
            ),
            self.model_runner.req_to_token_pool.size,
        )
        self.max_req_len = min(
            self.model_config.context_len - 1,
            self.max_total_num_tokens - 1,
        )
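        # Leave a small margin (5 tokens) between the longest allowed input and
        # the hard per-request limit.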
        self.max_req_input_len = self.max_req_len - 5
        assert (
            self.max_req_len > 0 and self.max_req_input_len > 0
        ), "Memory pool size is too small"

        # Sync random seed across TP workers
        self.random_seed = broadcast_pyobj(
            [server_args.random_seed],
            self.tp_rank,
            self.model_runner.tp_group.cpu_group,
        )[0]
        set_random_seed(self.random_seed)

    def get_worker_info(self):
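        """Return the worker's token/request limits and runtime metadata."""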
        return (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            global_server_args_dict,
            self.model_runner.req_to_token_pool.size,
            self.model_runner.req_to_token_pool.max_context_len,
            self.model_runner.token_to_kv_pool.size,
        )

    def get_pad_input_ids_func(self):
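        """Return the model's `pad_input_ids` function, or None if it does not define one."""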
        return getattr(self.model_runner.model, "pad_input_ids", None)

    def get_tp_cpu_group(self):
        return self.model_runner.tp_group.cpu_group

    def get_attention_tp_cpu_group(self):
        return self.model_runner.attention_tp_group.cpu_group

    def get_memory_pool(self):
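        """Return the request-to-token pool and the token-to-KV-cache allocator."""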
        return (
            self.model_runner.req_to_token_pool,
            self.model_runner.token_to_kv_pool_allocator,
        )

    def forward_batch_generation(
        self,
        model_worker_batch: ModelWorkerBatch,
        launch_done: Optional[threading.Event] = None,
        skip_sample: bool = False,
    ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
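        """Run a forward pass and, unless `skip_sample` is set, sample next token ids.

        If `launch_done` is provided, it is set once the forward pass has
        completed, so a caller overlapping CPU work can wait on it.
        """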
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
        if launch_done:
            launch_done.set()

        if skip_sample:
            next_token_ids = None
        else:
            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)

        return logits_output, next_token_ids

    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
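        """Run a forward pass and return the embeddings (embedding models)."""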
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
        embeddings = logits_output.embeddings
        return embeddings

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
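        """Reload model weights from a checkpoint on disk."""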
        success, message = self.model_runner.update_weights_from_disk(
            recv_req.model_path, recv_req.load_format
        )
        return success, message

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
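        """Initialize a distributed process group for receiving weight updates."""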
        success, message = self.model_runner.init_weights_update_group(
            recv_req.master_address,
            recv_req.master_port,
            recv_req.rank_offset,
            recv_req.world_size,
            recv_req.group_name,
            recv_req.backend,
        )
        return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
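        """Receive an updated named parameter over the distributed weights group."""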
        success, message = self.model_runner.update_weights_from_distributed(
            recv_req.name, recv_req.dtype, recv_req.shape
        )
        return success, message

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
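        """Load named tensors that were serialized by another process."""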
        success, message = self.model_runner.update_weights_from_tensor(
            named_tensors=MultiprocessingSerializer.deserialize(
                recv_req.serialized_named_tensors
            ),
            load_format=recv_req.load_format,
        )
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
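        """Return a (possibly truncated) copy of the named parameter."""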
        parameter = self.model_runner.get_weights_by_name(
            recv_req.name, recv_req.truncate_size
        )
        return parameter
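
# A minimal usage sketch (illustrative only): a single-GPU, TP=1 worker. The
# model path and reliance on `ServerArgs` defaults below are assumptions for
# illustration, not part of this module.
#
#     worker = TpModelWorker(
#         server_args=ServerArgs(model_path="meta-llama/Llama-2-7b-hf"),
#         gpu_id=0,
#         tp_rank=0,
#         dp_rank=None,
#         nccl_port=29500,
#     )
#     max_total_num_tokens, *_ = worker.get_worker_info()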