# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tensor parallel worker."""

import logging
import threading
from typing import Optional

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import broadcast_pyobj, set_random_seed

logger = logging.getLogger(__name__)


class TpModelWorker:
    """A tensor parallel model worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
    ):
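        """Load the model onto the given GPU, build the tokenizer, and derive batch/request limits."""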
        # Parse args
        self.tp_rank = tp_rank

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=nccl_port,
            server_args=server_args,
        )
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
        self.device = self.model_runner.device

        # Derive token and request limits from the profiled memory pool
        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
        self.max_prefill_tokens = server_args.max_prefill_tokens
        self.max_running_requests = min(
            (
                self.max_total_num_tokens // 2
                if server_args.max_running_requests is None
                else server_args.max_running_requests
            ),
            self.model_runner.req_to_token_pool.size,
        )
        self.max_req_len = min(
            self.model_config.context_len - 1,
            self.max_total_num_tokens - 1,
        )
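        # Keep a small buffer below max_req_len (likely headroom for special tokens).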
        self.max_req_input_len = self.max_req_len - 5
        assert (
            self.max_req_len > 0 and self.max_req_input_len > 0
        ), "Memory pool size is too small"

        # Sync random seed across TP workers
        self.random_seed = broadcast_pyobj(
            [server_args.random_seed],
            self.tp_rank,
            self.model_runner.tp_group.cpu_group,
        )[0]
        set_random_seed(self.random_seed)

    def get_worker_info(self):
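        """Return this worker's static limits and metadata (token budgets, seed, device, pool sizes)."""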
        return (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            global_server_args_dict,
            self.model_runner.req_to_token_pool.size,
            self.model_runner.req_to_token_pool.max_context_len,
            self.model_runner.token_to_kv_pool.size,
        )

    def get_pad_input_ids_func(self):
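        """Return the model's `pad_input_ids` hook (defined by multimodal models), or None."""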
        return getattr(self.model_runner.model, "pad_input_ids", None)

    def get_tp_cpu_group(self):
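        """Return the CPU-side process group of the TP group (used above for `broadcast_pyobj`)."""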
        return self.model_runner.tp_group.cpu_group

    def get_memory_pool(self):
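        """Return the req-to-token and token-to-KV-cache pools owned by the model runner."""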
        return (
            self.model_runner.req_to_token_pool,
            self.model_runner.token_to_kv_pool,
        )

    def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
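        """Run a forward pass on an idle (empty) batch, keeping this rank in step with its peers."""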
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        self.model_runner.forward(forward_batch)

    def forward_batch_generation(
        self,
        model_worker_batch: ModelWorkerBatch,
        launch_done: Optional[threading.Event] = None,
        skip_sample: bool = False,
    ):
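        """Run a generation forward pass; sample next tokens unless `skip_sample` is set.

        If given, `launch_done` is set once the forward pass has been launched,
        so a caller that overlaps CPU work can resume before sampling finishes.
        """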
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
        if launch_done:
            launch_done.set()

        if skip_sample:
            next_token_ids = None
        else:
            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)

        return logits_output, next_token_ids

    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
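        """Run a forward pass for an embedding batch and return the embeddings."""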
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
        embeddings = logits_output.embeddings
        return embeddings

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
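        """Reload model weights from `recv_req.model_path` without restarting the server."""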
        success, message = self.model_runner.update_weights_from_disk(
            recv_req.model_path, recv_req.load_format
        )
        return success, message

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
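        """Set up the process group over which an external process (e.g. a trainer) broadcasts updated weights."""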
        success, message = self.model_runner.init_weights_update_group(
            recv_req.master_address,
            recv_req.master_port,
            recv_req.rank_offset,
            recv_req.world_size,
            recv_req.group_name,
            recv_req.backend,
        )
        return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
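        """Receive a named weight tensor broadcast over the group created by `init_weights_update_group`."""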
        success, message = self.model_runner.update_weights_from_distributed(
            recv_req.name, recv_req.dtype, recv_req.shape
        )
        return success, message

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
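        """Update a single named parameter from a tensor carried in the request."""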
        success, message = self.model_runner.update_weights_from_tensor(
            recv_req.name, recv_req.tensor
        )
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
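        """Fetch a parameter by name, truncated to `recv_req.truncate_size` (e.g. for debugging)."""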
        parameter = self.model_runner.get_weights_by_name(
            recv_req.name, recv_req.truncate_size
        )
        return parameter