Unverified Commit c26d7349 authored by shangmingc, committed by GitHub

[PD] Add custom memory pool option to support Mooncake PD with NVLink (#7264)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
parent ceaa85c9
import os
import threading
from importlib import resources
from typing import Dict, Final, Optional

import torch
from torch.cuda.memory import CUDAPluggableAllocator


# TODO(shangming): move this class into mooncake's package for more general use cases
class MooncakeNVLinkAllocator:
    _instances: Dict[torch.device, CUDAPluggableAllocator] = {}
    _lock: Final = threading.Lock()

    @classmethod
    def _get_so_path(cls) -> str:
        """Dynamically locate hook.so in the mooncake package installation."""
        try:
            # Attempt to locate the packaged resource directly
            with resources.path("mooncake", "hook.so") as so_path:
                if so_path.exists():
                    return str(so_path)
        except (ImportError, FileNotFoundError, TypeError):
            pass

        # Fallback strategy: check the package location via import metadata
        try:
            import mooncake

            base_path = os.path.dirname(os.path.abspath(mooncake.__file__))
            so_path = os.path.join(base_path, "hook.so")
            if os.path.exists(so_path):
                return so_path
        except (ImportError, FileNotFoundError, TypeError):
            pass

        # Raised both when mooncake is missing and when hook.so is absent,
        # so the caller gets one actionable message either way.
        raise ImportError(
            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL requires mooncake-transfer-engine >= 0.3.3.post2."
        )

    @classmethod
    def get_allocator(cls, device: torch.device) -> CUDAPluggableAllocator:
        # Lazily create one allocator per device; the lock makes this safe
        # when multiple threads initialize KV pools concurrently.
        with cls._lock:
            if device not in cls._instances:
                so_path = cls._get_so_path()
                cls._instances[device] = CUDAPluggableAllocator(
                    so_path, "mc_nvlink_malloc", "mc_nvlink_free"
                )
            return cls._instances[device]
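
A minimal usage sketch of the allocator (not from this diff; it mirrors what the KV pool classes below do and assumes a PyTorch build with torch.cuda.MemPool support plus mooncake-transfer-engine >= 0.3.3.post2):

import torch

from sglang.srt.disaggregation.mooncake.memory_pool import MooncakeNVLinkAllocator

device = torch.device("cuda:0")

# Per-device singleton backed by mooncake's hook.so (mc_nvlink_malloc/free).
allocator = MooncakeNVLinkAllocator.get_allocator(device)
pool = torch.cuda.MemPool(allocator.allocator())

# Every CUDA allocation made inside this context is served by the custom pool.
with torch.cuda.use_mem_pool(pool):
    buf = torch.zeros((1024,), dtype=torch.int32, device=device)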
@@ -6,6 +6,7 @@ import random
import threading
import warnings
from collections import deque
from contextlib import nullcontext
from enum import Enum
from typing import TYPE_CHECKING, List, Optional
@@ -84,24 +85,37 @@ class ReqToMetadataIdxAllocator:
class MetadataBuffers:
    def __init__(self, size: int, max_top_logprobs_num: int = 128):
        # TODO: abort top_logprobs_num > 128 in PD
        # We transfer the metadata of the first output token to decode.
        # The minimal size for RDMA is 64 bytes, so we pad it to > 64 bytes.
        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
        self.output_token_logprobs_val = torch.zeros(
            (size, 16), dtype=torch.float32, device="cpu"
        )
        self.output_token_logprobs_idx = torch.zeros(
            (size, 16), dtype=torch.int32, device="cpu"
        )
        self.output_top_logprobs_val = torch.zeros(
            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
        )
        self.output_top_logprobs_idx = torch.zeros(
            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
        )
    def __init__(
        self,
        size: int,
        max_top_logprobs_num: int = 128,
        custom_mem_pool: Optional[torch.cuda.MemPool] = None,
    ):
        self.custom_mem_pool = custom_mem_pool
        device = "cuda" if self.custom_mem_pool else "cpu"
        with (
            torch.cuda.use_mem_pool(self.custom_mem_pool)
            if self.custom_mem_pool
            else nullcontext()
        ):
            # TODO: abort top_logprobs_num > 128 in PD
            # We transfer the metadata of the first output token to decode.
            # The minimal size for RDMA is 64 bytes, so we pad it to > 64 bytes.
            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
            self.output_token_logprobs_val = torch.zeros(
                (size, 16), dtype=torch.float32, device=device
            )
            self.output_token_logprobs_idx = torch.zeros(
                (size, 16), dtype=torch.int32, device=device
            )
            self.output_top_logprobs_val = torch.zeros(
                (size, max_top_logprobs_num), dtype=torch.float32, device=device
            )
            self.output_top_logprobs_idx = torch.zeros(
                (size, max_top_logprobs_num), dtype=torch.int32, device=device
            )

    def get_buf_infos(self):
        ptrs = [
......
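
The nullcontext() fallback above keeps the CPU path untouched when no pool is supplied. The same pattern in standalone form, as a sketch (alloc_metadata_buffer is an illustrative name, not from the codebase):

from contextlib import nullcontext
from typing import Optional

import torch


def alloc_metadata_buffer(
    size: int, pool: Optional[torch.cuda.MemPool] = None
) -> torch.Tensor:
    # With a custom pool, allocate on GPU inside it so the buffer can be
    # exposed over NVLink; otherwise fall back to a plain CPU tensor.
    device = "cuda" if pool else "cpu"
    ctx = torch.cuda.use_mem_pool(pool) if pool else nullcontext()
    with ctx:
        return torch.zeros((size, 16), dtype=torch.int32, device=device)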
@@ -622,7 +622,10 @@ class Scheduler(
        self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
            buffer_size
        )
        self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
        self.disagg_metadata_buffers = MetadataBuffers(
            buffer_size,
            custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
        )

        # The decode requests polling kv cache
        self.disagg_decode_transfer_queue = DecodeTransferQueue(
@@ -669,7 +672,10 @@ class Scheduler(
        self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
            buffer_size
        )
        self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
        self.disagg_metadata_buffers = MetadataBuffers(
            buffer_size,
            custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
        )

        self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
            token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
......
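
Both the decode path above and this prefill path fetch the pool through the same maybe_get_custom_mem_pool() accessor, so MetadataBuffers stays agnostic to the attention backend. Note that the gate is read from the environment when the KV pool is constructed, so the flag must be set before server startup; a hedged sketch of the toggle (the launch command shown is illustrative):

import os

# Must be set before MHATokenToKVPool / MLATokenToKVPool are constructed,
# since they read the flag in __init__, e.g. in the launching shell:
#   SGLANG_MOONCAKE_CUSTOM_MEM_POOL=true python -m sglang.launch_server ...
os.environ["SGLANG_MOONCAKE_CUSTOM_MEM_POOL"] = "true"

# With the flag on, maybe_get_custom_mem_pool() returns a torch.cuda.MemPool
# backed by mooncake's NVLink allocator; with it off, it returns None and
# the metadata buffers fall back to CPU tensors.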
@@ -26,6 +26,8 @@ KVCache actually holds the physical kv cache.
import abc
import logging
import os
from contextlib import nullcontext
from typing import List, Optional, Tuple, Union

import numpy as np
@@ -34,7 +36,7 @@ import triton
import triton.language as tl

from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2

logger = logging.getLogger(__name__)
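
get_bool_env_var comes from sglang.srt.utils; roughly, it parses the variable as a case-insensitive boolean with a string default. A sketch of the behavior assumed here (not the exact implementation):

import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Approximate semantics: case-insensitive match against truthy strings.
    value = os.getenv(name, default).lower()
    return value in ("true", "1")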
@@ -260,6 +262,22 @@ class MHATokenToKVPool(KVCache):
        self.head_num = head_num
        self.head_dim = head_dim

        # for disagg with nvlink
        self.enable_custom_mem_pool = get_bool_env_var(
            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
        )
        if self.enable_custom_mem_pool:
            from sglang.srt.disaggregation.mooncake.memory_pool import (
                MooncakeNVLinkAllocator,
            )

            # TODO(shangming): abstract custom allocator class for more backends
            allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
        else:
            self.custom_mem_pool = None

        self._create_buffers()

        # used for chunked cpu-offloading
@@ -275,24 +293,29 @@ class MHATokenToKVPool(KVCache):
    def _create_buffers(self):
        with self.memory_saver_adapter.region():
            # [size, head_num, head_dim] for each layer
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.k_buffer = [
                torch.zeros(
                    (self.size + self.page_size, self.head_num, self.head_dim),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]
            self.v_buffer = [
                torch.zeros(
                    (self.size + self.page_size, self.head_num, self.head_dim),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]
            with (
                torch.cuda.use_mem_pool(self.custom_mem_pool)
                if self.enable_custom_mem_pool
                else nullcontext()
            ):
                # [size, head_num, head_dim] for each layer
                # The padded slot 0 is used for writing dummy outputs from padded tokens.
                self.k_buffer = [
                    torch.zeros(
                        (self.size + self.page_size, self.head_num, self.head_dim),
                        dtype=self.store_dtype,
                        device=self.device,
                    )
                    for _ in range(self.layer_num)
                ]
                self.v_buffer = [
                    torch.zeros(
                        (self.size + self.page_size, self.head_num, self.head_dim),
                        dtype=self.store_dtype,
                        device=self.device,
                    )
                    for _ in range(self.layer_num)
                ]

        self.data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
@@ -349,6 +372,9 @@ class MHATokenToKVPool(KVCache):
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def maybe_get_custom_mem_pool(self):
        return self.custom_mem_pool

    def get_cpu_copy(self, indices):
        torch.cuda.synchronize()
        kv_cache_cpu = []
@@ -569,16 +595,36 @@ class MLATokenToKVPool(KVCache):
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim

        # for disagg with nvlink
        self.enable_custom_mem_pool = get_bool_env_var(
            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
        )
        if self.enable_custom_mem_pool:
            from sglang.srt.disaggregation.mooncake.memory_pool import (
                MooncakeNVLinkAllocator,
            )

            # TODO(shangming): abstract custom allocator class for more backends
            allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
        else:
            self.custom_mem_pool = None

        with self.memory_saver_adapter.region():
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.kv_buffer = [
                torch.zeros(
                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
                    dtype=self.store_dtype,
                    device=device,
                )
                for _ in range(layer_num)
            ]
            with (
                torch.cuda.use_mem_pool(self.custom_mem_pool)
                if self.custom_mem_pool
                else nullcontext()
            ):
                # The padded slot 0 is used for writing dummy outputs from padded tokens.
                self.kv_buffer = [
                    torch.zeros(
                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
                        dtype=self.store_dtype,
                        device=device,
                    )
                    for _ in range(layer_num)
                ]

        self.layer_transfer_counter = None
@@ -604,6 +650,9 @@ class MLATokenToKVPool(KVCache):
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def maybe_get_custom_mem_pool(self):
        return self.custom_mem_pool

    def get_key_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
......
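
MHATokenToKVPool and MLATokenToKVPool duplicate the env-var check and allocator wiring, which is what the TODO about abstracting a custom allocator class points at. One possible shared helper, sketched under that assumption (maybe_create_custom_mem_pool is hypothetical, not part of this commit):

from typing import Optional

import torch

from sglang.srt.utils import get_bool_env_var


def maybe_create_custom_mem_pool(device: torch.device) -> Optional[torch.cuda.MemPool]:
    # Mirrors the duplicated blocks above: gate on the env var, then wrap
    # mooncake's NVLink-aware pluggable allocator in a MemPool.
    if not get_bool_env_var("SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"):
        return None

    from sglang.srt.disaggregation.mooncake.memory_pool import (
        MooncakeNVLinkAllocator,
    )

    allocator = MooncakeNVLinkAllocator.get_allocator(device)
    return torch.cuda.MemPool(allocator.allocator())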