Unverified Commit c26d7349 authored by shangmingc, committed by GitHub

[PD] Add custom memory pool option to support Mooncake PD with NVLink (#7264)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
parent ceaa85c9
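
The feature is opt-in: it is gated by the SGLANG_MOONCAKE_CUSTOM_MEM_POOL environment variable and needs a mooncake-transfer-engine build that ships hook.so (>= 0.3.3.post2). The snippet below is an illustrative pre-flight sketch, not part of the patch; it only mirrors the fallback lookup performed by the new MooncakeNVLinkAllocator module that follows.

import importlib
import os

# Illustrative pre-flight check (not part of this patch): confirm that the
# installed mooncake package ships hook.so before enabling the custom pool.
try:
    mooncake = importlib.import_module("mooncake")
    hook_path = os.path.join(os.path.dirname(os.path.abspath(mooncake.__file__)), "hook.so")
    if os.path.exists(hook_path):
        os.environ["SGLANG_MOONCAKE_CUSTOM_MEM_POOL"] = "true"
    else:
        print("hook.so not found; mooncake-transfer-engine >= 0.3.3.post2 is required")
except ImportError:
    print("mooncake-transfer-engine is not installed")
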
import os
import threading
from importlib import resources
from typing import Dict, Final, Optional

import torch
from torch.cuda.memory import CUDAPluggableAllocator


# TODO(shangming): move this class into mooncake's package for more general use cases
class MooncakeNVLinkAllocator:
    _instances: Dict[torch.device, CUDAPluggableAllocator] = {}
    _lock: Final = threading.Lock()

    @classmethod
    def _get_so_path(cls) -> str:
        """Dynamically locate hook.so in the mooncake package installation."""
        try:
            # Attempt to locate the packaged resource first
            with resources.path("mooncake", "hook.so") as so_path:
                if so_path.exists():
                    return str(so_path)
        except (ImportError, FileNotFoundError, TypeError):
            pass

        # Fallback strategy: check the package location via import metadata
        try:
            import mooncake

            base_path = os.path.dirname(os.path.abspath(mooncake.__file__))
            so_path = os.path.join(base_path, "hook.so")
            if os.path.exists(so_path):
                return so_path
        except (ImportError, FileNotFoundError, TypeError):
            pass

        # Neither strategy found hook.so: the installed mooncake build is too old
        raise ImportError(
            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL requires mooncake-transfer-engine >= 0.3.3.post2."
        )

    @classmethod
    def get_allocator(cls, device: torch.device) -> CUDAPluggableAllocator:
        # Cache one allocator per device; creation is guarded by a lock
        with cls._lock:
            if device not in cls._instances:
                so_path = cls._get_so_path()
                cls._instances[device] = CUDAPluggableAllocator(
                    so_path, "mc_nvlink_malloc", "mc_nvlink_free"
                )
            return cls._instances[device]
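
For orientation, a minimal consumption sketch follows (assumptions: a CUDA device and a mooncake build that ships hook.so). It uses the same allocator to torch.cuda.MemPool to torch.cuda.use_mem_pool pattern that the KV cache pools adopt further down in this diff.

import torch

from sglang.srt.disaggregation.mooncake.memory_pool import MooncakeNVLinkAllocator

device = torch.device("cuda:0")
allocator = MooncakeNVLinkAllocator.get_allocator(device)
custom_mem_pool = torch.cuda.MemPool(allocator.allocator())

# Tensors allocated inside this context go through mc_nvlink_malloc/mc_nvlink_free,
# which is what lets Mooncake move KV data over NVLink in PD disaggregation.
with torch.cuda.use_mem_pool(custom_mem_pool):
    kv_buffer = torch.zeros((1024, 8, 128), dtype=torch.float16, device=device)
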
@@ -6,6 +6,7 @@ import random
 import threading
 import warnings
 from collections import deque
+from contextlib import nullcontext
 from enum import Enum
 from typing import TYPE_CHECKING, List, Optional
@@ -84,24 +85,37 @@ class ReqToMetadataIdxAllocator:
 class MetadataBuffers:
-    def __init__(self, size: int, max_top_logprobs_num: int = 128):
-        # TODO: abort top_logprobs_num > 128 in PD
-
-        # We transfer the metadata of first output token to decode
-        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
-        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
-        self.output_token_logprobs_val = torch.zeros(
-            (size, 16), dtype=torch.float32, device="cpu"
-        )
-        self.output_token_logprobs_idx = torch.zeros(
-            (size, 16), dtype=torch.int32, device="cpu"
-        )
-        self.output_top_logprobs_val = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
-        )
-        self.output_top_logprobs_idx = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
-        )
+    def __init__(
+        self,
+        size: int,
+        max_top_logprobs_num: int = 128,
+        custom_mem_pool: torch.cuda.MemPool = None,
+    ):
+        self.custom_mem_pool = custom_mem_pool
+        device = "cuda" if self.custom_mem_pool else "cpu"
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.custom_mem_pool
+            else nullcontext()
+        ):
+            # TODO: abort top_logprobs_num > 128 in PD
+
+            # We transfer the metadata of first output token to decode
+            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+            self.output_token_logprobs_val = torch.zeros(
+                (size, 16), dtype=torch.float32, device=device
+            )
+            self.output_token_logprobs_idx = torch.zeros(
+                (size, 16), dtype=torch.int32, device=device
+            )
+            self.output_top_logprobs_val = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.float32, device=device
+            )
+            self.output_top_logprobs_idx = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.int32, device=device
+            )
 
     def get_buf_infos(self):
         ptrs = [
...
@@ -622,7 +622,10 @@ class Scheduler(
         self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
             buffer_size
         )
-        self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
+        self.disagg_metadata_buffers = MetadataBuffers(
+            buffer_size,
+            custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+        )
 
         # The decode requests polling kv cache
         self.disagg_decode_transfer_queue = DecodeTransferQueue(
@@ -669,7 +672,10 @@ class Scheduler(
         self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
             buffer_size
         )
-        self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
+        self.disagg_metadata_buffers = MetadataBuffers(
+            buffer_size,
+            custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+        )
 
         self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
             token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
...
@@ -26,6 +26,8 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
+import os
+from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -34,7 +36,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
+from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
@@ -260,6 +262,22 @@ class MHATokenToKVPool(KVCache):
         self.head_num = head_num
         self.head_dim = head_dim
 
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            from sglang.srt.disaggregation.mooncake.memory_pool import (
+                MooncakeNVLinkAllocator,
+            )
+
+            # TODO(shangming): abstract custom allocator class for more backends
+            allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         self._create_buffers()
 
         # used for chunked cpu-offloading
@@ -275,24 +293,29 @@ class MHATokenToKVPool(KVCache):
     def _create_buffers(self):
         with self.memory_saver_adapter.region():
-            # [size, head_num, head_dim] for each layer
-            # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.k_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
-            self.v_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                # [size, head_num, head_dim] for each layer
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.k_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
 
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
@@ -349,6 +372,9 @@ class MHATokenToKVPool(KVCache):
             ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
@@ -569,16 +595,36 @@ class MLATokenToKVPool(KVCache):
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
 
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            from sglang.srt.disaggregation.mooncake.memory_pool import (
+                MooncakeNVLinkAllocator,
+            )
+
+            # TODO(shangming): abstract custom allocator class for more backends
+            allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         with self.memory_saver_adapter.region():
-            # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = [
-                torch.zeros(
-                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
-                    dtype=self.store_dtype,
-                    device=device,
-                )
-                for _ in range(layer_num)
-            ]
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.custom_mem_pool
+                else nullcontext()
+            ):
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.kv_buffer = [
+                    torch.zeros(
+                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                        dtype=self.store_dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
 
         self.layer_transfer_counter = None
@@ -604,6 +650,9 @@ class MLATokenToKVPool(KVCache):
             ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
...
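
MHATokenToKVPool and MLATokenToKVPool duplicate the same env-gated pool construction, which is what the TODO(shangming) comments point at. One possible, purely hypothetical shape for that abstraction (not part of this patch; the helper name is invented for illustration):

from typing import Optional

import torch

from sglang.srt.utils import get_bool_env_var


def maybe_create_custom_mem_pool(device: torch.device) -> Optional[torch.cuda.MemPool]:
    # Hypothetical helper (not in this patch): centralize the env-gated pool
    # construction duplicated in MHATokenToKVPool and MLATokenToKVPool.
    if not get_bool_env_var("SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"):
        return None
    from sglang.srt.disaggregation.mooncake.memory_pool import MooncakeNVLinkAllocator

    allocator = MooncakeNVLinkAllocator.get_allocator(device)
    return torch.cuda.MemPool(allocator.allocator())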