Unverified Commit 0c984e25 authored by Shangyan Zhou's avatar Shangyan Zhou Committed by GitHub
Browse files

Explicitly destroy the C++ runtime and release resources. (#292)

* Explicitly destroy the C++ runtime and release resources.

* Small fix

* fix typo

* Add a flag to control whether explicit `destroy()` is required.
parent 06f417dc
......@@ -12,10 +12,11 @@
namespace deep_ep {
Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode):
Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy):
rank(rank), num_ranks(num_ranks),
num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes),
low_latency_mode(low_latency_mode),
explicitly_destroy(explicitly_destroy),
comm_stream(at::cuda::getStreamFromPool(true)) {
// Metadata memory
int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
......@@ -81,40 +82,12 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
}
Buffer::~Buffer() noexcept(false) {
// Synchronize
CUDA_CHECK(cudaDeviceSynchronize());
if (num_nvl_bytes > 0) {
// Barrier
intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
CUDA_CHECK(cudaDeviceSynchronize());
// Close remote IPC
if (is_available()) {
for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank)
CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
}
// Free local buffer and error flag
CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
if (not explicitly_destroy) {
destroy();
} else if (not destroyed) {
printf("WARNING: destroy() was not called before DeepEP buffer destruction, which can leak resources.\n");
fflush(stdout);
}
// Free NVSHMEM
#ifndef DISABLE_NVSHMEM
if (num_rdma_bytes > 0) {
CUDA_CHECK(cudaDeviceSynchronize());
internode::barrier();
internode::free(rdma_buffer_ptr);
internode::finalize();
}
#endif
// Free workspace and MoE counter
CUDA_CHECK(cudaFree(workspace));
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_counter)));
// Free chunked mode staffs
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
}
bool Buffer::is_available() const {
......@@ -167,6 +140,48 @@ torch::Stream Buffer::get_comm_stream() const {
return comm_stream;
}
void Buffer::destroy() {
EP_HOST_ASSERT(not destroyed);
// Synchronize
CUDA_CHECK(cudaDeviceSynchronize());
if (num_nvl_bytes > 0) {
// Barrier
intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
CUDA_CHECK(cudaDeviceSynchronize());
// Close remote IPC
if (is_available()) {
for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank)
CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
}
// Free local buffer and error flag
CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
}
// Free NVSHMEM
#ifndef DISABLE_NVSHMEM
if (is_available() and num_rdma_bytes > 0) {
CUDA_CHECK(cudaDeviceSynchronize());
internode::barrier();
internode::free(rdma_buffer_ptr);
internode::finalize();
}
#endif
// Free workspace and MoE counter
CUDA_CHECK(cudaFree(workspace));
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_counter)));
// Free chunked mode staffs
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
destroyed = true;
available = false;
}
void Buffer::sync(const std::vector<int> &device_ids,
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
const std::optional<pybind11::bytearray>& root_unique_id_opt) {
......@@ -1323,7 +1338,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait);
pybind11::class_<deep_ep::Buffer>(m, "Buffer")
.def(pybind11::init<int, int, int64_t, int64_t, bool>())
.def(pybind11::init<int, int, int64_t, int64_t, bool, bool>())
.def("is_available", &deep_ep::Buffer::is_available)
.def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks)
.def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank)
......@@ -1334,6 +1349,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor)
.def("get_comm_stream", &deep_ep::Buffer::get_comm_stream)
.def("sync", &deep_ep::Buffer::sync)
.def("destroy", &deep_ep::Buffer::destroy)
.def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout)
.def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch)
.def("intranode_combine", &deep_ep::Buffer::intranode_combine)
......
......@@ -52,6 +52,11 @@ private:
// After IPC/NVSHMEM synchronization, this flag will be true
bool available = false;
// Whether explicit `destroy()` is required.
bool explicitly_destroy;
// After `destroy()` be called, this flag will be true
bool destroyed = false;
// Barrier signals
int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** barrier_signal_ptrs_gpu = nullptr;
......@@ -72,7 +77,7 @@ private:
int* moe_recv_rdma_counter_mapped = nullptr;
public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy);
~Buffer() noexcept(false);
......@@ -98,6 +103,8 @@ public:
void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);
void destroy();
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event,
bool async, bool allocate_on_comm_stream);
......
......@@ -33,7 +33,8 @@ class Buffer:
num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
low_latency_mode: bool = False, num_qps_per_rank: int = 24,
allow_nvlink_for_low_latency_mode: bool = True,
allow_mnnvl: bool = False) -> None:
allow_mnnvl: bool = False,
explicitly_destroy: bool = False) -> None:
"""
Initialize the communication buffer.
......@@ -49,6 +50,9 @@ class Buffer:
Warning: PCIe connections may lead to errors due to memory ordering issues,
please make sure all connections are via NVLink.
allow_mnnvl: whether to allow MNNVL
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
"""
check_nvlink_connections(group)
......@@ -59,7 +63,8 @@ class Buffer:
self.num_nvl_bytes = num_nvl_bytes
self.num_rdma_bytes = num_rdma_bytes
self.low_latency_mode = low_latency_mode
self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode)
self.explicitly_destroy = explicitly_destroy
self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, explicitly_destroy)
# Synchronize device IDs
device_ids = [None, ] * self.group_size
......@@ -106,6 +111,18 @@ class Buffer:
self.runtime.sync(device_ids, ipc_handles, root_unique_id)
assert self.runtime.is_available()
def destroy(self):
"""
Destroy the cpp runtime and release resources.
"""
assert self.explicitly_destroy, '`explicitly_destroy` flag must be set'
self.runtime.destroy()
self.runtime = None
@staticmethod
def is_sm90_compiled():
return deep_ep_cpp.is_sm90_compiled()
......
......@@ -235,7 +235,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
buffer = deep_ep.Buffer(group, int(2e9), int(1e9), low_latency_mode=args.test_ll_compatibility,
num_qps_per_rank=num_qps_per_rank)
num_qps_per_rank=num_qps_per_rank, explicitly_destroy=True)
assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank)
......@@ -249,7 +249,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
# Destroy the communication group
# Destroy the buffer runtime and communication group
buffer.destroy()
dist.barrier()
dist.destroy_process_group()
......
......@@ -238,7 +238,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), explicitly_destroy=True)
torch.manual_seed(rank)
for i in (24, ):
......@@ -251,7 +251,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
# Destroy the communication group
# Destroy the buffer runtime and communication group
buffer.destroy()
dist.barrier()
dist.destroy_process_group()
......
......@@ -160,7 +160,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink)
allow_nvlink_for_low_latency_mode=not args.disable_nvlink, explicitly_destroy=True)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=1)
......@@ -174,7 +174,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the communication group
# Destroy the buffer runtime and communication group
buffer.destroy()
dist.barrier()
dist.destroy_process_group()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment