Unverified Commit 73401fd0 authored by Lianmin Zheng, committed by GitHub

Sync distributed package from vllm 0.6.4.post1 (#3010)

parent 89cd9235
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py
from typing import Any, Dict, Optional, Union
import torch
import torch.distributed
from sglang.srt.distributed.parallel_state import get_tp_group
from .parallel_state import get_tp_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
......
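The body of tensor_model_parallel_all_reduce is collapsed above, but its contract is simple: it sums the input tensor across every rank of the tensor-parallel group and returns the same result on all ranks. A minimal usage sketch, illustrative only (not part of this commit) and assuming the distributed environment and the TP group were already initialized by the server runtime:

# Illustrative sketch, not part of the diff. Assumes torch.distributed and the
# tensor-parallel group have been set up by the surrounding runtime.
import torch
from sglang.srt.distributed.communication_op import tensor_model_parallel_all_reduce

partial = torch.ones(4, device="cuda")              # each TP rank holds a partial sum
total = tensor_model_parallel_all_reduce(partial)   # every rank now sees tp_size * ones(4)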
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py
import ctypes
import logging
import os
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py
import ctypes
import json
import logging
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
import logging
from contextlib import contextmanager
from typing import Optional, Union
# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
@@ -143,6 +145,57 @@ class PyNcclCommunicator:
cudaStream_t(stream.cuda_stream),
)
def all_gather(
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}"
)
if stream is None:
stream = self.stream
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
self.comm,
cudaStream_t(stream.cuda_stream),
)
def reduce_scatter(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None,
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}"
)
if stream is None:
stream = self.stream
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)
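Both new collectives follow the usual NCCL shape contract: for all_gather the output buffer holds world_size concatenated copies of the input, while for reduce_scatter the input is treated as world_size output-sized chunks that are reduced element-wise. A hedged sketch of the expected buffer sizes, illustrative only; comm and world_size are assumed to come from the surrounding setup, and world_size is assumed to divide the input size:

# Illustrative only. `comm` is an already-constructed PyNcclCommunicator and
# `world_size` its size; both are provided by the surrounding runtime.
inp = torch.arange(8, dtype=torch.float32, device=comm.device)

# all_gather: output numel = world_size * input numel
out = torch.empty(world_size * inp.numel(), dtype=inp.dtype, device=comm.device)
comm.all_gather(out, inp)

# reduce_scatter: input numel = world_size * output numel (reduced with ReduceOp.SUM)
chunk = torch.empty(inp.numel() // world_size, dtype=inp.dtype, device=comm.device)
comm.reduce_scatter(chunk, inp)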
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
@@ -179,6 +232,32 @@ class PyNcclCommunicator:
cudaStream_t(stream.cuda_stream),
)
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = self.stream
if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr())
# NCCL requires the sender also to have a receive buffer
recvbuff = buffer_type(tensor.data_ptr())
else:
sendbuff = buffer_type()
recvbuff = buffer_type(tensor.data_ptr())
self.nccl.ncclBroadcast(
sendbuff,
recvbuff,
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
src,
self.comm,
cudaStream_t(stream.cuda_stream),
)
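Note that this broadcast is in-place: the same tensor is passed as both send and receive buffer, so only the root rank needs meaningful data on entry and every rank holds the root's data on return. A hedged calling sketch, illustrative only; comm and src_rank are assumed to come from the surrounding setup:

# Illustrative only. On the root rank `buf` carries the payload; on the other
# ranks its contents are overwritten with the root's data by the call below.
buf = torch.empty(16, dtype=torch.float32, device=comm.device)
if comm.rank == src_rank:
    buf.fill_(1.0)
comm.broadcast(buf, src=src_rank)
# buf is now identical on every rank of the communicator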
@contextmanager
def change_state(
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
@@ -187,6 +187,43 @@ class NCCLLibrary:
cudaStream_t,
],
),
# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclAllGather",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclReduceScatter",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclRedOp_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
@@ -217,6 +254,23 @@ class NCCLLibrary:
cudaStream_t,
],
),
# ncclResult_t ncclBroadcast(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, int root, ncclComm_t comm,
# cudaStream_t stream);
Function(
"ncclBroadcast",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
@@ -321,6 +375,46 @@ class NCCLLibrary:
)
)
def ncclReduceScatter(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(
self._funcs["ncclReduceScatter"](
sendbuff, recvbuff, count, datatype, op, comm, stream
)
)
def ncclAllGather(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# which is an alias of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(
self._funcs["ncclAllGather"](
sendbuff, recvbuff, count, datatype, comm, stream
)
)
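These wrappers lean on the Function entries registered above: NCCLLibrary resolves each symbol in the NCCL shared library and assigns the declared restype/argtypes so ctypes can marshal plain Python ints, pointers, and handles. Here is a standalone sketch of that same registration pattern, using libc's strlen as a stand-in so it runs without NCCL; it is purely illustrative and none of the names below come from the diff.

# Illustrative stand-in, not NCCL: the same "table of prototypes" pattern,
# resolved against libc so it runs anywhere ctypes can find the C library.
import ctypes
import ctypes.util
from dataclasses import dataclass
from typing import Any, List


@dataclass
class Function:
    name: str
    restype: Any
    argtypes: List[Any]


# One table entry per C symbol, mirroring the Function entries above.
exported_functions = [
    Function("strlen", ctypes.c_size_t, [ctypes.c_char_p]),
]

lib = ctypes.CDLL(ctypes.util.find_library("c"))
funcs = {}
for func in exported_functions:
    f = getattr(lib, func.name)
    f.restype = func.restype      # declare the C return type
    f.argtypes = func.argtypes    # declare the C argument types
    funcs[func.name] = f

assert funcs["strlen"](b"nccl") == 4  # call through the registered prototype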
def ncclSend(
self,
sendbuff: buffer_type,
@@ -347,6 +441,22 @@ class NCCLLibrary:
self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
)
def ncclBroadcast(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
root: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
self.NCCL_CHECK(
self._funcs["ncclBroadcast"](
sendbuff, recvbuff, count, datatype, root, comm, stream
)
)
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/shm_broadcast.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/shm_broadcast.py
import ipaddress
import logging
import os
import pickle
import socket
import time
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
@@ -18,6 +16,8 @@ from torch.distributed import ProcessGroup
from zmq import IPV6 # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60")
@@ -26,73 +26,6 @@ SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
logger = logging.getLogger(__name__)
def get_ip() -> str:
# SGLANG_HOST_IP env can be ignore
host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
if host_ip:
return host_ip
# IP is not set, try to get it from the network interface
# try ipv4
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
# try ipv6
try:
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# Google's public DNS server, see
# https://developers.google.com/speed/public-dns/docs/using#addresses
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
warnings.warn(
"Failed to get the IP address, using 0.0.0.0 by default."
"The value can be set by the environment variable"
" SGLANG_HOST_IP or HOST_IP.",
stacklevel=2,
)
return "0.0.0.0"
def get_open_port() -> int:
port = os.getenv("SGLANG_PORT")
if port is not None:
while True:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", port))
return port
except OSError:
port += 1 # Increment port number if already in use
logger.info("Port %d is already in use, trying port %d", port - 1, port)
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def is_valid_ipv6_address(address: str) -> bool:
try:
ipaddress.IPv6Address(address)
return True
except ValueError:
return False
class ShmRingBuffer:
def __init__(
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/xpu_communicator.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/xpu_communicator.py
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/parallel_state.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/parallel_state.py
# Copyright 2023 The vLLM team.
# Adapted from
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/utils.py
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
......
@@ -29,8 +29,8 @@ from sglang.srt.utils import (
get_nvgpu_memory_capacity,
is_flashinfer_available,
is_hip,
is_ipv6,
is_port_available,
is_valid_ipv6_address,
nullable_str,
)
@@ -883,7 +883,7 @@ class ServerArgs:
return cls(**{attr: getattr(args, attr) for attr in attrs})
def url(self):
if is_ipv6(self.host):
if is_valid_ipv6_address(self.host):
return f"http://[{self.host}]:{self.port}" return f"http://[{self.host}]:{self.port}"
else: else:
return f"http://{self.host}:{self.port}" return f"http://{self.host}:{self.port}"
......
@@ -102,14 +102,6 @@ def is_cuda_available():
return torch.cuda.is_available() and torch.version.cuda
def is_ipv6(address):
try:
ipaddress.IPv6Address(address)
return True
except ipaddress.AddressValueError:
return False
def enable_show_time_cost():
global show_time_cost
show_time_cost = True
@@ -1383,3 +1375,70 @@ def set_uvicorn_logging_configs():
"fmt"
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
def get_ip() -> str:
# SGLANG_HOST_IP env can be ignored
host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
if host_ip:
return host_ip
# IP is not set, try to get it from the network interface
# try ipv4
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
# try ipv6
try:
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# Google's public DNS server, see
# https://developers.google.com/speed/public-dns/docs/using#addresses
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
warnings.warn(
"Failed to get the IP address, using 0.0.0.0 by default."
"The value can be set by the environment variable"
" SGLANG_HOST_IP or HOST_IP.",
stacklevel=2,
)
return "0.0.0.0"
def get_open_port() -> int:
port = os.getenv("SGLANG_PORT")
if port is not None:
port = int(port)  # env vars are strings; socket.bind() needs an int
while True:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", port))
return port
except OSError:
port += 1  # Increment port number if already in use
logger.info("Port %d is already in use, trying port %d", port - 1, port)
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def is_valid_ipv6_address(address: str) -> bool:
try:
ipaddress.IPv6Address(address)
return True
except ValueError:
return False
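get_ip, get_open_port, and is_valid_ipv6_address now live in sglang.srt.utils so that both shm_broadcast and the server arguments can share them. A rough sketch of how they compose when binding a local endpoint; illustrative only, the real call sites are in the server and shm_broadcast code:

# Illustrative only.
ip = get_ip()            # honors SGLANG_HOST_IP / HOST_IP, else probes a local interface
port = get_open_port()   # honors SGLANG_PORT, else asks the OS for a free port
host = f"[{ip}]" if is_valid_ipv6_address(ip) else ip
print(f"local endpoint: tcp://{host}:{port}")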
"""Common utilities""" """Common utilities"""
import base64 import base64
import gc
import importlib
import json
import logging
......