Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
ab0afb04
Commit
ab0afb04
authored
Apr 17, 2026
by
lishen
Browse files
Merge branch 'normal_update' into 'main'
Normal update See merge request dcutoolkit/deeplearing/DeepEP!28
parents
766b17b3
30aa7a87
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
235 additions
and
130 deletions
+235
-130
csrc/config.hpp
csrc/config.hpp
+35
-28
csrc/deep_ep.cu
csrc/deep_ep.cu
+2
-0
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+68
-59
deep_ep/buffer.py
deep_ep/buffer.py
+2
-1
tests/1.sh
tests/1.sh
+6
-3
tests/2.sh
tests/2.sh
+6
-3
tests/test_internode.py
tests/test_internode.py
+10
-3
tests/test_intranode.py
tests/test_intranode.py
+1
-1
tests/test_low_latency.py
tests/test_low_latency.py
+4
-3
tests_mpi/test_env.sh
tests_mpi/test_env.sh
+9
-6
tests_mpi/test_internode.py
tests_mpi/test_internode.py
+9
-3
tests_mpi/test_intranode.py
tests_mpi/test_intranode.py
+4
-3
tests_mpi/test_low_latency.py
tests_mpi/test_low_latency.py
+79
-17
No files found.
csrc/config.hpp
View file @
ab0afb04
...
@@ -47,21 +47,25 @@ struct Config {
...
@@ -47,21 +47,25 @@ struct Config {
EP_HOST_ASSERT
(
num_ranks
<=
NUM_MAX_NVL_PEERS
or
num_sms
%
(
2
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
)
==
0
);
EP_HOST_ASSERT
(
num_ranks
<=
NUM_MAX_NVL_PEERS
or
num_sms
%
(
2
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
)
==
0
);
const
auto
num_rdma_ranks
=
std
::
max
(
num_ranks
/
NUM_MAX_NVL_PEERS
,
1
);
const
auto
num_rdma_ranks
=
std
::
max
(
num_ranks
/
NUM_MAX_NVL_PEERS
,
1
);
const
auto
num_nvl_ranks
=
std
::
min
(
num_ranks
,
NUM_MAX_NVL_PEERS
);
const
auto
num_nvl_ranks
=
std
::
min
(
num_ranks
,
NUM_MAX_NVL_PEERS
);
const
int
num_channels
=
num_sms
;
const
int
num_channels
=
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
size_t
num_bytes
=
0
;
// 计算每个nvl通信数据包的数据量
num_bytes
+=
num_channels
*
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
3
)
*
sizeof
(
int
);
size_t
num_single_nvl_bag_bytes
=
num_bytes
+=
num_channels
*
num_nvl_ranks
*
num_max_nvl_chunked_recv_tokens
*
hidden_bytes
;
hidden_bytes
+
// 数据缓冲区(Token Data)。存储从 RDMA 转发过来的 token 数据(x 张量)
#ifndef DISABLE_ROCSHMEM
#ifndef DISABLE_ROCSHMEM
num_bytes
+=
num_channels
*
num_nvl_ranks
*
num_max_nvl_chunked_recv_tokens
*
internode
::
get_source_meta_bytes
()
+
// 源元数据缓冲区(Source Metadata)。存储每个 token 的源信息(哪个 RDMA rank 发送的)
internode
::
get_source_meta_bytes
();
#endif
#endif
num_bytes
+=
num_channels
*
num_nvl_ranks
*
num_max_nvl_chunked_recv_tokens
*
kNumMaxTopK
*
kNumMaxTopK
*
sizeof
(
int
)
+
// TopK 索引缓冲区。存储每个 token 的 top-k 专家索引
sizeof
(
int64_t
);
kNumMaxTopK
*
sizeof
(
float
)
+
// TopK 权重缓冲区。存储每个 token 的 top-k 专家权重
num_bytes
+=
num_channels
*
num_nvl_ranks
*
num_max_nvl_chunked_recv_tokens
*
kNumMaxTopK
*
kNumMaxScales
*
sizeof
(
float
);
// Scale 缓冲区。存储每个 token 的量化缩放因子
sizeof
(
float
);
num_bytes
+=
num_channels
*
num_nvl_ranks
*
num_max_nvl_chunked_recv_tokens
*
// 计算每个 NVL channel 的控制信息所需的字节数,存储每个 NVL channel 的前缀索引信息,用于快速定位数据(nvl_channel_prefix_start、nvl_channel_prefix_end 等)
kNumMaxScales
*
sizeof
(
float
);
size_t
num_single_nvl_control_bytes
=
(
2
*
num_rdma_ranks
+
3
)
*
sizeof
(
int
);
// NVL 数据总的字节数
size_t
num_bytes
=
(
num_single_nvl_bag_bytes
*
num_max_nvl_chunked_recv_tokens
+
num_single_nvl_control_bytes
)
*
num_channels
*
num_nvl_ranks
;
// 128 字节对齐,匹配 GPU 缓存行大小,优化内存访问。
num_bytes
=
((
num_bytes
+
127
)
/
128
)
*
128
;
num_bytes
=
((
num_bytes
+
127
)
/
128
)
*
128
;
return
num_bytes
;
return
num_bytes
;
}
}
...
@@ -79,22 +83,25 @@ struct Config {
...
@@ -79,22 +83,25 @@ struct Config {
EP_HOST_ASSERT
(
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
EP_HOST_ASSERT
(
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
const
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
const
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
const
int
num_channels
=
num_sms
;
const
int
num_channels
=
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
size_t
num_bytes
=
0
;
// 计算每个rdma通信数据包的数据量
num_bytes
+=
num_channels
*
num_rdma_ranks
*
(
NUM_MAX_NVL_PEERS
*
2
+
2
)
*
2
*
sizeof
(
int
);
size_t
num_single_rdma_bag_bytes
=
num_bytes
+=
hidden_bytes
+
// 数据缓冲区。存储实际的 token 数据(x 张量),对应代码中的 rdma_channel_data
num_channels
*
num_rdma_ranks
*
num_max_rdma_chunked_recv_tokens
*
hidden_bytes
*
2
;
internode
::
get_source_meta_bytes
()
+
// 源元数据缓冲区。存储每个 token 的源信息(SourceMeta)
num_bytes
+=
num_channels
*
num_rdma_ranks
*
num_max_rdma_chunked_recv_tokens
*
kNumMaxTopK
*
sizeof
(
int
)
+
// 存储每个 token 的 top-k 专家索引。对应 topk_idx 数据
internode
::
get_source_meta_bytes
()
*
2
;
kNumMaxTopK
*
sizeof
(
float
)
+
// 存储每个 token 的 top-k 专家权重。对应 topk_weights 数据
num_bytes
+=
num_channels
*
num_rdma_ranks
*
num_max_rdma_chunked_recv_tokens
*
kNumMaxScales
*
sizeof
(
float
)
+
// 存储每个 token 的缩放因子(x_scales)
kNumMaxTopK
*
sizeof
(
int64_t
)
*
2
;
sizeof
(
int4
);
// 预留空间用于内存对齐和未来扩展
num_bytes
+=
num_channels
*
num_rdma_ranks
*
num_max_rdma_chunked_recv_tokens
*
kNumMaxTopK
*
sizeof
(
float
)
*
2
;
// 计算每个 RDMA channel 的控制信息(起始/结束索引)所需的字节数,对应代码中的 rdma_channel_meta
num_bytes
+=
num_channels
*
num_rdma_ranks
*
num_max_rdma_chunked_recv_tokens
*
size_t
num_single_rdma_control_bytes
=
(
NUM_MAX_NVL_PEERS
*
2
+
4
)
*
sizeof
(
int
);
kNumMaxScales
*
sizeof
(
float
)
*
2
;
num_bytes
+=
// RDMA 数据总的字节数
num_channels
*
num_rdma_ranks
*
num_max_rdma_chunked_recv_tokens
*
sizeof
(
int4
)
*
2
;
size_t
num_bytes
=
(
num_single_rdma_bag_bytes
*
num_max_rdma_chunked_recv_tokens
+
num_single_rdma_control_bytes
)
*
num_channels
*
num_rdma_ranks
*
2
;
// 128 字节对齐(缓存行对齐),优化内存访问性能
num_bytes
=
((
num_bytes
+
127
)
/
128
)
*
128
;
num_bytes
=
((
num_bytes
+
127
)
/
128
)
*
128
;
return
num_bytes
;
return
num_bytes
;
#else
#else
...
...
csrc/deep_ep.cu
View file @
ab0afb04
...
@@ -937,6 +937,7 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
...
@@ -937,6 +937,7 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
gbl_channel_prefix_matrix
=
cached_gbl_channel_prefix_matrix
.
value
();
gbl_channel_prefix_matrix
=
cached_gbl_channel_prefix_matrix
.
value
();
recv_gbl_rank_prefix_sum
=
cached_recv_gbl_rank_prefix_sum
.
value
();
recv_gbl_rank_prefix_sum
=
cached_recv_gbl_rank_prefix_sum
.
value
();
EP_HOST_ASSERT
(
num_rdma_bytes
>=
config
.
get_rdma_buffer_size_hint
(
hidden_int4
*
sizeof
(
int4
),
num_ranks
));
// Just a barrier and clean flags
// Just a barrier and clean flags
internode
::
cached_notify
(
internode
::
cached_notify
(
hidden_int4
,
num_scales
,
num_topk
,
num_topk
,
num_ranks
,
num_channels
,
0
,
nullptr
,
hidden_int4
,
num_scales
,
num_topk
,
num_topk
,
num_ranks
,
num_channels
,
0
,
nullptr
,
...
@@ -1205,6 +1206,7 @@ Buffer::internode_combine(
...
@@ -1205,6 +1206,7 @@ Buffer::internode_combine(
EP_HOST_ASSERT
(
config
.
num_max_nvl_chunked_recv_tokens
%
num_rdma_ranks
==
0
);
EP_HOST_ASSERT
(
config
.
num_max_nvl_chunked_recv_tokens
%
num_rdma_ranks
==
0
);
EP_HOST_ASSERT
(
config
.
num_max_nvl_chunked_send_tokens
<=
EP_HOST_ASSERT
(
config
.
num_max_nvl_chunked_send_tokens
<=
config
.
num_max_nvl_chunked_recv_tokens
/
num_rdma_ranks
);
config
.
num_max_nvl_chunked_recv_tokens
/
num_rdma_ranks
);
EP_HOST_ASSERT
(
num_rdma_bytes
>=
config
.
get_rdma_buffer_size_hint
(
hidden_int4
*
sizeof
(
int4
),
num_ranks
));
// Launch barrier and reset queue head and tail
// Launch barrier and reset queue head and tail
internode
::
cached_notify
(
internode
::
cached_notify
(
...
...
csrc/kernels/internode.cu
View file @
ab0afb04
...
@@ -7,6 +7,10 @@
...
@@ -7,6 +7,10 @@
#ifndef DISABLE_ROCSHMEM
#ifndef DISABLE_ROCSHMEM
// 安全检查:确保宏已定义
#ifndef HIP_VERSION_PATCH
#error "HIP_VERSION_PATCH not defined! Check your HIP installation."
#endif
// TODO: fix unroll warnings
// TODO: fix unroll warnings
// #ifdef __clang__
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic push
...
@@ -56,16 +60,18 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
...
@@ -56,16 +60,18 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
get_rdma_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
get_rdma_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
int
num_rdma_ranks
,
int
num_rdma_recv_buffer_tokens
,
int
num_sms
)
{
int
num_rdma_ranks
,
int
num_rdma_recv_buffer_tokens
,
int
num_channels
)
{
// Return `int32_t` offset and count to clean
// Return `int32_t` offset and count to clean
return
{(
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk_idx
,
num_topk_weights
)
*
return
{(
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk_idx
,
num_topk_weights
)
*
num_rdma_recv_buffer_tokens
*
num_rdma_ranks
*
2
*
num_
sm
s
)
/
sizeof
(
int
),
num_rdma_recv_buffer_tokens
*
num_rdma_ranks
*
2
*
num_
channel
s
)
/
sizeof
(
int
),
(
NUM_MAX_NVL_PEERS
*
2
+
4
)
*
num_rdma_ranks
*
2
*
num_
sm
s
};
(
NUM_MAX_NVL_PEERS
*
2
+
4
)
*
num_rdma_ranks
*
2
*
num_
channel
s
};
}
}
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
get_nvl_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
get_nvl_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
int
num_rdma_ranks
,
int
num_nvl_ranks
,
int
num_nvl_recv_buffer_tokens
,
int
num_rdma_ranks
,
int
num_nvl_ranks
,
int
num_nvl_recv_buffer_tokens
,
int
num_
sm
s
)
{
int
num_
channel
s
)
{
// Return `int32_t` offset and to clean
// Return `int32_t` offset and to clean
EP_STATIC_ASSERT
(
sizeof
(
SourceMeta
)
%
sizeof
(
int
)
==
0
,
EP_STATIC_ASSERT
(
sizeof
(
SourceMeta
)
%
sizeof
(
int
)
==
0
,
"Invalid size of `SourceMeta`"
);
"Invalid size of `SourceMeta`"
);
...
@@ -73,8 +79,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -73,8 +79,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
(
num_nvl_recv_buffer_tokens
*
(
num_nvl_recv_buffer_tokens
*
(
hidden_int4
*
sizeof
(
int4
)
+
num_scales
*
sizeof
(
float
)
+
num_topk_idx
*
sizeof
(
int
)
+
(
hidden_int4
*
sizeof
(
int4
)
+
num_scales
*
sizeof
(
float
)
+
num_topk_idx
*
sizeof
(
int
)
+
num_topk_weights
*
sizeof
(
float
)
+
sizeof
(
SourceMeta
))
*
num_topk_weights
*
sizeof
(
float
)
+
sizeof
(
SourceMeta
))
*
num_nvl_ranks
*
num_
sm
s
)
/
sizeof
(
int
),
num_nvl_ranks
*
num_
channel
s
)
/
sizeof
(
int
),
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
2
)
*
num_
sm
s
,
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
2
)
*
num_
channel
s
,
};
};
}
}
...
@@ -1230,13 +1236,13 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1230,13 +1236,13 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if
(
is_cached_dispatch
)
if
(
is_cached_dispatch
)
return
;
return
;
EP_DEVICE_ASSERT
(
num_warps
>=
num_channels
);
EP_DEVICE_ASSERT
(
num_rdma_ranks
<=
kWarpSize
);
EP_DEVICE_ASSERT
(
num_rdma_ranks
<=
kWarpSize
);
// Iterate in reverse order
// Iterate in reverse order
if
(
lane_id
<
num_rdma_ranks
and
warp_id
<
num_channels
)
{
for
(
int
channel_id
=
warp_id
;
channel_id
<
num_channels
;
channel_id
+=
num_warps
)
{
if
(
lane_id
<
num_rdma_ranks
)
{
int
token_start_idx
,
token_end_idx
;
int
token_start_idx
,
token_end_idx
;
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
warp
_id
,
token_start_idx
,
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
channel
_id
,
token_start_idx
,
token_end_idx
);
token_end_idx
);
// NOTES: `1 << 25` is a heuristic large number
// NOTES: `1 << 25` is a heuristic large number
...
@@ -1251,26 +1257,26 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1251,26 +1257,26 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
}
}
}
}
}
}
}
}
else
{
}
else
{
if
(
is_cached_dispatch
)
if
(
is_cached_dispatch
)
return
;
return
;
EP_DEVICE_ASSERT
(
num_warps
>=
num_channels
);
EP_DEVICE_ASSERT
(
rdma_channel_prefix_matrix
!=
nullptr
and
rdma_rank_prefix_sum
!=
nullptr
);
EP_DEVICE_ASSERT
(
rdma_channel_prefix_matrix
!=
nullptr
and
rdma_rank_prefix_sum
!=
nullptr
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Too many NVL peers"
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Too many NVL peers"
);
constexpr
int
num_clean_sms
=
2
;
constexpr
int
num_clean_sms
=
2
;
if
(
lane_id
<
NUM_MAX_NVL_PEERS
and
warp_id
<
num_channels
)
{
for
(
int
channel_id
=
warp_id
;
channel_id
<
num_channels
;
channel_id
+=
num_warps
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
for
(
int
dst_rdma_rank
=
sm_id
-
num_clean_sms
;
dst_rdma_rank
<
num_rdma_ranks
;
for
(
int
dst_rdma_rank
=
sm_id
-
num_clean_sms
;
dst_rdma_rank
<
num_rdma_ranks
;
dst_rdma_rank
+=
num_channels
*
2
-
num_clean_sms
)
{
dst_rdma_rank
+=
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
-
num_clean_sms
)
{
// Iterate in reverse order
// Iterate in reverse order
int
token_start_idx
=
int
token_start_idx
=
warp
_id
==
0
channel
_id
==
0
?
0
?
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
warp
_id
-
1
];
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel
_id
-
1
];
int
token_end_idx
=
int
token_end_idx
=
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
warp
_id
];
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel
_id
];
int
shift
=
dst_rdma_rank
==
0
?
0
:
rdma_rank_prefix_sum
[
dst_rdma_rank
-
1
];
int
shift
=
dst_rdma_rank
==
0
?
0
:
rdma_rank_prefix_sum
[
dst_rdma_rank
-
1
];
token_start_idx
+=
shift
,
token_end_idx
+=
shift
;
token_start_idx
+=
shift
,
token_end_idx
+=
shift
;
...
@@ -1288,6 +1294,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1288,6 +1294,7 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
}
}
}
}
}
}
}
}
}
void
cached_notify
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
void
cached_notify
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
...
@@ -1298,7 +1305,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -1298,7 +1305,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
int
num_max_nvl_chunked_recv_tokens
,
int
**
barrier_signal_ptrs
,
int
rank
,
int
num_max_nvl_chunked_recv_tokens
,
int
**
barrier_signal_ptrs
,
int
rank
,
hipStream_t
stream
,
int64_t
num_rdma_bytes
,
int64_t
num_nvl_bytes
,
hipStream_t
stream
,
int64_t
num_rdma_bytes
,
int64_t
num_nvl_bytes
,
bool
is_cached_dispatch
,
bool
low_latency_mode
)
{
bool
is_cached_dispatch
,
bool
low_latency_mode
)
{
const
int
num_threads
=
::
max
(
128
,
kWarpSize
*
num_channels
);
const
int
num_threads
=
::
min
(
1024
,
::
max
(
128
,
kWarpSize
*
num_channels
)
)
;
const
auto
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
const
auto
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
// Get clean meta
// Get clean meta
...
@@ -1314,11 +1321,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -1314,11 +1321,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
num_nvl_bytes
);
num_nvl_bytes
);
EP_HOST_ASSERT
(
num_rdma_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_rdma_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_nvl_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_nvl_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_channels
*
2
>
2
);
EP_HOST_ASSERT
(
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
>
2
);
// Launch kernel
// Launch kernel
auto
cached_notify_func
=
low_latency_mode
?
cached_notify
<
true
>
:
cached_notify
<
false
>
;
auto
cached_notify_func
=
low_latency_mode
?
cached_notify
<
true
>
:
cached_notify
<
false
>
;
SETUP_LAUNCH_CONFIG
(
num_channels
*
2
,
num_threads
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
,
num_threads
,
stream
);
LAUNCH_KERNEL_NON_COOPERATIVE
(
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
cached_notify_func
,
rdma_clean_meta
.
first
,
rdma_clean_meta
.
second
,
&
cfg
,
cached_notify_func
,
rdma_clean_meta
.
first
,
rdma_clean_meta
.
second
,
nvl_clean_meta
.
first
,
nvl_clean_meta
.
second
,
combined_rdma_head
,
num_combined_tokens
,
nvl_clean_meta
.
first
,
nvl_clean_meta
.
second
,
combined_rdma_head
,
num_combined_tokens
,
...
@@ -1327,11 +1334,12 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -1327,11 +1334,12 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
cpu_rdma_team
);
cpu_rdma_team
);
}
}
template
<
int
kNumRanks
,
typename
dtype_t
,
int
kMaxNumRanks
,
typename
Receive
Fn
,
typename
ReceiveTWFn
>
template
<
int
kNumRanks
,
typename
dtype_t
,
int
kMaxNumRanks
,
bool
kUseMLS
,
typename
GetAddr
Fn
,
typename
ReceiveTWFn
>
__device__
int
combine_token
(
bool
is_token_in_rank
,
int
head_idx
,
__device__
int
combine_token
(
bool
is_token_in_rank
,
int
head_idx
,
int
lane_id
,
int
hidden_int4
,
int
num_topk
,
int
lane_id
,
int
hidden_int4
,
int
num_topk
,
int4
*
combined_row
,
float
*
combined_topk_weights
,
int4
*
combined_row
,
float
*
combined_topk_weights
,
int
num_max_recv_tokens
,
const
ReceiveFn
&
recv_fn
,
const
ReceiveTWFn
&
recv_tw_fn
)
{
int
num_max_recv_tokens
,
const
GetAddrFn
&
get_addr_fn
,
const
ReceiveTWFn
&
recv_tw_fn
)
{
constexpr
auto
kDtypePerInt4
=
sizeof
(
int4
)
/
sizeof
(
dtype_t
);
constexpr
auto
kDtypePerInt4
=
sizeof
(
int4
)
/
sizeof
(
dtype_t
);
// Broadcast current heads
// Broadcast current heads
...
@@ -1353,7 +1361,7 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
...
@@ -1353,7 +1361,7 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
int4
recv_value_int4
[
kMaxNumRanks
];
int4
recv_value_int4
[
kMaxNumRanks
];
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
recv_value_int4
[
j
]
=
recv
_fn
(
topk_ranks
[
j
],
slot_indices
[
j
],
i
);
recv_value_int4
[
j
]
=
ld_nc_global
(
get_addr
_fn
(
topk_ranks
[
j
],
slot_indices
[
j
],
i
)
)
;
// Reduce all-to-all results
// Reduce all-to-all results
float
values
[
kDtypePerInt4
]
=
{
0
};
float
values
[
kDtypePerInt4
]
=
{
0
};
...
@@ -1416,6 +1424,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1416,6 +1424,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
__shared__
shmem_ctx_t
ctx
;
__shared__
shmem_ctx_t
ctx
;
shmem_wg_ctx_create
(
&
ctx
);
shmem_wg_ctx_create
(
&
ctx
);
#endif
#endif
EP_STATIC_ASSERT
(
kNumCombineForwarderWarps
<=
kWarpSize
,
"Invalid number of forwarder warps"
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
...
@@ -1717,14 +1726,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1717,14 +1726,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Combine current token
// Combine current token
auto
rdma_slot_idx
=
token_idx
%
num_max_rdma_chunked_recv_tokens
;
auto
rdma_slot_idx
=
token_idx
%
num_max_rdma_chunked_recv_tokens
;
void
*
shifted
=
send_buffer
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
void
*
shifted
=
send_buffer
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
auto
recv
_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
{
return
ld_nc_global
(
nvl_channel_x
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
hidden_int4
+
hidden_int4_idx
)
;
};
auto
get_addr
_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
*
{
return
reinterpret_cast
<
int4
*>
(
nvl_channel_x
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
hidden_int4
)
+
hidden_int4_idx
;
};
auto
recv_tw_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
num_topk
+
topk_idx
);
};
auto
recv_tw_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
num_topk
+
topk_idx
);
};
combine_token
<
NUM_MAX_NVL_PEERS
,
dtype_t
,
NUM_MAX_NVL_PEERS
>
(
expected_head
>=
0
,
combine_token
<
NUM_MAX_NVL_PEERS
,
dtype_t
,
NUM_MAX_NVL_PEERS
,
true
>
(
expected_head
>=
0
,
expected_head
,
lane_id
,
expected_head
,
lane_id
,
hidden_int4
,
num_topk
,
hidden_int4
,
num_topk
,
reinterpret_cast
<
int4
*>
(
shifted
),
reinterpret_cast
<
int4
*>
(
shifted
),
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
hidden_bytes
+
sizeof
(
SourceMeta
)),
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
hidden_bytes
+
sizeof
(
SourceMeta
)),
num_max_nvl_chunked_recv_tokens_per_rdma
,
recv_fn
,
recv_tw_fn
);
num_max_nvl_chunked_recv_tokens_per_rdma
,
get_addr_fn
,
recv_tw_fn
);
// Update head
// Update head
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
...
@@ -1787,7 +1797,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1787,7 +1797,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
EP_STATIC_ASSERT
(
kNumCombineForwarderWarps
<=
kWarpSize
,
"Invalid number of forwarder warps"
);
while
(
true
)
{
while
(
true
)
{
// Retired
// Retired
...
@@ -1853,14 +1862,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1853,14 +1862,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
syncwarp
();
syncwarp
();
// Combine current token
// Combine current token
auto
recv
_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
{
return
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
)
+
hidden_int4_idx
)
;};
auto
get_addr
_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
*
{
return
reinterpret_cast
<
int4
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
)
+
hidden_int4_idx
;
};
auto
recv_tw_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
reinterpret_cast
<
const
float
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
+
hidden_bytes
+
sizeof
(
SourceMeta
))
+
topk_idx
);};
auto
recv_tw_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
reinterpret_cast
<
const
float
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
+
hidden_bytes
+
sizeof
(
SourceMeta
))
+
topk_idx
);};
combine_token
<
kNumRDMARanks
,
dtype_t
,
kNumTopkRDMARanks
>
(
expected_head
>=
0
,
combine_token
<
kNumRDMARanks
,
dtype_t
,
kNumTopkRDMARanks
,
false
>
(
expected_head
>=
0
,
expected_head
,
lane_id
,
expected_head
,
lane_id
,
hidden_int4
,
num_topk
,
hidden_int4
,
num_topk
,
combined_x
+
token_idx
*
hidden_int4
,
combined_x
+
token_idx
*
hidden_int4
,
combined_topk_weights
+
token_idx
*
num_topk
,
combined_topk_weights
+
token_idx
*
num_topk
,
num_max_rdma_chunked_recv_tokens
,
recv_fn
,
recv_tw_fn
);
num_max_rdma_chunked_recv_tokens
,
get_addr_fn
,
recv_tw_fn
);
}
}
// Retired
// Retired
...
@@ -1879,7 +1889,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1879,7 +1889,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
dst_rdma_rank
=
lane_id
<
kNumRDMARanks
?
lane_id
:
0
;
int
dst_rdma_rank
=
lane_id
<
kNumRDMARanks
?
lane_id
:
0
;
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
EP_STATIC_ASSERT
(
kNumCombineForwarderWarps
<=
kWarpSize
,
"Invalid number of forwarder warps"
);
while
(
true
)
{
while
(
true
)
{
// Retired
// Retired
...
...
deep_ep/buffer.py
View file @
ab0afb04
...
@@ -207,7 +207,8 @@ class Buffer:
...
@@ -207,7 +207,8 @@ class Buffer:
new_num_sms: the new number to be set.
new_num_sms: the new number to be set.
"""
"""
assert
new_num_sms
%
2
==
0
,
"The SM count must be even"
assert
new_num_sms
%
2
==
0
,
"The SM count must be new_num_sms % 2 == 0"
assert
new_num_sms
%
3
==
0
,
"The SM count must be new_num_sms % 3 == 0"
Buffer
.
num_sms
=
new_num_sms
Buffer
.
num_sms
=
new_num_sms
@
staticmethod
@
staticmethod
...
...
tests/1.sh
View file @
ab0afb04
...
@@ -2,10 +2,13 @@
...
@@ -2,10 +2,13 @@
# rocSHMEM
# rocSHMEM
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
60
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_HEAP_SIZE
=
3737418240
export
ROCSHMEM_HEAP_SIZE
=
3737418240
export
ROCSHMEM_TOPO_FILE_FORCE
=
./topo.config
export
ROCSHMEM_TOPO_FILE_FORCE
=
./topo.config
# NMZ使用
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# # duSHMEM
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
...
@@ -17,8 +20,8 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
...
@@ -17,8 +20,8 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export
PYTHONPATH
=
$(
pwd
)
/../
export
PYTHONPATH
=
$(
pwd
)
/../
# test
# test
#
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 ./test_low_latency.py
# --pressure-test
#
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
tests/2.sh
View file @
ab0afb04
...
@@ -2,10 +2,13 @@
...
@@ -2,10 +2,13 @@
# rocSHMEM
# rocSHMEM
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
60
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_HEAP_SIZE
=
3737418240
export
ROCSHMEM_HEAP_SIZE
=
3737418240
export
ROCSHMEM_TOPO_FILE_FORCE
=
./topo.config
export
ROCSHMEM_TOPO_FILE_FORCE
=
./topo.config
# NMZ使用
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# # duSHMEM
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
...
@@ -17,8 +20,8 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
...
@@ -17,8 +20,8 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export
PYTHONPATH
=
$(
pwd
)
/../
export
PYTHONPATH
=
$(
pwd
)
/../
# test
# test
#
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 ./test_low_latency.py
# --pressure-test
#
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
tests/test_internode.py
View file @
ab0afb04
...
@@ -143,7 +143,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -143,7 +143,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
# Check `topk_weights`
# Check `topk_weights`
if
not
is_rand
:
if
not
is_rand
:
recv_topk_weights
[
recv_topk_idx
.
eq
(
-
1
)]
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
).
expand_as
(
recv_topk_weights
)[
recv_topk_idx
.
eq
(
-
1
)]
max_weights
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
)
# Shape: [Batch, 1]
recv_topk_weights
=
torch
.
where
(
recv_topk_idx
==
-
1
,
max_weights
,
recv_topk_weights
)
check_data
(
recv_topk_weights
,
recv_gbl_rank_prefix_sum
)
check_data
(
recv_topk_weights
,
recv_gbl_rank_prefix_sum
)
# Test cached dispatch (must without top-k staffs)
# Test cached dispatch (must without top-k staffs)
...
@@ -186,6 +187,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -186,6 +187,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
' passed'
,
flush
=
True
)
print
(
' passed'
,
flush
=
True
)
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
''
,
flush
=
True
)
print
(
''
,
flush
=
True
)
...
@@ -201,6 +203,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -201,6 +203,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
nvl_recv_bytes
=
(
dispatch_bf16_nvl_recv_bytes
*
fp8_factor
)
if
isinstance
(
current_x
,
tuple
)
else
dispatch_bf16_nvl_recv_bytes
nvl_recv_bytes
=
(
dispatch_bf16_nvl_recv_bytes
*
fp8_factor
)
if
isinstance
(
current_x
,
tuple
)
else
dispatch_bf16_nvl_recv_bytes
for
nvl_chunk_size
in
range
(
4
,
45
,
4
):
for
nvl_chunk_size
in
range
(
4
,
45
,
4
):
for
rdma_chunk_size
in
range
(
4
,
33
,
4
):
for
rdma_chunk_size
in
range
(
4
,
33
,
4
):
if
rdma_buffer_size
%
rdma_chunk_size
!=
0
:
continue
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
tune_args
=
{
'x'
:
current_x
,
'handle'
:
handle
,
'config'
:
config
}
tune_args
=
{
'x'
:
current_x
,
'handle'
:
handle
,
'config'
:
config
}
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
dispatch
(
**
tune_args
),
(
'dispatch'
,
'notify'
),
suppress_kineto_output
=
True
)
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
dispatch
(
**
tune_args
),
(
'dispatch'
,
'notify'
),
suppress_kineto_output
=
True
)
...
@@ -233,6 +237,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -233,6 +237,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
best_time
,
best_results
=
1e10
,
None
best_time
,
best_results
=
1e10
,
None
for
nvl_chunk_size
in
range
(
1
,
8
,
1
):
for
nvl_chunk_size
in
range
(
1
,
8
,
1
):
for
rdma_chunk_size
in
range
(
12
if
num_nodes
==
2
else
8
,
33
,
4
):
for
rdma_chunk_size
in
range
(
12
if
num_nodes
==
2
else
8
,
33
,
4
):
if
rdma_buffer_size
%
rdma_chunk_size
!=
0
:
continue
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
tune_args
=
{
'x'
:
recv_x
,
'handle'
:
handle
,
'config'
:
config
}
tune_args
=
{
'x'
:
recv_x
,
'handle'
:
handle
,
'config'
:
config
}
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
combine
(
**
tune_args
),
(
'combine'
,
'notify'
),
suppress_kineto_output
=
True
)
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
combine
(
**
tune_args
),
(
'combine'
,
'notify'
),
suppress_kineto_output
=
True
)
...
@@ -265,8 +271,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -265,8 +271,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
num_rdma_bytes_ll
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
ll_num_tokens
,
ll_hidden
,
num_ranks
,
ll_num_experts
)
num_rdma_bytes_ll
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
ll_num_tokens
,
ll_hidden
,
num_ranks
,
ll_num_experts
)
num_sms
=
48
num_sms
=
60
num_qps_per_rank
=
max
(
num_sms
,
ll_num_experts
//
num_ranks
if
args
.
test_ll_compatibility
else
0
)
num_qps_per_rank
=
max
(
num_sms
,
ll_num_experts
//
num_ranks
if
args
.
test_ll_compatibility
else
0
)
deep_ep
.
Buffer
.
set_num_sms
(
num_sms
)
hidden_bytes
=
get_hidden_bytes
(
args
)
hidden_bytes
=
get_hidden_bytes
(
args
)
num_nvl_bytes
,
num_rdma_bytes
,
num_rdma_bytes_norm
=
0
,
0
,
0
num_nvl_bytes
,
num_rdma_bytes
,
num_rdma_bytes_norm
=
0
,
0
,
0
...
@@ -292,7 +299,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -292,7 +299,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
break
break
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
f
'
{
ref_hash
=
}
'
)
print
(
f
'
ref_hash=
{
ref_hash
}
'
)
print
(
''
,
flush
=
True
)
print
(
''
,
flush
=
True
)
for
j
in
range
(
20
):
for
j
in
range
(
20
):
...
...
tests/test_intranode.py
View file @
ab0afb04
...
@@ -244,7 +244,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -244,7 +244,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank
=
(
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
1
),
explicitly_destroy
=
True
)
num_qps_per_rank
=
(
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
1
),
explicitly_destroy
=
True
)
torch
.
manual_seed
(
rank
)
torch
.
manual_seed
(
rank
)
for
i
in
(
24
,
):
for
i
in
(
60
,
):
test_main
(
args
,
i
,
local_rank
,
num_ranks
,
rank
,
buffer
,
group
)
test_main
(
args
,
i
,
local_rank
,
num_ranks
,
rank
,
buffer
,
group
)
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
''
,
flush
=
True
)
print
(
''
,
flush
=
True
)
...
...
tests/test_low_latency.py
View file @
ab0afb04
...
@@ -52,8 +52,9 @@ def test_main(num_tokens: int,
...
@@ -52,8 +52,9 @@ def test_main(num_tokens: int,
seed
:
int
=
0
):
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
if
rank
==
0
:
print
(
f
"enable_dispatch_ll_layered=
{
enable_dispatch_ll_layered
}
, enable_combine_overlap=
{
enable_combine_overlap
}
, use_logfmt=
{
use_logfmt
}
"
)
print
(
f
"enable_dispatch_ll_layered=
{
enable_dispatch_ll_layered
}
, enable_combine_overlap=
{
enable_combine_overlap
}
, use_logfmt=
{
use_logfmt
}
"
)
assert
not
(
use_logfmt
and
(
enable_dispatch_ll_layered
or
enable_combine_overlap
)),
\
assert
not
(
use_logfmt
and
(
enable_dispatch_ll_layered
or
enable_combine_overlap
)),
\
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert
num_experts
%
num_ranks
==
0
assert
num_experts
%
num_ranks
==
0
...
@@ -144,7 +145,7 @@ def test_main(num_tokens: int,
...
@@ -144,7 +145,7 @@ def test_main(num_tokens: int,
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
if
(
enable_dispatch_ll_layered
or
enable_combine_overlap
)
:
if
enable_dispatch_ll_layered
or
enable_combine_overlap
:
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
&
int_mask
# 掩掉多余的信息
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
&
int_mask
# 掩掉多余的信息
else
:
else
:
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
...
@@ -179,7 +180,7 @@ def test_main(num_tokens: int,
...
@@ -179,7 +180,7 @@ def test_main(num_tokens: int,
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
if
enable_combine_overlap
:
if
enable_combine_overlap
:
block_m
,
threshold
,
num_sms
=
64
,
10
,
3
block_m
,
threshold
,
num_sms
=
64
,
10
,
3
total_num_per_expert
=
ceil_div
(
num_tokens
*
num_ranks
,
block_m
)
# 每个本地专家 总的信号数
??
total_num_per_expert
=
ceil_div
(
num_tokens
*
num_ranks
,
block_m
)
# 每个本地专家 总的信号数
comp_signal
=
torch
.
zeros
(
num_local_experts
*
total_num_per_expert
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
comp_signal
=
torch
.
zeros
(
num_local_experts
*
total_num_per_expert
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
for
i
in
range
(
num_local_experts
):
for
i
in
range
(
num_local_experts
):
...
...
tests_mpi/test_env.sh
View file @
ab0afb04
...
@@ -8,12 +8,15 @@ export PYTHONPATH=$(pwd)
...
@@ -8,12 +8,15 @@ export PYTHONPATH=$(pwd)
# rocSHMEM
# rocSHMEM
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
60
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_HEAP_SIZE
=
10
737418240
export
ROCSHMEM_HEAP_SIZE
=
3
737418240
export
ROCSHMEM_TOPO_FILE_FORCE
=
$(
pwd
)
/tests_mpi/topo.config
export
ROCSHMEM_TOPO_FILE_FORCE
=
$(
pwd
)
/tests_mpi/topo.config
# NMZ使用
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# duSHMEM
#
# duSHMEM
export
LD_LIBRARY_PATH
=
/opt/dtk/dushmem/lib:
$LD_LIBRARY_PATH
#
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export
DEEP_EP_DEVICE_TO_HCA_MAPPING
=
0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
#
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export
NVSHMEM_SYMMETRIC_SIZE
=
10737418240
#
export NVSHMEM_SYMMETRIC_SIZE=10737418240
tests_mpi/test_internode.py
View file @
ab0afb04
...
@@ -145,7 +145,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -145,7 +145,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
# Check `topk_weights`
# Check `topk_weights`
if
not
is_rand
:
if
not
is_rand
:
recv_topk_weights
[
recv_topk_idx
.
eq
(
-
1
)]
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
).
expand_as
(
recv_topk_weights
)[
recv_topk_idx
.
eq
(
-
1
)]
max_weights
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
)
# Shape: [Batch, 1]
recv_topk_weights
=
torch
.
where
(
recv_topk_idx
==
-
1
,
max_weights
,
recv_topk_weights
)
check_data
(
recv_topk_weights
,
recv_gbl_rank_prefix_sum
)
check_data
(
recv_topk_weights
,
recv_gbl_rank_prefix_sum
)
# Test cached dispatch (must without top-k staffs)
# Test cached dispatch (must without top-k staffs)
...
@@ -203,6 +204,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -203,6 +204,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
nvl_recv_bytes
=
(
dispatch_bf16_nvl_recv_bytes
*
fp8_factor
)
if
isinstance
(
current_x
,
tuple
)
else
dispatch_bf16_nvl_recv_bytes
nvl_recv_bytes
=
(
dispatch_bf16_nvl_recv_bytes
*
fp8_factor
)
if
isinstance
(
current_x
,
tuple
)
else
dispatch_bf16_nvl_recv_bytes
for
nvl_chunk_size
in
range
(
4
,
45
,
4
):
for
nvl_chunk_size
in
range
(
4
,
45
,
4
):
for
rdma_chunk_size
in
range
(
4
,
33
,
4
):
for
rdma_chunk_size
in
range
(
4
,
33
,
4
):
if
rdma_buffer_size
%
rdma_chunk_size
!=
0
:
continue
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
tune_args
=
{
'x'
:
current_x
,
'handle'
:
handle
,
'config'
:
config
}
tune_args
=
{
'x'
:
current_x
,
'handle'
:
handle
,
'config'
:
config
}
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
dispatch
(
**
tune_args
),
(
'dispatch'
,
'notify'
),
suppress_kineto_output
=
True
)
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
dispatch
(
**
tune_args
),
(
'dispatch'
,
'notify'
),
suppress_kineto_output
=
True
)
...
@@ -235,6 +238,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -235,6 +238,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
best_time
,
best_results
=
1e10
,
None
best_time
,
best_results
=
1e10
,
None
for
nvl_chunk_size
in
range
(
1
,
8
,
1
):
for
nvl_chunk_size
in
range
(
1
,
8
,
1
):
for
rdma_chunk_size
in
range
(
12
if
num_nodes
==
2
else
8
,
33
,
4
):
for
rdma_chunk_size
in
range
(
12
if
num_nodes
==
2
else
8
,
33
,
4
):
if
rdma_buffer_size
%
rdma_chunk_size
!=
0
:
continue
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
tune_args
=
{
'x'
:
recv_x
,
'handle'
:
handle
,
'config'
:
config
}
tune_args
=
{
'x'
:
recv_x
,
'handle'
:
handle
,
'config'
:
config
}
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
combine
(
**
tune_args
),
(
'combine'
,
'notify'
),
suppress_kineto_output
=
True
)
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
combine
(
**
tune_args
),
(
'combine'
,
'notify'
),
suppress_kineto_output
=
True
)
...
@@ -272,8 +277,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -272,8 +277,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
num_rdma_bytes_ll
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
ll_num_tokens
,
ll_hidden
,
num_ranks
,
ll_num_experts
)
num_rdma_bytes_ll
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
ll_num_tokens
,
ll_hidden
,
num_ranks
,
ll_num_experts
)
num_sms
=
48
num_sms
=
60
num_qps_per_rank
=
max
(
num_sms
,
ll_num_experts
//
num_ranks
if
args
.
test_ll_compatibility
else
0
)
num_qps_per_rank
=
max
(
num_sms
,
ll_num_experts
//
num_ranks
if
args
.
test_ll_compatibility
else
0
)
deep_ep
.
Buffer
.
set_num_sms
(
num_sms
)
hidden_bytes
=
get_hidden_bytes
(
args
)
hidden_bytes
=
get_hidden_bytes
(
args
)
num_nvl_bytes
,
num_rdma_bytes
,
num_rdma_bytes_norm
=
0
,
0
,
0
num_nvl_bytes
,
num_rdma_bytes
,
num_rdma_bytes_norm
=
0
,
0
,
0
...
@@ -299,7 +305,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -299,7 +305,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
break
break
if
rank
==
0
:
if
rank
==
0
:
print
(
f
'
{
ref_hash
=
}
'
)
print
(
f
'
ref_hash=
{
ref_hash
}
'
)
print
(
''
,
flush
=
True
)
print
(
''
,
flush
=
True
)
for
j
in
range
(
20
):
for
j
in
range
(
20
):
...
...
tests_mpi/test_intranode.py
View file @
ab0afb04
...
@@ -119,7 +119,8 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
...
@@ -119,7 +119,8 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
# Check `topk_weights`
# Check `topk_weights`
recv_topk_weights_clone
=
recv_topk_weights
.
clone
()
recv_topk_weights_clone
=
recv_topk_weights
.
clone
()
if
current_x
is
not
x_pure_rand
:
if
current_x
is
not
x_pure_rand
:
recv_topk_weights
[
recv_topk_idx
.
eq
(
-
1
)]
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
).
expand_as
(
recv_topk_weights
)[
recv_topk_idx
.
eq
(
-
1
)]
max_weights
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
)
# Shape: [Batch, 1]
recv_topk_weights
=
torch
.
where
(
recv_topk_idx
==
-
1
,
max_weights
,
recv_topk_weights
)
check_data
(
recv_topk_weights
,
rank_prefix_matrix
)
check_data
(
recv_topk_weights
,
rank_prefix_matrix
)
# Test `num_worst_tokens != 0`
# Test `num_worst_tokens != 0`
...
@@ -251,7 +252,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -251,7 +252,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank
=
(
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
1
),
explicitly_destroy
=
True
)
num_qps_per_rank
=
(
ll_num_experts
//
num_ranks
if
test_ll_compatibility
else
1
),
explicitly_destroy
=
True
)
torch
.
manual_seed
(
rank
)
torch
.
manual_seed
(
rank
)
for
i
in
(
48
,
):
for
i
in
(
60
,
):
test_main
(
args
,
i
,
local_rank
,
num_ranks
,
rank
,
buffer
,
group
)
test_main
(
args
,
i
,
local_rank
,
num_ranks
,
rank
,
buffer
,
group
)
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
''
,
flush
=
True
)
print
(
''
,
flush
=
True
)
...
...
tests_mpi/test_low_latency.py
View file @
ab0afb04
...
@@ -36,6 +36,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu
...
@@ -36,6 +36,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu
assert
set
(
mask_status
.
nonzero
().
squeeze
(
-
1
).
tolist
())
==
expected_masked_ranks
assert
set
(
mask_status
.
nonzero
().
squeeze
(
-
1
).
tolist
())
==
expected_masked_ranks
def
ceil_div
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
def
test_main
(
num_tokens
:
int
,
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_experts
:
int
,
...
@@ -44,11 +48,17 @@ def test_main(num_tokens: int,
...
@@ -44,11 +48,17 @@ def test_main(num_tokens: int,
num_ranks
:
int
,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
buffer
:
deep_ep
.
Buffer
,
enable_dispatch_ll_layered
:
bool
=
False
,
enable_combine_overlap
:
bool
=
False
,
use_logfmt
:
bool
=
False
,
use_logfmt
:
bool
=
False
,
seed
:
int
=
0
):
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
if
rank
==
0
:
print
(
f
"enable_dispatch_ll_layered=
{
enable_dispatch_ll_layered
}
, enable_combine_overlap=
{
enable_combine_overlap
}
, use_logfmt=
{
use_logfmt
}
"
)
assert
not
(
use_logfmt
and
(
enable_dispatch_ll_layered
or
enable_combine_overlap
)),
\
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert
num_experts
%
num_ranks
==
0
assert
num_experts
%
num_ranks
==
0
num_local_experts
=
num_experts
//
num_ranks
num_local_experts
=
num_experts
//
num_ranks
...
@@ -86,6 +96,9 @@ def test_main(num_tokens: int,
...
@@ -86,6 +96,9 @@ def test_main(num_tokens: int,
hash_value
,
num_times
=
0
,
0
hash_value
,
num_times
=
0
,
0
for
x_i
,
current_x
in
enumerate
(
x_list
):
for
x_i
,
current_x
in
enumerate
(
x_list
):
for
return_recv_hook
in
(
False
,
True
):
for
return_recv_hook
in
(
False
,
True
):
if
enable_combine_overlap
and
(
not
return_recv_hook
):
# return_recv_hook 为False 时,不能启用 overlop
continue
for
quant_type
in
(
0
,
1
,
2
,
3
,
):
# 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
for
quant_type
in
(
0
,
1
,
2
,
3
,
):
# 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant
=
quant_type
>
0
dispatch_use_quant
=
quant_type
>
0
for
fp8_round_scale
in
(
False
,
True
)
if
quant_type
!=
3
else
(
True
,
):
for
fp8_round_scale
in
(
False
,
True
)
if
quant_type
!=
3
else
(
True
,
):
...
@@ -133,7 +146,12 @@ def test_main(num_tokens: int,
...
@@ -133,7 +146,12 @@ def test_main(num_tokens: int,
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
if
enable_dispatch_ll_layered
or
enable_combine_overlap
:
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
&
int_mask
# 掩掉多余的信息
else
:
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
dispatch_use_quant
:
if
dispatch_use_quant
:
...
@@ -150,6 +168,7 @@ def test_main(num_tokens: int,
...
@@ -150,6 +168,7 @@ def test_main(num_tokens: int,
if
not
fp8_round_scale
:
if
not
fp8_round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
if
dispatch_use_quant
:
if
dispatch_use_quant
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
...
@@ -161,6 +180,28 @@ def test_main(num_tokens: int,
...
@@ -161,6 +180,28 @@ def test_main(num_tokens: int,
if
zero_copy
:
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
if
enable_combine_overlap
:
block_m
,
threshold
,
num_sms
=
64
,
10
,
3
total_num_per_expert
=
ceil_div
(
num_tokens
*
num_ranks
,
block_m
)
# 每个本地专家 总的信号数
comp_signal
=
torch
.
zeros
(
num_local_experts
*
total_num_per_expert
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
for
i
in
range
(
num_local_experts
):
vaild_num
=
ceil_div
(
packed_recv_count
[
i
],
block_m
)
comp_signal
[
i
*
total_num_per_expert
:
i
*
total_num_per_expert
+
vaild_num
]
=
threshold
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
packed_recv_count
=
packed_recv_count
,
comp_signal
=
comp_signal
,
block_m
=
block_m
,
threshold
=
threshold
,
num_sms
=
num_sms
,
async_finish
=
not
return_recv_hook
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
else
:
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_idx
,
topk_weights
,
topk_weights
,
...
@@ -170,6 +211,7 @@ def test_main(num_tokens: int,
...
@@ -170,6 +211,7 @@ def test_main(num_tokens: int,
zero_copy
=
zero_copy
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
if
do_check
:
if
do_check
:
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
...
@@ -181,8 +223,10 @@ def test_main(num_tokens: int,
...
@@ -181,8 +223,10 @@ def test_main(num_tokens: int,
if
rank
==
0
:
if
rank
==
0
:
print
(
f
"data:
{
x_i
}
, return_recv_hook:
{
return_recv_hook
}
, quant_type:
{
quant_type
}
, "
,
print
(
f
"data:
{
x_i
}
, return_recv_hook:
{
return_recv_hook
}
, quant_type:
{
quant_type
}
, "
,
f
"fp8_round_scale:
{
fp8_round_scale
}
, quant_group_size:
{
quant_group_size
}
pass"
)
f
"fp8_round_scale:
{
fp8_round_scale
}
, quant_group_size:
{
quant_group_size
}
pass"
)
if
rank
==
0
:
print
(
''
,
flush
=
True
)
print
(
"deep_ep 全部正确性测试完成"
)
if
enable_dispatch_ll_layered
or
enable_combine_overlap
:
return
hash_value
# noinspection PyShadowingNames
# noinspection PyShadowingNames
def
large_gemm_with_hook
(
hook
):
def
large_gemm_with_hook
(
hook
):
...
@@ -252,9 +296,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -252,9 +296,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
num_topk
,
num_experts
=
args
.
num_topk
,
args
.
num_experts
num_topk
,
num_experts
=
args
.
num_topk
,
args
.
num_experts
print
(
f
"num_tokens, hidden, num_ranks, num_experts =
{
num_tokens
}
,
{
hidden
}
,
{
num_ranks
}
,
{
num_experts
}
"
)
enable_dispatch_ll_layered
=
args
.
enable_dispatch_ll_layered
enable_combine_overlap
=
args
.
enable_combine_overlap
if
enable_dispatch_ll_layered
:
enable_combine_overlap
=
True
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
)
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
,
enable_dispatch_ll_layered
=
enable_dispatch_ll_layered
)
if
rank
==
0
:
if
rank
==
0
:
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
buffer
=
deep_ep
.
Buffer
(
group
,
buffer
=
deep_ep
.
Buffer
(
group
,
...
@@ -263,7 +311,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -263,7 +311,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank
=
num_experts
//
num_ranks
,
num_qps_per_rank
=
num_experts
//
num_ranks
,
allow_nvlink_for_low_latency_mode
=
not
args
.
disable_nvlink
,
allow_nvlink_for_low_latency_mode
=
not
args
.
disable_nvlink
,
explicitly_destroy
=
True
,
explicitly_destroy
=
True
,
allow_mnnvl
=
args
.
allow_mnnvl
)
allow_mnnvl
=
args
.
allow_mnnvl
,
enable_dispatch_ll_layered
=
enable_dispatch_ll_layered
,
enable_combine_overlap
=
enable_combine_overlap
)
print
(
"deep_ep 初始化完成"
)
test_main
(
num_tokens
,
test_main
(
num_tokens
,
hidden
,
hidden
,
num_experts
,
num_experts
,
...
@@ -273,6 +325,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -273,6 +325,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group
,
group
,
buffer
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
use_logfmt
=
args
.
use_logfmt
,
enable_dispatch_ll_layered
=
enable_dispatch_ll_layered
,
enable_combine_overlap
=
enable_combine_overlap
,
seed
=
1
)
seed
=
1
)
do_pressure_test
=
args
.
pressure_test
do_pressure_test
=
args
.
pressure_test
...
@@ -288,6 +342,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -288,6 +342,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group
,
group
,
buffer
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
use_logfmt
=
args
.
use_logfmt
,
enable_dispatch_ll_layered
=
enable_dispatch_ll_layered
,
enable_combine_overlap
=
enable_combine_overlap
,
seed
=
seed
)
seed
=
seed
)
for
_
in
range
(
20
):
for
_
in
range
(
20
):
assert
test_main
(
num_tokens
,
assert
test_main
(
num_tokens
,
...
@@ -299,6 +355,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -299,6 +355,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group
,
group
,
buffer
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
use_logfmt
=
args
.
use_logfmt
,
enable_dispatch_ll_layered
=
enable_dispatch_ll_layered
,
enable_combine_overlap
=
enable_combine_overlap
,
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
# Destroy the buffer runtime and communication group
# Destroy the buffer runtime and communication group
...
@@ -331,6 +389,10 @@ if __name__ == '__main__':
...
@@ -331,6 +389,10 @@ if __name__ == '__main__':
parser
.
add_argument
(
"--pressure-test"
,
action
=
'store_true'
,
help
=
'Whether to do pressure test'
)
parser
.
add_argument
(
"--pressure-test"
,
action
=
'store_true'
,
help
=
'Whether to do pressure test'
)
parser
.
add_argument
(
"--shrink-test"
,
action
=
'store_true'
,
help
=
'Whether to simulate failure and test shrink mode'
)
parser
.
add_argument
(
"--shrink-test"
,
action
=
'store_true'
,
help
=
'Whether to simulate failure and test shrink mode'
)
parser
.
add_argument
(
'--use-logfmt'
,
action
=
'store_true'
,
help
=
'Whether to test LogFMT combine'
)
parser
.
add_argument
(
'--use-logfmt'
,
action
=
'store_true'
,
help
=
'Whether to test LogFMT combine'
)
# 新版 sbo 需要的
parser
.
add_argument
(
'--enable-dispatch-ll-layered'
,
action
=
'store_true'
,
help
=
'Enable low-latency layered dispatch optimization'
)
parser
.
add_argument
(
"--enable-combine-overlap"
,
action
=
'store_true'
,
help
=
'Enable GEMM-compute/communication overlap in the combine phase'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
args
.
world_size
>
args
.
num_processes
:
if
args
.
world_size
>
args
.
num_processes
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment