Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
d7f41337
Commit
d7f41337
authored
Jan 15, 2026
by
lijian6
Browse files
Modify nvshmem to dushmem.
Signed-off-by:
lijian
<
lijian6@sugon.com
>
parent
1a2f45fc
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
109 additions
and
109 deletions
+109
-109
build.sh
build.sh
+3
-3
csrc/deep_ep.cu
csrc/deep_ep.cu
+5
-5
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+3
-3
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+22
-22
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+16
-16
csrc/kernels/runtime.cu
csrc/kernels/runtime.cu
+3
-3
csrc/kernels/shmem_wrapper.cuh
csrc/kernels/shmem_wrapper.cuh
+36
-36
deep_ep/buffer.py
deep_ep/buffer.py
+21
-21
No files found.
build.sh
View file @
d7f41337
...
...
@@ -48,7 +48,7 @@ for arg in "$@"; do
ROCM_USE_MULTIQP
=
ON
;;
*
)
echo
"Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX=ON] [ROCM_USE_MULTIQP=ON] / ./build.sh
nv
shmem"
echo
"Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX=ON] [ROCM_USE_MULTIQP=ON] / ./build.sh
du
shmem"
exit
1
;;
esac
...
...
@@ -133,8 +133,8 @@ if [ "$USE_NVSHMEM" == "ON" ]; then
# build_dushmem
# SHMEM_INSTALL_PREFIX=$(pwd)/third-party/dushmem_install
SHMEM_INSTALL_PREFIX
=
${
ROCM_PATH
}
/dushmem
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -DFORCE_
NV
SHMEM_API -DHIP_ENABLE_WARP_SYNC_BUILTINS -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 --offload-arch=gfx938 -std=c++17 -Wno-return-type
}
SHMEM_LINK_OPTIONS
=
"-Wl,-rpath,
${
SHMEM_INSTALL_PREFIX
}
/lib/ -l:lib
nv
shmem_device.a -l
nv
shmem_host"
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -DFORCE_
DU
SHMEM_API -DHIP_ENABLE_WARP_SYNC_BUILTINS -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 --offload-arch=gfx938 -std=c++17 -Wno-return-type
}
SHMEM_LINK_OPTIONS
=
"-Wl,-rpath,
${
SHMEM_INSTALL_PREFIX
}
/lib/ -l:lib
du
shmem_device.a -l
du
shmem_host"
fi
# -------------------------- duSHMEM END -------------------------- #
...
...
csrc/deep_ep.cu
View file @
d7f41337
...
...
@@ -143,7 +143,7 @@ pybind11::bytearray Buffer::get_local_ipc_handle() const {
return
{
ipc_handles
[
nvl_rank
].
reserved
,
HIP_IPC_HANDLE_SIZE
};
}
pybind11
::
bytearray
Buffer
::
get_local_
nv
shmem_unique_id
()
const
{
pybind11
::
bytearray
Buffer
::
get_local_
du
shmem_unique_id
()
const
{
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT
(
rdma_rank
==
0
and
"Only RDMA rank 0 can get ROCSHMEM unique ID"
);
auto
unique_id
=
internode
::
get_unique_id
();
...
...
@@ -260,9 +260,9 @@ void Buffer::sync(const std::vector<int> &device_
std
::
vector
<
uint8_t
>
root_unique_id
(
root_unique_id_opt
->
size
());
auto
root_unique_id_str
=
root_unique_id_opt
->
cast
<
std
::
string
>
();
std
::
memcpy
(
root_unique_id
.
data
(),
root_unique_id_str
.
c_str
(),
root_unique_id_opt
->
size
());
auto
nv
shmem_rank
=
low_latency_mode
?
rank
:
rdma_rank
;
auto
num_
nv
shmem_ranks
=
low_latency_mode
?
num_ranks
:
num_rdma_ranks
;
EP_HOST_ASSERT
(
nv
shmem_rank
==
internode
::
init
(
root_unique_id
,
nv
shmem_rank
,
num_
nv
shmem_ranks
,
low_latency_mode
));
auto
du
shmem_rank
=
low_latency_mode
?
rank
:
rdma_rank
;
auto
num_
du
shmem_ranks
=
low_latency_mode
?
num_ranks
:
num_rdma_ranks
;
EP_HOST_ASSERT
(
du
shmem_rank
==
internode
::
init
(
root_unique_id
,
du
shmem_rank
,
num_
du
shmem_ranks
,
low_latency_mode
));
internode
::
barrier
();
// Allocate
...
...
@@ -1531,7 +1531,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"get_root_rdma_rank"
,
&
deep_ep
::
Buffer
::
get_root_rdma_rank
)
.
def
(
"get_local_device_id"
,
&
deep_ep
::
Buffer
::
get_local_device_id
)
.
def
(
"get_local_ipc_handle"
,
&
deep_ep
::
Buffer
::
get_local_ipc_handle
)
.
def
(
"get_local_
nv
shmem_unique_id"
,
&
deep_ep
::
Buffer
::
get_local_
nv
shmem_unique_id
)
.
def
(
"get_local_
du
shmem_unique_id"
,
&
deep_ep
::
Buffer
::
get_local_
du
shmem_unique_id
)
.
def
(
"get_local_buffer_tensor"
,
&
deep_ep
::
Buffer
::
get_local_buffer_tensor
)
.
def
(
"get_comm_stream"
,
&
deep_ep
::
Buffer
::
get_comm_stream
)
.
def
(
"sync"
,
&
deep_ep
::
Buffer
::
sync
)
...
...
csrc/deep_ep.hpp
View file @
d7f41337
...
...
@@ -29,7 +29,7 @@ private:
void
*
nvl_buffer_ptrs
[
NUM_MAX_NVL_PEERS
]
=
{
nullptr
};
void
**
nvl_buffer_ptrs_gpu
=
nullptr
;
//
NV
SHMEM Buffer
//
DU
SHMEM Buffer
int64_t
num_rdma_bytes
;
void
*
rdma_buffer_ptr
=
nullptr
;
...
...
@@ -48,7 +48,7 @@ private:
// Stream for communication
at
::
hip
::
HIPStreamMasqueradingAsCUDA
comm_stream
;
// After IPC/
NV
SHMEM synchronization, this flag will be true
// After IPC/
DU
SHMEM synchronization, this flag will be true
bool
available
=
false
;
// Whether explicit `destroy()` is required.
...
...
@@ -95,7 +95,7 @@ public:
pybind11
::
bytearray
get_local_ipc_handle
()
const
;
pybind11
::
bytearray
get_local_
nv
shmem_unique_id
()
const
;
pybind11
::
bytearray
get_local_
du
shmem_unique_id
()
const
;
torch
::
Tensor
get_local_buffer_tensor
(
const
pybind11
::
object
&
dtype
,
int64_t
offset
,
bool
use_rdma_buffer
)
const
;
...
...
csrc/kernels/internode.cu
View file @
d7f41337
...
...
@@ -86,7 +86,7 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
template
<
bool
kLowLatencyMode
>
__forceinline__
__device__
void
nv
shmem_barrier_with_same_gpu_idx
(
const
shmem_team_t
&
rdma_team
)
{
du
shmem_barrier_with_same_gpu_idx
(
const
shmem_team_t
&
rdma_team
)
{
// NOTE: shmem_device_barrier_all() might be an issue as
// it doesn't follow OpenSHMEM specification on ROCm
kLowLatencyMode
?
shmem_barrier
(
rdma_team
)
:
shmem_device_barrier_all
();
...
...
@@ -119,7 +119,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
EP_DEVICE_ASSERT
(
num_warps
>
1
);
EP_DEVICE_ASSERT
(
kNumRDMARanks
<=
num_threads
);
if
(
thread_id
==
kWarpSize
)
nv
shmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
du
shmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
__syncthreads
();
...
...
@@ -161,7 +161,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
}
__syncthreads
();
if
(
thread_id
==
0
)
nv
shmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
du
shmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
__syncthreads
();
...
...
@@ -189,7 +189,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
nvl_buffer_ptr_int
[
nvl_clean_offset
+
i
]
=
0
;
// Reduce number of tokens per expert into the NVL send buffer
// TODO: may use
NV
SHMEM reduction
// TODO: may use
DU
SHMEM reduction
EP_DEVICE_ASSERT
(
num_rdma_experts
<=
num_threads
);
if
(
thread_id
<
num_rdma_experts
)
{
int
sum
=
0
;
...
...
@@ -257,7 +257,7 @@ notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, in
// Finally barrier
__syncthreads
();
if
(
thread_id
==
kWarpSize
)
nv
shmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
du
shmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
barrier_block
<
NUM_MAX_NVL_PEERS
>
(
barrier_signal_ptrs
,
nvl_rank
);
}
else
{
...
...
@@ -399,7 +399,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
kForwarderCoordinator
,
// 向远端RDMA确认接收
kNVLReceivers
// 从nvl缓存写入到recv_x
};
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
__shared__
shmem_ctx_t
ctx
;
shmem_wg_ctx_create
(
&
ctx
);
#endif
...
...
@@ -516,7 +516,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
syncwarp
();
if
(
dst_rdma_rank
!=
rdma_rank
)
{
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_int_put_nbi_warp
(
ctx
,
#else
shmemx_int_put_nbi_warp
(
...
...
@@ -527,7 +527,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
}
}
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet
(
ctx
);
#else
shmem_fence
();
...
...
@@ -690,7 +690,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。
kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32),
nv
shmem内存一致性(
nv
shmem_fence)和原子操作(
nv
shmemx_signal_op),减少硬同步,提升整体效率。
du
shmem内存一致性(
du
shmem_fence)和原子操作(
du
shmemx_signal_op),减少硬同步,提升整体效率。
*/
if
(
warp_id
>
kNumDispatchRDMASenderWarps
)
{
return
;
...
...
@@ -741,7 +741,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
dst_slot_idx
=
synced_last_issued_tail
%
num_max_rdma_chunked_recv_tokens
;
EP_DEVICE_ASSERT
(
dst_slot_idx
+
num_tokens_to_issue
<=
num_max_rdma_chunked_recv_tokens
);
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#else
shmemx_int8_put_nbi_warp
(
...
...
@@ -752,7 +752,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
dst_slot_idx
*
num_bytes_per_rdma_token
,
num_bytes_per_rdma_token
*
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet
(
ctx
);
#else
shmem_fence
();
...
...
@@ -768,7 +768,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
last_issued_tail
+=
num_tokens_to_issue
;
num_tokens_to_send
-=
num_tokens_to_issue
;
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add
(
ctx
,
#else
shmem_signal_op_add
(
...
...
@@ -1008,7 +1008,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
// 更新远程头部
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
&&
min_head
>=
last_head
+
num_max_rdma_chunked_send_tokens
&&
lane_id
<
kNumRDMARanks
){
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add
(
ctx
,
#else
shmem_signal_op_add
(
...
...
@@ -1127,7 +1127,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
}
}
// while(num_tokens_to_recv > 0)
}
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
}
...
...
@@ -1203,7 +1203,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if
(
sm_id
==
0
)
{
// Barrier for RDMA
if
(
thread_id
==
0
)
nv
shmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
du
shmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
__syncthreads
();
...
...
@@ -1216,7 +1216,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
// Barrier again
if
(
thread_id
==
0
)
nv
shmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
du
shmem_barrier_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
}
else
if
(
sm_id
==
1
)
{
// Barrier for NVL
...
...
@@ -1417,7 +1417,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
kRDMACoordinator
,
kNVLCoordinator
};
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
__shared__
shmem_ctx_t
ctx
;
shmem_wg_ctx_create
(
&
ctx
);
#endif
...
...
@@ -1744,7 +1744,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
if
(
sub_warp_id
==
kNumWarpsPerForwarder
-
1
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
rdma_slot_idx
=
token_start_idx
%
num_max_rdma_chunked_recv_tokens
;
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_schar_put_nbi_warp
(
ctx
,
#else
shmemx_int8_put_nbi_warp
(
...
...
@@ -1755,7 +1755,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
rdma_slot_idx
*
num_bytes_per_rdma_token
,
num_chunked_tokens
*
num_bytes_per_rdma_token
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
));
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_quiet
(
ctx
);
#else
shmem_fence
();
...
...
@@ -1767,7 +1767,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Write new RDMA tail
syncwarp
();
if
(
lane_id
==
0
)
{
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add
(
ctx
,
#else
shmem_signal_op_add
(
...
...
@@ -1900,7 +1900,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
min_head
=
min
(
min_head
,
rdma_receiver_rdma_head
[
i
][
dst_rdma_rank
]);
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_rdma_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_ctx_ulong_atomic_add
(
ctx
,
#else
shmem_signal_op_add
(
...
...
@@ -1917,7 +1917,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
}
}
}
#if !defined(FORCE_
NV
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
#if !defined(FORCE_
DU
SHMEM_API) && !defined(ROCM_DISABLE_CTX)
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
}
...
...
csrc/kernels/internode_ll.cu
View file @
d7f41337
...
...
@@ -257,10 +257,10 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_
NV
SHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nv
shmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
#if defined(FORCE_
DU
SHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
du
shmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nv
shmemi_device_state_d
.
heap_base
));
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
du
shmemi_device_state_d
.
heap_base
));
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
...
...
@@ -279,7 +279,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
(
dst_expert_local_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_
NV
SHMEM_API)
#endif // defined(FORCE_
DU
SHMEM_API)
}
else
{
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
...
...
@@ -342,11 +342,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_
NV
SHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nv
shmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
#if defined(FORCE_
DU
SHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
du
shmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
// P2P enabled
int
*
rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
)
-
(
char
*
)(
nv
shmemi_device_state_d
.
heap_base
)));
((
char
*
)(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
)
-
(
char
*
)(
du
shmemi_device_state_d
.
heap_base
)));
st_na_release
(
rptr_actual
,
-
num_tokens_sent
-
1
);
}
else
{
internode
::
shmem_long_atomic_add
(
...
...
@@ -361,7 +361,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
(
dst_expert_local_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_
NV
SHMEM_API)
#endif // defined(FORCE_
DU
SHMEM_API)
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
}
...
...
@@ -640,10 +640,10 @@ combine(void* combined_x,
if
(
not
zero_copy
)
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
#if defined(FORCE_
NV
SHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nv
shmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
#if defined(FORCE_
DU
SHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
du
shmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nv
shmemi_device_state_d
.
heap_base
));
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
du
shmemi_device_state_d
.
heap_base
));
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
{
...
...
@@ -661,7 +661,7 @@ combine(void* combined_x,
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
(
local_expert_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_
NV
SHMEM_API)
#endif // defined(FORCE_
DU
SHMEM_API)
}
}
...
...
@@ -676,11 +676,11 @@ combine(void* combined_x,
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_
NV
SHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nv
shmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
#if defined(FORCE_
DU
SHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
du
shmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
int
*
req_rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_flag
+
global_expert_idx
)
-
(
char
*
)(
nv
shmemi_device_state_d
.
heap_base
)));
((
char
*
)(
rdma_recv_flag
+
global_expert_idx
)
-
(
char
*
)(
du
shmemi_device_state_d
.
heap_base
)));
st_na_release
(
req_rptr_actual
,
1
);
}
else
{
internode
::
shmem_long_atomic_add
(
...
...
@@ -695,7 +695,7 @@ combine(void* combined_x,
rdma_recv_flag
+
global_expert_idx
,
1
,
(
local_expert_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_
NV
SHMEM_API)
#endif // defined(FORCE_
DU
SHMEM_API)
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
}
...
...
csrc/kernels/runtime.cu
View file @
d7f41337
...
...
@@ -61,9 +61,9 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
&
cpu_rdma_team_config
,
0
,
&
cpu_rdma_team
)
==
0
);
EP_HOST_ASSERT
(
cpu_rdma_team
!=
EP_SHMEM_TEAM_INVALID
);
#ifdef FORCE_
NV
SHMEM_API
nv
shmemi_device_host_state_t
*
dev_state_ptr
=
nullptr
;
CUDA_CHECK
(
hipGetSymbolAddress
(
reinterpret_cast
<
void
**>
(
&
dev_state_ptr
),
nv
shmemi_device_state_d
));
#ifdef FORCE_
DU
SHMEM_API
du
shmemi_device_host_state_t
*
dev_state_ptr
=
nullptr
;
CUDA_CHECK
(
hipGetSymbolAddress
(
reinterpret_cast
<
void
**>
(
&
dev_state_ptr
),
du
shmemi_device_state_d
));
bool
ibgda_is_initialized
=
false
;
CUDA_CHECK
(
hipMemcpy
(
&
dev_state_ptr
->
ibgda_is_initialized
,
&
ibgda_is_initialized
,
sizeof
(
bool
),
hipMemcpyHostToDevice
));
#endif
...
...
csrc/kernels/shmem_wrapper.cuh
View file @
d7f41337
#pragma once
/*
* Temporary wrapper for for platform specific
NV
SHMEM and rocSHMEM functions.
* Temporary wrapper for for platform specific
DU
SHMEM and rocSHMEM functions.
* Once hipify or hipify-torch fully supports this mapping, this file has to be
* removed and according
nv
shmem* functions restored.
* removed and according
du
shmem* functions restored.
*/
#ifndef DISABLE_ROCSHMEM
#include "configs.cuh"
#ifndef FORCE_
NV
SHMEM_API
#ifndef FORCE_
DU
SHMEM_API
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>
#include <rocshmem/rocshmem.hpp>
#else
#include <device_host_transport/
nv
shmem_common_ibgda.h>
#include <device_host_transport/
du
shmem_common_ibgda.h>
#include <infiniband/mlx5dv.h>
#include <
nv
shmem.h>
#include <
nv
shmemx.h>
#include <non_abi/device/threadgroup/
nv
shmemi_common_device_defines.cuh>
#include <
du
shmem.h>
#include <
du
shmemx.h>
#include <non_abi/device/threadgroup/
du
shmemi_common_device_defines.cuh>
#endif
namespace
deep_ep
::
internode
{
// rocSHMEM wrapper
#ifndef FORCE_
NV
SHMEM_API
#ifndef FORCE_
DU
SHMEM_API
using
shmem_team_t
=
rocshmem
::
rocshmem_team_t
;
using
shmem_team_config_t
=
rocshmem
::
rocshmem_team_config_t
;
const
shmem_team_t
EP_SHMEM_TEAM_INVALID
=
rocshmem
::
ROCSHMEM_TEAM_INVALID
;
...
...
@@ -171,106 +171,106 @@ __device__ inline void shmem_ctx_int_put_nbi_warp(
#else
//
NV
SHMEM wrapper
//
DU
SHMEM wrapper
#ifndef ROCM_DISABLE_CTX
#define ROCM_DISABLE_CTX
#endif
using
shmem_team_t
=
nv
shmem_team_t
;
using
shmem_team_config_t
=
nv
shmem_team_config_t
;
using
shmemx_uniqueid_t
=
nv
shmemx_uniqueid_t
;
using
shmemx_init_attr_t
=
nv
shmemx_init_attr_t
;
const
shmem_team_t
EP_SHMEM_TEAM_INVALID
=
NV
SHMEM_TEAM_INVALID
;
const
shmem_team_t
EP_SHMEM_TEAM_WORLD
=
NV
SHMEM_TEAM_WORLD
;
constexpr
auto
EP_SHMEMX_INIT_WITH_UNIQUEID
=
NV
SHMEMX_INIT_WITH_UNIQUEID
;
using
shmem_team_t
=
du
shmem_team_t
;
using
shmem_team_config_t
=
du
shmem_team_config_t
;
using
shmemx_uniqueid_t
=
du
shmemx_uniqueid_t
;
using
shmemx_init_attr_t
=
du
shmemx_init_attr_t
;
const
shmem_team_t
EP_SHMEM_TEAM_INVALID
=
DU
SHMEM_TEAM_INVALID
;
const
shmem_team_t
EP_SHMEM_TEAM_WORLD
=
DU
SHMEM_TEAM_WORLD
;
constexpr
auto
EP_SHMEMX_INIT_WITH_UNIQUEID
=
DU
SHMEMX_INIT_WITH_UNIQUEID
;
__host__
inline
int
shmemx_get_uniqueid
(
shmemx_uniqueid_t
*
uid
)
{
return
nv
shmemx_get_uniqueid
(
uid
);
return
du
shmemx_get_uniqueid
(
uid
);
}
__host__
inline
int
shmemx_set_attr_uniqueid_args
(
int
rank
,
int
nranks
,
shmemx_uniqueid_t
*
uid
,
shmemx_init_attr_t
*
attr
)
{
return
nv
shmemx_set_attr_uniqueid_args
(
rank
,
nranks
,
uid
,
attr
);
return
du
shmemx_set_attr_uniqueid_args
(
rank
,
nranks
,
uid
,
attr
);
}
__host__
inline
int
shmemx_init_attr
(
unsigned
int
flags
,
shmemx_init_attr_t
*
attr
)
{
return
nv
shmemx_init_attr
(
flags
,
attr
);
return
du
shmemx_init_attr
(
flags
,
attr
);
}
__host__
inline
int
shmem_team_split_strided
(
shmem_team_t
parent_team
,
int
start
,
int
stride
,
int
size
,
const
shmem_team_config_t
*
config
,
long
config_mask
,
shmem_team_t
*
new_team
)
{
return
nv
shmem_team_split_strided
(
parent_team
,
start
,
stride
,
size
,
config
,
config_mask
,
new_team
);
return
du
shmem_team_split_strided
(
parent_team
,
start
,
stride
,
size
,
config
,
config_mask
,
new_team
);
}
__host__
inline
void
shmem_barrier_all
()
{
nv
shmem_barrier_all
();
du
shmem_barrier_all
();
}
__device__
inline
void
shmem_device_barrier_all
()
{
nv
shmem_barrier_all
();
du
shmem_barrier_all
();
}
__device__
inline
void
shmem_barrier
(
shmem_team_t
team
)
{
void
(
nv
shmem_barrier
(
team
));
void
(
du
shmem_barrier
(
team
));
}
__host__
inline
int
shmem_my_pe
(){
return
nv
shmem_my_pe
();
return
du
shmem_my_pe
();
}
__host__
inline
void
shmem_free
(
void
*
ptr
){
nv
shmem_free
(
ptr
);
du
shmem_free
(
ptr
);
}
__host__
inline
void
*
shmem_align
(
const
size_t
alignment
,
const
size_t
size
)
{
return
nv
shmem_align
(
size
,
alignment
);
return
du
shmem_align
(
size
,
alignment
);
}
__host__
inline
void
shmem_finalize
()
{
nv
shmem_finalize
();
du
shmem_finalize
();
}
__host__
inline
void
shmem_team_destroy
(
shmem_team_t
team
)
{
nv
shmem_team_destroy
(
team
);
du
shmem_team_destroy
(
team
);
}
__device__
inline
void
shmem_fence
()
{
nv
shmem_fence
();
du
shmem_fence
();
}
__device__
inline
void
shmem_int_put_nbi
(
int
*
dest
,
const
int
*
source
,
size_t
nelems
,
int
pe
)
{
nv
shmem_int_put_nbi
(
dest
,
source
,
nelems
,
pe
);
du
shmem_int_put_nbi
(
dest
,
source
,
nelems
,
pe
);
}
__device__
inline
void
shmemx_int_put_nbi_warp
(
int
*
dest
,
const
int
*
source
,
size_t
nelems
,
int
pe
)
{
nv
shmemx_int_put_nbi_warp
(
dest
,
source
,
nelems
,
pe
);
du
shmemx_int_put_nbi_warp
(
dest
,
source
,
nelems
,
pe
);
}
__device__
inline
void
shmemx_int8_put_nbi_warp
(
signed
char
*
dest
,
const
signed
char
*
source
,
size_t
nelems
,
int
pe
)
{
nv
shmemx_int8_put_nbi_warp
(
dest
,
source
,
nelems
,
pe
);
du
shmemx_int8_put_nbi_warp
(
dest
,
source
,
nelems
,
pe
);
}
__device__
inline
void
shmem_signal_op_add
(
uint64_t
*
dest
,
uint64_t
value
,
int
pe
)
{
nv
shmemx_signal_op
(
dest
,
value
,
NV
SHMEM_SIGNAL_ADD
,
pe
);
du
shmemx_signal_op
(
dest
,
value
,
DU
SHMEM_SIGNAL_ADD
,
pe
);
}
__device__
inline
void
shmem_ulong_atomic_add
(
uint64_t
*
dest
,
uint64_t
value
,
int
pe
)
{
nv
shmem_ulong_atomic_add
(
dest
,
value
,
pe
);
du
shmem_ulong_atomic_add
(
dest
,
value
,
pe
);
}
__device__
inline
void
shmem_long_atomic_add
(
long
*
dest
,
long
value
,
int
pe
)
{
//
nv
shmem_##Name##_atomic_add(dest, value, pe);
nv
shmem_long_atomic_add
(
dest
,
value
,
pe
);
//
du
shmem_##Name##_atomic_add(dest, value, pe);
du
shmem_long_atomic_add
(
dest
,
value
,
pe
);
}
#endif
...
...
deep_ep/buffer.py
View file @
d7f41337
...
...
@@ -96,46 +96,46 @@ class Buffer:
local_ipc_handle
=
self
.
runtime
.
get_local_ipc_handle
()
dist
.
all_gather_object
(
ipc_handles
,
local_ipc_handle
,
group
)
# Synchronize
NV
SHMEM unique IDs
# Synchronize
DU
SHMEM unique IDs
root_unique_id
=
None
if
self
.
runtime
.
get_num_rdma_ranks
()
>
1
or
low_latency_mode
:
# Enable IBGDA
self
.
_setup_device_hca_mapping
()
assert
num_qps_per_rank
>
0
os
.
environ
[
"
NV
SHMEM_DISABLE_P2P"
]
=
"0"
if
allow_nvlink_for_low_latency_mode
else
"1"
# os.environ["
NV
SHMEM_IB_ENABLE_IBGDA"] = "1"
os
.
environ
[
"
NV
SHMEM_IB_ENABLE_IBGDA"
]
=
"0"
# force_use_ibrc
os
.
environ
[
"
DU
SHMEM_DISABLE_P2P"
]
=
"0"
if
allow_nvlink_for_low_latency_mode
else
"1"
# os.environ["
DU
SHMEM_IB_ENABLE_IBGDA"] = "1"
os
.
environ
[
"
DU
SHMEM_IB_ENABLE_IBGDA"
]
=
"0"
# force_use_ibrc
os
.
environ
[
"
NV
SHMEM_IBGDA_NIC_HANDLER"
]
=
"gpu"
os
.
environ
[
"
NV
SHMEM_IB_DISABLE_DMABUF"
]
=
"1"
os
.
environ
[
"
NV
SHMEM_ENABLE_NIC_PE_MAPPING"
]
=
"1"
os
.
environ
[
"
DU
SHMEM_IBGDA_NIC_HANDLER"
]
=
"gpu"
os
.
environ
[
"
DU
SHMEM_IB_DISABLE_DMABUF"
]
=
"1"
os
.
environ
[
"
DU
SHMEM_ENABLE_NIC_PE_MAPPING"
]
=
"1"
os
.
environ
[
"
NV
SHMEM_IBGDA_NUM_RC_PER_PE"
]
=
f
"
{
num_qps_per_rank
}
"
os
.
environ
[
"
DU
SHMEM_IBGDA_NUM_RC_PER_PE"
]
=
f
"
{
num_qps_per_rank
}
"
# Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
os
.
environ
[
"
NV
SHMEM_QP_DEPTH"
]
=
os
.
environ
.
get
(
"
NV
SHMEM_QP_DEPTH"
,
"1024"
)
os
.
environ
[
"
DU
SHMEM_QP_DEPTH"
]
=
os
.
environ
.
get
(
"
DU
SHMEM_QP_DEPTH"
,
"1024"
)
# Reduce gpu memory usage
# 6 default teams + 1 extra team
os
.
environ
[
"
NV
SHMEM_MAX_TEAMS"
]
=
"7"
os
.
environ
[
"
DU
SHMEM_MAX_TEAMS"
]
=
"7"
# Disable NVLink SHArP
os
.
environ
[
"
NV
SHMEM_DISABLE_NVLS"
]
=
"1"
# NOTES:
NV
SHMEM initialization requires at least 256 MiB
os
.
environ
[
"
NV
SHMEM_CUMEM_GRANULARITY"
]
=
f
"
{
2
**
29
}
"
os
.
environ
[
"
DU
SHMEM_DISABLE_NVLS"
]
=
"1"
# NOTES:
DU
SHMEM initialization requires at least 256 MiB
os
.
environ
[
"
DU
SHMEM_CUMEM_GRANULARITY"
]
=
f
"
{
2
**
29
}
"
if
not
allow_mnnvl
:
# Disable multi-node NVLink detection
os
.
environ
[
"
NV
SHMEM_DISABLE_MNNVL"
]
=
"1"
os
.
environ
[
"
DU
SHMEM_DISABLE_MNNVL"
]
=
"1"
# Synchronize using the root ID
nv
shmem_unique_ids
=
[
du
shmem_unique_ids
=
[
None
,
]
*
self
.
group_size
if
(
low_latency_mode
and
self
.
rank
==
0
)
or
(
not
low_latency_mode
and
self
.
runtime
.
get_rdma_rank
()
==
0
):
root_unique_id
=
self
.
runtime
.
get_local_
nv
shmem_unique_id
()
dist
.
all_gather_object
(
nv
shmem_unique_ids
,
root_unique_id
,
group
)
root_unique_id
=
nv
shmem_unique_ids
[
root_unique_id
=
self
.
runtime
.
get_local_
du
shmem_unique_id
()
dist
.
all_gather_object
(
du
shmem_unique_ids
,
root_unique_id
,
group
)
root_unique_id
=
du
shmem_unique_ids
[
0
if
low_latency_mode
else
self
.
runtime
.
get_root_rdma_rank
(
True
)
]
...
...
@@ -169,9 +169,9 @@ class Buffer:
# assert visible_devices[current_device].isdigit(), f"DEEP_EP_DEVICE_TO_HCA_MAPPING requires CUDA_VISIBLE_DEVICES to contain integer indices"
# current_device = int(visible_devices[current_device])
assert
current_device
in
device_mapping
,
f
"Current
CUDA
device
{
current_device
}
not found in DEEP_EP_DEVICE_TO_HCA_MAPPING"
os
.
environ
[
'
NV
SHMEM_ENABLE_PE_MAPPING'
]
=
'1'
os
.
environ
[
'
NV
SHMEM_HCA_LIST'
]
=
device_mapping
[
current_device
]
assert
current_device
in
device_mapping
,
f
"Current
HIP
device
{
current_device
}
not found in DEEP_EP_DEVICE_TO_HCA_MAPPING"
os
.
environ
[
'
DU
SHMEM_ENABLE_PE_MAPPING'
]
=
'1'
os
.
environ
[
'
DU
SHMEM_HCA_LIST'
]
=
device_mapping
[
current_device
]
def
destroy
(
self
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment