Unverified Commit 73401fd0 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Sync distributed package from vllm 0.6.4.post1 (#3010)

parent 89cd9235
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py
from typing import Any, Dict, Optional, Union
import torch
import torch.distributed
from sglang.srt.distributed.parallel_state import get_tp_group
from .parallel_state import get_tp_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py
import ctypes
import logging
import os
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py
import ctypes
import json
import logging
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
import logging
from contextlib import contextmanager
from typing import Optional, Union
# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
......@@ -143,6 +145,57 @@ class PyNcclCommunicator:
cudaStream_t(stream.cuda_stream),
)
def all_gather(
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}"
)
if stream is None:
stream = self.stream
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
self.comm,
cudaStream_t(stream.cuda_stream),
)
def reduce_scatter(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None,
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}"
)
if stream is None:
stream = self.stream
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
......@@ -179,6 +232,32 @@ class PyNcclCommunicator:
cudaStream_t(stream.cuda_stream),
)
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = self.stream
if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr())
# NCCL requires the sender also to have a receive buffer
recvbuff = buffer_type(tensor.data_ptr())
else:
sendbuff = buffer_type()
recvbuff = buffer_type(tensor.data_ptr())
self.nccl.ncclBroadcast(
sendbuff,
recvbuff,
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
src,
self.comm,
cudaStream_t(stream.cuda_stream),
)
@contextmanager
def change_state(
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
......@@ -187,6 +187,43 @@ class NCCLLibrary:
cudaStream_t,
],
),
# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclAllGather",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclReduceScatter",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclRedOp_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
......@@ -217,6 +254,23 @@ class NCCLLibrary:
cudaStream_t,
],
),
# ncclResult_t ncclBroadcast(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, int root, ncclComm_t comm,
# cudaStream_t stream);
Function(
"ncclBroadcast",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
......@@ -321,6 +375,46 @@ class NCCLLibrary:
)
)
def ncclReduceScatter(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(
self._funcs["ncclReduceScatter"](
sendbuff, recvbuff, count, datatype, op, comm, stream
)
)
def ncclAllGather(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# which is an aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(
self._funcs["ncclAllGather"](
sendbuff, recvbuff, count, datatype, comm, stream
)
)
def ncclSend(
self,
sendbuff: buffer_type,
......@@ -347,6 +441,22 @@ class NCCLLibrary:
self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
)
def ncclBroadcast(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
root: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
self.NCCL_CHECK(
self._funcs["ncclBroadcast"](
sendbuff, recvbuff, count, datatype, root, comm, stream
)
)
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/shm_broadcast.py
import ipaddress
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/shm_broadcast.py
import logging
import os
import pickle
import socket
import time
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
......@@ -18,6 +16,8 @@ from torch.distributed import ProcessGroup
from zmq import IPV6 # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60")
......@@ -26,73 +26,6 @@ SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
logger = logging.getLogger(__name__)
def get_ip() -> str:
# SGLANG_HOST_IP env can be ignore
host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
if host_ip:
return host_ip
# IP is not set, try to get it from the network interface
# try ipv4
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
# try ipv6
try:
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# Google's public DNS server, see
# https://developers.google.com/speed/public-dns/docs/using#addresses
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
warnings.warn(
"Failed to get the IP address, using 0.0.0.0 by default."
"The value can be set by the environment variable"
" SGLANG_HOST_IP or HOST_IP.",
stacklevel=2,
)
return "0.0.0.0"
def get_open_port() -> int:
port = os.getenv("SGLANG_PORT")
if port is not None:
while True:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", port))
return port
except OSError:
port += 1 # Increment port number if already in use
logger.info("Port %d is already in use, trying port %d", port - 1, port)
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def is_valid_ipv6_address(address: str) -> bool:
try:
ipaddress.IPv6Address(address)
return True
except ValueError:
return False
class ShmRingBuffer:
def __init__(
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/xpu_communicator.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/xpu_communicator.py
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/parallel_state.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/parallel_state.py
# Copyright 2023 The vLLM team.
# Adapted from
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/utils.py
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
......
......@@ -29,8 +29,8 @@ from sglang.srt.utils import (
get_nvgpu_memory_capacity,
is_flashinfer_available,
is_hip,
is_ipv6,
is_port_available,
is_valid_ipv6_address,
nullable_str,
)
......@@ -883,7 +883,7 @@ class ServerArgs:
return cls(**{attr: getattr(args, attr) for attr in attrs})
def url(self):
if is_ipv6(self.host):
if is_valid_ipv6_address(self.host):
return f"http://[{self.host}]:{self.port}"
else:
return f"http://{self.host}:{self.port}"
......
......@@ -102,14 +102,6 @@ def is_cuda_available():
return torch.cuda.is_available() and torch.version.cuda
def is_ipv6(address):
try:
ipaddress.IPv6Address(address)
return True
except ipaddress.AddressValueError:
return False
def enable_show_time_cost():
global show_time_cost
show_time_cost = True
......@@ -1383,3 +1375,70 @@ def set_uvicorn_logging_configs():
"fmt"
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
def get_ip() -> str:
# SGLANG_HOST_IP env can be ignore
host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
if host_ip:
return host_ip
# IP is not set, try to get it from the network interface
# try ipv4
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
# try ipv6
try:
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# Google's public DNS server, see
# https://developers.google.com/speed/public-dns/docs/using#addresses
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
warnings.warn(
"Failed to get the IP address, using 0.0.0.0 by default."
"The value can be set by the environment variable"
" SGLANG_HOST_IP or HOST_IP.",
stacklevel=2,
)
return "0.0.0.0"
def get_open_port() -> int:
port = os.getenv("SGLANG_PORT")
if port is not None:
while True:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", port))
return port
except OSError:
port += 1 # Increment port number if already in use
logger.info("Port %d is already in use, trying port %d", port - 1, port)
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def is_valid_ipv6_address(address: str) -> bool:
try:
ipaddress.IPv6Address(address)
return True
except ValueError:
return False
"""Common utilities"""
import base64
import gc
import importlib
import json
import logging
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment