Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
8aaddf76
Unverified
Commit
8aaddf76
authored
Jun 16, 2025
by
Chenggang Zhao
Committed by
GitHub
Jun 16, 2025
Browse files
Remove the low-latency usage flag (#214)
parent
1b92be8a
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
15 additions
and
69 deletions
+15
-69
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+4
-24
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+0
-6
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+4
-6
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+7
-19
deep_ep/buffer.py
deep_ep/buffer.py
+0
-13
tests/test_low_latency.py
tests/test_low_latency.py
+0
-1
No files found.
csrc/deep_ep.cpp
View file @
8aaddf76
...
@@ -78,13 +78,6 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
...
@@ -78,13 +78,6 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
CUDA_CHECK
(
cudaHostGetDevicePointer
(
&
moe_recv_rdma_counter_mapped
,
const_cast
<
int
*>
(
moe_recv_rdma_counter
),
0
));
CUDA_CHECK
(
cudaHostGetDevicePointer
(
&
moe_recv_rdma_counter_mapped
,
const_cast
<
int
*>
(
moe_recv_rdma_counter
),
0
));
*
moe_recv_rdma_counter
=
-
1
;
*
moe_recv_rdma_counter
=
-
1
;
}
}
// Low-latency kernels' usage flag
if
(
low_latency_mode
)
{
CUDA_CHECK
(
cudaMallocHost
(
&
low_latency_usage_flag
,
sizeof
(
int
),
cudaHostAllocMapped
));
CUDA_CHECK
(
cudaHostGetDevicePointer
(
&
low_latency_usage_flag_mapped
,
const_cast
<
int
*>
(
low_latency_usage_flag
),
0
));
*
low_latency_usage_flag
=
0
;
}
}
}
Buffer
::~
Buffer
()
noexcept
(
false
)
{
Buffer
::~
Buffer
()
noexcept
(
false
)
{
...
@@ -1028,16 +1021,6 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
...
@@ -1028,16 +1021,6 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
#endif
#endif
}
}
uint64_t
Buffer
::
get_low_latency_usage_flag
()
const
{
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT
(
low_latency_usage_flag
!=
nullptr
);
return
reinterpret_cast
<
uint64_t
>
(
low_latency_usage_flag
);
#else
EP_HOST_ASSERT
(
false
and
"NVSHMEM is disable during compilation"
);
return
0
;
#endif
}
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
#ifndef DISABLE_NVSHMEM
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
...
@@ -1143,9 +1126,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1143,9 +1126,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_fp8
,
round_scale
,
use_ue8m0
,
workspace
,
low_latency_usage_flag_mapped
,
workspace
,
num_device_sms
,
num_device_sms
,
launch_stream
,
launch_stream
,
phases
);
phases
);
};
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
@@ -1237,9 +1219,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1237,9 +1219,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
workspace
,
low_latency_usage_flag_mapped
,
workspace
,
num_device_sms
,
num_device_sms
,
launch_stream
,
launch_stream
,
phases
,
zero_copy
);
phases
,
zero_copy
);
};
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
@@ -1328,7 +1309,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -1328,7 +1309,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"intranode_combine"
,
&
deep_ep
::
Buffer
::
intranode_combine
)
.
def
(
"intranode_combine"
,
&
deep_ep
::
Buffer
::
intranode_combine
)
.
def
(
"internode_dispatch"
,
&
deep_ep
::
Buffer
::
internode_dispatch
)
.
def
(
"internode_dispatch"
,
&
deep_ep
::
Buffer
::
internode_dispatch
)
.
def
(
"internode_combine"
,
&
deep_ep
::
Buffer
::
internode_combine
)
.
def
(
"internode_combine"
,
&
deep_ep
::
Buffer
::
internode_combine
)
.
def
(
"get_low_latency_usage_flag"
,
&
deep_ep
::
Buffer
::
get_low_latency_usage_flag
)
.
def
(
"clean_low_latency_buffer"
,
&
deep_ep
::
Buffer
::
clean_low_latency_buffer
)
.
def
(
"clean_low_latency_buffer"
,
&
deep_ep
::
Buffer
::
clean_low_latency_buffer
)
.
def
(
"low_latency_dispatch"
,
&
deep_ep
::
Buffer
::
low_latency_dispatch
)
.
def
(
"low_latency_dispatch"
,
&
deep_ep
::
Buffer
::
low_latency_dispatch
)
.
def
(
"low_latency_combine"
,
&
deep_ep
::
Buffer
::
low_latency_combine
)
.
def
(
"low_latency_combine"
,
&
deep_ep
::
Buffer
::
low_latency_combine
)
...
...
csrc/deep_ep.hpp
View file @
8aaddf76
...
@@ -71,10 +71,6 @@ private:
...
@@ -71,10 +71,6 @@ private:
volatile
int
*
moe_recv_rdma_counter
=
nullptr
;
volatile
int
*
moe_recv_rdma_counter
=
nullptr
;
int
*
moe_recv_rdma_counter_mapped
=
nullptr
;
int
*
moe_recv_rdma_counter_mapped
=
nullptr
;
// Host-side low-latency kernels' usages
volatile
int
*
low_latency_usage_flag
=
nullptr
;
int
*
low_latency_usage_flag_mapped
=
nullptr
;
public:
public:
Buffer
(
int
rank
,
int
num_ranks
,
int64_t
num_nvl_bytes
,
int64_t
num_rdma_bytes
,
bool
low_latency_mode
);
Buffer
(
int
rank
,
int
num_ranks
,
int64_t
num_nvl_bytes
,
int64_t
num_rdma_bytes
,
bool
low_latency_mode
);
...
@@ -134,8 +130,6 @@ public:
...
@@ -134,8 +130,6 @@ public:
const
torch
::
Tensor
&
combined_rdma_head
,
const
torch
::
Tensor
&
combined_nvl_head
,
const
torch
::
Tensor
&
combined_rdma_head
,
const
torch
::
Tensor
&
combined_nvl_head
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
);
const
Config
&
config
,
std
::
optional
<
EventHandle
>&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
);
uint64_t
get_low_latency_usage_flag
()
const
;
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
...
...
csrc/kernels/api.cuh
View file @
8aaddf76
...
@@ -147,9 +147,8 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -147,9 +147,8 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
*
usage_flag
,
void
*
workspace
,
int
num_device_sms
,
int
num_device_sms
,
cudaStream_t
stream
,
cudaStream_t
stream
,
int
phases
);
int
phases
);
void
combine
(
void
*
combined_x
,
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_flag
,
void
*
rdma_send_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_flag
,
void
*
rdma_send_x
,
...
@@ -158,9 +157,8 @@ void combine(void* combined_x,
...
@@ -158,9 +157,8 @@ void combine(void* combined_x,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
int
*
usage_flag
,
void
*
workspace
,
int
num_device_sms
,
int
num_device_sms
,
cudaStream_t
stream
,
cudaStream_t
stream
,
int
phases
,
bool
zero_copy
);
int
phases
,
bool
zero_copy
);
}
// namespace internode_ll
}
// namespace internode_ll
...
...
csrc/kernels/internode_ll.cu
View file @
8aaddf76
...
@@ -48,9 +48,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -48,9 +48,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
round_scale
,
int
*
usage_flag
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
phases
)
{
bool
round_scale
,
int
phases
)
{
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
warp_id
=
thread_id
/
32
,
lane_id
=
get_lane_id
();
const
auto
warp_id
=
thread_id
/
32
,
lane_id
=
get_lane_id
();
...
@@ -189,10 +188,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -189,10 +188,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#pragma unroll
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
num_experts
;
i
+=
32
)
for
(
int
i
=
lane_id
;
i
<
num_experts
;
i
+=
32
)
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
);
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
);
}
else
if
(
sm_id
==
1
)
{
// The second SM is also responsible for notifying PCIe usage
if
(
lane_id
==
0
)
atomicAdd_system
(
usage_flag
,
1
);
}
}
// This SM should be responsible for some destination experts, read `topk_idx` for them
// This SM should be responsible for some destination experts, read `topk_idx` for them
...
@@ -341,9 +336,8 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -341,9 +336,8 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
*
usage_flag
,
void
*
workspace
,
int
num_device_sms
,
int
num_device_sms
,
cudaStream_t
stream
,
cudaStream_t
stream
,
int
phases
)
{
int
phases
)
{
constexpr
int
kNumMaxTopK
=
9
;
constexpr
int
kNumMaxTopK
=
9
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
32
/
num_warp_groups
;
const
int
num_warps_per_group
=
32
/
num_warp_groups
;
...
@@ -380,9 +374,8 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
...
@@ -380,9 +374,8 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
next_clean, num_next_clean_int, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_topk, num_experts, rank, num_ranks, \
round_scale, usage_flag, \
num_warp_groups, num_warps_per_group, \
num_warp_groups, num_warps_per_group, \
phases); } break
round_scale,
phases); } break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
32
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
32
,
stream
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
...
@@ -400,7 +393,6 @@ combine(void* combined_x,
...
@@ -400,7 +393,6 @@ combine(void* combined_x,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_max_dispatch_tokens_per_rank
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
*
usage_flag
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
phases
,
bool
zero_copy
)
{
int
phases
,
bool
zero_copy
)
{
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
...
@@ -497,13 +489,11 @@ combine(void* combined_x,
...
@@ -497,13 +489,11 @@ combine(void* combined_x,
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
return
;
return
;
// Wait all ranks to arrive
and notify usages
// Wait all ranks to arrive
if
(
responsible_expert_idx
<
num_experts
)
{
if
(
responsible_expert_idx
<
num_experts
)
{
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
0
and
lane_id
==
0
)
{
while
(
ld_acquire_sys_global
(
rdma_recv_flag
+
responsible_expert_idx
)
==
0
);
while
(
ld_acquire_sys_global
(
rdma_recv_flag
+
responsible_expert_idx
)
==
0
);
}
else
if
(
sm_id
==
0
and
sub_warp_id
==
1
and
lane_id
==
0
)
{
atomicAdd_system
(
usage_flag
,
1
);
}
}
}
}
cg
::
this_grid
().
sync
();
cg
::
this_grid
().
sync
();
...
@@ -555,9 +545,8 @@ void combine(void* combined_x,
...
@@ -555,9 +545,8 @@ void combine(void* combined_x,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
int
*
usage_flag
,
void
*
workspace
,
int
num_device_sms
,
int
num_device_sms
,
cudaStream_t
stream
,
cudaStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
int
phases
,
bool
zero_copy
)
{
constexpr
int
kNumMaxTopk
=
9
;
constexpr
int
kNumMaxTopk
=
9
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
32
/
num_warp_groups
;
const
int
num_warps_per_group
=
32
/
num_warp_groups
;
...
@@ -582,7 +571,6 @@ LAUNCH_KERNEL(&cfg, combine_func, \
...
@@ -582,7 +571,6 @@ LAUNCH_KERNEL(&cfg, combine_func, \
num_combined_tokens, hidden, num_topk, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
num_experts, rank, num_ranks, \
usage_flag, \
num_warp_groups, num_warps_per_group, \
num_warp_groups, num_warps_per_group, \
phases, zero_copy); } break
phases, zero_copy); } break
...
...
deep_ep/buffer.py
View file @
8aaddf76
...
@@ -457,19 +457,6 @@ class Buffer:
...
@@ -457,19 +457,6 @@ class Buffer:
async_finish
,
allocate_on_comm_stream
)
async_finish
,
allocate_on_comm_stream
)
return
combined_x
,
combined_topk_weights
,
EventOverlap
(
event
)
return
combined_x
,
combined_topk_weights
,
EventOverlap
(
event
)
def
get_low_latency_usage_flag
(
self
):
"""
Return a host-side integer flag, which indicates the stages of low-latency kernels.
The initial value is 0, the low-latency dispatch will add 1 before communication, the low-latency combine
will add 1 after communication.
This is useful when there is no two-batch overlap, and you want to overlap H2D/D2H transfer with attention layers.
Returns:
flag: the host-side integer flag pointer. The value is in `int`, but returns a `uint64_t` pointer. Please
`reinterpret_cast` the returned value into `int*`.
"""
return
self
.
runtime
.
get_low_latency_usage_flag
()
def
clean_low_latency_buffer
(
self
,
num_max_dispatch_tokens_per_rank
:
int
,
hidden
:
int
,
num_experts
:
int
)
->
None
:
def
clean_low_latency_buffer
(
self
,
num_max_dispatch_tokens_per_rank
:
int
,
hidden
:
int
,
num_experts
:
int
)
->
None
:
"""
"""
As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
...
...
tests/test_low_latency.py
View file @
8aaddf76
...
@@ -166,7 +166,6 @@ def test_loop(local_rank: int, num_local_ranks: int):
...
@@ -166,7 +166,6 @@ def test_loop(local_rank: int, num_local_ranks: int):
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
buffer
=
deep_ep
.
Buffer
(
group
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
True
,
buffer
=
deep_ep
.
Buffer
(
group
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
True
,
num_qps_per_rank
=
num_experts
//
num_ranks
)
num_qps_per_rank
=
num_experts
//
num_ranks
)
buffer
.
get_low_latency_usage_flag
()
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
do_pressure_test
=
False
do_pressure_test
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment