Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
243eca85
Commit
243eca85
authored
Apr 17, 2026
by
lishen01
Browse files
fix: 解决高吞吐的SM最大只能到48的问题,提升高吞吐的整体性能
parent
766b17b3
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
123 additions
and
95 deletions
+123
-95
csrc/config.hpp
csrc/config.hpp
+35
-28
csrc/deep_ep.cu
csrc/deep_ep.cu
+2
-0
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+68
-59
deep_ep/buffer.py
deep_ep/buffer.py
+2
-1
tests/1.sh
tests/1.sh
+3
-2
tests/2.sh
tests/2.sh
+3
-2
tests/test_internode.py
tests/test_internode.py
+10
-3
No files found.
csrc/config.hpp
View file @
243eca85
...
@@ -47,21 +47,25 @@ struct Config {
...
@@ -47,21 +47,25 @@ struct Config {
EP_HOST_ASSERT
(
num_ranks
<=
NUM_MAX_NVL_PEERS
or
num_sms
%
(
2
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
)
==
0
);
EP_HOST_ASSERT
(
num_ranks
<=
NUM_MAX_NVL_PEERS
or
num_sms
%
(
2
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
)
==
0
);
const
auto
num_rdma_ranks
=
std
::
max
(
num_ranks
/
NUM_MAX_NVL_PEERS
,
1
);
const
auto
num_rdma_ranks
=
std
::
max
(
num_ranks
/
NUM_MAX_NVL_PEERS
,
1
);
const
auto
num_nvl_ranks
=
std
::
min
(
num_ranks
,
NUM_MAX_NVL_PEERS
);
const
auto
num_nvl_ranks
=
std
::
min
(
num_ranks
,
NUM_MAX_NVL_PEERS
);
const
int
num_channels
=
num_sms
;
const
int
num_channels
=
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
size_t
num_bytes
=
0
;
// 计算每个nvl通信数据包的数据量
num_bytes
+=
num_channels
*
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
3
)
*
sizeof
(
int
);
size_t
num_single_nvl_bag_bytes
=
num_bytes
+=
num_channels
*
num_nvl_ranks
*
num_max_nvl_chunked_recv_tokens
*
hidden_bytes
;
hidden_bytes
+
// 数据缓冲区(Token Data)。存储从 RDMA 转发过来的 token 数据(x 张量)
#ifndef DISABLE_ROCSHMEM
#ifndef DISABLE_ROCSHMEM
num_bytes
+=
num_channels
*
num_nvl_ranks
*
num_max_nvl_chunked_recv_tokens
*
internode
::
get_source_meta_bytes
()
+
// 源元数据缓冲区(Source Metadata)。存储每个 token 的源信息(哪个 RDMA rank 发送的)
internode
::
get_source_meta_bytes
();
#endif
#endif
num_bytes
+=
num_channels
*
num_nvl_ranks
*
num_max_nvl_chunked_recv_tokens
*
kNumMaxTopK
*
kNumMaxTopK
*
sizeof
(
int
)
+
// TopK 索引缓冲区。存储每个 token 的 top-k 专家索引
sizeof
(
int64_t
);
kNumMaxTopK
*
sizeof
(
float
)
+
// TopK 权重缓冲区。存储每个 token 的 top-k 专家权重
num_bytes
+=
num_channels
*
num_nvl_ranks
*
num_max_nvl_chunked_recv_tokens
*
kNumMaxTopK
*
kNumMaxScales
*
sizeof
(
float
);
// Scale 缓冲区。存储每个 token 的量化缩放因子
sizeof
(
float
);
num_bytes
+=
num_channels
*
num_nvl_ranks
*
num_max_nvl_chunked_recv_tokens
*
// 计算每个 NVL channel 的控制信息所需的字节数,存储每个 NVL channel 的前缀索引信息,用于快速定位数据(nvl_channel_prefix_start、nvl_channel_prefix_end 等)
kNumMaxScales
*
sizeof
(
float
);
size_t
num_single_nvl_control_bytes
=
(
2
*
num_rdma_ranks
+
3
)
*
sizeof
(
int
);
// NVL 数据总的字节数
size_t
num_bytes
=
(
num_single_nvl_bag_bytes
*
num_max_nvl_chunked_recv_tokens
+
num_single_nvl_control_bytes
)
*
num_channels
*
num_nvl_ranks
;
// 128 字节对齐,匹配 GPU 缓存行大小,优化内存访问。
num_bytes
=
((
num_bytes
+
127
)
/
128
)
*
128
;
num_bytes
=
((
num_bytes
+
127
)
/
128
)
*
128
;
return
num_bytes
;
return
num_bytes
;
}
}
...
@@ -79,22 +83,25 @@ struct Config {
...
@@ -79,22 +83,25 @@ struct Config {
EP_HOST_ASSERT
(
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
num_ranks
%
NUM_MAX_NVL_PEERS
==
0
);
EP_HOST_ASSERT
(
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
EP_HOST_ASSERT
(
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
const
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
const
int
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
const
int
num_channels
=
num_sms
;
const
int
num_channels
=
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
size_t
num_bytes
=
0
;
// 计算每个rdma通信数据包的数据量
num_bytes
+=
num_channels
*
num_rdma_ranks
*
(
NUM_MAX_NVL_PEERS
*
2
+
2
)
*
2
*
sizeof
(
int
);
size_t
num_single_rdma_bag_bytes
=
num_bytes
+=
hidden_bytes
+
// 数据缓冲区。存储实际的 token 数据(x 张量),对应代码中的 rdma_channel_data
num_channels
*
num_rdma_ranks
*
num_max_rdma_chunked_recv_tokens
*
hidden_bytes
*
2
;
internode
::
get_source_meta_bytes
()
+
// 源元数据缓冲区。存储每个 token 的源信息(SourceMeta)
num_bytes
+=
num_channels
*
num_rdma_ranks
*
num_max_rdma_chunked_recv_tokens
*
kNumMaxTopK
*
sizeof
(
int
)
+
// 存储每个 token 的 top-k 专家索引。对应 topk_idx 数据
internode
::
get_source_meta_bytes
()
*
2
;
kNumMaxTopK
*
sizeof
(
float
)
+
// 存储每个 token 的 top-k 专家权重。对应 topk_weights 数据
num_bytes
+=
num_channels
*
num_rdma_ranks
*
num_max_rdma_chunked_recv_tokens
*
kNumMaxScales
*
sizeof
(
float
)
+
// 存储每个 token 的缩放因子(x_scales)
kNumMaxTopK
*
sizeof
(
int64_t
)
*
2
;
sizeof
(
int4
);
// 预留空间用于内存对齐和未来扩展
num_bytes
+=
num_channels
*
num_rdma_ranks
*
num_max_rdma_chunked_recv_tokens
*
kNumMaxTopK
*
sizeof
(
float
)
*
2
;
// 计算每个 RDMA channel 的控制信息(起始/结束索引)所需的字节数,对应代码中的 rdma_channel_meta
num_bytes
+=
num_channels
*
num_rdma_ranks
*
num_max_rdma_chunked_recv_tokens
*
size_t
num_single_rdma_control_bytes
=
(
NUM_MAX_NVL_PEERS
*
2
+
4
)
*
sizeof
(
int
);
kNumMaxScales
*
sizeof
(
float
)
*
2
;
num_bytes
+=
// RDMA 数据总的字节数
num_channels
*
num_rdma_ranks
*
num_max_rdma_chunked_recv_tokens
*
sizeof
(
int4
)
*
2
;
size_t
num_bytes
=
(
num_single_rdma_bag_bytes
*
num_max_rdma_chunked_recv_tokens
+
num_single_rdma_control_bytes
)
*
num_channels
*
num_rdma_ranks
*
2
;
// 128 字节对齐(缓存行对齐),优化内存访问性能
num_bytes
=
((
num_bytes
+
127
)
/
128
)
*
128
;
num_bytes
=
((
num_bytes
+
127
)
/
128
)
*
128
;
return
num_bytes
;
return
num_bytes
;
#else
#else
...
...
csrc/deep_ep.cu
View file @
243eca85
...
@@ -937,6 +937,7 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
...
@@ -937,6 +937,7 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
gbl_channel_prefix_matrix
=
cached_gbl_channel_prefix_matrix
.
value
();
gbl_channel_prefix_matrix
=
cached_gbl_channel_prefix_matrix
.
value
();
recv_gbl_rank_prefix_sum
=
cached_recv_gbl_rank_prefix_sum
.
value
();
recv_gbl_rank_prefix_sum
=
cached_recv_gbl_rank_prefix_sum
.
value
();
EP_HOST_ASSERT
(
num_rdma_bytes
>=
config
.
get_rdma_buffer_size_hint
(
hidden_int4
*
sizeof
(
int4
),
num_ranks
));
// Just a barrier and clean flags
// Just a barrier and clean flags
internode
::
cached_notify
(
internode
::
cached_notify
(
hidden_int4
,
num_scales
,
num_topk
,
num_topk
,
num_ranks
,
num_channels
,
0
,
nullptr
,
hidden_int4
,
num_scales
,
num_topk
,
num_topk
,
num_ranks
,
num_channels
,
0
,
nullptr
,
...
@@ -1205,6 +1206,7 @@ Buffer::internode_combine(
...
@@ -1205,6 +1206,7 @@ Buffer::internode_combine(
EP_HOST_ASSERT
(
config
.
num_max_nvl_chunked_recv_tokens
%
num_rdma_ranks
==
0
);
EP_HOST_ASSERT
(
config
.
num_max_nvl_chunked_recv_tokens
%
num_rdma_ranks
==
0
);
EP_HOST_ASSERT
(
config
.
num_max_nvl_chunked_send_tokens
<=
EP_HOST_ASSERT
(
config
.
num_max_nvl_chunked_send_tokens
<=
config
.
num_max_nvl_chunked_recv_tokens
/
num_rdma_ranks
);
config
.
num_max_nvl_chunked_recv_tokens
/
num_rdma_ranks
);
EP_HOST_ASSERT
(
num_rdma_bytes
>=
config
.
get_rdma_buffer_size_hint
(
hidden_int4
*
sizeof
(
int4
),
num_ranks
));
// Launch barrier and reset queue head and tail
// Launch barrier and reset queue head and tail
internode
::
cached_notify
(
internode
::
cached_notify
(
...
...
csrc/kernels/internode.cu
View file @
243eca85
...
@@ -7,6 +7,10 @@
...
@@ -7,6 +7,10 @@
#ifndef DISABLE_ROCSHMEM
#ifndef DISABLE_ROCSHMEM
// 安全检查:确保宏已定义
#ifndef HIP_VERSION_PATCH
#error "HIP_VERSION_PATCH not defined! Check your HIP installation."
#endif
// TODO: fix unroll warnings
// TODO: fix unroll warnings
// #ifdef __clang__
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic push
...
@@ -56,16 +60,18 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
...
@@ -56,16 +60,18 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
get_rdma_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
get_rdma_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
int
num_rdma_ranks
,
int
num_rdma_recv_buffer_tokens
,
int
num_sms
)
{
int
num_rdma_ranks
,
int
num_rdma_recv_buffer_tokens
,
int
num_channels
)
{
// Return `int32_t` offset and count to clean
// Return `int32_t` offset and count to clean
return
{(
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk_idx
,
num_topk_weights
)
*
return
{(
get_num_bytes_per_rdma_token
(
hidden_int4
,
num_scales
,
num_topk_idx
,
num_topk_weights
)
*
num_rdma_recv_buffer_tokens
*
num_rdma_ranks
*
2
*
num_
sm
s
)
/
sizeof
(
int
),
num_rdma_recv_buffer_tokens
*
num_rdma_ranks
*
2
*
num_
channel
s
)
/
sizeof
(
int
),
(
NUM_MAX_NVL_PEERS
*
2
+
4
)
*
num_rdma_ranks
*
2
*
num_
sm
s
};
(
NUM_MAX_NVL_PEERS
*
2
+
4
)
*
num_rdma_ranks
*
2
*
num_
channel
s
};
}
}
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
__host__
__device__
__forceinline__
std
::
pair
<
int
,
int
>
get_nvl_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
get_nvl_clean_meta
(
int
hidden_int4
,
int
num_scales
,
int
num_topk_idx
,
int
num_topk_weights
,
int
num_rdma_ranks
,
int
num_nvl_ranks
,
int
num_nvl_recv_buffer_tokens
,
int
num_rdma_ranks
,
int
num_nvl_ranks
,
int
num_nvl_recv_buffer_tokens
,
int
num_
sm
s
)
{
int
num_
channel
s
)
{
// Return `int32_t` offset and to clean
// Return `int32_t` offset and to clean
EP_STATIC_ASSERT
(
sizeof
(
SourceMeta
)
%
sizeof
(
int
)
==
0
,
EP_STATIC_ASSERT
(
sizeof
(
SourceMeta
)
%
sizeof
(
int
)
==
0
,
"Invalid size of `SourceMeta`"
);
"Invalid size of `SourceMeta`"
);
...
@@ -73,8 +79,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -73,8 +79,8 @@ get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_to
(
num_nvl_recv_buffer_tokens
*
(
num_nvl_recv_buffer_tokens
*
(
hidden_int4
*
sizeof
(
int4
)
+
num_scales
*
sizeof
(
float
)
+
num_topk_idx
*
sizeof
(
int
)
+
(
hidden_int4
*
sizeof
(
int4
)
+
num_scales
*
sizeof
(
float
)
+
num_topk_idx
*
sizeof
(
int
)
+
num_topk_weights
*
sizeof
(
float
)
+
sizeof
(
SourceMeta
))
*
num_topk_weights
*
sizeof
(
float
)
+
sizeof
(
SourceMeta
))
*
num_nvl_ranks
*
num_
sm
s
)
/
sizeof
(
int
),
num_nvl_ranks
*
num_
channel
s
)
/
sizeof
(
int
),
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
2
)
*
num_
sm
s
,
num_nvl_ranks
*
(
2
*
num_rdma_ranks
+
2
)
*
num_
channel
s
,
};
};
}
}
...
@@ -1230,24 +1236,25 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1230,24 +1236,25 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if
(
is_cached_dispatch
)
if
(
is_cached_dispatch
)
return
;
return
;
EP_DEVICE_ASSERT
(
num_warps
>=
num_channels
);
EP_DEVICE_ASSERT
(
num_rdma_ranks
<=
kWarpSize
);
EP_DEVICE_ASSERT
(
num_rdma_ranks
<=
kWarpSize
);
// Iterate in reverse order
// Iterate in reverse order
if
(
lane_id
<
num_rdma_ranks
and
warp_id
<
num_channels
)
{
for
(
int
channel_id
=
warp_id
;
channel_id
<
num_channels
;
channel_id
+=
num_warps
)
{
int
token_start_idx
,
token_end_idx
;
if
(
lane_id
<
num_rdma_ranks
)
{
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
warp_id
,
token_start_idx
,
int
token_start_idx
,
token_end_idx
;
token_end_idx
);
get_channel_task_range
(
num_combined_tokens
,
num_channels
,
channel_id
,
token_start_idx
,
token_end_idx
);
// NOTES: `1 << 25` is a heuristic large number
// NOTES: `1 << 25` is a heuristic large number
int
last_head
=
1
<<
25
;
int
last_head
=
1
<<
25
;
for
(
int
token_idx
=
token_end_idx
-
1
;
token_idx
>=
token_start_idx
;
--
token_idx
)
{
for
(
int
token_idx
=
token_end_idx
-
1
;
token_idx
>=
token_start_idx
;
--
token_idx
)
{
auto
current_head
=
auto
current_head
=
__ldg
(
combined_rdma_head
+
token_idx
*
num_rdma_ranks
+
lane_id
);
__ldg
(
combined_rdma_head
+
token_idx
*
num_rdma_ranks
+
lane_id
);
if
(
current_head
<
0
)
{
if
(
current_head
<
0
)
{
combined_rdma_head
[
token_idx
*
num_rdma_ranks
+
lane_id
]
=
-
last_head
-
1
;
combined_rdma_head
[
token_idx
*
num_rdma_ranks
+
lane_id
]
=
-
last_head
-
1
;
}
else
{
}
else
{
last_head
=
current_head
;
last_head
=
current_head
;
}
}
}
}
}
}
}
...
@@ -1255,34 +1262,34 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
...
@@ -1255,34 +1262,34 @@ cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const i
if
(
is_cached_dispatch
)
if
(
is_cached_dispatch
)
return
;
return
;
EP_DEVICE_ASSERT
(
num_warps
>=
num_channels
);
EP_DEVICE_ASSERT
(
rdma_channel_prefix_matrix
!=
nullptr
and
rdma_rank_prefix_sum
!=
nullptr
);
EP_DEVICE_ASSERT
(
rdma_channel_prefix_matrix
!=
nullptr
and
rdma_rank_prefix_sum
!=
nullptr
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Too many NVL peers"
);
EP_STATIC_ASSERT
(
NUM_MAX_NVL_PEERS
<=
kWarpSize
,
"Too many NVL peers"
);
constexpr
int
num_clean_sms
=
2
;
constexpr
int
num_clean_sms
=
2
;
if
(
lane_id
<
NUM_MAX_NVL_PEERS
and
warp_id
<
num_channels
)
{
for
(
int
channel_id
=
warp_id
;
channel_id
<
num_channels
;
channel_id
+=
num_warps
)
{
for
(
int
dst_rdma_rank
=
sm_id
-
num_clean_sms
;
dst_rdma_rank
<
num_rdma_ranks
;
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
dst_rdma_rank
+=
num_channels
*
2
-
num_clean_sms
)
{
for
(
int
dst_rdma_rank
=
sm_id
-
num_clean_sms
;
dst_rdma_rank
<
num_rdma_ranks
;
// Iterate in reverse order
dst_rdma_rank
+=
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
-
num_clean_sms
)
{
int
token_start_idx
=
// Iterate in reverse order
warp_id
==
0
int
token_start_idx
=
?
0
channel_id
==
0
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
warp_id
-
1
];
?
0
int
token_end_idx
=
:
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
-
1
];
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
warp_id
];
int
token_end_idx
=
int
shift
=
dst_rdma_rank
==
0
?
0
:
rdma_rank_prefix_sum
[
dst_rdma_rank
-
1
];
rdma_channel_prefix_matrix
[
dst_rdma_rank
*
num_channels
+
channel_id
];
token_start_idx
+=
shift
,
token_end_idx
+=
shift
;
int
shift
=
dst_rdma_rank
==
0
?
0
:
rdma_rank_prefix_sum
[
dst_rdma_rank
-
1
];
token_start_idx
+=
shift
,
token_end_idx
+=
shift
;
// NOTES: `1 << 25` is a heuristic large number
int
last_head
=
1
<<
25
;
// NOTES: `1 << 25` is a heuristic large number
for
(
int
token_idx
=
token_end_idx
-
1
;
token_idx
>=
token_start_idx
;
--
token_idx
)
{
int
last_head
=
1
<<
25
;
auto
current_head
=
for
(
int
token_idx
=
token_end_idx
-
1
;
token_idx
>=
token_start_idx
;
--
token_idx
)
{
__ldg
(
combined_nvl_head
+
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
);
auto
current_head
=
if
(
current_head
<
0
)
{
__ldg
(
combined_nvl_head
+
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
);
combined_nvl_head
[
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
]
=
-
last_head
-
1
;
if
(
current_head
<
0
)
{
}
else
{
combined_nvl_head
[
token_idx
*
NUM_MAX_NVL_PEERS
+
lane_id
]
=
-
last_head
-
1
;
last_head
=
current_head
;
}
else
{
last_head
=
current_head
;
}
}
}
}
}
}
}
...
@@ -1298,7 +1305,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -1298,7 +1305,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
int
num_max_nvl_chunked_recv_tokens
,
int
**
barrier_signal_ptrs
,
int
rank
,
int
num_max_nvl_chunked_recv_tokens
,
int
**
barrier_signal_ptrs
,
int
rank
,
hipStream_t
stream
,
int64_t
num_rdma_bytes
,
int64_t
num_nvl_bytes
,
hipStream_t
stream
,
int64_t
num_rdma_bytes
,
int64_t
num_nvl_bytes
,
bool
is_cached_dispatch
,
bool
low_latency_mode
)
{
bool
is_cached_dispatch
,
bool
low_latency_mode
)
{
const
int
num_threads
=
::
max
(
128
,
kWarpSize
*
num_channels
);
const
int
num_threads
=
::
min
(
1024
,
::
max
(
128
,
kWarpSize
*
num_channels
)
)
;
const
auto
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
const
auto
num_rdma_ranks
=
num_ranks
/
NUM_MAX_NVL_PEERS
;
// Get clean meta
// Get clean meta
...
@@ -1314,11 +1321,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -1314,11 +1321,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
num_nvl_bytes
);
num_nvl_bytes
);
EP_HOST_ASSERT
(
num_rdma_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_rdma_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_nvl_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_nvl_bytes
<
std
::
numeric_limits
<
int
>::
max
());
EP_HOST_ASSERT
(
num_channels
*
2
>
2
);
EP_HOST_ASSERT
(
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
>
2
);
// Launch kernel
// Launch kernel
auto
cached_notify_func
=
low_latency_mode
?
cached_notify
<
true
>
:
cached_notify
<
false
>
;
auto
cached_notify_func
=
low_latency_mode
?
cached_notify
<
true
>
:
cached_notify
<
false
>
;
SETUP_LAUNCH_CONFIG
(
num_channels
*
2
,
num_threads
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_channels
*
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
,
num_threads
,
stream
);
LAUNCH_KERNEL_NON_COOPERATIVE
(
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
cached_notify_func
,
rdma_clean_meta
.
first
,
rdma_clean_meta
.
second
,
&
cfg
,
cached_notify_func
,
rdma_clean_meta
.
first
,
rdma_clean_meta
.
second
,
nvl_clean_meta
.
first
,
nvl_clean_meta
.
second
,
combined_rdma_head
,
num_combined_tokens
,
nvl_clean_meta
.
first
,
nvl_clean_meta
.
second
,
combined_rdma_head
,
num_combined_tokens
,
...
@@ -1327,11 +1334,12 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -1327,11 +1334,12 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
cpu_rdma_team
);
cpu_rdma_team
);
}
}
template
<
int
kNumRanks
,
typename
dtype_t
,
int
kMaxNumRanks
,
typename
Receive
Fn
,
typename
ReceiveTWFn
>
template
<
int
kNumRanks
,
typename
dtype_t
,
int
kMaxNumRanks
,
bool
kUseMLS
,
typename
GetAddr
Fn
,
typename
ReceiveTWFn
>
__device__
int
combine_token
(
bool
is_token_in_rank
,
int
head_idx
,
__device__
int
combine_token
(
bool
is_token_in_rank
,
int
head_idx
,
int
lane_id
,
int
hidden_int4
,
int
num_topk
,
int
lane_id
,
int
hidden_int4
,
int
num_topk
,
int4
*
combined_row
,
float
*
combined_topk_weights
,
int4
*
combined_row
,
float
*
combined_topk_weights
,
int
num_max_recv_tokens
,
const
ReceiveFn
&
recv_fn
,
const
ReceiveTWFn
&
recv_tw_fn
)
{
int
num_max_recv_tokens
,
const
GetAddrFn
&
get_addr_fn
,
const
ReceiveTWFn
&
recv_tw_fn
)
{
constexpr
auto
kDtypePerInt4
=
sizeof
(
int4
)
/
sizeof
(
dtype_t
);
constexpr
auto
kDtypePerInt4
=
sizeof
(
int4
)
/
sizeof
(
dtype_t
);
// Broadcast current heads
// Broadcast current heads
...
@@ -1353,7 +1361,7 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
...
@@ -1353,7 +1361,7 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
int4
recv_value_int4
[
kMaxNumRanks
];
int4
recv_value_int4
[
kMaxNumRanks
];
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
recv_value_int4
[
j
]
=
recv
_fn
(
topk_ranks
[
j
],
slot_indices
[
j
],
i
);
recv_value_int4
[
j
]
=
ld_nc_global
(
get_addr
_fn
(
topk_ranks
[
j
],
slot_indices
[
j
],
i
)
)
;
// Reduce all-to-all results
// Reduce all-to-all results
float
values
[
kDtypePerInt4
]
=
{
0
};
float
values
[
kDtypePerInt4
]
=
{
0
};
...
@@ -1416,6 +1424,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1416,6 +1424,7 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
__shared__
shmem_ctx_t
ctx
;
__shared__
shmem_ctx_t
ctx
;
shmem_wg_ctx_create
(
&
ctx
);
shmem_wg_ctx_create
(
&
ctx
);
#endif
#endif
EP_STATIC_ASSERT
(
kNumCombineForwarderWarps
<=
kWarpSize
,
"Invalid number of forwarder warps"
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
),
num_warps
=
num_threads
/
kWarpSize
;
...
@@ -1717,14 +1726,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1717,14 +1726,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
// Combine current token
// Combine current token
auto
rdma_slot_idx
=
token_idx
%
num_max_rdma_chunked_recv_tokens
;
auto
rdma_slot_idx
=
token_idx
%
num_max_rdma_chunked_recv_tokens
;
void
*
shifted
=
send_buffer
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
void
*
shifted
=
send_buffer
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
auto
recv
_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
{
return
ld_nc_global
(
nvl_channel_x
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
hidden_int4
+
hidden_int4_idx
)
;
};
auto
get_addr
_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
*
{
return
reinterpret_cast
<
int4
*>
(
nvl_channel_x
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
hidden_int4
)
+
hidden_int4_idx
;
};
auto
recv_tw_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
num_topk
+
topk_idx
);
};
auto
recv_tw_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
num_topk
+
topk_idx
);
};
combine_token
<
NUM_MAX_NVL_PEERS
,
dtype_t
,
NUM_MAX_NVL_PEERS
>
(
expected_head
>=
0
,
combine_token
<
NUM_MAX_NVL_PEERS
,
dtype_t
,
NUM_MAX_NVL_PEERS
,
true
>
(
expected_head
>=
0
,
expected_head
,
lane_id
,
expected_head
,
lane_id
,
hidden_int4
,
num_topk
,
hidden_int4
,
num_topk
,
reinterpret_cast
<
int4
*>
(
shifted
),
reinterpret_cast
<
int4
*>
(
shifted
),
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
hidden_bytes
+
sizeof
(
SourceMeta
)),
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
int8_t
*>
(
shifted
)
+
hidden_bytes
+
sizeof
(
SourceMeta
)),
num_max_nvl_chunked_recv_tokens_per_rdma
,
recv_fn
,
recv_tw_fn
);
num_max_nvl_chunked_recv_tokens_per_rdma
,
get_addr_fn
,
recv_tw_fn
);
// Update head
// Update head
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
{
...
@@ -1787,7 +1797,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1787,7 +1797,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
EP_STATIC_ASSERT
(
kNumCombineForwarderWarps
<=
kWarpSize
,
"Invalid number of forwarder warps"
);
while
(
true
)
{
while
(
true
)
{
// Retired
// Retired
...
@@ -1853,14 +1862,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1853,14 +1862,15 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
syncwarp
();
syncwarp
();
// Combine current token
// Combine current token
auto
recv
_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
{
return
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
)
+
hidden_int4_idx
)
;};
auto
get_addr
_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
*
{
return
reinterpret_cast
<
int4
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
)
+
hidden_int4_idx
;
};
auto
recv_tw_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
reinterpret_cast
<
const
float
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
+
hidden_bytes
+
sizeof
(
SourceMeta
))
+
topk_idx
);};
auto
recv_tw_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
reinterpret_cast
<
const
float
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
+
hidden_bytes
+
sizeof
(
SourceMeta
))
+
topk_idx
);};
combine_token
<
kNumRDMARanks
,
dtype_t
,
kNumTopkRDMARanks
>
(
expected_head
>=
0
,
combine_token
<
kNumRDMARanks
,
dtype_t
,
kNumTopkRDMARanks
,
false
>
(
expected_head
>=
0
,
expected_head
,
lane_id
,
expected_head
,
lane_id
,
hidden_int4
,
num_topk
,
hidden_int4
,
num_topk
,
combined_x
+
token_idx
*
hidden_int4
,
combined_x
+
token_idx
*
hidden_int4
,
combined_topk_weights
+
token_idx
*
num_topk
,
combined_topk_weights
+
token_idx
*
num_topk
,
num_max_rdma_chunked_recv_tokens
,
recv_fn
,
recv_tw_fn
);
num_max_rdma_chunked_recv_tokens
,
get_addr_fn
,
recv_tw_fn
);
}
}
// Retired
// Retired
...
@@ -1879,7 +1889,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1879,7 +1889,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
last_nvl_head
[
kNumRDMARanks
]
=
{
0
};
int
dst_rdma_rank
=
lane_id
<
kNumRDMARanks
?
lane_id
:
0
;
int
dst_rdma_rank
=
lane_id
<
kNumRDMARanks
?
lane_id
:
0
;
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
int
dst_nvl_rank
=
lane_id
<
NUM_MAX_NVL_PEERS
?
lane_id
:
0
;
EP_STATIC_ASSERT
(
kNumCombineForwarderWarps
<=
kWarpSize
,
"Invalid number of forwarder warps"
);
while
(
true
)
{
while
(
true
)
{
// Retired
// Retired
...
...
deep_ep/buffer.py
View file @
243eca85
...
@@ -207,7 +207,8 @@ class Buffer:
...
@@ -207,7 +207,8 @@ class Buffer:
new_num_sms: the new number to be set.
new_num_sms: the new number to be set.
"""
"""
assert
new_num_sms
%
2
==
0
,
"The SM count must be even"
assert
new_num_sms
%
2
==
0
,
"The SM count must be new_num_sms % 2 == 0"
assert
new_num_sms
%
3
==
0
,
"The SM count must be new_num_sms % 3 == 0"
Buffer
.
num_sms
=
new_num_sms
Buffer
.
num_sms
=
new_num_sms
@
staticmethod
@
staticmethod
...
...
tests/1.sh
View file @
243eca85
#!/bin/bash
# rocSHMEM
# rocSHMEM
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
60
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_HEAP_SIZE
=
3737418240
export
ROCSHMEM_HEAP_SIZE
=
3737418240
export
ROCSHMEM_TOPO_FILE_FORCE
=
./topo.config
export
ROCSHMEM_TOPO_FILE_FORCE
=
./topo.config
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# # duSHMEM
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
...
...
tests/2.sh
View file @
243eca85
#!/bin/bash
# rocSHMEM
# rocSHMEM
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX
=
288
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
48
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
60
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
ROCSHMEM_HEAP_SIZE
=
3737418240
export
ROCSHMEM_HEAP_SIZE
=
3737418240
export
ROCSHMEM_TOPO_FILE_FORCE
=
./topo.config
export
ROCSHMEM_TOPO_FILE_FORCE
=
./topo.config
# export ROCSHMEM_DISABLE_HDP_FLUSH=1
# export ROCSHMEM_GDR_DISABLE_XDP=1
# # duSHMEM
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
...
...
tests/test_internode.py
View file @
243eca85
...
@@ -143,7 +143,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -143,7 +143,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
# Check `topk_weights`
# Check `topk_weights`
if
not
is_rand
:
if
not
is_rand
:
recv_topk_weights
[
recv_topk_idx
.
eq
(
-
1
)]
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
).
expand_as
(
recv_topk_weights
)[
recv_topk_idx
.
eq
(
-
1
)]
max_weights
=
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
)
# Shape: [Batch, 1]
recv_topk_weights
=
torch
.
where
(
recv_topk_idx
==
-
1
,
max_weights
,
recv_topk_weights
)
check_data
(
recv_topk_weights
,
recv_gbl_rank_prefix_sum
)
check_data
(
recv_topk_weights
,
recv_gbl_rank_prefix_sum
)
# Test cached dispatch (must without top-k staffs)
# Test cached dispatch (must without top-k staffs)
...
@@ -186,6 +187,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -186,6 +187,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
' passed'
,
flush
=
True
)
print
(
' passed'
,
flush
=
True
)
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
''
,
flush
=
True
)
print
(
''
,
flush
=
True
)
...
@@ -201,6 +203,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -201,6 +203,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
nvl_recv_bytes
=
(
dispatch_bf16_nvl_recv_bytes
*
fp8_factor
)
if
isinstance
(
current_x
,
tuple
)
else
dispatch_bf16_nvl_recv_bytes
nvl_recv_bytes
=
(
dispatch_bf16_nvl_recv_bytes
*
fp8_factor
)
if
isinstance
(
current_x
,
tuple
)
else
dispatch_bf16_nvl_recv_bytes
for
nvl_chunk_size
in
range
(
4
,
45
,
4
):
for
nvl_chunk_size
in
range
(
4
,
45
,
4
):
for
rdma_chunk_size
in
range
(
4
,
33
,
4
):
for
rdma_chunk_size
in
range
(
4
,
33
,
4
):
if
rdma_buffer_size
%
rdma_chunk_size
!=
0
:
continue
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
tune_args
=
{
'x'
:
current_x
,
'handle'
:
handle
,
'config'
:
config
}
tune_args
=
{
'x'
:
current_x
,
'handle'
:
handle
,
'config'
:
config
}
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
dispatch
(
**
tune_args
),
(
'dispatch'
,
'notify'
),
suppress_kineto_output
=
True
)
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
dispatch
(
**
tune_args
),
(
'dispatch'
,
'notify'
),
suppress_kineto_output
=
True
)
...
@@ -233,6 +237,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
...
@@ -233,6 +237,8 @@ def test_main(args: argparse.Namespace, num_sms: int,
best_time
,
best_results
=
1e10
,
None
best_time
,
best_results
=
1e10
,
None
for
nvl_chunk_size
in
range
(
1
,
8
,
1
):
for
nvl_chunk_size
in
range
(
1
,
8
,
1
):
for
rdma_chunk_size
in
range
(
12
if
num_nodes
==
2
else
8
,
33
,
4
):
for
rdma_chunk_size
in
range
(
12
if
num_nodes
==
2
else
8
,
33
,
4
):
if
rdma_buffer_size
%
rdma_chunk_size
!=
0
:
continue
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
,
rdma_chunk_size
,
rdma_buffer_size
)
tune_args
=
{
'x'
:
recv_x
,
'handle'
:
handle
,
'config'
:
config
}
tune_args
=
{
'x'
:
recv_x
,
'handle'
:
handle
,
'config'
:
config
}
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
combine
(
**
tune_args
),
(
'combine'
,
'notify'
),
suppress_kineto_output
=
True
)
t
,
notify_t
=
bench_kineto
(
lambda
:
buffer
.
combine
(
**
tune_args
),
(
'combine'
,
'notify'
),
suppress_kineto_output
=
True
)
...
@@ -265,8 +271,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -265,8 +271,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
=
16
,
5120
,
256
,
9
num_rdma_bytes_ll
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
ll_num_tokens
,
ll_hidden
,
num_ranks
,
ll_num_experts
)
num_rdma_bytes_ll
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
ll_num_tokens
,
ll_hidden
,
num_ranks
,
ll_num_experts
)
num_sms
=
48
num_sms
=
60
num_qps_per_rank
=
max
(
num_sms
,
ll_num_experts
//
num_ranks
if
args
.
test_ll_compatibility
else
0
)
num_qps_per_rank
=
max
(
num_sms
,
ll_num_experts
//
num_ranks
if
args
.
test_ll_compatibility
else
0
)
deep_ep
.
Buffer
.
set_num_sms
(
num_sms
)
hidden_bytes
=
get_hidden_bytes
(
args
)
hidden_bytes
=
get_hidden_bytes
(
args
)
num_nvl_bytes
,
num_rdma_bytes
,
num_rdma_bytes_norm
=
0
,
0
,
0
num_nvl_bytes
,
num_rdma_bytes
,
num_rdma_bytes_norm
=
0
,
0
,
0
...
@@ -292,7 +299,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -292,7 +299,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
break
break
if
local_rank
==
0
:
if
local_rank
==
0
:
print
(
f
'
{
ref_hash
=
}
'
)
print
(
f
'
ref_hash=
{
ref_hash
}
'
)
print
(
''
,
flush
=
True
)
print
(
''
,
flush
=
True
)
for
j
in
range
(
20
):
for
j
in
range
(
20
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment