Unverified commit 86044712 authored by ronnie_zheng, committed by GitHub

[feature] kv transfer support of ascend npu (#7795)


Co-authored-by: liupeng <liupeng374@huawei.com>
parent 61555307
docs/backend/pd_disaggregation.md
@@ -111,3 +111,36 @@
# decode 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
```
## Ascend
### Usage
To use the Ascend transfer backend, install the [mf_adapter wheel (download link)](https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com:443/sglang/mf_adapter-1.0.0-cp311-cp311-linux_aarch64.whl?AccessKeyId=HPUAXT4YM0U8JNTERLST&Expires=1783151861&Signature=3j10QDUjqk70enaq8lostYV2bEA%3D) and set `ASCEND_MF_STORE_URL`:
```bash
pip install mf_adapter-1.0.0-cp311-cp311-linux_aarch64.whl --force-reinstall
export ASCEND_MF_STORE_URL="tcp://xxx.xx.xxx.xxx:xxxx"
```
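`ASCEND_MF_STORE_URL` points to the centralized config store used by the Ascend transfer engine. When a server is launched with `--disaggregation-transfer-backend ascend`, node rank 0 creates this store automatically via `mf_adapter.create_config_store` (see the `tokenizer_manager.py` change further down), so the snippet below is only an optional pre-flight check that the wheel imports and the URL is set; it mirrors what rank 0 does at startup.
```python
import os

from mf_adapter import create_config_store

# Should match the ASCEND_MF_STORE_URL exported above, e.g. "tcp://xxx.xx.xxx.xxx:xxxx".
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
create_config_store(ascend_url)
```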
To run Ascend transfers over the Mooncake backend instead, set the environment variable below; see the Mooncake section for more details.
```bash
export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true
```
### Llama Single Node
```bash
# prefill
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend ascend
# decode
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend ascend
# load balancer
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
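Once the prefill server, decode server, and mini load balancer are running, a plain `/generate` request against the load balancer is a quick sanity check. This is a minimal sketch using the standard SGLang generate API; the prompt and sampling parameters are arbitrary.
```python
import requests

# The mini_lb above listens on port 8000 and routes prefill/decode internally.
resp = requests.post(
    "http://127.0.0.1:8000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
)
print(resp.json())
```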
### DeepSeek Multi-Node
```bash
# prefill 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend ascend --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 1 --node-rank 0 --tp-size 16
# decode 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend ascend --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 1 --node-rank 0 --tp-size 16
```
# sglang/srt/disaggregation/ascend/__init__.py
from sglang.srt.disaggregation.ascend.conn import (
    AscendKVBootstrapServer,
    AscendKVManager,
    AscendKVReceiver,
    AscendKVSender,
)


# sglang/srt/disaggregation/ascend/conn.py
import logging

from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
from sglang.srt.disaggregation.mooncake.conn import (
    MooncakeKVBootstrapServer,
    MooncakeKVManager,
    MooncakeKVReceiver,
    MooncakeKVSender,
)
from sglang.srt.utils import get_local_ip_by_remote

logger = logging.getLogger(__name__)


class AscendKVManager(MooncakeKVManager):
    def init_engine(self):
        # Initialize the TransferEngine on the Ascend NPU.
        local_ip = get_local_ip_by_remote()
        self.engine = AscendTransferEngine(
            hostname=local_ip,
            npu_id=self.kv_args.gpu_id,
            disaggregation_mode=self.disaggregation_mode,
        )

    def register_buffer_to_engine(self):
        # The KV cache on Ascend is allocated as one contiguous buffer, so the
        # whole region can be registered with a single call.
        self.engine.register(
            self.kv_args.kv_data_ptrs[0], sum(self.kv_args.kv_data_lens)
        )
        # The Ascend backend optimizes batch registration for small memory blocks.
        self.engine.batch_register(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        )


class AscendKVSender(MooncakeKVSender):
    pass


class AscendKVReceiver(MooncakeKVReceiver):
    pass


class AscendKVBootstrapServer(MooncakeKVBootstrapServer):
    pass
# sglang/srt/disaggregation/ascend/transfer_engine.py
import logging
import os
from typing import List, Optional

from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode

logger = logging.getLogger(__name__)


class AscendTransferEngine(MooncakeTransferEngine):
    def __init__(
        self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
    ):
        try:
            from mf_adapter import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mf_adapter; for details, see docs/backend/pd_disaggregation.md"
            ) from e

        self.engine = TransferEngine()
        self.hostname = hostname
        self.npu_id = npu_id
        # Address of the centralized store used by the AscendTransferEngine.
        self.store_url = os.getenv("ASCEND_MF_STORE_URL")
        if disaggregation_mode == DisaggregationMode.PREFILL:
            self.role = "Prefill"
        elif disaggregation_mode == DisaggregationMode.DECODE:
            self.role = "Decode"
        else:
            logger.error(f"Unsupported DisaggregationMode: {disaggregation_mode}")
            raise ValueError(f"Unsupported DisaggregationMode: {disaggregation_mode}")
        self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
        self.initialize()

    def initialize(self) -> None:
        """Initialize the Ascend transfer instance."""
        ret_value = self.engine.initialize(
            self.store_url,
            self.session_id,
            self.role,
            self.npu_id,
        )
        if ret_value != 0:
            logger.error("Ascend Transfer Engine initialization failed.")
            raise RuntimeError("Ascend Transfer Engine initialization failed.")

    def batch_register(self, ptrs: List[int], lengths: List[int]):
        try:
            ret_value = self.engine.batch_register_memory(ptrs, lengths)
        except Exception:
            # Mark the registration as failed.
            ret_value = -1
        if ret_value != 0:
            logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
sglang/srt/disaggregation/mooncake/conn.py
@@ -132,13 +132,9 @@ class MooncakeKVManager(BaseKVManager):
     ):
         self.kv_args = args
         self.local_ip = get_local_ip_auto()
-        self.engine = MooncakeTransferEngine(
-            hostname=self.local_ip,
-            gpu_id=self.kv_args.gpu_id,
-            ib_device=self.kv_args.ib_device,
-        )
         self.is_mla_backend = is_mla_backend
         self.disaggregation_mode = disaggregation_mode
+        self.init_engine()
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
@@ -225,6 +221,13 @@ class MooncakeKVManager(BaseKVManager):
         self.failure_records: Dict[int, str] = {}
         self.failure_lock = threading.Lock()

+    def init_engine(self):
+        self.engine = MooncakeTransferEngine(
+            hostname=self.local_ip,
+            gpu_id=self.kv_args.gpu_id,
+            ib_device=self.kv_args.ib_device,
+        )
+
     def register_buffer_to_engine(self):
         for kv_data_ptr, kv_data_len in zip(
             self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
...
sglang/srt/disaggregation/mooncake/transfer_engine.py
 import logging
 from typing import List, Optional

+from sglang.srt.utils import get_bool_env_var, get_free_port
+
 logger = logging.getLogger(__name__)
@@ -53,6 +55,15 @@ class MooncakeTransferEngine:
         device_name: Optional[str],
     ) -> None:
         """Initialize the mooncake instance."""
-        ret_value = self.engine.initialize(
-            hostname,
-            "P2PHANDSHAKE",
+        if get_bool_env_var("ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE", "false"):
+            hostname += f":{get_free_port()}:npu_{self.gpu_id}"
+            ret_value = self.engine.initialize(
+                hostname,
+                "P2PHANDSHAKE",
+                "ascend",
+                device_name if device_name is not None else "",
+            )
+        else:
+            ret_value = self.engine.initialize(
+                hostname,
+                "P2PHANDSHAKE",
...
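With `ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true`, the hostname handed to Mooncake's `initialize` is extended with a free RPC port and the NPU id, presumably so each NPU rank on a host registers with a distinct identity. A hedged illustration of the resulting string, with made-up values:
```python
# Mirrors the branch above: hostname += f":{get_free_port()}:npu_{self.gpu_id}"
hostname = "192.168.1.10"
free_port, gpu_id = 23456, 0  # illustrative values
hostname += f":{free_port}:npu_{gpu_id}"
print(hostname)  # -> 192.168.1.10:23456:npu_0
```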
sglang/srt/disaggregation/utils.py
@@ -15,7 +15,7 @@ import requests
 import torch
 import torch.distributed as dist

-from sglang.srt.utils import get_ip
+from sglang.srt.utils import get_ip, is_npu

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -94,8 +94,12 @@ class MetadataBuffers:
         custom_mem_pool: torch.cuda.MemPool = None,
     ):
         self.custom_mem_pool = custom_mem_pool
-        device = "cuda" if self.custom_mem_pool else "cpu"
+        device = "cpu"
+        if is_npu():
+            # For the Ascend backend, output tokens are placed on the NPU and transferred over the D2D channel.
+            device = "npu"
+        elif self.custom_mem_pool:
+            device = "cuda"

         with (
             torch.cuda.use_mem_pool(self.custom_mem_pool)
             if self.custom_mem_pool
@@ -200,6 +204,7 @@ class MetadataBuffers:
 class TransferBackend(Enum):
     MOONCAKE = "mooncake"
     NIXL = "nixl"
+    ASCEND = "ascend"
     FAKE = "fake"
@@ -231,6 +236,23 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
             KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
         }
         return class_mapping.get(class_type)
+    elif transfer_backend == TransferBackend.ASCEND:
+        from sglang.srt.disaggregation.ascend import (
+            AscendKVBootstrapServer,
+            AscendKVManager,
+            AscendKVReceiver,
+            AscendKVSender,
+        )
+        from sglang.srt.disaggregation.base import KVArgs
+
+        class_mapping = {
+            KVClassType.KVARGS: KVArgs,
+            KVClassType.MANAGER: AscendKVManager,
+            KVClassType.SENDER: AscendKVSender,
+            KVClassType.RECEIVER: AscendKVReceiver,
+            KVClassType.BOOTSTRAP_SERVER: AscendKVBootstrapServer,
+        }
+        return class_mapping.get(class_type)
     elif transfer_backend == TransferBackend.NIXL:
         from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.nixl import (
...
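The new `TransferBackend.ASCEND` branch is what makes `--disaggregation-transfer-backend ascend` resolve to the classes defined earlier. Below is a minimal sketch of the lookup; the enum values and function names come from this diff, while the calling context is illustrative only.
```python
from sglang.srt.disaggregation.utils import KVClassType, TransferBackend, get_kv_class

backend = TransferBackend("ascend")  # parsed from --disaggregation-transfer-backend
manager_cls = get_kv_class(backend, KVClassType.MANAGER)  # -> AscendKVManager
sender_cls = get_kv_class(backend, KVClassType.SENDER)  # -> AscendKVSender
receiver_cls = get_kv_class(backend, KVClassType.RECEIVER)  # -> AscendKVReceiver
bootstrap_cls = get_kv_class(backend, KVClassType.BOOTSTRAP_SERVER)  # -> AscendKVBootstrapServer
```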
sglang/srt/managers/tokenizer_manager.py
@@ -285,6 +285,20 @@ class TokenizerManager:
             self.bootstrap_server = kv_bootstrap_server_class(
                 self.server_args.disaggregation_bootstrap_port
             )
+            is_create_store = (
+                self.server_args.node_rank == 0
+                and self.server_args.disaggregation_transfer_backend == "ascend"
+            )
+            if is_create_store:
+                try:
+                    from mf_adapter import create_config_store
+
+                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                    create_config_store(ascend_url)
+                except Exception as e:
+                    error_message = f"Failed to create mf store, invalid ASCEND_MF_STORE_URL. With exception {e}"
+                    raise RuntimeError(error_message) from e

         # For load balancing
         self.current_load = 0
...
sglang/srt/mem_cache/memory_pool.py
@@ -604,9 +604,12 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # [size, head_num, head_dim] for each layer
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.k_buffer = [
-                torch.zeros(
+            # Contiguous memory improves the efficiency of the Ascend transfer backend,
+            # while other backends remain unchanged.
+            self.kv_buffer = torch.zeros(
                 (
+                    2,
+                    self.layer_num,
                     self.size // self.page_size + 1,
                     self.page_size,
                     self.head_num,
@@ -615,21 +618,35 @@ class AscendTokenToKVPool(MHATokenToKVPool):
                 dtype=self.store_dtype,
                 device=self.device,
             )
-                for _ in range(self.layer_num)
-            ]
-            self.v_buffer = [
-                torch.zeros(
-                    (
-                        self.size // self.page_size + 1,
-                        self.page_size,
-                        self.head_num,
-                        self.head_dim,
-                    ),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
+            self.k_buffer = self.kv_buffer[0]
+            self.v_buffer = self.kv_buffer[1]
+
+    # for disagg
+    def get_contiguous_buf_infos(self):
+        # layer_num x [seq_len, head_num, head_dim]
+        # layer_num x [page_num, page_size, head_num, head_dim]
+        kv_data_ptrs = [
+            self.get_key_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
+        kv_data_lens = [
+            self.get_key_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
+        kv_item_lens = [
+            self.get_key_buffer(i)[0].nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i)[0].nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens

     def set_kv_buffer(
         self,
@@ -969,9 +986,9 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = [
-                torch.zeros(
+            self.kv_buffer = torch.zeros(
                 (
+                    layer_num,
                     self.size // self.page_size + 1,
                     self.page_size,
                     self.kv_lora_rank + self.qk_rope_head_dim,
@@ -979,8 +996,6 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                 dtype=self.store_dtype,
                 device=self.device,
             )
-                for _ in range(layer_num)
-            ]

         self.layer_transfer_counter = None
@@ -990,6 +1005,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         )
         self.mem_usage = kv_size / GB

+    # for disagg
+    def get_contiguous_buf_infos(self):
+        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
+        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
+        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
+        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
     def set_kv_buffer(
         self,
         layer: RadixAttention,
...
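The buffer refactor above does not change how attention code indexes the pool: `k_buffer` and `v_buffer` are now views into the leading dimensions of one contiguous `kv_buffer`, so per-layer access such as `get_key_buffer(i)` behaves as before, while the whole region becomes a single contiguous block for the transfer engine. A small sketch of the equivalence, with illustrative shapes:
```python
import torch

layer_num, pages, page_size, heads, dim = 2, 4, 16, 2, 8

# Old layout (illustrative): a Python list of per-layer tensors.
old_k = [torch.zeros((pages, page_size, heads, dim)) for _ in range(layer_num)]

# New layout: one contiguous buffer with k/v and layer folded into leading dims.
kv_buffer = torch.zeros((2, layer_num, pages, page_size, heads, dim))
k_buffer, v_buffer = kv_buffer[0], kv_buffer[1]

# Per-layer indexing is unchanged for callers like get_key_buffer(i) ...
assert k_buffer[0].shape == old_k[0].shape
# ... while the full KV region is now one contiguous block.
assert kv_buffer.is_contiguous()
```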
sglang/srt/server_args.py
@@ -1621,7 +1621,7 @@ class ServerArgs:
             "--disaggregation-transfer-backend",
             type=str,
             default=ServerArgs.disaggregation_transfer_backend,
-            choices=["mooncake", "nixl"],
+            choices=["mooncake", "nixl", "ascend"],
             help="The backend for disaggregation transfer. Default is mooncake.",
         )
         parser.add_argument(
...