Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
21efbe9b
Unverified
Commit
21efbe9b
authored
Jun 12, 2025
by
Shifang Xu
Committed by
GitHub
Jun 12, 2025
Browse files
Support UE8M0 data format. (#206)
parent
9ec06120
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
258 additions
and
118 deletions
+258
-118
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+38
-19
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+2
-1
csrc/kernels/CMakeLists.txt
csrc/kernels/CMakeLists.txt
+2
-2
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+6
-3
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+11
-4
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+48
-21
csrc/kernels/intranode.cu
csrc/kernels/intranode.cu
+10
-2
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+37
-0
deep_ep/buffer.py
deep_ep/buffer.py
+16
-10
install.sh
install.sh
+12
-0
tests/test_internode.py
tests/test_internode.py
+5
-0
tests/test_intranode.py
tests/test_intranode.py
+1
-0
tests/test_low_latency.py
tests/test_low_latency.py
+67
-56
tests/utils.py
tests/utils.py
+3
-0
No files found.
csrc/deep_ep.cpp
View file @
21efbe9b
...
...
@@ -359,14 +359,16 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// FP8 scales checks
float
*
x_scales_ptr
=
nullptr
;
int
num_scales
=
0
;
int
num_scales
=
0
,
scale_token_stride
=
0
,
scale_hidden_stride
=
0
;
if
(
x_scales
.
has_value
())
{
EP_HOST_ASSERT
(
x
.
element_size
()
==
1
);
EP_HOST_ASSERT
(
x_scales
->
scalar_type
()
==
torch
::
kFloat32
);
EP_HOST_ASSERT
(
x_scales
->
dim
()
>
0
and
x_scales
->
dim
()
<
3
and
x_scales
->
is_contiguous
()
);
EP_HOST_ASSERT
(
x_scales
->
scalar_type
()
==
torch
::
kFloat32
or
x_scales
->
scalar_type
()
==
torch
::
kInt
);
EP_HOST_ASSERT
(
x_scales
->
dim
()
==
2
);
EP_HOST_ASSERT
(
x_scales
->
size
(
0
)
==
num_tokens
);
num_scales
=
x_scales
->
dim
()
==
1
?
1
:
static_cast
<
int
>
(
x_scales
->
size
(
1
));
x_scales_ptr
=
x_scales
->
data_ptr
<
float
>
();
x_scales_ptr
=
static_cast
<
float
*>
(
x_scales
->
data_ptr
());
scale_token_stride
=
static_cast
<
int
>
(
x_scales
->
stride
(
0
));
scale_hidden_stride
=
static_cast
<
int
>
(
x_scales
->
stride
(
1
));
}
// Allocate all tensors on comm stream if set
...
...
@@ -474,7 +476,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
recv_x_scales
=
x_scales
->
dim
()
==
1
?
torch
::
empty
({
num_recv_tokens
},
x_scales
->
options
())
:
torch
::
empty
({
num_recv_tokens
,
num_scales
},
x_scales
->
options
());
recv_x_scales_ptr
=
recv_x_scales
->
data_ptr
<
float
>
();
recv_x_scales_ptr
=
static_cast
<
float
*>
(
recv_x_scales
->
data_ptr
(
)
);
}
// Dispatch
...
...
@@ -492,7 +494,9 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
send_head
.
data_ptr
<
int
>
(),
x
.
data_ptr
(),
x_scales_ptr
,
topk_idx_ptr
,
topk_weights_ptr
,
is_token_in_rank
.
data_ptr
<
bool
>
(),
channel_prefix_matrix
.
data_ptr
<
int
>
(),
num_tokens
,
num_worst_tokens
,
static_cast
<
int
>
(
hidden
*
recv_x
.
element_size
()
/
sizeof
(
int4
)),
num_topk
,
num_experts
,
num_scales
,
num_tokens
,
num_worst_tokens
,
static_cast
<
int
>
(
hidden
*
recv_x
.
element_size
()
/
sizeof
(
int4
)),
num_topk
,
num_experts
,
num_scales
,
scale_token_stride
,
scale_hidden_stride
,
buffer_ptrs_gpu
,
rank
,
num_ranks
,
comm_stream
,
config
.
num_sms
,
config
.
num_max_nvl_chunked_send_tokens
,
config
.
num_max_nvl_chunked_recv_tokens
);
...
...
@@ -708,14 +712,16 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// FP8 scales checks
float
*
x_scales_ptr
=
nullptr
;
int
num_scales
=
0
;
int
num_scales
=
0
,
scale_token_stride
=
0
,
scale_hidden_stride
=
0
;
if
(
x_scales
.
has_value
())
{
EP_HOST_ASSERT
(
x
.
element_size
()
==
1
);
EP_HOST_ASSERT
(
x_scales
->
scalar_type
()
==
torch
::
kFloat32
);
EP_HOST_ASSERT
(
x_scales
->
dim
()
>
0
and
x_scales
->
dim
()
<
3
and
x_scales
->
is_contiguous
()
);
EP_HOST_ASSERT
(
x_scales
->
scalar_type
()
==
torch
::
kFloat32
or
x_scales
->
scalar_type
()
==
torch
::
kInt
);
EP_HOST_ASSERT
(
x_scales
->
dim
()
==
2
);
EP_HOST_ASSERT
(
x_scales
->
size
(
0
)
==
num_tokens
);
num_scales
=
x_scales
->
dim
()
==
1
?
1
:
static_cast
<
int
>
(
x_scales
->
size
(
1
));
x_scales_ptr
=
x_scales
->
data_ptr
<
float
>
();
x_scales_ptr
=
static_cast
<
float
*>
(
x_scales
->
data_ptr
());
scale_token_stride
=
static_cast
<
int
>
(
x_scales
->
stride
(
0
));
scale_hidden_stride
=
static_cast
<
int
>
(
x_scales
->
stride
(
1
));
}
// Allocate all tensors on comm stream if set
...
...
@@ -838,7 +844,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
recv_x_scales
=
x_scales
->
dim
()
==
1
?
torch
::
empty
({
num_recv_tokens
},
x_scales
->
options
())
:
torch
::
empty
({
num_recv_tokens
,
num_scales
},
x_scales
->
options
());
recv_x_scales_ptr
=
recv_x_scales
->
data_ptr
<
float
>
();
recv_x_scales_ptr
=
static_cast
<
float
*>
(
recv_x_scales
->
data_ptr
(
)
);
}
// Launch data dispatch
...
...
@@ -851,8 +857,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
cached_mode
?
nullptr
:
recv_gbl_channel_prefix_matrix
->
data_ptr
<
int
>
(),
rdma_channel_prefix_matrix
.
data_ptr
<
int
>
(),
recv_rdma_rank_prefix_sum
.
data_ptr
<
int
>
(),
gbl_channel_prefix_matrix
.
data_ptr
<
int
>
(),
recv_gbl_rank_prefix_sum
.
data_ptr
<
int
>
(),
num_tokens
,
hidden_int4
,
num_scales
,
num_topk
,
num_experts
,
is_token_in_rank
.
data_ptr
<
bool
>
(),
num_tokens
,
hidden_int4
,
num_scales
,
num_topk
,
num_experts
,
scale_token_stride
,
scale_hidden_stride
,
rdma_buffer_ptr
,
config
.
num_max_rdma_chunked_send_tokens
,
config
.
num_max_rdma_chunked_recv_tokens
,
buffer_ptrs_gpu
,
config
.
num_max_nvl_chunked_send_tokens
,
config
.
num_max_nvl_chunked_recv_tokens
,
rank
,
num_ranks
,
cached_mode
,
...
...
@@ -1057,7 +1064,8 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
cumulative_local_expert_recv_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
)
{
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
async
,
bool
return_recv_hook
)
{
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT
(
low_latency_mode
);
...
...
@@ -1077,7 +1085,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto
num_tokens
=
static_cast
<
int
>
(
x
.
size
(
0
)),
hidden
=
static_cast
<
int
>
(
x
.
size
(
1
));
auto
num_scales
=
hidden
/
128
,
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
int
num_local_experts
=
num_experts
/
num_ranks
;
auto
num_local_experts
=
num_experts
/
num_ranks
;
// Buffer control
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
...
...
@@ -1102,12 +1110,22 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
float
*
packed_recv_x_scales_ptr
=
nullptr
;
void
*
packed_recv_x_scales_ptr
=
nullptr
;
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
if
(
use_fp8
)
{
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
num_scales
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
// TODO: support unaligned cases
EP_HOST_ASSERT
(
hidden
%
512
==
0
);
if
(
not
use_ue8m0
)
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
128
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
else
{
EP_HOST_ASSERT
(
round_scale
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
512
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
}
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
<
float
>
();
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
();
}
// Kernel launch
...
...
@@ -1122,7 +1140,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
round_scale
,
use_ue8m0
,
workspace
,
low_latency_usage_flag_mapped
,
launch_stream
,
phases
);
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
...
csrc/deep_ep.hpp
View file @
21efbe9b
...
...
@@ -141,7 +141,8 @@ public:
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
cumulative_local_expert_recv_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
);
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
async
,
bool
return_recv_hook
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
...
...
csrc/kernels/CMakeLists.txt
View file @
21efbe9b
...
...
@@ -4,8 +4,8 @@ function(add_deep_ep_library target_name source_file)
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD_REQUIRED ON
CXX_STANDARD 1
4
CUDA_STANDARD 1
4
CXX_STANDARD 1
7
CUDA_STANDARD 1
7
CUDA_SEPARABLE_COMPILATION ON
)
target_link_libraries
(
${
target_name
}
PUBLIC nvshmem cudart cudadevrt mlx5
)
...
...
csrc/kernels/api.cuh
View file @
21efbe9b
...
...
@@ -57,6 +57,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
int
*
send_head
,
const
void
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
int
num_tokens
,
int
num_worst_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
scale_token_stride
,
int
scale_hidden_stride
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
cudaStream_t
stream
,
int
num_sms
,
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
);
...
...
@@ -99,8 +100,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
int
*
recv_rdma_channel_prefix_matrix
,
int
*
recv_gbl_channel_prefix_matrix
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
recv_rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
const
int
*
recv_gbl_rank_prefix_sum
,
int
num_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
const
bool
*
is_token_in_rank
,
int
num_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
int
scale_token_stride
,
int
scale_hidden_stride
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
,
bool
is_cached_dispatch
,
...
...
@@ -135,7 +137,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int
*
clean_1
,
int
num_clean_int_1
,
cudaStream_t
stream
);
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
cumulative_local_expert_recv_stats
,
...
...
@@ -143,7 +145,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
*
usage_flag
,
cudaStream_t
stream
,
int
phases
);
...
...
csrc/kernels/internode.cu
View file @
21efbe9b
...
...
@@ -343,8 +343,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
int
*
recv_rdma_channel_prefix_matrix
,
int
*
recv_gbl_channel_prefix_matrix
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
recv_rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
const
int
*
recv_gbl_rank_prefix_sum
,
int
num_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
const
bool
*
is_token_in_rank
,
int
num_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
int
scale_token_stride
,
int
scale_hidden_stride
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
)
{
...
...
@@ -536,7 +537,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Copy `x_scales` into symmetric send buffer
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
32
)
{
auto
value
=
ld_nc_global
(
x_scales
+
token_idx
*
num_scales
+
i
);
auto
offset
=
token_idx
*
scale_token_stride
+
i
*
scale_hidden_stride
;
auto
value
=
ld_nc_global
(
x_scales
+
offset
);
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
j
])
+
i
,
value
);
...
...
@@ -938,14 +940,18 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
int
*
recv_rdma_channel_prefix_matrix
,
int
*
recv_gbl_channel_prefix_matrix
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
recv_rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
const
int
*
recv_gbl_rank_prefix_sum
,
int
num_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
const
bool
*
is_token_in_rank
,
int
num_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
int
scale_token_stride
,
int
scale_hidden_stride
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
,
bool
is_cached_dispatch
,
cudaStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
)
{
constexpr
int
kNumDispatchRDMASenderWarps
=
7
;
// Make sure never OOB
EP_HOST_ASSERT
(
static_cast
<
int64_t
>
(
num_scales
)
*
scale_hidden_stride
<
std
::
numeric_limits
<
int
>::
max
());
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
auto dispatch_func = low_latency_mode ? \
(is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>) : \
...
...
@@ -957,8 +963,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
is_token_in_rank, \
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
scale_token_stride, scale_hidden_stride, \
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
rank, num_ranks); } break
...
...
csrc/kernels/internode_ll.cu
View file @
21efbe9b
...
...
@@ -36,9 +36,10 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
);
}
template
<
bool
kUseFP8
,
int
kNumWarpGroups
,
int
kNumWarpsPerGroup
,
int
kHidden
>
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
int
kNumWarpGroups
,
int
kNumWarpsPerGroup
,
int
kHidden
>
__global__
__launch_bounds__
(
kNumWarpGroups
*
kNumWarpsPerGroup
*
32
,
1
)
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
cumulative_local_expert_recv_stats
,
...
...
@@ -48,7 +49,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
*
usage_flag
,
int
phases
)
{
bool
round_scale
,
int
*
usage_flag
,
int
phases
)
{
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
warp_id
=
thread_id
/
32
,
lane_id
=
get_lane_id
();
...
...
@@ -59,9 +60,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
auto
sub_warp_id
=
warp_id
%
kNumWarpsPerGroup
;
const
auto
responsible_expert_idx
=
sm_id
*
kNumWarpGroups
+
warp_group_id
;
// May extract UE8M0 from the scales
using
scale_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint8_t
,
float
>
;
using
packed_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint32_t
,
float
>
;
EP_STATIC_ASSERT
(
sizeof
(
packed_t
)
%
sizeof
(
scale_t
)
==
0
,
"Invalid vector length"
);
// FP8 staffs
constexpr
int
kNumPerChannels
=
128
;
constexpr
float
kFP8Margin
=
1e-4
,
kFP8Amax
=
448
,
kFP8AmaxInv
=
1.0
f
/
448.0
f
;
const
int
num_scales
=
kHidden
/
kNumPerChannels
;
const
size_t
hidden_bytes
=
kHidden
*
(
kUseFP8
?
sizeof
(
__nv_fp8_storage_t
)
:
sizeof
(
nv_bfloat16
));
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
...
...
@@ -96,7 +101,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
auto
rdma_x_vec
=
reinterpret_cast
<
vec_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_src_idx
)
+
sizeof
(
int4
));
const
auto
rdma_x_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_vec
)
+
hidden_bytes
);
// Overlap top-k index read and source token index write
// Overlap top-k index read and source token index write
s
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
...
...
@@ -106,7 +111,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Read
auto
int4_value
=
__ldg
(
x_int4
+
i
);
if
(
kUseFP8
)
{
if
constexpr
(
kUseFP8
)
{
// Calculate local amax
auto
bf16_values
=
reinterpret_cast
<
nv_bfloat16
*>
(
&
int4_value
);
float
fp32_values
[
kNumElemsPerRead
];
...
...
@@ -119,7 +124,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Reduce amax and scale
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
32
/
kNumPerChannels
==
2
,
"Invalid vectorization"
);
amax
=
half_warp_reduce_max
(
amax
),
scale
=
kFP8Amax
/
amax
,
scale_inv
=
amax
*
kFP8AmaxInv
;
amax
=
half_warp_reduce_max
(
amax
);
calculate_fp8_scales
(
amax
,
scale
,
scale_inv
,
round_scale
);
if
(
lane_id
==
0
or
lane_id
==
16
)
rdma_x_scales
[
i
*
kNumElemsPerRead
/
128
]
=
scale_inv
;
...
...
@@ -256,9 +262,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
src_rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
;
const
auto
recv_x_int4
=
reinterpret_cast
<
int4
*>
(
packed_recv_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
const
auto
recv_x_scales
=
packed_recv_x_scales
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_scales
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
const
auto
num_aligned_scales
=
align
<
int
>
(
num_scales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
recv_x_scales
=
reinterpret_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_aligned_scales
;
// Shared between sub-warps in warp groups
__shared__
int
shared_num_recv_tokens
[
kNumWarpGroups
],
shared_recv_token_begin_idx
[
kNumWarpGroups
];
...
...
@@ -297,20 +304,32 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
// Copy scales
if
(
kUseFP8
)
{
if
constexpr
(
kUseFP8
)
{
// Equivalent CuTe layout:
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
dst_scales
=
reinterpret_cast
<
float
*>
(
recv_x_scales
+
recv_token_begin_idx
+
i
);
const
auto
scale_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
;
auto
scale_0
=
lane_id
<
num_scales
?
ld_nc_global
(
src_scales
+
lane_id
)
:
0
;
auto
scale_1
=
(
lane_id
+
32
)
<
num_scales
?
ld_nc_global
(
src_scales
+
lane_id
+
32
)
:
0
;
lane_id
<
num_scales
?
dst_scales
[
lane_id
*
scale_stride
]
=
scale_0
:
0.0
f
;
(
lane_id
+
32
)
<
num_scales
?
dst_scales
[(
lane_id
+
32
)
*
scale_stride
]
=
scale_1
:
0.0
f
;
const
auto
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
packed_t
)
/
sizeof
(
scale_t
));
const
auto
token_idx
=
recv_token_begin_idx
+
i
;
const
auto
token_stride
=
num_elems_per_pack
;
const
auto
pack_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_elems_per_pack
;
if
(
lane_id
<
num_scales
)
{
const
auto
pack_idx
=
lane_id
/
num_elems_per_pack
;
const
auto
elem_idx
=
lane_id
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
if
(
lane_id
+
32
<
num_scales
)
{
const
auto
pack_idx
=
(
lane_id
+
32
)
/
num_elems_per_pack
;
const
auto
elem_idx
=
(
lane_id
+
32
)
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
+
32
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
}
}
}
}
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
cumulative_local_expert_recv_stats
,
...
...
@@ -318,7 +337,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
*
usage_flag
,
cudaStream_t
stream
,
int
phases
)
{
constexpr
int
kNumMaxTopK
=
9
;
...
...
@@ -331,13 +351,20 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopK
);
// Workspace checks
auto
atomic_counter_per_expert
=
reinterpret
_cast
<
int
*>
(
workspace
);
auto
atomic_counter_per_expert
=
static
_cast
<
int
*>
(
workspace
);
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
// FP8 checks
if
(
use_ue8m0
)
EP_HOST_ASSERT
(
round_scale
and
"UE8M0 SF requires `round_scale=True`"
);
#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden> : \
dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
auto dispatch_func = dispatch<false, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
...
...
@@ -349,7 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
usage_flag, phases); } break
round_scale,
usage_flag, phases); } break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
32
,
stream
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
...
...
csrc/kernels/intranode.cu
View file @
21efbe9b
...
...
@@ -174,6 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int
*
send_head
,
const
int4
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
int
num_tokens
,
int
num_worst_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
scale_token_stride
,
int
scale_hidden_stride
,
void
**
buffer_ptrs
,
int
rank
,
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
)
{
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
),
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
...
...
@@ -326,8 +327,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Copy `x_scales`
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
32
)
channel_x_scales_buffers
[
dst_slot_idx
*
num_scales
+
i
]
=
__ldg
(
x_scales
+
token_idx
*
num_scales
+
i
);
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
32
)
{
auto
offset
=
token_idx
*
scale_token_stride
+
i
*
scale_hidden_stride
;
channel_x_scales_buffers
[
dst_slot_idx
*
num_scales
+
i
]
=
__ldg
(
x_scales
+
offset
);
}
}
// Move token index
...
...
@@ -478,6 +481,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
int
*
send_head
,
const
void
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
int
num_tokens
,
int
num_worst_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
scale_token_stride
,
int
scale_hidden_stride
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
cudaStream_t
stream
,
int
num_sms
,
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
)
{
constexpr
int
kNumThreads
=
768
;
...
...
@@ -486,6 +490,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
constexpr
int
smem_size
=
kNumTMABytesPerWarp
*
(
kNumThreads
/
32
);
#endif
// Make sure never OOB
EP_HOST_ASSERT
(
static_cast
<
int64_t
>
(
num_scales
)
*
scale_hidden_stride
<
std
::
numeric_limits
<
int
>::
max
());
#define DISPATCH_LAUNCH_CASE(ranks) { \
auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
SET_SHARED_MEMORY_FOR_TMA(kernel); \
...
...
@@ -494,6 +501,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
is_token_in_rank, channel_prefix_matrix, \
num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
scale_token_stride, scale_hidden_stride, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
} break
...
...
csrc/kernels/utils.cuh
View file @
21efbe9b
...
...
@@ -401,6 +401,43 @@ __forceinline__ __device__ int get_lane_id() {
return
lane_id
;
}
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFinfoAmaxE4M3
=
448.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1
/
448.0
f
;
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
// We can ensure `-126 <= x and x <= 127`
uint32_t
bits_x
=
(
x
+
127
)
<<
23
;
return
*
reinterpret_cast
<
float
*>
(
&
bits_x
);
}
__forceinline__
__device__
int
fast_log2_ceil
(
float
x
)
{
auto
bits_x
=
*
reinterpret_cast
<
uint32_t
*>
(
&
x
);
auto
exp_x
=
(
bits_x
>>
23
)
&
0xff
;
auto
man_bits
=
bits_x
&
((
1
<<
23
)
-
1
);
return
exp_x
-
127
+
(
man_bits
!=
0
);
}
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
)
{
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
scale
=
kFinfoAmaxE4M3
/
amax
;
}
}
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
if
constexpr
(
kIsUE8M0
)
{
return
static_cast
<
uint8_t
>
((
*
reinterpret_cast
<
uint32_t
*>
(
&
value
))
>>
23
);
}
else
{
return
value
;
}
}
template
<
int
kNumRanks
>
__forceinline__
__device__
void
barrier_block
(
int
**
barrier_signal_ptrs
,
int
rank
)
{
...
...
deep_ep/buffer.py
View file @
21efbe9b
...
...
@@ -178,6 +178,7 @@ class Buffer:
config: the recommended config.
"""
# TODO: automatically tune
config_map
=
{
2
:
Config
(
Buffer
.
num_sms
,
24
,
256
,
6
,
128
),
4
:
Config
(
Buffer
.
num_sms
,
6
,
256
,
6
,
128
),
...
...
@@ -205,6 +206,7 @@ class Buffer:
config: the recommended config.
"""
# TODO: automatically tune
config_map
=
{
2
:
Config
(
Buffer
.
num_sms
,
10
,
256
,
6
,
128
),
4
:
Config
(
Buffer
.
num_sms
,
9
,
256
,
6
,
128
),
...
...
@@ -486,14 +488,14 @@ class Buffer:
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
cumulative_local_expert_recv_stats
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp8
:
bool
=
True
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
use_fp8
:
bool
=
True
,
round_scale
:
bool
=
False
,
use_ue8m0
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
A low-latency implementation for dispatching with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
low-latency kernels' result tensors at a single moment.
Arguments:
...
...
@@ -507,17 +509,21 @@ class Buffer:
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
If you
do
not set this flag, the kernel will ensure the data's arrival.
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
...
...
@@ -533,7 +539,8 @@ class Buffer:
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
cumulative_local_expert_recv_stats
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
async_finish
,
return_recv_hook
)
use_fp8
,
round_scale
,
use_ue8m0
,
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
...
...
@@ -551,9 +558,8 @@ class Buffer:
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
low-latency kernels' result tensors at a single moment.
Arguments:
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
...
...
@@ -569,7 +575,7 @@ class Buffer:
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
If you
do
not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
Returns:
...
...
install.sh
0 → 100755
View file @
21efbe9b
# Change current directory into project root
original_dir
=
$(
pwd
)
script_dir
=
$(
dirname
"
$0
"
)
cd
"
$script_dir
"
# Remove old dist file, build, and install
rm
-rf
dist
python setup.py bdist_wheel
pip
install
dist/
*
.whl
# Open users' original directory
cd
"
$original_dir
"
tests/test_internode.py
View file @
21efbe9b
...
...
@@ -22,6 +22,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
rank
x_pure_rand
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
x_e4m3
=
per_token_cast_to_fp8
(
x
)
x_e4m3
=
(
x_e4m3
[
0
],
x_e4m3
[
1
].
T
.
contiguous
().
T
)
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
group_scores
=
scores
.
view
(
num_tokens
,
num_nodes
,
-
1
).
amax
(
dim
=-
1
)
group_idx
=
torch
.
topk
(
group_scores
,
k
=
num_topk_groups
,
dim
=-
1
,
sorted
=
False
).
indices
...
...
@@ -241,6 +242,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
buffer
.
clean_low_latency_buffer
(
ll_num_tokens
,
ll_hidden
,
ll_num_experts
)
test_low_latency
.
test_main
(
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
# Destroy the communication group
dist
.
barrier
()
dist
.
destroy_process_group
()
if
__name__
==
'__main__'
:
num_processes
=
8
...
...
tests/test_intranode.py
View file @
21efbe9b
...
...
@@ -21,6 +21,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
rank
x_pure_rand
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
x_e4m3
=
per_token_cast_to_fp8
(
x
)
if
deep_ep
.
Buffer
.
is_sm90_compiled
()
else
None
x_e4m3
=
(
x_e4m3
[
0
],
x_e4m3
[
1
].
T
.
contiguous
().
T
)
if
x_e4m3
is
not
None
else
None
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
False
)[
1
]
topk_weights
=
torch
.
ones
((
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
*
rank
...
...
tests/test_low_latency.py
View file @
21efbe9b
...
...
@@ -34,61 +34,68 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hash_value
,
num_times
=
0
,
0
for
return_recv_hook
in
(
False
,
True
):
for
dispatch_use_fp8
in
(
False
,
True
):
num_times
+=
1
for
i
in
range
((
num_times
%
2
)
+
1
):
cumulative_local_expert_recv_stats
=
torch
.
zeros
((
num_local_experts
,
),
dtype
=
torch
.
int
,
device
=
'cuda'
)
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
cumulative_local_expert_recv_stats
=
cumulative_local_expert_recv_stats
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
\
if
dispatch_use_fp8
else
packed_recv_x
.
clone
()
all_topk_idx
=
torch
.
empty
((
num_ranks
,
num_tokens
,
num_topk
),
dtype
=
topk_idx
.
dtype
,
device
=
'cuda'
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
recv_x
=
per_token_cast_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
int_mask
=
(
2
**
32
)
-
1
num_valid_tokens
=
recv_count
.
item
()
assert
cumulative_local_expert_recv_stats
[
i
].
item
()
==
num_valid_tokens
,
f
'
{
cumulative_local_expert_recv_stats
[
i
].
item
()
}
!=
{
num_valid_tokens
}
'
assert
num_valid_tokens
==
(
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
().
item
()
}
'
# Check received data
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
for
j
in
range
(
num_ranks
):
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
][:
-
128
]
-
j
).
sum
().
item
()
==
0
if
dispatch_use_fp8
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
# Check combine correctness
for
zero_copy
in
(
False
,
True
):
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
async_finish
=
not
return_recv_hook
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
if
do_check
:
diff
=
calc_diff
(
x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
assert
diff
<
1e-5
,
f
'Error:
{
diff
=
}
,
{
zero_copy
=
}
'
hash_value
^=
hash_tensor
(
combined_x
)
for
round_scale
in
(
False
,
True
)
if
dispatch_use_fp8
else
(
False
,
):
for
use_ue8m0
in
(
False
,
True
)
if
round_scale
else
(
False
,
):
num_times
+=
1
for
i
in
range
((
num_times
%
2
)
+
1
):
cumulative_local_expert_recv_stats
=
torch
.
zeros
((
num_local_experts
,
),
dtype
=
torch
.
int
,
device
=
'cuda'
)
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
cumulative_local_expert_recv_stats
=
cumulative_local_expert_recv_stats
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
\
if
dispatch_use_fp8
else
packed_recv_x
.
clone
()
all_topk_idx
=
torch
.
empty
((
num_ranks
,
num_tokens
,
num_topk
),
dtype
=
topk_idx
.
dtype
,
device
=
'cuda'
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
recv_x
=
per_token_cast_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
int_mask
=
(
2
**
32
)
-
1
num_valid_tokens
=
recv_count
.
item
()
assert
cumulative_local_expert_recv_stats
[
i
].
item
()
==
num_valid_tokens
,
f
'
{
cumulative_local_expert_recv_stats
[
i
].
item
()
}
!=
{
num_valid_tokens
}
'
assert
num_valid_tokens
==
(
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
().
item
()
}
'
# Check received data
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
if
round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
for
j
in
range
(
num_ranks
):
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
if
not
round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
][:
-
128
]
-
j
).
sum
().
item
()
==
0
if
dispatch_use_fp8
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
# Check combine correctness
for
zero_copy
in
(
False
,
True
):
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
async_finish
=
not
return_recv_hook
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
out
=
out
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
if
do_check
:
diff
=
calc_diff
(
x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
assert
diff
<
(
7e-4
if
round_scale
else
1e-5
),
f
'Error:
{
diff
=
}
,
{
zero_copy
=
}
'
hash_value
^=
hash_tensor
(
combined_x
)
def
create_test_cast_with_outliers
(
num_outliers
):
tmp
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
...
...
@@ -112,7 +119,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
cumulative_local_expert_recv_stats
=
cumulative_local_expert_recv_stats
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
use_fp8
=
True
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
...
...
@@ -170,6 +177,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
for
i
in
range
(
20
):
assert
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
# Destroy the communication group
dist
.
barrier
()
dist
.
destroy_process_group
()
if
__name__
==
'__main__'
:
# TODO: you may modify NUMA binding for less CPU overhead
...
...
tests/utils.py
View file @
21efbe9b
...
...
@@ -43,6 +43,9 @@ def per_token_cast_to_fp8(x: torch.Tensor):
def
per_token_cast_back
(
x_fp8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
if
x_scales
.
dtype
==
torch
.
int
:
x_scales
=
x_scales
.
view
(
dtype
=
torch
.
int8
).
to
(
torch
.
int
)
<<
23
x_scales
=
x_scales
.
view
(
dtype
=
torch
.
float
)
x_fp32
=
x_fp8
.
to
(
torch
.
float32
).
view
(
x_fp8
.
size
(
0
),
-
1
,
128
)
x_scales
=
x_scales
.
view
(
x_fp8
.
size
(
0
),
-
1
,
1
)
return
(
x_fp32
*
x_scales
).
view
(
x_fp8
.
shape
).
to
(
torch
.
bfloat16
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment