Unverified Commit 04ec6ba2 authored by Liangsheng Yin, committed by GitHub

Fix dockerfile and triton cache manager (#720)

parent d63f13c1
Dockerfile

@@ -23,18 +23,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
 RUN apt-get update -y \
     && apt-get install -y python3-pip git curl sudo
-# Workaround for https://github.com/openai/triton/issues/2507 and
-# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
-# this won't be needed for future versions of this docker image
-# or future versions of triton.
-RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
 WORKDIR /sgl-workspace
 RUN pip3 --no-cache-dir install --upgrade pip \
     && pip3 --no-cache-dir install "sglang[all]" \
-    && pip3 --no-cache-dir uninstall -y triton triton-nightly \
-    && pip3 --no-cache-dir install --no-deps --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly \
     && pip3 --no-cache-dir install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
 ENV DEBIAN_FRONTEND=interactive
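With the triton-nightly pin and the ldconfig workaround removed, the image now uses whichever Triton wheel `sglang[all]` resolves from PyPI. A quick sanity check inside the built container (a hedged sketch; the exact version printed depends on the resolved dependencies, not on this diff):

    # Run inside the built image to confirm which Triton ships with it.
    import triton
    print(triton.__version__)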
python/sglang/srt/server.py

@@ -52,6 +52,7 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    maybe_set_triton_cache_manager,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback

@@ -201,6 +202,11 @@ def launch_server(
         "reinstall the latest version by following the instructions "
         "at https://docs.flashinfer.ai/installation.html.",
     )
+
+    if server_args.tp_size // server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
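The new guard runs only when a data-parallel replica spans more than one tensor-parallel process, i.e. when several local processes may JIT-compile the same Triton kernel at once and race on a shared cache directory. Triton locates the replacement class through the `TRITON_CACHE_MANAGER` variable, whose value uses a `module.path:ClassName` format, as the string set in `maybe_set_triton_cache_manager` below shows. A minimal sketch of that resolution step, written from the env-var format rather than from Triton's own source:

    # Hedged sketch: resolve a "sglang.srt.utils:CustomCacheManager"-style spec
    # into a class object. Illustrative only, not Triton's implementation.
    import importlib
    import os

    def resolve_cache_manager_cls(default=None):
        spec = os.environ.get("TRITON_CACHE_MANAGER")
        if spec is None:
            return default
        module_path, cls_name = spec.split(":")
        return getattr(importlib.import_module(module_path), cls_name)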
python/sglang/srt/utils.py

@@ -18,10 +18,15 @@ import psutil
 import requests
 import torch
 import torch.distributed as dist
+import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
 from starlette.middleware.base import BaseHTTPMiddleware
+from triton.runtime.cache import (
+    FileCacheManager,
+    default_cache_dir,
+    default_dump_dir,
+    default_override_dir,
+)

 logger = logging.getLogger(__name__)
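The imported helpers come from Triton's runtime cache module: `FileCacheManager` is the base class subclassed below, and the `default_*_dir` functions return Triton's standard cache, dump, and override locations. A small inspection snippet (assumes a local Triton install; the paths printed vary by platform):

    from triton.runtime.cache import (
        default_cache_dir,
        default_dump_dir,
        default_override_dir,
    )

    # Print Triton's standard directories on this machine.
    for name, fn in [
        ("cache", default_cache_dir),
        ("dump", default_dump_dir),
        ("override", default_override_dir),
    ]:
        print(name, fn())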
@@ -460,6 +465,44 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
     setattr(GroupCoordinator, "all_gather", all_gather)


+def maybe_set_triton_cache_manager() -> None:
+    """Set environment variable to tell Triton to use a custom cache manager."""
+    cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
+    if cache_manager is None:
+        manager = "sglang.srt.utils:CustomCacheManager"
+        logger.info("Setting Triton cache manager to: %s", manager)
+        os.environ["TRITON_CACHE_MANAGER"] = manager
+
+
+class CustomCacheManager(FileCacheManager):
+    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
+
+    def __init__(self, key, override=False, dump=False):
+        self.key = key
+        self.lock_path = None
+        if dump:
+            self.cache_dir = default_dump_dir()
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+            self.lock_path = os.path.join(self.cache_dir, "lock")
+            os.makedirs(self.cache_dir, exist_ok=True)
+        elif override:
+            self.cache_dir = default_override_dir()
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+        else:
+            # Create the cache directory if it doesn't exist.
+            self.cache_dir = (
+                os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
+            )
+            if self.cache_dir:
+                # Make the cache directory unique per process so concurrent
+                # tensor-parallel workers never share a kernel cache.
+                self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
+                self.cache_dir = os.path.join(self.cache_dir, self.key)
+                self.lock_path = os.path.join(self.cache_dir, "lock")
+                os.makedirs(self.cache_dir, exist_ok=True)
+            else:
+                raise RuntimeError("Could not create or locate cache dir")
+
+
 API_KEY_HEADER_NAME = "X-API-Key"
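The important behavior is the final branch: suffixing the cache path with `os.getpid()` gives every process a private Triton kernel cache, which is what prevents concurrent tensor-parallel workers from corrupting each other's cache entries. A hypothetical smoke test (not part of the commit; assumes sglang is importable) showing two processes landing in distinct directories:

    import multiprocessing
    import os

    def show_cache_dir(key: str = "demo-key") -> None:
        # Each process should print a cache_dir containing its own PID.
        from sglang.srt.utils import CustomCacheManager
        print(os.getpid(), CustomCacheManager(key).cache_dir)

    if __name__ == "__main__":
        workers = [multiprocessing.Process(target=show_cache_dir) for _ in range(2)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()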