"googlemock/include/gmock/vscode:/vscode.git/clone" did not exist on "aa1146da816954a87bdd6ad745d6281a529e0d90"
sglang.rs 2.27 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::path::Path;
use std::sync::Arc;

use crate::backend::ExecutionContext;
use triton_distributed_runtime::pipeline::error as pipeline_error;
use triton_distributed_runtime::CancellationToken;

mod worker;

mod engine;
use engine::SgLangEngine;

mod subprocess;
pub use subprocess::run_subprocess;

pub async fn make_engine(
    cancel_token: CancellationToken,
    // Full path to the model directory
    model_path: &Path,
    // Unique string to name zmq sockets
    sock_code: &str,
37
38
    // Multi node settings
    node_conf: super::MultiNodeConfig,
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
    // How many GPUs to use
    tensor_parallel_size: u32,
    // The base GPU ID to start allocating GPUs from
    base_gpu_id: u32,
) -> pipeline_error::Result<(ExecutionContext, tokio::task::JoinHandle<()>)> {
    let mut engine = SgLangEngine::new(
        cancel_token,
        sock_code,
        model_path,
        node_conf,
        tensor_parallel_size,
        base_gpu_id,
    )
    .await?;
    let sglang_process = engine.take_sglang_worker_handle();
    let engine: ExecutionContext = Arc::new(engine);
    Ok((engine, sglang_process))
}

#[derive(Debug, Clone, Copy)]
pub struct MultiGPUConfig {
    /// How many GPUs we are using / how many processes
    pub tp_size: u32,
    /// Tensor Parallel Rank. Must be unique across all nodes and GPUs.
    pub tp_rank: u32,
    /// GPU ID. Which GPU to run on. In single-node setup this is the same as tp_rank.
    pub gpu_id: u32,
}

impl Default for MultiGPUConfig {
    fn default() -> Self {
        MultiGPUConfig {
            tp_size: 1,
            tp_rank: 0,
            gpu_id: 0,
        }
    }
}