Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
c4b8ffc3
Unverified
Commit
c4b8ffc3
authored
Mar 18, 2025
by
Chenggang Zhao
Committed by
GitHub
Mar 18, 2025
Browse files
Merge pull request #79 from deepseek-ai/zero-copy-combine
Support zero-copy for low-latency combine
parents
82dcf48f
66465476
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
80 additions
and
28 deletions
+80
-28
csrc/config.hpp
csrc/config.hpp
+6
-1
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+20
-3
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+5
-1
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+2
-1
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+6
-4
deep_ep/buffer.py
deep_ep/buffer.py
+22
-5
tests/test_low_latency.py
tests/test_low_latency.py
+19
-13
No files found.
csrc/config.hpp
View file @
c4b8ffc3
...
@@ -102,6 +102,9 @@ struct LowLatencyBuffer {
...
@@ -102,6 +102,9 @@ struct LowLatencyBuffer {
void
*
combine_rdma_recv_data_buffer
=
nullptr
;
void
*
combine_rdma_recv_data_buffer
=
nullptr
;
int
*
combine_rdma_recv_flag_buffer
=
nullptr
;
int
*
combine_rdma_recv_flag_buffer
=
nullptr
;
void
*
combine_rdma_send_buffer_data_start
=
nullptr
;
size_t
num_bytes_per_combine_msg
=
0
;
std
::
pair
<
int
*
,
int
>
clean_meta
()
{
std
::
pair
<
int
*
,
int
>
clean_meta
()
{
EP_HOST_ASSERT
(
dispatch_rdma_recv_count_buffer
==
combine_rdma_recv_flag_buffer
);
EP_HOST_ASSERT
(
dispatch_rdma_recv_count_buffer
==
combine_rdma_recv_flag_buffer
);
return
{
dispatch_rdma_recv_count_buffer
,
num_clean_int
};
return
{
dispatch_rdma_recv_count_buffer
,
num_clean_int
};
...
@@ -163,7 +166,9 @@ struct LowLatencyLayout {
...
@@ -163,7 +166,9 @@ struct LowLatencyLayout {
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
)
advance
<
int
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
+
sizeof
(
int4
)),
num_bytes_per_combine_msg
};
};
}
}
}
}
...
...
csrc/deep_ep.cpp
View file @
c4b8ffc3
...
@@ -1100,7 +1100,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
...
@@ -1100,7 +1100,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
async
,
bool
return_recv_hook
,
std
::
optional
<
torch
::
Tensor
>
out
)
{
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
// Tensor checks
// Tensor checks
...
@@ -1159,7 +1160,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1159,7 +1160,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
workspace
,
launch_stream
,
phases
);
workspace
,
launch_stream
,
phases
,
zero_copy
);
};
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
@@ -1182,6 +1184,20 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1182,6 +1184,20 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
return
{
combined_x
,
event
,
recv_hook
};
return
{
combined_x
,
event
,
recv_hook
};
}
}
/// Return a BF16 tensor view over the send region of the low-latency buffer selected by
/// `low_latency_buffer_idx`, so a caller can write combine inputs directly into it.
///
/// The returned tensor does not own its memory: it is created with `torch::from_blob` on
/// `combine_rdma_send_buffer_data_start` and has logical shape
/// `[num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`.
/// Consecutive rows are spaced by one combine message (`num_bytes_per_combine_msg`), not by
/// `hidden` elements, so the view is not necessarily contiguous.
///
/// @param num_max_dispatch_tokens_per_rank  per-rank token capacity used to build the layout.
/// @param hidden                            hidden size (innermost dimension of the view).
/// @param num_experts                       total number of experts across all ranks.
/// @return a non-owning CUDA BF16 tensor aliasing the RDMA send buffer.
torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
    const auto& buffer = layout.buffers[low_latency_buffer_idx];
    const auto dtype = torch::kBFloat16;

    // Check divisibility before dividing so a malformed message size fails loudly
    // instead of silently producing a truncated element stride.
    const auto elem_bytes = elementSize(dtype);
    EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elem_bytes == 0);
    const auto num_msg_elems = static_cast<int64_t>(buffer.num_bytes_per_combine_msg / elem_bytes);

    // Do all size/stride arithmetic in 64 bits: the per-expert stride is a product of three
    // values and can exceed the range of `int` for large rank/token/hidden configurations.
    const auto num_local_experts = static_cast<int64_t>(num_experts / num_ranks);
    const auto num_slots = static_cast<int64_t>(num_ranks) * static_cast<int64_t>(num_max_dispatch_tokens_per_rank);

    return torch::from_blob(buffer.combine_rdma_send_buffer_data_start,
                            {num_local_experts, num_slots, static_cast<int64_t>(hidden)},
                            {num_slots * num_msg_elems, num_msg_elems, int64_t{1}},
                            torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
}
}
// namespace deep_ep
}
// namespace deep_ep
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
...
@@ -1218,5 +1234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -1218,5 +1234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"internode_combine"
,
&
deep_ep
::
Buffer
::
internode_combine
)
.
def
(
"internode_combine"
,
&
deep_ep
::
Buffer
::
internode_combine
)
.
def
(
"clean_low_latency_buffer"
,
&
deep_ep
::
Buffer
::
clean_low_latency_buffer
)
.
def
(
"clean_low_latency_buffer"
,
&
deep_ep
::
Buffer
::
clean_low_latency_buffer
)
.
def
(
"low_latency_dispatch"
,
&
deep_ep
::
Buffer
::
low_latency_dispatch
)
.
def
(
"low_latency_dispatch"
,
&
deep_ep
::
Buffer
::
low_latency_dispatch
)
.
def
(
"low_latency_combine"
,
&
deep_ep
::
Buffer
::
low_latency_combine
);
.
def
(
"low_latency_combine"
,
&
deep_ep
::
Buffer
::
low_latency_combine
)
.
def
(
"get_next_low_latency_combine_buffer"
,
&
deep_ep
::
Buffer
::
get_next_low_latency_combine_buffer
);
}
}
csrc/deep_ep.hpp
View file @
c4b8ffc3
...
@@ -143,7 +143,11 @@ public:
...
@@ -143,7 +143,11 @@ public:
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
async
,
bool
return_recv_hook
,
std
::
optional
<
torch
::
Tensor
>
out
=
std
::
nullopt
);
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
torch
::
Tensor
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
};
};
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/api.cuh
View file @
c4b8ffc3
...
@@ -147,7 +147,8 @@ void combine(void* combined_x,
...
@@ -147,7 +147,8 @@ void combine(void* combined_x,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
cudaStream_t
stream
,
int
phases
);
void
*
workspace
,
cudaStream_t
stream
,
int
phases
,
bool
zero_copy
);
}
// namespace internode_ll
}
// namespace internode_ll
...
...
csrc/kernels/internode_ll.cu
View file @
c4b8ffc3
...
@@ -353,7 +353,7 @@ combine(void* combined_x,
...
@@ -353,7 +353,7 @@ combine(void* combined_x,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_max_dispatch_tokens_per_rank
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
phases
)
{
int
phases
,
bool
zero_copy
)
{
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
...
@@ -420,7 +420,8 @@ combine(void* combined_x,
...
@@ -420,7 +420,8 @@ combine(void* combined_x,
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
{
}
else
{
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
if
(
not
zero_copy
)
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
buf_ptr
,
hidden
*
sizeof
(
nv_bfloat16
),
dst_rank
,
local_expert_idx
,
lane_id
,
token_idx
-
offset
);
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
buf_ptr
,
hidden
*
sizeof
(
nv_bfloat16
),
dst_rank
,
local_expert_idx
,
lane_id
,
token_idx
-
offset
);
}
}
}
}
...
@@ -500,7 +501,8 @@ void combine(void* combined_x,
...
@@ -500,7 +501,8 @@ void combine(void* combined_x,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
cudaStream_t
stream
,
int
phases
)
{
void
*
workspace
,
cudaStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
constexpr
int
kNumWarpsPerGroup
=
10
;
constexpr
int
kNumWarpsPerGroup
=
10
;
constexpr
int
kNumWarpGroups
=
3
;
constexpr
int
kNumWarpGroups
=
3
;
constexpr
int
kNumMaxTopk
=
9
;
constexpr
int
kNumMaxTopk
=
9
;
...
@@ -524,7 +526,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
...
@@ -524,7 +526,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
num_combined_tokens, hidden, num_topk, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
num_experts, rank, num_ranks, \
phases); } break
phases
, zero_copy
); } break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
32
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
32
,
stream
);
SWITCH_HIDDEN
(
COMBINE_LAUNCH_CASE
);
SWITCH_HIDDEN
(
COMBINE_LAUNCH_CASE
);
...
...
deep_ep/buffer.py
View file @
c4b8ffc3
...
@@ -488,7 +488,7 @@ class Buffer:
...
@@ -488,7 +488,7 @@ class Buffer:
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
async_finish
,
return_recv_hook
)
use_fp8
,
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
num_experts
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
)
packed_recv_src_info
,
packed_recv_layout_range
)
...
@@ -497,8 +497,8 @@ class Buffer:
...
@@ -497,8 +497,8 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
handle
:
tuple
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
handle
:
tuple
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
"""
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
...
@@ -517,6 +517,8 @@ class Buffer:
...
@@ -517,6 +517,8 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
handle: the communication handle given by the `dispatch` function.
zero_copy: whether the input tensor has already been written into the RDMA buffer; must be used together
with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
...
@@ -528,9 +530,24 @@ class Buffer:
...
@@ -528,9 +530,24 @@ class Buffer:
event: the event after executing the kernel (valid only if `async_finish` is set).
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
"""
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
num_experts
=
handle
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
async_finish
,
return_recv_hook
,
out
)
zero_copy
,
async_finish
,
return_recv_hook
,
out
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
return
combined_x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
return
combined_x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
def get_next_low_latency_combine_buffer(self, handle: tuple) -> torch.Tensor:
    """
    Get the raw registered RDMA buffer tensor for the next low-latency combine, so that the next combine
    kernel can skip the copy into the buffer (call `low_latency_combine` with `zero_copy=True`).

    Arguments:
        handle: the communication handle given by the `dispatch` function, i.e. the 5-tuple
            `(src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts)`.
            Only the last three entries are used here.

    Returns:
        buffer: the raw RDMA low-latency buffer as a BF16 PyTorch tensor with shape
            `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`. You must fill
            this buffer yourself before issuing the combine.
    """
    # The first two handle entries (source info and layout range) are not needed to locate the buffer.
    _, _, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
    return self.runtime.get_next_low_latency_combine_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)
