Unverified Commit 0c984e25 authored by Shangyan Zhou's avatar Shangyan Zhou Committed by GitHub
Browse files

Explicitly destroy the C++ runtime and release resources. (#292)

* Explicitly destroy the C++ runtime and release resources.

* Small fix

* fix typo

* Add a flag to control whether explicit `destroy()` is required.
parent 06f417dc
...@@ -12,10 +12,11 @@ ...@@ -12,10 +12,11 @@
namespace deep_ep { namespace deep_ep {
Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode): Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy):
rank(rank), num_ranks(num_ranks), rank(rank), num_ranks(num_ranks),
num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes), num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes),
low_latency_mode(low_latency_mode), low_latency_mode(low_latency_mode),
explicitly_destroy(explicitly_destroy),
comm_stream(at::cuda::getStreamFromPool(true)) { comm_stream(at::cuda::getStreamFromPool(true)) {
// Metadata memory // Metadata memory
int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
...@@ -81,40 +82,12 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ ...@@ -81,40 +82,12 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
} }
Buffer::~Buffer() noexcept(false) { Buffer::~Buffer() noexcept(false) {
// Synchronize if (not explicitly_destroy) {
CUDA_CHECK(cudaDeviceSynchronize()); destroy();
} else if (not destroyed) {
if (num_nvl_bytes > 0) { printf("WARNING: destroy() was not called before DeepEP buffer destruction, which can leak resources.\n");
// Barrier fflush(stdout);
intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
CUDA_CHECK(cudaDeviceSynchronize());
// Close remote IPC
if (is_available()) {
for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank)
CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
}
// Free local buffer and error flag
CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
} }
// Free NVSHMEM
#ifndef DISABLE_NVSHMEM
if (num_rdma_bytes > 0) {
CUDA_CHECK(cudaDeviceSynchronize());
internode::barrier();
internode::free(rdma_buffer_ptr);
internode::finalize();
}
#endif
// Free workspace and MoE counter
CUDA_CHECK(cudaFree(workspace));
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_counter)));
// Free chunked mode staffs
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
} }
bool Buffer::is_available() const { bool Buffer::is_available() const {
...@@ -167,6 +140,48 @@ torch::Stream Buffer::get_comm_stream() const { ...@@ -167,6 +140,48 @@ torch::Stream Buffer::get_comm_stream() const {
return comm_stream; return comm_stream;
} }
// Release every resource owned by this buffer: the NVLink/IPC buffers (when
// `num_nvl_bytes > 0`), the NVSHMEM/RDMA buffer (when compiled in, available and
// `num_rdma_bytes > 0`), the workspace and the host-side MoE counters.
// On completion `destroyed` is set and `available` is cleared.
// Must not be called twice: a second call fails the `EP_HOST_ASSERT` below.
void Buffer::destroy() {
EP_HOST_ASSERT(not destroyed);
// Synchronize: wait for outstanding device work before tearing anything down
CUDA_CHECK(cudaDeviceSynchronize());
if (num_nvl_bytes > 0) {
// Intranode barrier issued on `comm_stream`, then a device sync to wait for it.
// NOTE(review): presumably this ensures peers are done with this rank's buffer
// before it is closed/freed — confirm against `intranode::barrier`.
intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
CUDA_CHECK(cudaDeviceSynchronize());
// Close remote IPC handles (every NVLink peer except this rank), but only when
// the buffer reports itself available.
// NOTE(review): presumably peer handles are only opened during `sync()` — confirm.
if (is_available()) {
for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank)
CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
}
// Free local buffer and error flag
CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
}
// Free NVSHMEM (section is compiled out when `DISABLE_NVSHMEM` is defined)
#ifndef DISABLE_NVSHMEM
// Skipped unless the buffer is available and an RDMA buffer was requested
if (is_available() and num_rdma_bytes > 0) {
CUDA_CHECK(cudaDeviceSynchronize());
internode::barrier();
internode::free(rdma_buffer_ptr);
internode::finalize();
}
#endif
// Free workspace and MoE counter
CUDA_CHECK(cudaFree(workspace));
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_counter)));
// Free chunked mode stuff (the per-expert receive counter)
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
// Record the teardown so the destructor and later calls can detect it
destroyed = true;
available = false;
}
void Buffer::sync(const std::vector<int> &device_ids, void Buffer::sync(const std::vector<int> &device_ids,
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles, const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
const std::optional<pybind11::bytearray>& root_unique_id_opt) { const std::optional<pybind11::bytearray>& root_unique_id_opt) {
...@@ -1323,7 +1338,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -1323,7 +1338,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait); .def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait);
pybind11::class_<deep_ep::Buffer>(m, "Buffer") pybind11::class_<deep_ep::Buffer>(m, "Buffer")
.def(pybind11::init<int, int, int64_t, int64_t, bool>()) .def(pybind11::init<int, int, int64_t, int64_t, bool, bool>())
.def("is_available", &deep_ep::Buffer::is_available) .def("is_available", &deep_ep::Buffer::is_available)
.def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks)
.def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank)
...@@ -1334,6 +1349,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -1334,6 +1349,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor) .def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor)
.def("get_comm_stream", &deep_ep::Buffer::get_comm_stream) .def("get_comm_stream", &deep_ep::Buffer::get_comm_stream)
.def("sync", &deep_ep::Buffer::sync) .def("sync", &deep_ep::Buffer::sync)
.def("destroy", &deep_ep::Buffer::destroy)
.def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout) .def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout)
.def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch) .def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch)
.def("intranode_combine", &deep_ep::Buffer::intranode_combine) .def("intranode_combine", &deep_ep::Buffer::intranode_combine)
......
...@@ -52,6 +52,11 @@ private: ...@@ -52,6 +52,11 @@ private:
// After IPC/NVSHMEM synchronization, this flag will be true // After IPC/NVSHMEM synchronization, this flag will be true
bool available = false; bool available = false;
// Whether explicit `destroy()` is required.
bool explicitly_destroy;
// After `destroy()` has been called, this flag will be true
bool destroyed = false;
// Barrier signals // Barrier signals
int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** barrier_signal_ptrs_gpu = nullptr; int** barrier_signal_ptrs_gpu = nullptr;
...@@ -72,7 +77,7 @@ private: ...@@ -72,7 +77,7 @@ private:
int* moe_recv_rdma_counter_mapped = nullptr; int* moe_recv_rdma_counter_mapped = nullptr;
public: public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode); Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy);
~Buffer() noexcept(false); ~Buffer() noexcept(false);
...@@ -98,6 +103,8 @@ public: ...@@ -98,6 +103,8 @@ public:
void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt); void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);
void destroy();
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event, get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event,
bool async, bool allocate_on_comm_stream); bool async, bool allocate_on_comm_stream);
......
...@@ -33,7 +33,8 @@ class Buffer: ...@@ -33,7 +33,8 @@ class Buffer:
num_nvl_bytes: int = 0, num_rdma_bytes: int = 0, num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
low_latency_mode: bool = False, num_qps_per_rank: int = 24, low_latency_mode: bool = False, num_qps_per_rank: int = 24,
allow_nvlink_for_low_latency_mode: bool = True, allow_nvlink_for_low_latency_mode: bool = True,
allow_mnnvl: bool = False) -> None: allow_mnnvl: bool = False,
explicitly_destroy: bool = False) -> None:
""" """
Initialize the communication buffer. Initialize the communication buffer.
...@@ -49,6 +50,9 @@ class Buffer: ...@@ -49,6 +50,9 @@ class Buffer:
Warning: PCIe connections may lead to errors due to memory ordering issues, Warning: PCIe connections may lead to errors due to memory ordering issues,
please make sure all connections are via NVLink. please make sure all connections are via NVLink.
allow_mnnvl: whether to allow MNNVL allow_mnnvl: whether to allow MNNVL
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
""" """
check_nvlink_connections(group) check_nvlink_connections(group)
...@@ -59,7 +63,8 @@ class Buffer: ...@@ -59,7 +63,8 @@ class Buffer:
self.num_nvl_bytes = num_nvl_bytes self.num_nvl_bytes = num_nvl_bytes
self.num_rdma_bytes = num_rdma_bytes self.num_rdma_bytes = num_rdma_bytes
self.low_latency_mode = low_latency_mode self.low_latency_mode = low_latency_mode
self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode) self.explicitly_destroy = explicitly_destroy
self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, explicitly_destroy)
# Synchronize device IDs # Synchronize device IDs
device_ids = [None, ] * self.group_size device_ids = [None, ] * self.group_size
...@@ -106,6 +111,18 @@ class Buffer: ...@@ -106,6 +111,18 @@ class Buffer:
self.runtime.sync(device_ids, ipc_handles, root_unique_id) self.runtime.sync(device_ids, ipc_handles, root_unique_id)
assert self.runtime.is_available() assert self.runtime.is_available()
def destroy(self):
"""
Destroy the cpp runtime and release resources.

Only valid when the buffer was created with `explicitly_destroy=True`;
otherwise the assertion below fails.
After this call `self.runtime` is set to `None`.
"""
assert self.explicitly_destroy, '`explicitly_destroy` flag must be set'
self.runtime.destroy()
self.runtime = None
@staticmethod @staticmethod
def is_sm90_compiled(): def is_sm90_compiled():
return deep_ep_cpp.is_sm90_compiled() return deep_ep_cpp.is_sm90_compiled()
......
...@@ -235,7 +235,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -235,7 +235,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0) num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
buffer = deep_ep.Buffer(group, int(2e9), int(1e9), low_latency_mode=args.test_ll_compatibility, buffer = deep_ep.Buffer(group, int(2e9), int(1e9), low_latency_mode=args.test_ll_compatibility,
num_qps_per_rank=num_qps_per_rank) num_qps_per_rank=num_qps_per_rank, explicitly_destroy=True)
assert num_local_ranks == 8 and num_ranks > 8 assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank) torch.manual_seed(rank)
...@@ -249,7 +249,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -249,7 +249,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1) test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
# Destroy the communication group # Destroy the buffer runtime and communication group
buffer.destroy()
dist.barrier() dist.barrier()
dist.destroy_process_group() dist.destroy_process_group()
......
...@@ -238,7 +238,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -238,7 +238,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts) num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility, buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1)) num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), explicitly_destroy=True)
torch.manual_seed(rank) torch.manual_seed(rank)
for i in (24, ): for i in (24, ):
...@@ -251,7 +251,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -251,7 +251,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1) test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
# Destroy the communication group # Destroy the buffer runtime and communication group
buffer.destroy()
dist.barrier() dist.barrier()
dist.destroy_process_group() dist.destroy_process_group()
......
...@@ -160,7 +160,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -160,7 +160,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True) print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks, num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink) allow_nvlink_for_low_latency_mode=not args.disable_nvlink, explicitly_destroy=True)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=1) use_logfmt=args.use_logfmt, seed=1)
...@@ -174,7 +174,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -174,7 +174,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=seed) == ref_hash, f'Error: seed={seed}' use_logfmt=args.use_logfmt, seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the communication group # Destroy the buffer runtime and communication group
buffer.destroy()
dist.barrier() dist.barrier()
dist.destroy_process_group() dist.destroy_process_group()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment