tp_worker.py 5.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
"""A tensor parallel worker."""

Lianmin Zheng's avatar
Lianmin Zheng committed
18
import logging
19
from typing import Optional
20

21
from sglang.srt.configs.model_config import ModelConfig
Lianmin Zheng's avatar
Lianmin Zheng committed
22
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
23
from sglang.srt.managers.io_struct import UpdateWeightReqInput
24
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
25
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
26
from sglang.srt.model_executor.model_runner import ModelRunner
Mingyi's avatar
Mingyi committed
27
from sglang.srt.server_args import ServerArgs
28
from sglang.srt.utils import broadcast_pyobj, set_random_seed
29

Ying Sheng's avatar
Ying Sheng committed
30
# Module-level logger, namespaced to this module per the standard logging convention.
logger = logging.getLogger(__name__)
Lianmin Zheng's avatar
Lianmin Zheng committed
31
32


33
34
35
class TpModelWorker:
    """A tensor parallel model worker.

    Owns one tensor-parallel shard of the model: it builds the ModelConfig and
    ModelRunner for a single GPU, optionally loads the tokenizer/processor, and
    derives the request/token capacity limits used by the scheduler. It also
    exposes thin wrappers around the ModelRunner for generation, embedding, and
    weight updates.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
    ):
        """Initialize the worker for one TP rank on one GPU.

        Args:
            server_args: Global server configuration.
            gpu_id: Index of the GPU this worker runs on.
            tp_rank: This worker's rank within the tensor-parallel group.
            dp_rank: Data-parallel rank. NOTE(review): accepted but never read
                in this constructor — presumably kept for interface symmetry
                with callers; confirm before removing.
            nccl_port: Port used by the ModelRunner for NCCL communication.
        """
        # Parse args
        self.tp_rank = tp_rank

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
        )
        # ModelRunner construction has side effects (device setup, NCCL via
        # nccl_port); it must exist before the capacity profiling below, which
        # reads its pools and token counts.
        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=nccl_port,
            server_args=server_args,
        )
        if server_args.skip_tokenizer_init:
            # Caller handles tokenization elsewhere; keep both handles None.
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                # Multimodal models need the full processor; the text tokenizer
                # is taken from it so both refer to the same vocabulary.
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
        self.device = self.model_runner.device

        # Profile number of tokens
        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
        self.max_prefill_tokens = server_args.max_prefill_tokens
        # Default cap is half the KV token budget when not configured, and is
        # always bounded by the request-to-token pool size.
        self.max_running_requests = min(
            (
                self.max_total_num_tokens // 2
                if server_args.max_running_requests is None
                else server_args.max_running_requests
            ),
            self.model_runner.req_to_token_pool.size,
        )
        # A single request can never exceed the model context or the KV pool.
        self.max_req_len = min(
            self.model_config.context_len - 1,
            self.max_total_num_tokens - 1,
        )
        # Reserve a small margin of 5 tokens below max_req_len for the input —
        # presumably headroom for special/bos tokens; rationale not visible
        # here, confirm before changing.
        self.max_req_input_len = self.max_req_len - 5
        # NOTE(review): assert is stripped under `python -O`; an explicit raise
        # would be more robust for this configuration check.
        assert (
            self.max_req_len > 0 and self.max_req_input_len > 0
        ), "Memory pool size is too small"

        # Sync random seed across TP workers
        # Rank 0's seed is broadcast over the TP CPU group so all ranks sample
        # identically.
        self.random_seed = broadcast_pyobj(
            [server_args.random_seed],
            self.tp_rank,
            self.model_runner.tp_group.cpu_group,
        )[0]
        set_random_seed(self.random_seed)

    def get_worker_info(self):
        """Return the worker's capacity limits and runtime metadata.

        Returns:
            A tuple of (max_total_num_tokens, max_prefill_tokens,
            max_running_requests, max_req_len, max_req_input_len, random_seed,
            device, global_server_args_dict, req_to_token_pool size,
            req_to_token_pool max_context_len, token_to_kv_pool size).
            Callers unpack positionally, so the order is part of the contract.
        """
        return (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            global_server_args_dict,
            self.model_runner.req_to_token_pool.size,
            self.model_runner.req_to_token_pool.max_context_len,
            self.model_runner.token_to_kv_pool.size,
        )

    def get_pad_input_ids_func(self):
        """Return the model's `pad_input_ids` callable, or None if the model
        does not define one (e.g. text-only models)."""
        return getattr(self.model_runner.model, "pad_input_ids", None)

    def get_tp_cpu_group(self):
        """Return the CPU process group of the tensor-parallel group."""
        return self.model_runner.tp_group.cpu_group

    def get_memory_pool(self):
        """Return the (req_to_token_pool, token_to_kv_pool) pair owned by the
        underlying ModelRunner."""
        return (
            self.model_runner.req_to_token_pool,
            self.model_runner.token_to_kv_pool,
        )

    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
        """Run one generation forward pass and sample next tokens.

        Args:
            model_worker_batch: Batch description from the scheduler.

        Returns:
            (logits_output, next_token_ids) — the runner's forward output and
            the token ids sampled from it for this batch.
        """
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
        next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
        return logits_output, next_token_ids

    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
        """Run one embedding forward pass.

        Args:
            model_worker_batch: Batch description from the scheduler.

        Returns:
            The `embeddings` attribute of the runner's forward output.
        """
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
        embeddings = logits_output.embeddings
        return embeddings

    def update_weights(self, recv_req: UpdateWeightReqInput):
        """Reload model weights in place via the ModelRunner.

        Args:
            recv_req: Request carrying the new `model_path` and `load_format`.

        Returns:
            (success, message) as reported by `ModelRunner.update_weights`.
        """
        success, message = self.model_runner.update_weights(
            recv_req.model_path, recv_req.load_format
        )
        return success, message