tests/test_low_latency.py
View file @
c4b8ffc3
...
@@ -73,15 +73,19 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -73,15 +73,19 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
# Check combine correctness
# Check combine correctness
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
for
zero_copy
in
(
False
,
True
):
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
if
zero_copy
:
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
if
do_check
:
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
diff
=
calc_diff
(
x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
async_finish
=
not
return_recv_hook
,
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
return_recv_hook
=
return_recv_hook
,
out
=
out
)
assert
diff
<
1e-5
,
f
'Error: diff=
{
diff
}
'
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hash_value
^=
hash_tensor
(
combined_x
)
if
do_check
:
diff
=
calc_diff
(
x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
assert
diff
<
1e-5
,
f
'Error: diff=
{
diff
}
'
hash_value
^=
hash_tensor
(
combined_x
)
def
create_test_cast_with_outliers
(
num_outliers
):
def
create_test_cast_with_outliers
(
num_outliers
):
tmp
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
tmp
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
...
@@ -101,13 +105,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -101,13 +105,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hook
()
hook
()
# noinspection PyShadowingNames
# noinspection PyShadowingNames
def
test_func
(
return_recv_hook
):
def
test_func
(
zero_copy
:
bool
,
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
return_recv_hook
=
return_recv_hook
)
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
# Calculate bandwidth
# Calculate bandwidth
...
@@ -119,14 +125,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -119,14 +125,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
# Dispatch + combine testing
# Dispatch + combine testing
avg_t
,
min_t
,
max_t
=
bench
(
partial
(
test_func
,
return_recv_hook
=
False
))
avg_t
,
min_t
,
max_t
=
bench
(
partial
(
test_func
,
zero_copy
=
False
,
return_recv_hook
=
False
))
print
(
f
'[rank
{
rank
}
] Dispatch + combine bandwidth:
{
(
num_dispatch_comm_bytes
+
num_combine_comm_bytes
)
/
1e9
/
avg_t
:.
2
f
}
GB/s, '
print
(
f
'[rank
{
rank
}
] Dispatch + combine bandwidth:
{
(
num_dispatch_comm_bytes
+
num_combine_comm_bytes
)
/
1e9
/
avg_t
:.
2
f
}
GB/s, '
f
'avg_t=
{
avg_t
*
1e6
:.
2
f
}
us, min_t=
{
min_t
*
1e6
:.
2
f
}
us, max_t=
{
max_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
f
'avg_t=
{
avg_t
*
1e6
:.
2
f
}
us, min_t=
{
min_t
*
1e6
:.
2
f
}
us, max_t=
{
max_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
# Separate profiling
# Separate profiling
for
return_recv_hook
in
(
False
,
True
):
for
return_recv_hook
in
(
False
,
True
):
group
.
barrier
()
group
.
barrier
()
dispatch_t
,
combine_t
=
bench_kineto
(
partial
(
test_func
,
return_recv_hook
=
return_recv_hook
),
dispatch_t
,
combine_t
=
bench_kineto
(
partial
(
test_func
,
zero_copy
=
True
,
return_recv_hook
=
return_recv_hook
),
kernel_names
=
(
'dispatch'
,
'combine'
),
barrier_comm_profiling
=
True
,
kernel_names
=
(
'dispatch'
,
'combine'
),
barrier_comm_profiling
=
True
,
suppress_kineto_output
=
True
)
suppress_kineto_output
=
True
)
if
not
return_recv_hook
:
if
not
return_recv_hook
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment