Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
ce671dd4
Commit
ce671dd4
authored
Nov 05, 2025
by
lishen
Browse files
低延迟接口支持int8类型通信
parent
da13c63a
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
399 additions
and
74 deletions
+399
-74
1.sh
1.sh
+2
-1
2.sh
2.sh
+2
-1
csrc/deep_ep.cu
csrc/deep_ep.cu
+20
-9
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+2
-1
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+5
-4
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+0
-5
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+70
-40
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+2
-3
deep_ep/buffer.py
deep_ep/buffer.py
+15
-10
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+281
-0
No files found.
1.sh
View file @
ce671dd4
...
@@ -14,4 +14,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
...
@@ -14,4 +14,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export
ROCSHMEM_HEAP_SIZE
=
2880100992
export
ROCSHMEM_HEAP_SIZE
=
2880100992
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_low_latency_new.py
2.sh
View file @
ce671dd4
...
@@ -14,4 +14,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
...
@@ -14,4 +14,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export
ROCSHMEM_HEAP_SIZE
=
2880100992
export
ROCSHMEM_HEAP_SIZE
=
2880100992
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_low_latency_new.py
csrc/deep_ep.cu
View file @
ce671dd4
...
@@ -1293,7 +1293,8 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
...
@@ -1293,7 +1293,8 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
)
{
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
async
,
bool
return_recv_hook
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
// Tensor checks
// Tensor checks
...
@@ -1306,8 +1307,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1306,8 +1307,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT
(
num_experts
%
num_ranks
==
0
);
EP_HOST_ASSERT
(
num_experts
%
num_ranks
==
0
);
auto
num_tokens
=
static_cast
<
int
>
(
x
.
size
(
0
)),
hidden
=
static_cast
<
int
>
(
x
.
size
(
1
));
auto
num_tokens
=
static_cast
<
int
>
(
x
.
size
(
0
)),
hidden
=
static_cast
<
int
>
(
x
.
size
(
1
));
auto
num_scales
=
hidden
/
128
,
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
auto
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
int
num_local_experts
=
num_experts
/
num_ranks
;
auto
num_local_experts
=
num_experts
/
num_ranks
;
// Buffer control
// Buffer control
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
...
@@ -1339,12 +1340,21 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1339,12 +1340,21 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
// Allocate column-majored scales
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
float
*
packed_recv_x_scales_ptr
=
nullptr
;
void
*
packed_recv_x_scales_ptr
=
nullptr
;
if
(
use_fp8
)
{
if
(
use_fp8
)
{
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
num_scales
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
// TODO: support unaligned cases
EP_HOST_ASSERT
(
hidden
%
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
)
==
0
);
if
(
not
use_ue8m0
)
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
else
{
EP_HOST_ASSERT
(
round_scale
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
),
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
}
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
<
float
>
();
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
();
}
}
// Kernel launch
// Kernel launch
...
@@ -1359,8 +1369,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1359,8 +1369,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
workspace
,
launch_stream
,
phases
);
use_fp8
,
round_scale
,
use_ue8m0
,
workspace
,
num_device_sms
,
launch_stream
,
phases
);
};
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
@@ -1454,7 +1465,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1454,7 +1465,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
workspace
,
launch_stream
,
workspace
,
num_device_sms
,
launch_stream
,
phases
,
zero_copy
);
phases
,
zero_copy
);
};
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
...
csrc/deep_ep.hpp
View file @
ce671dd4
...
@@ -177,7 +177,8 @@ public:
...
@@ -177,7 +177,8 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
);
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
async
,
bool
return_recv_hook
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
...
...
csrc/kernels/api.cuh
View file @
ce671dd4
...
@@ -138,7 +138,7 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
...
@@ -138,7 +138,7 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t
*
clean_1
,
int
num_clean_int_1
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
hipStream_t
stream
);
hipStream_t
stream
);
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
...
@@ -146,8 +146,9 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -146,8 +146,9 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
hipStream_t
stream
,
int
phases
);
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
...
@@ -157,7 +158,7 @@ void combine(void* combined_x,
...
@@ -157,7 +158,7 @@ void combine(void* combined_x,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
hipStream_t
stream
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
);
int
phases
,
bool
zero_copy
);
}
// namespace internode_ll
}
// namespace internode_ll
...
...
csrc/kernels/internode.cu
View file @
ce671dd4
...
@@ -440,9 +440,6 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
...
@@ -440,9 +440,6 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
auto
warp_role
=
role_meta
.
first
;
auto
warp_role
=
role_meta
.
first
;
auto
target_rank
=
role_meta
.
second
;
// Not applicable for RDMA senders
auto
target_rank
=
role_meta
.
second
;
// Not applicable for RDMA senders
// if(lane_id==0){
// printf("tid=%d, bid=%d, warp_role=%d\n", threadIdx.x, blockIdx.x, warp_role);
// }
// RDMA symmetric layout
// RDMA symmetric layout
auto
hidden_bytes
=
hidden_int4
*
sizeof
(
int4
);
auto
hidden_bytes
=
hidden_int4
*
sizeof
(
int4
);
...
@@ -1610,8 +1607,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
...
@@ -1610,8 +1607,6 @@ combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_
int
lds_dst_rdma_rank
=
dst_rdma_rank
+
(
iter
%
num_sync_large_iteration
)
*
kNumRDMARanks
+
mode
*
rdma_warp_counters
;
int
lds_dst_rdma_rank
=
dst_rdma_rank
+
(
iter
%
num_sync_large_iteration
)
*
kNumRDMARanks
+
mode
*
rdma_warp_counters
;
//reset index in the LDS to avoid race condition due to warp scheduling
//reset index in the LDS to avoid race condition due to warp scheduling
int
reset_idx
=
dst_rdma_rank
+
((
iter
+
num_sync_large_iteration
/
2
)
%
num_sync_large_iteration
)
*
kNumRDMARanks
+
mode
*
rdma_warp_counters
;
int
reset_idx
=
dst_rdma_rank
+
((
iter
+
num_sync_large_iteration
/
2
)
%
num_sync_large_iteration
)
*
kNumRDMARanks
+
mode
*
rdma_warp_counters
;
// // if (lane_id==0)
// // printf("rank %d dst_rdma_rank %d iter %d warp_id %d val %d\n", rank, dst_rdma_rank, iter, warp_id, sync_large_warp_counters[lds_dst_rdma_rank]);
auto
start_time
=
wall_clock64
();
auto
start_time
=
wall_clock64
();
if
(
lane_id
==
0
){
if
(
lane_id
==
0
){
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
lds_dst_rdma_rank
],
1
);
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
lds_dst_rdma_rank
],
1
);
...
...
csrc/kernels/internode_ll.cu
View file @
ce671dd4
...
@@ -85,9 +85,9 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
...
@@ -85,9 +85,9 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
);
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
);
}
}
template
<
bool
kUseFP8
,
int
kHidden
>
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
int
kHidden
>
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
...
@@ -97,7 +97,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -97,7 +97,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
phases
)
{
int
num_warp_groups
,
int
num_warps_per_group
,
bool
round_scale
,
int
phases
)
{
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
...
@@ -113,6 +114,11 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -113,6 +114,11 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
auto
sub_warp_id
=
warp_id
%
num_warps_per_group
;
const
auto
sub_warp_id
=
warp_id
%
num_warps_per_group
;
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
// May extract UE8M0 from the scales
using
scale_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint8_t
,
float
>
;
using
packed_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint32_t
,
float
>
;
EP_STATIC_ASSERT
(
sizeof
(
packed_t
)
%
sizeof
(
scale_t
)
==
0
,
"Invalid vector length"
);
// FP8 staffs
// FP8 staffs
constexpr
int
kNumPerChannels
=
FP8_QUANTIZATION_NUM_PER_CHANNEL
;
constexpr
int
kNumPerChannels
=
FP8_QUANTIZATION_NUM_PER_CHANNEL
;
const
int
num_scales
=
kHidden
/
kNumPerChannels
;
const
int
num_scales
=
kHidden
/
kNumPerChannels
;
...
@@ -184,9 +190,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -184,9 +190,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Reduce amax and scale
// Reduce amax and scale
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
/
kNumPerChannels
==
4
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
/
kNumPerChannels
==
4
,
"Invalid vectorization"
);
amax
=
warp_reduce_max
<
16
>
(
amax
);
amax
=
warp_reduce_max
<
16
>
(
amax
);
calculate_fp8_scales
<
/*round_scale*/
false
>
(
amax
,
scale
,
scale_inv
);
calculate_fp8_scales
(
amax
,
scale
,
scale_inv
,
round_scale
);
if
(
lane_id
%
16
==
0
)
if
(
lane_id
%
16
==
0
)
rdma_x_scales
[
i
*
kNumElemsPerRead
/
128
]
=
scale_inv
;
rdma_x_scales
[
i
*
kNumElemsPerRead
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
]
=
scale_inv
;
// Cast into send buffer
// Cast into send buffer
vec_t
int2_value
;
vec_t
int2_value
;
...
@@ -316,9 +322,11 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -316,9 +322,11 @@ LOW_LATENCY_DISPATCH_RECV:
src_rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
;
src_rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
;
const
auto
recv_x_int4
=
reinterpret_cast
<
int4
*>
(
packed_recv_x
)
+
const
auto
recv_x_int4
=
reinterpret_cast
<
int4
*>
(
packed_recv_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
const
auto
recv_x_scales
=
packed_recv_x_scales
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_scales
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
const
auto
num_aligned_scales
=
ALIGN
<
int
>
(
num_scales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
recv_x_scales
=
static_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_aligned_scales
;
// Shared between sub-warps in warp groups
// Shared between sub-warps in warp groups
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
...
@@ -366,12 +374,23 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -366,12 +374,23 @@ LOW_LATENCY_DISPATCH_RECV:
// Copy scales
// Copy scales
if
(
kUseFP8
)
{
if
(
kUseFP8
)
{
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
dst_scales
=
reinterpret_cast
<
float
*>
(
recv_x_scales
+
recv_token_begin_idx
+
i
);
const
auto
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
packed_t
)
/
sizeof
(
scale_t
));
const
auto
scale_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
token_idx
=
recv_token_begin_idx
+
i
;
auto
scale_0
=
lane_id
<
num_scales
?
ld_nc_global
(
src_scales
+
lane_id
)
:
0
;
const
auto
token_stride
=
num_elems_per_pack
;
auto
scale_1
=
(
lane_id
+
kWarpSize
)
<
num_scales
?
ld_nc_global
(
src_scales
+
lane_id
+
kWarpSize
)
:
0
;
const
auto
pack_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_elems_per_pack
;
lane_id
<
num_scales
?
dst_scales
[
lane_id
*
scale_stride
]
=
scale_0
:
0.0
f
;
(
lane_id
+
kWarpSize
)
<
num_scales
?
dst_scales
[(
lane_id
+
kWarpSize
)
*
scale_stride
]
=
scale_1
:
0.0
f
;
if
(
lane_id
<
num_scales
)
{
const
auto
pack_idx
=
lane_id
/
num_elems_per_pack
;
const
auto
elem_idx
=
lane_id
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
if
(
lane_id
+
kWarpSize
<
num_scales
)
{
const
auto
pack_idx
=
(
lane_id
+
kWarpSize
)
/
num_elems_per_pack
;
const
auto
elem_idx
=
(
lane_id
+
kWarpSize
)
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
+
kWarpSize
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
}
}
}
}
}
}
...
@@ -381,7 +400,7 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -381,7 +400,7 @@ LOW_LATENCY_DISPATCH_RECV:
#endif
#endif
}
}
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
...
@@ -389,10 +408,12 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -389,10 +408,12 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
hipStream_t
stream
,
int
phases
)
{
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
)
{
constexpr
int
kNumMaxTopK
=
11
;
constexpr
int
kNumMaxTopK
=
11
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
/*
num_device_sms
*/
80
);
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
);
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
);
EP_HOST_ASSERT
(
kNumMaxTopK
+
1
<=
num_warp_groups
*
num_warps_per_group
);
EP_HOST_ASSERT
(
kNumMaxTopK
+
1
<=
num_warp_groups
*
num_warps_per_group
);
...
@@ -407,8 +428,11 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -407,8 +428,11 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
#define DISPATCH_LAUNCH_CASE(hidden) { \
#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = use_fp8 ? dispatch<true, hidden> : \
auto dispatch_func = dispatch<false, false, hidden>; \
dispatch<false, hidden>; \
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, hidden>; \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_src_info, packed_recv_layout_range, \
...
@@ -420,15 +444,15 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
...
@@ -420,15 +444,15 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
next_clean, num_next_clean_int, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, phases); } break
num_warp_groups, num_warps_per_group,
round_scale,
phases); } break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
#undef DISPATCH_LAUNCH_CASE
#undef DISPATCH_LAUNCH_CASE
}
}
template
<
int
kNumWarpGroups
,
int
kNumWarpsPerGroup
,
int
kHidden
,
int
kNumMaxTopk
>
template
<
int
kHidden
,
int
kNumMaxTopk
>
__global__
__launch_bounds__
(
kNumWarpGroups
*
kNumWarpsPerGroup
*
kWarpSize
,
1
)
void
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
combine
(
void
*
combined_x
,
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
...
@@ -439,6 +463,7 @@ combine(void* combined_x,
...
@@ -439,6 +463,7 @@ combine(void* combined_x,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_max_dispatch_tokens_per_rank
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
phases
,
bool
zero_copy
)
{
int
phases
,
bool
zero_copy
)
{
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
...
@@ -451,19 +476,21 @@ combine(void* combined_x,
...
@@ -451,19 +476,21 @@ combine(void* combined_x,
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
);
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
);
const
auto
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
const
auto
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
const
auto
num_local_experts
=
num_experts
/
num_ranks
;
const
auto
num_local_experts
=
num_experts
/
num_ranks
;
const
auto
warp_group_id
=
warp_id
/
kNumWarpsPerG
roup
;
const
auto
warp_group_id
=
warp_id
/
num_warps_per_g
roup
;
const
auto
sub_warp_id
=
warp_id
%
kNumWarpsPerG
roup
;
const
auto
sub_warp_id
=
warp_id
%
num_warps_per_g
roup
;
const
auto
responsible_expert_idx
=
sm_id
*
kNumW
arp
G
roups
+
warp_group_id
;
const
auto
responsible_expert_idx
=
sm_id
*
num_w
arp
_g
roups
+
warp_group_id
;
// Data type staffs
// Data type staffs
constexpr
int
kNumElemsPerInt4
=
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
);
constexpr
int
kNumElemsPerInt4
=
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
);
const
size_t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerInt4
;
const
size_t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerInt4
;
// Message package
// Message package
// BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
EP_STATIC_ASSERT
(
kHidden
%
FP8_QUANTIZATION_NUM_PER_CHANNEL
==
0
,
"Invalid hidden"
);
constexpr
size_t
num_bytes_per_slot
=
sizeof
(
int4
)
+
kHidden
*
sizeof
(
hip_bfloat16
);
constexpr
int
kNumDivisions
=
kHidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
;
constexpr
int
kNumMetaBytes
=
kNumDivisions
*
sizeof
(
float
);
constexpr
size_t
num_bytes_per_slot
=
sizeof
(
int4
)
+
kHidden
*
sizeof
(
hip_bfloat16
)
+
kNumMetaBytes
;
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
__syncthreads
();
// 16 is the max possible number of warps in AMD GPUs
// 16 is the max possible number of warps in AMD GPUs
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
__shared__
volatile
int
sync_large_warp_counters
[
kMaxNumWarps
];
__shared__
volatile
int
sync_large_warp_counters
[
kMaxNumWarps
];
...
@@ -508,13 +535,13 @@ combine(void* combined_x,
...
@@ -508,13 +535,13 @@ combine(void* combined_x,
unpack2
(
layout
,
num_tokens_to_send
,
offset
);
unpack2
(
layout
,
num_tokens_to_send
,
offset
);
// Issue IBGDA send
// Issue IBGDA send
for
(
int
token_idx
=
offset
+
sub_warp_id
;
token_idx
<
offset
+
num_tokens_to_send
;
token_idx
+=
kNumWarpsPerG
roup
)
{
for
(
int
token_idx
=
offset
+
sub_warp_id
;
token_idx
<
offset
+
num_tokens_to_send
;
token_idx
+=
num_warps_per_g
roup
)
{
const
auto
x_int4
=
local_x
+
token_idx
*
hidden_bf16_int4
;
const
auto
x_int4
=
local_x
+
token_idx
*
hidden_bf16_int4
;
const
auto
rdma_send_type_row
=
reinterpret_cast
<
int
*>
(
rdma_send_x_vec
+
token_idx
*
num_bytes_per_slot
);
const
auto
rdma_send_type_row
=
reinterpret_cast
<
int
*>
(
rdma_send_x_vec
+
token_idx
*
num_bytes_per_slot
);
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
+
4
);
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
+
4
);
// Copy directly to local rank, or copy to buffer and issue RDMA
// Copy directly to local rank, or copy to buffer and issue RDMA
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
src_idx
=
shfl_sync
(
__ldg
(
local_src_info
+
token_idx
)
,
0
)
;
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
+
sizeof
(
int4
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
+
sizeof
(
int4
);
if
(
dst_rank
==
rank
)
{
if
(
dst_rank
==
rank
)
{
...
@@ -542,13 +569,13 @@ combine(void* combined_x,
...
@@ -542,13 +569,13 @@ combine(void* combined_x,
}
}
// Put finishing flag
// Put finishing flag
EP_
STAT
IC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Requires more than one warp per group"
);
EP_
DEV
IC
E
_ASSERT
(
num_warps_per_group
>
1
);
if
(
lane_id
==
0
){
if
(
lane_id
==
0
){
// volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
// volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
warp_group_id
],
1
);
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
warp_group_id
],
1
);
}
}
syncwarp
();
syncwarp
();
while
(
sync_large_warp_counters
[
warp_group_id
]
<
(
kNumWarpsPerG
roup
)
)
;
while
(
sync_large_warp_counters
[
warp_group_id
]
<
num_warps_per_g
roup
);
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
...
@@ -572,7 +599,7 @@ combine(void* combined_x,
...
@@ -572,7 +599,7 @@ combine(void* combined_x,
// Wait all ranks to arrive and notify PCIe usage
// Wait all ranks to arrive and notify PCIe usage
if
(
responsible_expert_idx
<
num_experts
)
{
if
(
responsible_expert_idx
<
num_experts
)
{
EP_
STAT
IC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Invalid number of warps per group"
);
EP_
DEV
IC
E
_ASSERT
(
num_warps_per_group
>
1
);
if
(
sub_warp_id
==
0
and
lane_id
==
0
){
if
(
sub_warp_id
==
0
and
lane_id
==
0
){
while
(
ld_acquire_global
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
responsible_expert_idx
))
==
0
);
while
(
ld_acquire_global
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
responsible_expert_idx
))
==
0
);
}
}
...
@@ -630,14 +657,17 @@ void combine(void* combined_x,
...
@@ -630,14 +657,17 @@ void combine(void* combined_x,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
hipStream_t
stream
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
int
phases
,
bool
zero_copy
)
{
constexpr
int
kNumWarpsPerGroup
=
4
;
constexpr
int
kNumMaxTopk
=
11
;
constexpr
int
kNumWarpGroups
=
4
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
constexpr
int
kNumMaxTopk
=
9
;
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
// num_warps_per_group>1, "Requires more than one warp per group"
const
int
num_recv_per_sm
=
ceil_div
(
num_combined_tokens
,
num_device_sms
);
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
and
num_recv_per_sm
>=
0
);
const
auto
num_warps
=
kNumWarpGroups
*
kNumWarpsPerGroup
;
const
auto
num_warps
=
num_warp_groups
*
num_warps_per_group
;
const
auto
num_sms
=
ceil_div
(
num_experts
,
kNumWarpGroups
);
const
auto
num_sms
=
max
(
ceil_div
(
num_experts
,
num_warp_groups
),
num_recv_per_sm
==
0
?
1
:
ceil_div
(
num_combined_tokens
,
num_recv_per_sm
));
// Check workspace
// Check workspace
auto
atomic_clean_flag
=
reinterpret_cast
<
int
*>
(
workspace
);
auto
atomic_clean_flag
=
reinterpret_cast
<
int
*>
(
workspace
);
...
@@ -645,7 +675,7 @@ void combine(void* combined_x,
...
@@ -645,7 +675,7 @@ void combine(void* combined_x,
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopk
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopk
);
#define COMBINE_LAUNCH_CASE(hidden) { \
#define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = combine<
kNumWarpGroups, kNumWarpsPerGroup,
hidden, kNumMaxTopk>; \
auto combine_func = combine<hidden, kNumMaxTopk>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
combined_x, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
...
@@ -656,7 +686,7 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
...
@@ -656,7 +686,7 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
num_combined_tokens, hidden, num_topk, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
num_experts, rank, num_ranks, \
phases, zero_copy); } break
num_warp_groups, num_warps_per_group,
phases, zero_copy); } break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SWITCH_HIDDEN
(
COMBINE_LAUNCH_CASE
);
SWITCH_HIDDEN
(
COMBINE_LAUNCH_CASE
);
...
...
csrc/kernels/utils.cuh
View file @
ce671dd4
...
@@ -365,9 +365,8 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
...
@@ -365,9 +365,8 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return
exp_x
-
127
+
(
man_bits
!=
0
);
return
exp_x
-
127
+
(
man_bits
!=
0
);
}
}
template
<
bool
kRoundScale
>
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
)
{
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
)
{
if
(
round_scale
)
{
if
constexpr
(
kRoundScale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
...
...
deep_ep/buffer.py
View file @
ce671dd4
...
@@ -804,42 +804,46 @@ class Buffer:
...
@@ -804,42 +804,46 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
use_fp8
:
bool
=
True
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
use_fp8
:
bool
=
True
,
round_scale
:
bool
=
False
,
use_ue8m0
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
"""
A low-latency implementation for dispatching with IBGDA.
A low-latency implementation for dispatching with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensors at a single moment.
low-latency kernels' result tensor at a single moment.
Arguments:
Arguments:
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`.
supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`.
topk_idx: `torch.Tensor` with `torch.int64`, shaped as `[num_tokens, num_topk]`,
only several top-k shapes
topk_idx: `torch.Tensor` with
`deep_ep.topk_idx_t` (typically
`torch.int64`
)
, shaped as `[num_tokens, num_topk]`,
are supported. `-1` indices (not selecting any expert) are supported.
only several top-k shapes
are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether to round the scaling factors to powers of 2.
use_ue8m0: whether to use UE8M0 as the scaling factor format (available only with `round_scale=True`).
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
If you
do
not set this flag, the kernel will ensure the data's arrival.
Returns:
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that the last two dimensions of the scaling tensors are in column-major layout for TMA compatibility.
Notice that the last two dimensions of the scaling tensors are in column-major layout for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also incompatible with CUDA graph if synced).
as we do not synchronize CPU received count with GPU (also incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receive. As mentioned before,
all not
tokens are valid in `recv_x`.
expert receive
s
. As mentioned before,
not all
tokens are valid in `recv_x`.
handle: the communication handle to be used in the `low_latency_combine` function.
handle: the communication handle to be used in the `low_latency_combine` function.
event: the event after executing the kernel (valid only if `async_finish` is set).
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
...
@@ -847,7 +851,8 @@ class Buffer:
...
@@ -847,7 +851,8 @@ class Buffer:
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
async_finish
,
return_recv_hook
)
use_fp8
,
round_scale
,
use_ue8m0
,
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
...
...
tests/test_low_latency_new.py
0 → 100644
View file @
ce671dd4
import
argparse
import
random
import
torch
import
torch.distributed
as
dist
from
functools
import
partial
from
typing
import
Literal
,
Set
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_back
def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]):
    """Simulate a rank failing the first time it calls a given communication API.

    Arguments:
        rank: the rank evaluating the call.
        api: name of the communication API about to be called.
        expected_masked_ranks: set of ranks already considered failed; it is
            updated in place with the rank scripted to fail on `api`.

    Returns:
        `True` when `rank` is already in `expected_masked_ranks`, or when it is
        the rank scripted to fail on `api`; `False` otherwise.
    """
    # A rank that already failed stays failed; the set is left untouched
    if rank in expected_masked_ranks:
        return True

    # API -> rank to fail (rank fails when it first calls the corresponding communication API)
    doomed_rank = {'dispatch': 1, 'combine': 3, 'clean': 5}.get(api)
    if doomed_rank is None:
        # No failure is scripted for this API
        return False

    # Every caller records the scripted failure, not only the failing rank itself
    expected_masked_ranks.add(doomed_rank)
    if doomed_rank != rank:
        return False

    print(f"Rank {rank} failed when first calling {api} communication API, exit...", flush=True)
    return True
def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], buffer: deep_ep.Buffer, mask_status: torch.Tensor,
                                expected_masked_ranks: Set[int]):
    """Query the buffer's mask status and assert it matches the expected masked ranks.

    `buffer.low_latency_query_mask_buffer` is called with `mask_status`; the
    indices of its non-zero entries must then equal `expected_masked_ranks`.
    The `api` argument is accepted but not used by this check.
    """
    buffer.low_latency_query_mask_buffer(mask_status)
    masked_now = set(mask_status.nonzero().squeeze(-1).tolist())
    assert masked_now == expected_masked_ranks
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
              rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0):
    """Run a low-latency dispatch/combine correctness sweep, then benchmark both kernels.

    The sweep iterates over `return_recv_hook`, `use_fp8`, `round_scale` and
    `use_ue8m0`, checks the dispatched data against the gathered top-k indices
    and checks the combined output against the weighted input.

    Returns:
        An integer obtained by XOR-ing `hash_tensor` of the valid received
        tokens and of every combined output.
    """
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)

    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    # NOTES: the integers greater than 256 exceed the BF16 precision limit
    rank_offset = 128
    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'

    # Every element carries `rank - rank_offset`; the last 128 columns carry the token index
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
    x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
    x_list = [x]
    # # NOTES: the last one is for performance testing
    # # Most of the values in the perf case is lower than the threshold, casting most channels
    # x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)

    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()

    # Randomly mask some positions
    for _ in range(10):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1

    # Gather the selections of all ranks so that received counts can be verified locally
    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)

    # For failure simulation and shrink testing
    mask_status = torch.zeros((num_ranks,), dtype=torch.int, device='cuda')
    expected_masked_ranks = set()

    # Check dispatch correctness
    do_check = True
    hash_value, num_times = 0, 0
    for current_x in x_list:
        for return_recv_hook in (False, True):
            for dispatch_use_fp8 in (False, True):
                for round_scale in (False, True) if dispatch_use_fp8 else (False,):
                    for use_ue8m0 in (False, True) if round_scale else (False,):
                        num_times += 1
                        # Alternate between one and two back-to-back dispatches
                        for _ in range((num_times % 2) + 1):
                            packed_recv_x, packed_recv_count, handle, event, hook = \
                                buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                                            use_fp8=dispatch_use_fp8, round_scale=round_scale,
                                                            use_ue8m0=use_ue8m0,
                                                            async_finish=not return_recv_hook,
                                                            return_recv_hook=return_recv_hook)
                            if return_recv_hook:
                                hook()
                            else:
                                event.current_stream_wait()
                        if dispatch_use_fp8:
                            packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous())
                            simulated_gemm_x = per_token_cast_back(
                                packed_recv_x[0].view(-1, hidden),
                                packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
                        else:
                            simulated_gemm_x = packed_recv_x.clone()
                        for i in range(num_local_experts if do_check else 0):
                            expert_id = rank * num_local_experts + i
                            recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
                            recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]

                            # Check expert indices
                            # The low 32 bits of each layout-range entry are read as a count, the high bits as a begin index
                            int_mask = (2 ** 32) - 1
                            num_valid_tokens = recv_count.item()
                            assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), \
                                f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
                            assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item(), \
                                f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item()}'

                            if num_valid_tokens == 0:
                                continue
                            # Check received data
                            if current_x is x:
                                recv_x = recv_x[:num_valid_tokens]
                                recv_x_amin = recv_x[:, :-128].amin(dim=-1)
                                recv_src_info = recv_src_info[:num_valid_tokens]
                                assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
                                if round_scale:
                                    assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
                                else:
                                    assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
                                for j in range(num_ranks):
                                    begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
                                    if not round_scale:
                                        assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
                                        assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
                            if dispatch_use_fp8:
                                hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
                                hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
                            else:
                                hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])

                        # Check combine correctness
                        for zero_copy in (False, True):
                            if zero_copy:
                                buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
                            out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
                            combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                                                 async_finish=not return_recv_hook,
                                                                                 # zero_copy=zero_copy,
                                                                                 return_recv_hook=return_recv_hook, out=out)
                            if return_recv_hook:
                                hook()
                            else:
                                event.current_stream_wait()
                            if do_check:
                                diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
                                assert torch.isnan(combined_x).sum().item() == 0
                                # if not round_scale:
                                assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), \
                                    f'Error: diff={diff}, dispatch_use_fp8={dispatch_use_fp8}, zero_copy={zero_copy}'
                                hash_value ^= hash_tensor(combined_x)

    # noinspection PyShadowingNames
    def large_gemm_with_hook(hook):
        # Run a large matrix multiply before invoking the receive hook
        mat_0 = torch.randn((8192, 8192), dtype=torch.float)
        mat_1 = torch.randn((8192, 8192), dtype=torch.float)
        mat_0 @ mat_1
        hook()

    # noinspection PyShadowingNames
    def test_func(return_recv_hook: bool):
        # One FP8 dispatch followed by one combine, reusing the tensors left over from the sweep above
        recv_x, recv_count, handle, event, hook = \
            buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                        use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
        if return_recv_hook:
            large_gemm_with_hook(hook)
        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                             return_recv_hook=return_recv_hook)
        if return_recv_hook:
            large_gemm_with_hook(hook)

    # Calculate bandwidth
    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
    num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4
    num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
    for i in range(num_tokens):
        num_selections = (topk_idx[i] != -1).sum().item()
        num_dispatch_comm_bytes += num_fp8_bytes * num_selections
        num_combine_comm_bytes += num_bf16_bytes * num_selections

    # Dispatch + combine testing
    avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
    print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
          f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)

    # Separate profiling
    for return_recv_hook in (False, True):
        group.barrier()
        dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
                                             kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
                                             suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1)
        if not return_recv_hook:
            print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
                  f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True)
        else:
            print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
                  f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True)
    return hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    """Per-process entry point: set up the group and buffer, run the tests, tear down.

    Arguments:
        local_rank: index of this process, forwarded to `init_dist`.
        num_local_ranks: number of processes, forwarded to `init_dist`.
        args: parsed command-line options (token/hidden/top-k/expert sizes and flags).
    """
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    num_tokens = args.num_tokens
    hidden = args.hidden
    num_topk = args.num_topk
    num_experts = args.num_experts

    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
    if local_rank == 0:
        print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
    buffer = deep_ep.Buffer(group,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=True,
                            num_qps_per_rank=num_experts // num_ranks,
                            allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
                            explicitly_destroy=True,
                            allow_mnnvl=args.allow_mnnvl)

    # Single baseline run with a fixed seed
    test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)

    # Pressure test: each seed is re-run 20 times and must reproduce the reference hash
    do_pressure_test = args.pressure_test
    num_pressure_seeds = int(1e9) if do_pressure_test else 0
    for seed in range(num_pressure_seeds):
        if local_rank == 0:
            print(f'Testing with seed {seed} ...', flush=True)
        ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed)
        for _ in range(20):
            rerun_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed)
            assert rerun_hash == ref_hash, f'Error: seed={seed}'

    # Destroy the buffer runtime and communication group
    buffer.destroy()
    dist.barrier()
    dist.destroy_process_group()
if __name__ == '__main__':
    # TODO: you may modify NUMA binding for less CPU overhead
    # TODO: buggy with `num_tokens=512`
    parser = argparse.ArgumentParser(description='Test low-latency EP kernels')

    # Integer-valued options: (flag, default, help text)
    for flag, default, help_text in (
        ('--num-processes', 8, 'Number of processes to spawn (default: 8)'),
        ('--num-tokens', 128, 'Number of tokens (default: 128)'),
        ('--hidden', 7168, 'Hidden dimension size (default: 7168)'),
        ('--num-topk', 8, 'Number of top-k experts (default: 8)'),
        ('--num-experts', 288, 'Number of experts (default: 288)'),
    ):
        parser.add_argument(flag, type=int, default=default, help=help_text)

    # Boolean switches: (flag, help text)
    for flag, help_text in (
        ('--allow-mnnvl', 'Allow MNNVL for communication'),
        ('--disable-nvlink', 'Whether to disable NVLink for testing'),
        ('--pressure-test', 'Whether to do pressure test'),
        ('--shrink-test', 'Whether to simulate failure and test shrink mode'),
        ('--use-logfmt', 'Whether to test LogFMT combine'),
    ):
        parser.add_argument(flag, action='store_true', help=help_text)

    args = parser.parse_args()

    # One worker process per requested rank; each one runs `test_loop`
    num_processes = args.num_processes
    torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment