Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
1553fc42
Commit
1553fc42
authored
Mar 04, 2025
by
Chenggang Zhao
Browse files
Improve EP2/4 performance
parent
55cdd9a6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
88 additions
and
74 deletions
+88
-74
csrc/kernels/intranode.cu
csrc/kernels/intranode.cu
+81
-64
deep_ep/buffer.py
deep_ep/buffer.py
+6
-8
tests/test_intranode.py
tests/test_intranode.py
+1
-2
No files found.
csrc/kernels/intranode.cu
View file @
1553fc42
...
@@ -174,8 +174,8 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
...
@@ -174,8 +174,8 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
}
}
template
<
int
kNumRanks
>
template
<
int
kNumRanks
,
int
kNumThreads
>
__global__
void
__launch_bounds__
(
kNum
Ranks
*
32
,
1
)
__global__
void
__launch_bounds__
(
kNum
Threads
,
1
)
dispatch
(
int4
*
recv_x
,
float
*
recv_x_scales
,
int
*
recv_src_idx
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
int
*
recv_channel_offset
,
dispatch
(
int4
*
recv_x
,
float
*
recv_x_scales
,
int
*
recv_src_idx
,
int64_t
*
recv_topk_idx
,
float
*
recv_topk_weights
,
int
*
recv_channel_offset
,
int
*
send_head
,
const
int4
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
int
*
send_head
,
const
int4
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
...
@@ -187,11 +187,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -187,11 +187,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
const
bool
is_sender
=
sm_id
%
2
==
0
;
const
bool
is_sender
=
sm_id
%
2
==
0
;
EP_DEVICE_ASSERT
(
num_sms
%
2
==
0
);
EP_DEVICE_ASSERT
(
num_sms
%
2
==
0
);
// Each warp is responsible for a single rank
// Several warps are responsible for a single rank
const
auto
num_threads_per_rank
=
kNumThreads
/
kNumRanks
;
const
auto
num_channels
=
num_sms
/
2
;
const
auto
num_channels
=
num_sms
/
2
;
const
auto
responsible_rank
=
(
static_cast
<
int
>
(
thread_id
))
/
32
;
const
auto
responsible_rank
=
(
static_cast
<
int
>
(
thread_id
))
/
num_threads_per_rank
;
// Even-numbered blocks for sending, odd-numbered blocks for receiving.
// Even-numbered blocks for sending, odd-numbered blocks for receiving
const
auto
responsible_channel
=
sm_id
/
2
;
const
auto
responsible_channel
=
sm_id
/
2
;
int
num_experts_per_rank
=
num_experts
/
kNumRanks
;
int
num_experts_per_rank
=
num_experts
/
kNumRanks
;
...
@@ -234,19 +234,20 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -234,19 +234,20 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
if
(
is_sender
)
{
if
(
is_sender
)
{
// Workers for sending
// Workers for sending
constexpr
int
num_send_warps
=
kNumRanks
;
constexpr
int
num_send_warps
=
kNumThreads
/
32
;
constexpr
int
num_send_warps_per_rank
=
num_send_warps
/
kNumRanks
;
const
auto
send_thread_id
=
thread_id
;
const
auto
send_thread_id
=
thread_id
;
const
auto
send_warp_id
=
send_thread_id
/
32
;
const
auto
send_lane_id
=
send_thread_id
%
32
;
const
auto
send_lane_id
=
send_thread_id
%
32
;
const
auto
send_warp_id_in_rank
=
send_thread_id
%
num_threads_per_rank
/
32
;
EP_DEVICE_ASSERT
(
kNumRanks
<=
32
);
EP_DEVICE_ASSERT
(
kNumRanks
<=
32
);
EP_DEVICE_ASSERT
(
num_send_warps
==
kNumRanks
and
send_warp_id
==
responsible_rank
);
EP_DEVICE_ASSERT
(
num_send_warps
%
kNumRanks
==
0
);
// Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
// Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
// NOTES: this is for distinguishing zero tokens
// NOTES: this is for distinguishing zero tokens
if
(
send_lane_id
==
0
)
{
if
(
send_lane_id
==
0
and
send_warp_id_in_rank
==
0
)
{
int
value
=
responsible_channel
>
0
?
channel_prefix_matrix
[
send_warp_id
*
num_channels
+
responsible_channel
-
1
]
:
0
;
int
value
=
responsible_channel
>
0
?
channel_prefix_matrix
[
responsible_rank
*
num_channels
+
responsible_channel
-
1
]
:
0
;
st_relaxed_sys_global
(
channel_start_offset
.
buffer
(),
-
value
-
1
);
st_relaxed_sys_global
(
channel_start_offset
.
buffer
(),
-
value
-
1
);
value
=
channel_prefix_matrix
[
send_warp_id
*
num_channels
+
responsible_channel
];
value
=
channel_prefix_matrix
[
responsible_rank
*
num_channels
+
responsible_channel
];
st_relaxed_sys_global
(
channel_end_offset
.
buffer
(),
-
value
-
1
);
st_relaxed_sys_global
(
channel_end_offset
.
buffer
(),
-
value
-
1
);
}
}
__syncwarp
();
__syncwarp
();
...
@@ -257,8 +258,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -257,8 +258,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Iterate over all tokens and send by chunks
// Iterate over all tokens and send by chunks
int
cached_channel_tail_idx
=
0
;
int
cached_channel_tail_idx
=
0
;
for
(
int64_t
token_idx
=
token_start_idx
;
token_idx
<
token_end_idx
;)
{
for
(
int64_t
token_idx
=
token_start_idx
;
token_idx
<
token_end_idx
;
)
{
// Check destination queue emptiness, or wait for a buffer to be released (rare cases)
// Check destination queue emptiness, or wait for a buffer to be released (rare cases)
// NOTES: the head index received by different warps may not be the same
auto
start_time
=
clock64
();
auto
start_time
=
clock64
();
while
(
send_lane_id
==
0
)
{
while
(
send_lane_id
==
0
)
{
// NOTES: we only consider the worst case, because counting the real numbers is time-consuming
// NOTES: we only consider the worst case, because counting the real numbers is time-consuming
...
@@ -276,17 +278,19 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -276,17 +278,19 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int
chunk_token_idx
=
0
;
int
chunk_token_idx
=
0
;
while
(
chunk_token_idx
<
num_max_send_tokens
and
token_idx
<
token_end_idx
)
{
while
(
chunk_token_idx
<
num_max_send_tokens
and
token_idx
<
token_end_idx
)
{
if
(
send_lane_id
==
0
)
// NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
send_head
[
token_idx
*
kNumRanks
+
send_warp_id
]
=
is_token_in_rank
[
token_idx
*
kNumRanks
+
send_warp_id
]
?
cached_channel_tail_idx
:
-
1
;
if
(
send_lane_id
==
0
and
token_idx
%
num_send_warps_per_rank
==
send_warp_id_in_rank
)
send_head
[
token_idx
*
kNumRanks
+
responsible_rank
]
=
is_token_in_rank
[
token_idx
*
kNumRanks
+
responsible_rank
]
?
cached_channel_tail_idx
:
-
1
;
// Skip if not selected
// Skip if not selected
if
(
not
is_token_in_rank
[
token_idx
*
kNumRanks
+
send_warp_id
])
{
if
(
not
is_token_in_rank
[
token_idx
*
kNumRanks
+
responsible_rank
])
{
token_idx
++
;
token_idx
++
;
continue
;
continue
;
}
}
// Get an empty slot
// Get an empty slot
int
dst_slot_idx
=
(
cached_channel_tail_idx
++
)
%
num_recv_buffer_tokens
;
int
dst_slot_idx
=
(
cached_channel_tail_idx
++
)
%
num_recv_buffer_tokens
;
if
(
cached_channel_tail_idx
%
num_send_warps_per_rank
==
send_warp_id_in_rank
)
{
// Copy data
// Copy data
auto
shifted_channel_x_buffers
=
channel_x_buffers
.
buffer
()
+
dst_slot_idx
*
hidden_int4
;
auto
shifted_channel_x_buffers
=
channel_x_buffers
.
buffer
()
+
dst_slot_idx
*
hidden_int4
;
auto
shifted_x
=
x
+
token_idx
*
hidden_int4
;
auto
shifted_x
=
x
+
token_idx
*
hidden_int4
;
...
@@ -300,7 +304,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -300,7 +304,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Copy `topk_idx` and `topk_weights` with transformed index
// Copy `topk_idx` and `topk_weights` with transformed index
if
(
send_lane_id
<
num_topk
)
{
if
(
send_lane_id
<
num_topk
)
{
// Top-k index
// Top-k index
int
recv_expert_begin
=
send_warp_id
*
num_experts_per_rank
,
recv_expert_end
=
(
send_warp_id
+
1
)
*
num_experts_per_rank
;
int
recv_expert_begin
=
responsible_rank
*
num_experts_per_rank
,
recv_expert_end
=
(
responsible_rank
+
1
)
*
num_experts_per_rank
;
auto
idx_value
=
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
send_lane_id
);
auto
idx_value
=
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
send_lane_id
);
idx_value
=
(
idx_value
>=
recv_expert_begin
and
idx_value
<
recv_expert_end
)
?
idx_value
-
recv_expert_begin
:
-
1
;
idx_value
=
(
idx_value
>=
recv_expert_begin
and
idx_value
<
recv_expert_end
)
?
idx_value
-
recv_expert_begin
:
-
1
;
channel_topk_idx_buffers
[
dst_slot_idx
*
num_topk
+
send_lane_id
]
=
idx_value
;
channel_topk_idx_buffers
[
dst_slot_idx
*
num_topk
+
send_lane_id
]
=
idx_value
;
...
@@ -315,28 +319,32 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -315,28 +319,32 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
#pragma unroll
#pragma unroll
for
(
int
i
=
send_lane_id
;
i
<
num_scales
;
i
+=
32
)
for
(
int
i
=
send_lane_id
;
i
<
num_scales
;
i
+=
32
)
channel_x_scales_buffers
[
dst_slot_idx
*
num_scales
+
i
]
=
__ldg
(
x_scales
+
token_idx
*
num_scales
+
i
);
channel_x_scales_buffers
[
dst_slot_idx
*
num_scales
+
i
]
=
__ldg
(
x_scales
+
token_idx
*
num_scales
+
i
);
}
// Move token index
// Move token index
chunk_token_idx
++
,
token_idx
++
;
chunk_token_idx
++
,
token_idx
++
;
}
}
// Move tail index
// Move tail index
__syncwarp
();
// NOTES: here all warps should share the same new tail
if
(
send_lane_id
==
0
)
asm
volatile
(
"bar.sync %0, %1;"
::
"r"
(
responsible_rank
),
"r"
(
num_threads_per_rank
));
if
(
send_warp_id_in_rank
==
0
and
send_lane_id
==
0
)
st_release_sys_global
(
channel_tail_idx
.
buffer
(),
cached_channel_tail_idx
);
st_release_sys_global
(
channel_tail_idx
.
buffer
(),
cached_channel_tail_idx
);
}
}
}
else
{
}
else
{
// Workers for receiving and copying into buffer
// Workers for receiving and copying into buffer
constexpr
int
num_recv_warps
=
kNumRanks
;
constexpr
int
num_recv_warps
=
kNumThreads
/
32
;
constexpr
int
num_recv_warps_per_rank
=
num_recv_warps
/
kNumRanks
;
const
auto
recv_thread_id
=
thread_id
;
const
auto
recv_thread_id
=
thread_id
;
const
auto
recv_warp_id
=
recv_thread_id
/
32
;
const
auto
recv_lane_id
=
recv_thread_id
%
32
;
const
auto
recv_lane_id
=
recv_thread_id
%
32
;
EP_DEVICE_ASSERT
(
kNumRanks
<=
32
and
recv_warp_id
==
responsible_rank
);
const
auto
recv_thread_id_in_rank
=
recv_thread_id
%
num_threads_per_rank
;
EP_DEVICE_ASSERT
(
recv_thread_id
>=
0
and
num_recv_warps
==
kNumRanks
);
const
auto
recv_warp_id_in_rank
=
recv_thread_id_in_rank
/
32
;
EP_DEVICE_ASSERT
(
kNumRanks
<=
32
);
EP_DEVICE_ASSERT
(
recv_thread_id
>=
0
and
num_recv_warps
%
kNumRanks
==
0
);
// Calculate offset first
// Calculate offset first
auto
rank_prefix_matrix
=
reinterpret_cast
<
int
*>
(
buffer_ptrs
[
rank
]);
auto
rank_prefix_matrix
=
reinterpret_cast
<
int
*>
(
buffer_ptrs
[
rank
]);
int
rank_offset
=
re
cv_warp_id
>
0
?
rank_prefix_matrix
[(
re
cv_warp_id
-
1
)
*
kNumRanks
+
rank
]
:
0
;
int
rank_offset
=
re
sponsible_rank
>
0
?
rank_prefix_matrix
[(
re
sponsible_rank
-
1
)
*
kNumRanks
+
rank
]
:
0
;
// Receive channel offset
// Receive channel offset
int
total_offset
,
num_tokens_to_recv
;
int
total_offset
,
num_tokens_to_recv
;
...
@@ -344,23 +352,29 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -344,23 +352,29 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
while
(
recv_lane_id
==
0
and
(
num_tokens_to_recv
=
ld_volatile_global
(
channel_end_offset
.
buffer
()))
==
0
);
while
(
recv_lane_id
==
0
and
(
num_tokens_to_recv
=
ld_volatile_global
(
channel_end_offset
.
buffer
()))
==
0
);
if
(
recv_lane_id
==
0
)
{
if
(
recv_lane_id
==
0
)
{
total_offset
=
-
total_offset
-
1
,
num_tokens_to_recv
=
-
num_tokens_to_recv
-
1
;
total_offset
=
-
total_offset
-
1
,
num_tokens_to_recv
=
-
num_tokens_to_recv
-
1
;
recv_channel_offset
[
recv_warp_id
*
num_channels
+
responsible_channel
]
=
total_offset
;
if
(
recv_warp_id_in_rank
==
0
)
recv_channel_offset
[
responsible_rank
*
num_channels
+
responsible_channel
]
=
total_offset
;
num_tokens_to_recv
-=
total_offset
;
num_tokens_to_recv
-=
total_offset
;
}
}
total_offset
=
__shfl_sync
(
0xffffffff
,
total_offset
,
0
);
total_offset
=
__shfl_sync
(
0xffffffff
,
total_offset
,
0
);
total_offset
+=
rank_offset
;
total_offset
+=
rank_offset
;
num_tokens_to_recv
=
__shfl_sync
(
0xffffffff
,
num_tokens_to_recv
,
0
);
num_tokens_to_recv
=
__shfl_sync
(
0xffffffff
,
num_tokens_to_recv
,
0
);
// Shared tail indices for different warps
__shared__
volatile
int
shared_channel_tail_idx
[
kNumRanks
];
auto
start_time
=
clock64
();
auto
start_time
=
clock64
();
int
cached_channel_head_idx
=
0
,
cached_channel_tail_idx
=
0
;
int
cached_channel_head_idx
=
0
,
cached_channel_tail_idx
=
0
;
while
(
num_tokens_to_recv
>
0
)
{
while
(
num_tokens_to_recv
>
0
)
{
//
Check channel status by lane 0
//
NOTES: unlike the sender, the receiver must ensure that the tail indices held by different warps are the same
while
(
recv_
lane_id
==
0
)
{
while
(
recv_
thread_id_in_rank
==
0
)
{
cached_channel_tail_idx
=
ld_acquire_sys_global
(
channel_tail_idx
.
buffer
());;
cached_channel_tail_idx
=
ld_acquire_sys_global
(
channel_tail_idx
.
buffer
());;
// Ready to copy
// Ready to copy
if
(
cached_channel_head_idx
!=
cached_channel_tail_idx
)
if
(
cached_channel_head_idx
!=
cached_channel_tail_idx
)
{
shared_channel_tail_idx
[
responsible_rank
]
=
cached_channel_tail_idx
;
break
;
break
;
}
// Timeout check
// Timeout check
if
(
clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
if
(
clock64
()
-
start_time
>
NUM_TIMEOUT_CYCLES
)
{
...
@@ -369,12 +383,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -369,12 +383,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
}
}
}
}
// Sync queue tail
// Synchronize queue tail
cached_channel_tail_idx
=
__shfl_sync
(
0xffffffff
,
cached_channel_tail_idx
,
0
);
asm
volatile
(
"bar.sync %0, %1;"
::
"r"
(
responsible_rank
),
"r"
(
num_threads_per_rank
));
cached_channel_tail_idx
=
shared_channel_tail_idx
[
responsible_rank
];
// Copy data
// Copy data
int
num_recv_tokens
=
cached_channel_tail_idx
-
cached_channel_head_idx
;
int
num_recv_tokens
=
cached_channel_tail_idx
-
cached_channel_head_idx
;
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_recv_tokens
;
++
chunk_idx
)
{
for
(
int
chunk_idx
=
recv_warp_id_in_rank
;
chunk_idx
<
num_recv_tokens
;
chunk_idx
+=
num_recv_warps_per_rank
)
{
int
token_idx_in_buffer
=
(
cached_channel_head_idx
+
chunk_idx
)
%
num_recv_buffer_tokens
;
int
token_idx_in_buffer
=
(
cached_channel_head_idx
+
chunk_idx
)
%
num_recv_buffer_tokens
;
auto
shifted_buffer_x_int4
=
channel_x_buffers
.
buffer
()
+
token_idx_in_buffer
*
hidden_int4
;
auto
shifted_buffer_x_int4
=
channel_x_buffers
.
buffer
()
+
token_idx_in_buffer
*
hidden_int4
;
auto
shifted_recv_x_int4
=
recv_x
+
static_cast
<
int64_t
>
(
total_offset
+
chunk_idx
)
*
hidden_int4
;
auto
shifted_recv_x_int4
=
recv_x
+
static_cast
<
int64_t
>
(
total_offset
+
chunk_idx
)
*
hidden_int4
;
...
@@ -384,12 +399,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -384,12 +399,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Copy `src_idx`
// Copy `src_idx`
#pragma unroll 4
#pragma unroll 4
for
(
int
chunk_idx
=
cached_channel_head_idx
+
recv_
lane_id
;
chunk_idx
<
cached_channel_tail_idx
;
chunk_idx
+=
32
)
for
(
int
chunk_idx
=
cached_channel_head_idx
+
recv_
thread_id_in_rank
;
chunk_idx
<
cached_channel_tail_idx
;
chunk_idx
+=
32
*
num_recv_warps_per_rank
)
recv_src_idx
[
total_offset
+
chunk_idx
-
cached_channel_head_idx
]
=
ld_nc_global
(
channel_src_idx_buffers
.
buffer
()
+
chunk_idx
%
num_recv_buffer_tokens
);
recv_src_idx
[
total_offset
+
chunk_idx
-
cached_channel_head_idx
]
=
ld_nc_global
(
channel_src_idx_buffers
.
buffer
()
+
chunk_idx
%
num_recv_buffer_tokens
);
// Copy `topk_idx` and `topk_weights`
// Copy `topk_idx` and `topk_weights`
#pragma unroll 4
#pragma unroll 4
for
(
int
idx
=
recv_
lane_id
;
idx
<
num_recv_tokens
*
num_topk
;
idx
+=
32
)
{
for
(
int
idx
=
recv_
thread_id_in_rank
;
idx
<
num_recv_tokens
*
num_topk
;
idx
+=
32
*
num_recv_warps_per_rank
)
{
int
chunk_idx
=
idx
/
num_topk
,
token_topk_idx
=
idx
%
num_topk
;
int
chunk_idx
=
idx
/
num_topk
,
token_topk_idx
=
idx
%
num_topk
;
int
token_idx_in_buffer
=
(
cached_channel_head_idx
+
chunk_idx
)
%
num_recv_buffer_tokens
;
int
token_idx_in_buffer
=
(
cached_channel_head_idx
+
chunk_idx
)
%
num_recv_buffer_tokens
;
auto
recv_idx
=
static_cast
<
int64_t
>
(
total_offset
+
chunk_idx
)
*
num_topk
+
token_topk_idx
;
auto
recv_idx
=
static_cast
<
int64_t
>
(
total_offset
+
chunk_idx
)
*
num_topk
+
token_topk_idx
;
...
@@ -400,7 +415,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -400,7 +415,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Copy `x_scales`
// Copy `x_scales`
#pragma unroll 4
#pragma unroll 4
for
(
int
i
=
recv_
lane_id
;
i
<
num_recv_tokens
*
num_scales
;
i
+=
32
)
{
for
(
int
i
=
recv_
thread_id_in_rank
;
i
<
num_recv_tokens
*
num_scales
;
i
+=
32
*
num_recv_warps_per_rank
)
{
int
chunk_idx
=
i
/
num_scales
,
scales_idx
=
i
%
num_scales
;
int
chunk_idx
=
i
/
num_scales
,
scales_idx
=
i
%
num_scales
;
int
token_idx_in_buffer
=
(
cached_channel_head_idx
+
chunk_idx
)
%
num_recv_buffer_tokens
;
int
token_idx_in_buffer
=
(
cached_channel_head_idx
+
chunk_idx
)
%
num_recv_buffer_tokens
;
recv_x_scales
[
static_cast
<
int64_t
>
(
total_offset
+
chunk_idx
)
*
num_scales
+
scales_idx
]
=
recv_x_scales
[
static_cast
<
int64_t
>
(
total_offset
+
chunk_idx
)
*
num_scales
+
scales_idx
]
=
...
@@ -410,8 +425,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -410,8 +425,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Move queue
// Move queue
cached_channel_head_idx
+=
num_recv_tokens
;
cached_channel_head_idx
+=
num_recv_tokens
;
total_offset
+=
num_recv_tokens
;
total_offset
+=
num_recv_tokens
;
__syncwarp
(
);
asm
volatile
(
"bar.sync %0, %1;"
::
"r"
(
responsible_rank
),
"r"
(
num_threads_per_rank
)
);
if
(
recv_lane_id
==
0
)
if
(
recv_warp_id_in_rank
==
num_recv_warps_per_rank
-
1
and
recv_lane_id
==
0
)
st_relaxed_sys_global
(
channel_head_idx
.
buffer
(),
cached_channel_head_idx
);
st_relaxed_sys_global
(
channel_head_idx
.
buffer
(),
cached_channel_head_idx
);
// Exit
// Exit
...
@@ -426,8 +441,10 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
...
@@ -426,8 +441,10 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
int
num_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
num_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
cudaStream_t
stream
,
int
num_sms
,
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
)
{
cudaStream_t
stream
,
int
num_sms
,
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
)
{
constexpr
int
kNumThreads
=
512
;
#define DISPATCH_LAUNCH_CASE(ranks) \
#define DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, dispatch<ranks>, \
LAUNCH_KERNEL(&cfg, dispatch<ranks
, kNumThreads
>, \
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
is_token_in_rank, channel_prefix_matrix, \
is_token_in_rank, channel_prefix_matrix, \
...
@@ -438,7 +455,7 @@ break
...
@@ -438,7 +455,7 @@ break
// Even-numbered blocks for sending, odd-numbered blocks for receiving.
// Even-numbered blocks for sending, odd-numbered blocks for receiving.
EP_HOST_ASSERT
(
num_sms
%
2
==
0
);
EP_HOST_ASSERT
(
num_sms
%
2
==
0
);
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_ranks
*
32
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_sms
,
kNumThreads
,
stream
);
SWITCH_RANKS
(
DISPATCH_LAUNCH_CASE
);
SWITCH_RANKS
(
DISPATCH_LAUNCH_CASE
);
#undef DISPATCH_LAUNCH_CASE
#undef DISPATCH_LAUNCH_CASE
}
}
...
...
deep_ep/buffer.py
View file @
1553fc42
...
@@ -160,12 +160,11 @@ class Buffer:
...
@@ -160,12 +160,11 @@ class Buffer:
Returns:
Returns:
config: the recommended config.
config: the recommended config.
"""
"""
# Intranode
if
num_ranks
<=
8
:
return
Config
(
Buffer
.
num_sms
,
6
,
256
,
6
,
128
)
# Internode
config_map
=
{
config_map
=
{
2
:
Config
(
Buffer
.
num_sms
,
16
,
256
,
6
,
128
),
4
:
Config
(
Buffer
.
num_sms
,
16
,
256
,
6
,
128
),
8
:
Config
(
Buffer
.
num_sms
,
6
,
256
,
6
,
128
),
16
:
Config
(
Buffer
.
num_sms
,
16
,
288
,
20
,
128
),
16
:
Config
(
Buffer
.
num_sms
,
16
,
288
,
20
,
128
),
24
:
Config
(
Buffer
.
num_sms
,
8
,
288
,
32
,
128
),
24
:
Config
(
Buffer
.
num_sms
,
8
,
288
,
32
,
128
),
32
:
Config
(
Buffer
.
num_sms
,
8
,
288
,
32
,
128
),
32
:
Config
(
Buffer
.
num_sms
,
8
,
288
,
32
,
128
),
...
@@ -188,12 +187,11 @@ class Buffer:
...
@@ -188,12 +187,11 @@ class Buffer:
Returns:
Returns:
config: the recommended config.
config: the recommended config.
"""
"""
# Intranode
if
num_ranks
<=
8
:
return
Config
(
Buffer
.
num_sms
,
6
,
256
,
6
,
128
)
# Internode
config_map
=
{
config_map
=
{
2
:
Config
(
Buffer
.
num_sms
,
6
,
256
,
6
,
128
),
4
:
Config
(
Buffer
.
num_sms
,
6
,
256
,
6
,
128
),
8
:
Config
(
Buffer
.
num_sms
,
6
,
256
,
6
,
128
),
16
:
Config
(
Buffer
.
num_sms
,
2
,
288
,
28
,
128
),
16
:
Config
(
Buffer
.
num_sms
,
2
,
288
,
28
,
128
),
24
:
Config
(
Buffer
.
num_sms
,
1
,
288
,
20
,
128
),
24
:
Config
(
Buffer
.
num_sms
,
1
,
288
,
20
,
128
),
32
:
Config
(
Buffer
.
num_sms
,
1
,
288
,
20
,
128
),
32
:
Config
(
Buffer
.
num_sms
,
1
,
288
,
20
,
128
),
...
...
tests/test_intranode.py
View file @
1553fc42
...
@@ -13,7 +13,6 @@ import test_low_latency
...
@@ -13,7 +13,6 @@ import test_low_latency
def
test_main
(
num_sms
:
int
,
local_rank
:
int
,
num_ranks
:
int
,
rank
:
int
,
buffer
:
deep_ep
.
Buffer
,
group
:
dist
.
ProcessGroup
):
def
test_main
(
num_sms
:
int
,
local_rank
:
int
,
num_ranks
:
int
,
rank
:
int
,
buffer
:
deep_ep
.
Buffer
,
group
:
dist
.
ProcessGroup
):
# Settings
# Settings
# TODO: fix EP2/4/8 performance
num_tokens
,
hidden
,
num_topk
,
num_experts
=
4096
,
7168
,
8
,
(
256
//
num_ranks
)
*
num_ranks
num_tokens
,
hidden
,
num_topk
,
num_experts
=
4096
,
7168
,
8
,
(
256
//
num_ranks
)
*
num_ranks
assert
num_experts
%
num_ranks
==
0
assert
num_experts
%
num_ranks
==
0
if
local_rank
==
0
:
if
local_rank
==
0
:
...
@@ -182,7 +181,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
...
@@ -182,7 +181,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
# Tune combine performance
# Tune combine performance
best_time
,
best_results
=
1e10
,
None
best_time
,
best_results
=
1e10
,
None
for
nvl_chunk_size
in
range
(
1
,
5
,
1
):
for
nvl_chunk_size
in
range
(
1
,
7
,
1
):
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
)
config
=
deep_ep
.
Config
(
num_sms
,
nvl_chunk_size
,
nvl_buffer_size
)
tune_args
=
{
'x'
:
recv_x
,
'handle'
:
handle
,
'config'
:
config
}
tune_args
=
{
'x'
:
recv_x
,
'handle'
:
handle
,
'config'
:
config
}
t
=
bench
(
lambda
:
buffer
.
combine
(
**
tune_args
))[
0
]
t
=
bench
(
lambda
:
buffer
.
combine
(
**
tune_args
))[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment