Unverified Commit 9eefe2c0 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Set CUDA_VISIBLE_DEVICES to achieve one GPU per process (#9170)


Co-authored-by: default avatarSangBin Cho <rkooo567@gmail.com>
Co-authored-by: default avatarCheng Wan <cwan@x.ai>
Co-authored-by: default avatarCheng Wan <54331508+ch-wan@users.noreply.github.com>
parent 69fe3c97
...@@ -74,6 +74,7 @@ SGLang supports various environment variables that can be used to configure its ...@@ -74,6 +74,7 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_BLOCK_NONZERO_RANK_CHILDREN` | Control blocking of non-zero rank children processes | `1` | | `SGLANG_BLOCK_NONZERO_RANK_CHILDREN` | Control blocking of non-zero rank children processes | `1` |
| `SGL_IS_FIRST_RANK_ON_NODE` | Indicates if the current process is the first rank on its node | `"true"` | | `SGL_IS_FIRST_RANK_ON_NODE` | Indicates if the current process is the first rank on its node | `"true"` |
| `SGLANG_PP_LAYER_PARTITION` | Pipeline parallel layer partition specification | Not set | | `SGLANG_PP_LAYER_PARTITION` | Pipeline parallel layer partition specification | Not set |
| `SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS` | Set one visible device per process for distributed computing | `false` |
## Testing & Debugging (Internal/CI) ## Testing & Debugging (Internal/CI)
......
...@@ -61,6 +61,7 @@ import torch.distributed as dist ...@@ -61,6 +61,7 @@ import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed.parallel_state import destroy_distributed_environment from sglang.srt.distributed.parallel_state import destroy_distributed_environment
from sglang.srt.entrypoints.engine import _set_envs_and_config from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.environ import envs
from sglang.srt.layers.moe import initialize_moe_config from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler from sglang.srt.managers.scheduler import Scheduler
...@@ -75,6 +76,7 @@ from sglang.srt.utils import ( ...@@ -75,6 +76,7 @@ from sglang.srt.utils import (
is_cuda_alike, is_cuda_alike,
is_xpu, is_xpu,
kill_process_tree, kill_process_tree,
maybe_reindex_device_id,
require_mlp_sync, require_mlp_sync,
require_mlp_tp_gather, require_mlp_tp_gather,
set_gpu_proc_affinity, set_gpu_proc_affinity,
...@@ -159,7 +161,7 @@ class BenchArgs: ...@@ -159,7 +161,7 @@ class BenchArgs:
) )
def load_model(server_args, port_args, tp_rank): def load_model(server_args, port_args, gpu_id, tp_rank):
suppress_other_loggers() suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
...@@ -168,7 +170,7 @@ def load_model(server_args, port_args, tp_rank): ...@@ -168,7 +170,7 @@ def load_model(server_args, port_args, tp_rank):
model_runner = ModelRunner( model_runner = ModelRunner(
model_config=model_config, model_config=model_config,
mem_fraction_static=server_args.mem_fraction_static, mem_fraction_static=server_args.mem_fraction_static,
gpu_id=tp_rank, gpu_id=gpu_id,
tp_rank=tp_rank, tp_rank=tp_rank,
tp_size=server_args.tp_size, tp_size=server_args.tp_size,
moe_ep_rank=moe_ep_rank, moe_ep_rank=moe_ep_rank,
...@@ -350,6 +352,7 @@ def correctness_test( ...@@ -350,6 +352,7 @@ def correctness_test(
server_args, server_args,
port_args, port_args,
bench_args, bench_args,
gpu_id,
tp_rank, tp_rank,
): ):
# Configure the logger # Configure the logger
...@@ -357,7 +360,7 @@ def correctness_test( ...@@ -357,7 +360,7 @@ def correctness_test(
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
# Load the model # Load the model
model_runner, tokenizer = load_model(server_args, port_args, tp_rank) model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)
# Prepare inputs # Prepare inputs
custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print) custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
...@@ -517,6 +520,7 @@ def latency_test( ...@@ -517,6 +520,7 @@ def latency_test(
server_args, server_args,
port_args, port_args,
bench_args, bench_args,
gpu_id,
tp_rank, tp_rank,
): ):
initialize_moe_config(server_args) initialize_moe_config(server_args)
...@@ -532,7 +536,7 @@ def latency_test( ...@@ -532,7 +536,7 @@ def latency_test(
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
# Load the model # Load the model
model_runner, tokenizer = load_model(server_args, port_args, tp_rank) model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)
# Prepare inputs for warm up # Prepare inputs for warm up
reqs = prepare_synthetic_inputs_for_latency_test( reqs = prepare_synthetic_inputs_for_latency_test(
...@@ -634,21 +638,23 @@ def main(server_args, bench_args): ...@@ -634,21 +638,23 @@ def main(server_args, bench_args):
port_args = PortArgs.init_new(server_args) port_args = PortArgs.init_new(server_args)
if server_args.tp_size == 1: if server_args.tp_size == 1:
work_func(server_args, port_args, bench_args, 0) work_func(server_args, port_args, bench_args, 0, 0)
else: else:
workers = [] workers = []
for tp_rank in range(server_args.tp_size): for tp_rank in range(server_args.tp_size):
proc = multiprocessing.Process( with maybe_reindex_device_id(tp_rank) as gpu_id:
target=work_func, proc = multiprocessing.Process(
args=( target=work_func,
server_args, args=(
port_args, server_args,
bench_args, port_args,
tp_rank, bench_args,
), gpu_id,
) tp_rank,
proc.start() ),
workers.append(proc) )
proc.start()
workers.append(proc)
for proc in workers: for proc in workers:
proc.join() proc.join()
......
...@@ -39,6 +39,7 @@ import torch ...@@ -39,6 +39,7 @@ import torch
import torch.distributed import torch.distributed
from torch.distributed import Backend, ProcessGroup from torch.distributed import Backend, ProcessGroup
from sglang.srt.environ import envs
from sglang.srt.utils import ( from sglang.srt.utils import (
direct_register_custom_op, direct_register_custom_op,
get_bool_env_var, get_bool_env_var,
...@@ -56,8 +57,6 @@ _is_npu = is_npu() ...@@ -56,8 +57,6 @@ _is_npu = is_npu()
_is_cpu = is_cpu() _is_cpu = is_cpu()
_supports_custom_op = supports_custom_op() _supports_custom_op = supports_custom_op()
IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
...@@ -277,11 +276,13 @@ class GroupCoordinator: ...@@ -277,11 +276,13 @@ class GroupCoordinator:
assert self.cpu_group is not None assert self.cpu_group is not None
assert self.device_group is not None assert self.device_group is not None
device_id = 0 if IS_ONE_DEVICE_PER_PROCESS else local_rank
if is_cuda_alike(): if is_cuda_alike():
device_id = (
0 if envs.SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS.get() else local_rank
)
self.device = torch.device(f"cuda:{device_id}") self.device = torch.device(f"cuda:{device_id}")
elif _is_npu: elif _is_npu:
self.device = torch.device(f"npu:{device_id}") self.device = torch.device(f"npu:{local_rank}")
else: else:
self.device = torch.device("cpu") self.device = torch.device("cpu")
self.device_module = torch.get_device_module(self.device) self.device_module = torch.get_device_module(self.device)
......
...@@ -75,6 +75,7 @@ from sglang.srt.utils import ( ...@@ -75,6 +75,7 @@ from sglang.srt.utils import (
is_cuda, is_cuda,
kill_process_tree, kill_process_tree,
launch_dummy_health_check_server, launch_dummy_health_check_server,
maybe_reindex_device_id,
prepare_model_and_tokenizer, prepare_model_and_tokenizer,
set_prometheus_multiproc_dir, set_prometheus_multiproc_dir,
set_ulimit, set_ulimit,
...@@ -782,22 +783,24 @@ def _launch_subprocesses( ...@@ -782,22 +783,24 @@ def _launch_subprocesses(
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
) )
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
proc = mp.Process(
target=run_scheduler_process,
args=(
server_args,
port_args,
gpu_id,
tp_rank,
moe_ep_rank,
pp_rank,
None,
writer,
),
)
with memory_saver_adapter.configure_subprocess(): with maybe_reindex_device_id(gpu_id) as gpu_id:
proc.start() proc = mp.Process(
target=run_scheduler_process,
args=(
server_args,
port_args,
gpu_id,
tp_rank,
moe_ep_rank,
pp_rank,
None,
writer,
),
)
with memory_saver_adapter.configure_subprocess():
proc.start()
scheduler_procs.append(proc) scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader) scheduler_pipe_readers.append(reader)
else: else:
......
...@@ -142,6 +142,7 @@ class Envs: ...@@ -142,6 +142,7 @@ class Envs:
# Model Parallel # Model Parallel
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True) SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS = EnvBool(False)
# Constrained Decoding # Constrained Decoding
SGLANG_DISABLE_OUTLINES_DISK_CACHE = EnvBool(True) SGLANG_DISABLE_OUTLINES_DISK_CACHE = EnvBool(True)
......
...@@ -46,6 +46,7 @@ from sglang.srt.utils import ( ...@@ -46,6 +46,7 @@ from sglang.srt.utils import (
configure_logger, configure_logger,
get_zmq_socket, get_zmq_socket,
kill_itself_when_parent_died, kill_itself_when_parent_died,
maybe_reindex_device_id,
) )
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.utils import TypeBasedDispatcher, get_exception_traceback from sglang.utils import TypeBasedDispatcher, get_exception_traceback
...@@ -139,6 +140,9 @@ class DataParallelController: ...@@ -139,6 +140,9 @@ class DataParallelController:
# Load balance budget # Load balance budget
self.dp_budget = DPBudget() self.dp_budget = DPBudget()
# To protect changing env vars to set CUDA_VISIBLE_DEVICES.
self.env_lock = threading.Lock()
# Launch data parallel workers # Launch data parallel workers
self.scheduler_procs = [] self.scheduler_procs = []
self.workers: List[zmq.Socket] = [None] * server_args.dp_size self.workers: List[zmq.Socket] = [None] * server_args.dp_size
...@@ -399,21 +403,22 @@ class DataParallelController: ...@@ -399,21 +403,22 @@ class DataParallelController:
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
) )
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
proc = mp.Process( with self.env_lock, maybe_reindex_device_id(gpu_id) as gpu_id:
target=run_scheduler_process, proc = mp.Process(
args=( target=run_scheduler_process,
server_args, args=(
rank_port_args, server_args,
gpu_id, rank_port_args,
tp_rank, gpu_id,
moe_ep_rank, tp_rank,
pp_rank, moe_ep_rank,
dp_rank, pp_rank,
writer, dp_rank,
), writer,
) ),
with memory_saver_adapter.configure_subprocess(): )
proc.start() with memory_saver_adapter.configure_subprocess():
proc.start()
self.scheduler_procs.append(proc) self.scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader) scheduler_pipe_readers.append(reader)
......
...@@ -88,6 +88,7 @@ from torch.profiler import ProfilerActivity, profile, record_function ...@@ -88,6 +88,7 @@ from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils._contextlib import _DecoratorContextManager from torch.utils._contextlib import _DecoratorContextManager
from typing_extensions import Literal from typing_extensions import Literal
from sglang.srt.environ import envs
from sglang.srt.metrics.func_timer import enable_func_timer from sglang.srt.metrics.func_timer import enable_func_timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -3273,7 +3274,7 @@ def json_list_type(value): ...@@ -3273,7 +3274,7 @@ def json_list_type(value):
@contextmanager @contextmanager
def maybe_reindex_device_id(gpu_id: int): def maybe_reindex_device_id(gpu_id: int):
if not is_cuda_alike(): if envs.SGLANG_ONE_VISIBLE_DEVICE_PER_PROCESS.get() is False or not is_cuda_alike():
yield gpu_id yield gpu_id
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment