Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
1b497233
Commit
1b497233
authored
Dec 30, 2025
by
lishen
Browse files
Merge branch 'updates' into 'main'
Updates See merge request dcutoolkit/deeplearing/DeepEP!12
parents
1b00b9d8
94694314
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
115 additions
and
95 deletions
+115
-95
1.sh
1.sh
+4
-3
2.sh
2.sh
+3
-2
build.sh
build.sh
+27
-15
csrc/config.hpp
csrc/config.hpp
+4
-4
csrc/kernels/configs.cuh
csrc/kernels/configs.cuh
+0
-2
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+61
-69
csrc/kernels/shmem_wrapper.cuh
csrc/kernels/shmem_wrapper.cuh
+16
-0
No files found.
1.sh
View file @
1b497233
pgrep
-f
/usr/bin/python | xargs
kill
-9
pgrep
-f
/usr/bin/python | xargs
kill
-9
export
OMPI_MCA_pml
=
ucx
export
OMPI_MCA_pml
=
ucx
export
OMPI_MCA_osc
=
ucx
export
OMPI_MCA_osc
=
ucx
...
@@ -6,7 +6,8 @@ export OMPI_MCA_coll_hcoll_enable=0
...
@@ -6,7 +6,8 @@ export OMPI_MCA_coll_hcoll_enable=0
export
UCX_TLS
=
rc,rocm
export
UCX_TLS
=
rc,rocm
# export ROCSHMEM_UNIQUEID_WITH_MPI=1
# export ROCSHMEM_UNIQUEID_WITH_MPI=1
export
OMPI_MCA_rmaps_base_mapping_policy
=
"slot:numa"
export
OMPI_MCA_rmaps_base_mapping_policy
=
"slot:numa"
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
32
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS
=
16384
export
UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS
=
16384
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
...
@@ -15,5 +16,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
...
@@ -15,5 +16,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
ROCSHMEM_HEAP_SIZE
=
10737418240
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py
--pressure-test
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_internode.py
--test-ll-compatibility
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_internode.py
--test-ll-compatibility
2.sh
View file @
1b497233
...
@@ -6,7 +6,8 @@ export OMPI_MCA_coll_hcoll_enable=0
...
@@ -6,7 +6,8 @@ export OMPI_MCA_coll_hcoll_enable=0
export
UCX_TLS
=
rc,rocm
export
UCX_TLS
=
rc,rocm
# export ROCSHMEM_UNIQUEID_WITH_MPI=1
# export ROCSHMEM_UNIQUEID_WITH_MPI=1
export
OMPI_MCA_rmaps_base_mapping_policy
=
"slot:numa"
export
OMPI_MCA_rmaps_base_mapping_policy
=
"slot:numa"
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
32
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS
=
16384
export
UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS
=
16384
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
...
@@ -15,5 +16,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
...
@@ -15,5 +16,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export
ROCSHMEM_HEAP_SIZE
=
10737418240
export
ROCSHMEM_HEAP_SIZE
=
10737418240
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py
--pressure-test
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_internode.py
--test-ll-compatibility
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_internode.py
--test-ll-compatibility
build.sh
View file @
1b497233
...
@@ -31,24 +31,33 @@ PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()[
...
@@ -31,24 +31,33 @@ PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()[
USE_NVSHMEM
=
OFF
USE_NVSHMEM
=
OFF
USE_ROCSHMEM
=
OFF
USE_ROCSHMEM
=
OFF
ROCM_DISABLE_CTX
=
OFF
ROCM_DISABLE_CTX
=
OFF
case
"
$1
"
in
ROCM_USE_MULTIQP
=
OFF
rocshmem
)
# 解析命令行参数
USE_ROCSHMEM
=
ON
for
arg
in
"
$@
"
;
do
;;
case
$arg
in
nvshmem|dushmem
)
rocshmem
)
USE_NVSHMEM
=
ON
USE_ROCSHMEM
=
ON
;;
;;
*
)
nvshmem|dushmem
)
echo
"Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX] / ./build.sh nvshmem"
USE_NVSHMEM
=
ON
exit
1
;;
;;
ROCM_DISABLE_CTX
=
ON
)
esac
ROCM_DISABLE_CTX
=
ON
if
[
"
${
2
:-}
"
=
"ROCM_DISABLE_CTX"
]
;
then
;;
ROCM_DISABLE_CTX
=
ON
ROCM_USE_MULTIQP
=
ON
)
fi
ROCM_USE_MULTIQP
=
ON
;;
*
)
echo
"Usage: ./build.sh rocshmem [ROCM_DISABLE_CTX=ON] [ROCM_USE_MULTIQP=ON] / ./build.sh nvshmem"
exit
1
;;
esac
done
echo
"USE_NVSHMEM=
$USE_NVSHMEM
"
echo
"USE_NVSHMEM=
$USE_NVSHMEM
"
echo
"USE_ROCSHMEM=
$USE_ROCSHMEM
"
echo
"USE_ROCSHMEM=
$USE_ROCSHMEM
"
echo
"ROCM_DISABLE_CTX=
$ROCM_DISABLE_CTX
"
echo
"ROCM_DISABLE_CTX=
$ROCM_DISABLE_CTX
"
echo
"ROCM_USE_MULTIQP=
$ROCM_USE_MULTIQP
"
# -------------------------- With rocSHMEM -------------------------- #
# -------------------------- With rocSHMEM -------------------------- #
build_rocshmem
()
build_rocshmem
()
...
@@ -84,6 +93,9 @@ if [ "$USE_ROCSHMEM" == "ON" ]; then
...
@@ -84,6 +93,9 @@ if [ "$USE_ROCSHMEM" == "ON" ]; then
if
[
"
$ROCM_DISABLE_CTX
"
==
"ON"
]
;
then
if
[
"
$ROCM_DISABLE_CTX
"
==
"ON"
]
;
then
COMPILE_OPTIONS
=
"-DROCM_DISABLE_CTX
$COMPILE_OPTIONS
"
COMPILE_OPTIONS
=
"-DROCM_DISABLE_CTX
$COMPILE_OPTIONS
"
fi
fi
if
[
"
$ROCM_USE_MULTIQP
"
==
"ON"
]
;
then
COMPILE_OPTIONS
=
"-DROCM_USE_MULTIQP
$COMPILE_OPTIONS
"
fi
SHMEM_LINK_OPTIONS
=
${
SHMEM_LINK_OPTIONS
:
=
"-Wl,-rpath,
${
SHMEM_INSTALL_PREFIX
}
/lib/ -l:librocshmem.a"
}
SHMEM_LINK_OPTIONS
=
${
SHMEM_LINK_OPTIONS
:
=
"-Wl,-rpath,
${
SHMEM_INSTALL_PREFIX
}
/lib/ -l:librocshmem.a"
}
fi
fi
# -------------------------- rocSHMEM END -------------------------- #
# -------------------------- rocSHMEM END -------------------------- #
...
...
csrc/config.hpp
View file @
1b497233
...
@@ -44,10 +44,10 @@ struct Config {
...
@@ -44,10 +44,10 @@ struct Config {
constexpr
int
kNumMaxTopK
=
128
;
constexpr
int
kNumMaxTopK
=
128
;
constexpr
int
kNumMaxScales
=
128
;
constexpr
int
kNumMaxScales
=
128
;
EP_HOST_ASSERT
(
num_ranks
<
NUM_MAX_NVL_PEERS
or
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
num_ranks
<
NUM_MAX_NVL_PEERS
or
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
num_ranks
<=
NUM_MAX_NVL_PEERS
or
num_sms
%
2
==
0
);
EP_HOST_ASSERT
(
num_ranks
<=
NUM_MAX_NVL_PEERS
or
num_sms
%
(
2
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
)
==
0
);
const
auto
num_rdma_ranks
=
std
::
max
(
num_ranks
/
NUM_MAX_NVL_PEERS
,
1
);
const
auto
num_rdma_ranks
=
std
::
max
(
num_ranks
/
NUM_MAX_NVL_PEERS
,
1
);
const
auto
num_nvl_ranks
=
std
::
min
(
num_ranks
,
NUM_MAX_NVL_PEERS
);
const
auto
num_nvl_ranks
=
std
::
min
(
num_ranks
,
NUM_MAX_NVL_PEERS
);
const
int
num_channels
=
num_sms
/
2
;
const
int
num_channels
=
num_sms
;
size_t
num_bytes
=
0
;
size_t
num_bytes
=
0
;
num_bytes
+=
num_channels
*
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
3
)
*
sizeof
(
int
);
num_bytes
+=
num_channels
*
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
3
)
*
sizeof
(
int
);
...
@@ -77,9 +77,9 @@ struct Config {
...
@@ -77,9 +77,9 @@ struct Config {
constexpr
int
kNumMaxTopK
=
128
;
constexpr
int
kNumMaxTopK
=
128
;
constexpr
int
kNumMaxScales
=
128
;
constexpr
int
kNumMaxScales
=
128
;
EP_HOST_ASSERT
(
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
num_sms
%
2
==
0
);
EP_HOST_ASSERT
(
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
const
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
const
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
const
int
num_channels
=
num_sms
/
2
;
const
int
num_channels
=
num_sms
;
size_t
num_bytes
=
0
;
size_t
num_bytes
=
0
;
num_bytes
+=
num_channels
*
num_rdma_ranks
*
(
NUM_MAX_NVL_PEERS
*
2
+
2
)
*
2
*
sizeof
(
int
);
num_bytes
+=
num_channels
*
num_rdma_ranks
*
(
NUM_MAX_NVL_PEERS
*
2
+
2
)
*
2
*
sizeof
(
int
);
...
...
csrc/kernels/configs.cuh
View file @
1b497233
...
@@ -25,8 +25,6 @@
...
@@ -25,8 +25,6 @@
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_RECV_TOKENS 256
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_RECV_TOKENS 256
...
...
csrc/kernels/internode_ll.cu
View file @
1b497233
...
@@ -8,9 +8,6 @@
...
@@ -8,9 +8,6 @@
#include "hip/hip_runtime.h"
#include "hip/hip_runtime.h"
// low latency+RocSHMEM has issue with CTX.
#define ROCM_DISABLE_CTX
#include "shmem_wrapper.cuh"
#include "shmem_wrapper.cuh"
namespace
deep_ep
{
namespace
deep_ep
{
...
@@ -133,11 +130,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -133,11 +130,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
goto
LOW_LATENCY_DISPATCH_RECV
;
goto
LOW_LATENCY_DISPATCH_RECV
;
#if !defined(ROCM_DISABLE_CTX)
__shared__
internode
::
shmem_ctx_t
ctx
;
internode
::
shmem_wg_ctx_create
(
&
ctx
);
#endif
// There are 2 kinds of warps in this part:
// There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information
// 2. The last warp for reading `topk_idx` and count for per-expert information
...
@@ -265,24 +257,29 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -265,24 +257,29 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
slot_idx
*
num_bytes_per_msg
;
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
else
}
else
{
#endif
internode
::
shmemx_int8_put_nbi_warp
(
{
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
#if !defined(ROCM_DISABLE_CTX)
num_bytes_per_msg
,
dst_rank
);
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
}
#else
#else
#if !defined(ROCM_USE_MULTIQP)
internode
::
shmemx_int8_put_nbi_warp
(
internode
::
shmemx_int8_put_nbi_warp
(
#endif
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
num_bytes_per_msg
,
dst_rank
);
}
#else
internode
::
shmemx_int8_put_nbi_warp_dp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
(
dst_expert_local_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_NVSHMEM_API)
}
else
{
}
else
{
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
...
@@ -345,22 +342,26 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -345,22 +342,26 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Wait local sends issued and send expert counts
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
// P2P enabled
if
(
peer_base_addr
)
{
// P2P enabled
int
*
rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
int
*
rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
((
char
*
)(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
st_na_release
(
rptr_actual
,
-
num_tokens_sent
-
1
);
st_na_release
(
rptr_actual
,
-
num_tokens_sent
-
1
);
}
else
}
else
{
#endif
internode
::
shmem_long_atomic_add
(
{
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
#if !defined(ROCM_DISABLE_CTX)
}
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
#else
#else
#if !defined(ROCM_USE_MULTIQP)
internode
::
shmem_long_atomic_add
(
internode
::
shmem_long_atomic_add
(
#endif
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
}
#else
internode
::
shmem_long_atomic_add_dp
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
(
dst_expert_local_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_NVSHMEM_API)
}
else
{
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
}
}
...
@@ -375,10 +376,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -375,10 +376,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
}
syncwarp
();
syncwarp
();
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
// Receiving phase
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
LOW_LATENCY_DISPATCH_RECV:
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
...
@@ -433,6 +430,12 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -433,6 +430,12 @@ LOW_LATENCY_DISPATCH_RECV:
recv_range
[
src_rank
]
=
pack2
<
int
,
int64_t
>
(
num_recv_tokens
,
recv_token_begin_idx
);
recv_range
[
src_rank
]
=
pack2
<
int
,
int64_t
>
(
num_recv_tokens
,
recv_token_begin_idx
);
}
}
#if defined(ROCM_USE_MULTIQP)
if
(
sub_warp_id
==
2
and
lane_id
==
0
)
{
internode
::
shmem_qp_quiet
(
num_ranks
+
responsible_expert_idx
);
}
#endif
// no needs to reset because there is no iteration
// no needs to reset because there is no iteration
if
(
lane_id
==
0
){
if
(
lane_id
==
0
){
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
sync_large_warp_counters
[
warp_group_id
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
sync_large_warp_counters
[
warp_group_id
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
...
@@ -591,11 +594,6 @@ combine(void* combined_x,
...
@@ -591,11 +594,6 @@ combine(void* combined_x,
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
goto
LOW_LATENCY_COMBINE_RECV
;
goto
LOW_LATENCY_COMBINE_RECV
;
#if !defined(ROCM_DISABLE_CTX)
__shared__
internode
::
shmem_ctx_t
ctx
;
internode
::
shmem_wg_ctx_create
(
&
ctx
);
#endif
// Clean up next buffer
// Clean up next buffer
if
(
sm_id
==
0
and
warp_group_id
==
0
and
sub_warp_id
==
0
)
{
if
(
sm_id
==
0
and
warp_group_id
==
0
and
sub_warp_id
==
0
)
{
#pragma unroll
#pragma unroll
...
@@ -642,23 +640,28 @@ combine(void* combined_x,
...
@@ -642,23 +640,28 @@ combine(void* combined_x,
if
(
not
zero_copy
)
if
(
not
zero_copy
)
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
if
(
peer_base_addr
)
{
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
char
*
req_rptr_actual
=
(
char
*
)(
peer_base_addr
)
+
((
char
*
)
dst_ptr
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
));
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
req_rptr_actual
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
}
else
{
#endif
internode
::
shmemx_int8_put_nbi_warp
(
{
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
#if !defined(ROCM_DISABLE_CTX)
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
internode
::
shmem_ctx_schar_put_nbi_warp
(
ctx
,
}
#else
#else
#if !defined(ROCM_USE_MULTIQP)
internode
::
shmemx_int8_put_nbi_warp
(
internode
::
shmemx_int8_put_nbi_warp
(
#endif
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
}
#else
internode
::
shmemx_int8_put_nbi_warp_dp
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
(
local_expert_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_NVSHMEM_API)
}
}
}
}
...
@@ -673,55 +676,39 @@ combine(void* combined_x,
...
@@ -673,55 +676,39 @@ combine(void* combined_x,
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
#if defined(FORCE_NVSHMEM_API)
#if defined(FORCE_NVSHMEM_API)
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
void
*
peer_base_addr
=
(
void
*
)
__ldg
((
const
long
long
unsigned
*
)
nvshmemi_device_state_d
.
peer_heap_base_p2p
+
dst_rank
);
if
(
peer_base_addr
)
{
if
(
peer_base_addr
)
{
int
*
req_rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
int
*
req_rptr_actual
=
(
int
*
)((
char
*
)(
peer_base_addr
)
+
((
char
*
)(
rdma_recv_flag
+
global_expert_idx
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
((
char
*
)(
rdma_recv_flag
+
global_expert_idx
)
-
(
char
*
)(
nvshmemi_device_state_d
.
heap_base
)));
st_na_release
(
req_rptr_actual
,
1
);
st_na_release
(
req_rptr_actual
,
1
);
}
else
}
else
{
#endif
internode
::
shmem_long_atomic_add
(
{
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
#if !defined(ROCM_DISABLE_CTX)
}
internode
::
shmem_ctx_long_atomic_add
(
ctx
,
#else
#else
#if !defined(ROCM_USE_MULTIQP)
internode
::
shmem_long_atomic_add
(
internode
::
shmem_long_atomic_add
(
#endif
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
}
#else
internode
::
shmem_long_atomic_add_dp
(
rdma_recv_flag
+
global_expert_idx
,
1
,
(
local_expert_idx
+
1
)
*
num_ranks
+
dst_rank
,
dst_rank
);
#endif
#endif // defined(FORCE_NVSHMEM_API)
}
else
{
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
}
}
atomic_add_release_global
(
atomic_clean_flag
,
-
1
);
atomic_add_release_global
(
atomic_clean_flag
,
-
1
);
}
}
syncwarp
();
syncwarp
();
if
(
num_ranks
>
8
){
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_ctx_quiet
(
ctx
);
#else
internode
::
shmem_fence
();
#endif
}
}
}
#if !defined(ROCM_DISABLE_CTX)
internode
::
shmem_wg_ctx_destroy
(
&
ctx
);
#endif
// Receiving phase
// Receiving phase
LOW_LATENCY_COMBINE_RECV:
LOW_LATENCY_COMBINE_RECV:
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
return
;
return
;
// if (num_ranks > 8){
// #if !defined(ROCM_DISABLE_CTX)
// internode::shmem_ctx_quiet(ctx);
// #else
// internode::shmem_fence();
// #endif
// }
// Wait all ranks to arrive and notify PCIe usage
// Wait all ranks to arrive and notify PCIe usage
if
(
responsible_expert_idx
<
num_experts
)
{
if
(
responsible_expert_idx
<
num_experts
)
{
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
...
@@ -743,6 +730,11 @@ LOW_LATENCY_COMBINE_RECV:
...
@@ -743,6 +730,11 @@ LOW_LATENCY_COMBINE_RECV:
atomicAdd
(
reinterpret_cast
<
unsigned
long
long
*>
(
combine_wait_recv_cost_stats
+
src_rank
),
wait_recv_cost
);
atomicAdd
(
reinterpret_cast
<
unsigned
long
long
*>
(
combine_wait_recv_cost_stats
+
src_rank
),
wait_recv_cost
);
}
}
}
}
#if defined(ROCM_USE_MULTIQP)
if
(
sub_warp_id
==
2
and
lane_id
==
0
)
{
internode
::
shmem_qp_quiet
(
num_ranks
+
responsible_expert_idx
);
}
#endif
}
}
grid_barrier
(
global_atomic_counter
,
num_sms
);
grid_barrier
(
global_atomic_counter
,
num_sms
);
...
...
csrc/kernels/shmem_wrapper.cuh
View file @
1b497233
...
@@ -116,6 +116,22 @@ __device__ inline void shmem_long_atomic_add(
...
@@ -116,6 +116,22 @@ __device__ inline void shmem_long_atomic_add(
rocshmem
::
rocshmem_long_atomic_add
(
dest
,
value
,
pe
);
rocshmem
::
rocshmem_long_atomic_add
(
dest
,
value
,
pe
);
}
}
#if defined(ROCM_USE_MULTIQP)
__device__
inline
void
shmem_qp_quiet
(
int
idx_qp
)
{
rocshmem
::
rocshmem_quiet_dp
(
idx_qp
);
}
__device__
inline
void
shmemx_int8_put_nbi_warp_dp
(
signed
char
*
dest
,
const
signed
char
*
source
,
size_t
nelems
,
int
qp_idx
,
int
pe
)
{
rocshmem
::
rocshmem_schar_put_nbi_wave_dp
(
dest
,
source
,
nelems
,
qp_idx
,
pe
);
}
__device__
inline
void
shmem_long_atomic_add_dp
(
long
*
dest
,
long
value
,
int
qp_idx
,
int
pe
)
{
rocshmem
::
rocshmem_long_atomic_add_dp
(
dest
,
value
,
qp_idx
,
pe
);
}
#endif
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
using
shmem_ctx_t
=
rocshmem
::
rocshmem_ctx_t
;
using
shmem_ctx_t
=
rocshmem
::
rocshmem_ctx_t
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment