Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
21efbe9b
"vscode:/vscode.git/clone" did not exist on "bd6c5aa925f4cbe9319cd1530f9732f40407bf1e"
Unverified
Commit
21efbe9b
authored
Jun 12, 2025
by
Shifang Xu
Committed by
GitHub
Jun 12, 2025
Browse files
Support UE8M0 data format. (#206)
parent
9ec06120
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
258 additions
and
118 deletions
+258
-118
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+38
-19
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+2
-1
csrc/kernels/CMakeLists.txt
csrc/kernels/CMakeLists.txt
+2
-2
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+6
-3
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+11
-4
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+48
-21
csrc/kernels/intranode.cu
csrc/kernels/intranode.cu
+10
-2
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+37
-0
deep_ep/buffer.py
deep_ep/buffer.py
+16
-10
install.sh
install.sh
+12
-0
tests/test_internode.py
tests/test_internode.py
+5
-0
tests/test_intranode.py
tests/test_intranode.py
+1
-0
tests/test_low_latency.py
tests/test_low_latency.py
+67
-56
tests/utils.py
tests/utils.py
+3
-0
No files found.
csrc/deep_ep.cpp
View file @
21efbe9b
...
@@ -359,14 +359,16 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
...
@@ -359,14 +359,16 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// FP8 scales checks
// FP8 scales checks
float
*
x_scales_ptr
=
nullptr
;
float
*
x_scales_ptr
=
nullptr
;
int
num_scales
=
0
;
int
num_scales
=
0
,
scale_token_stride
=
0
,
scale_hidden_stride
=
0
;
if
(
x_scales
.
has_value
())
{
if
(
x_scales
.
has_value
())
{
EP_HOST_ASSERT
(
x
.
element_size
()
==
1
);
EP_HOST_ASSERT
(
x
.
element_size
()
==
1
);
EP_HOST_ASSERT
(
x_scales
->
scalar_type
()
==
torch
::
kFloat32
);
EP_HOST_ASSERT
(
x_scales
->
scalar_type
()
==
torch
::
kFloat32
or
x_scales
->
scalar_type
()
==
torch
::
kInt
);
EP_HOST_ASSERT
(
x_scales
->
dim
()
>
0
and
x_scales
->
dim
()
<
3
and
x_scales
->
is_contiguous
()
);
EP_HOST_ASSERT
(
x_scales
->
dim
()
==
2
);
EP_HOST_ASSERT
(
x_scales
->
size
(
0
)
==
num_tokens
);
EP_HOST_ASSERT
(
x_scales
->
size
(
0
)
==
num_tokens
);
num_scales
=
x_scales
->
dim
()
==
1
?
1
:
static_cast
<
int
>
(
x_scales
->
size
(
1
));
num_scales
=
x_scales
->
dim
()
==
1
?
1
:
static_cast
<
int
>
(
x_scales
->
size
(
1
));
x_scales_ptr
=
x_scales
->
data_ptr
<
float
>
();
x_scales_ptr
=
static_cast
<
float
*>
(
x_scales
->
data_ptr
());
scale_token_stride
=
static_cast
<
int
>
(
x_scales
->
stride
(
0
));
scale_hidden_stride
=
static_cast
<
int
>
(
x_scales
->
stride
(
1
));
}
}
// Allocate all tensors on comm stream if set
// Allocate all tensors on comm stream if set
...
@@ -474,7 +476,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
...
@@ -474,7 +476,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
recv_x_scales
=
x_scales
->
dim
()
==
1
?
recv_x_scales
=
x_scales
->
dim
()
==
1
?
torch
::
empty
({
num_recv_tokens
},
x_scales
->
options
())
:
torch
::
empty
({
num_recv_tokens
},
x_scales
->
options
())
:
torch
::
empty
({
num_recv_tokens
,
num_scales
},
x_scales
->
options
());
torch
::
empty
({
num_recv_tokens
,
num_scales
},
x_scales
->
options
());
recv_x_scales_ptr
=
recv_x_scales
->
data_ptr
<
float
>
();
recv_x_scales_ptr
=
static_cast
<
float
*>
(
recv_x_scales
->
data_ptr
(
)
);
}
}
// Dispatch
// Dispatch
...
@@ -492,7 +494,9 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
...
@@ -492,7 +494,9 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
send_head
.
data_ptr
<
int
>
(),
send_head
.
data_ptr
<
int
>
(),
x
.
data_ptr
(),
x_scales_ptr
,
topk_idx_ptr
,
topk_weights_ptr
,
x
.
data_ptr
(),
x_scales_ptr
,
topk_idx_ptr
,
topk_weights_ptr
,
is_token_in_rank
.
data_ptr
<
bool
>
(),
channel_prefix_matrix
.
data_ptr
<
int
>
(),
is_token_in_rank
.
data_ptr
<
bool
>
(),
channel_prefix_matrix
.
data_ptr
<
int
>
(),
num_tokens
,
num_worst_tokens
,
static_cast
<
int
>
(
hidden
*
recv_x
.
element_size
()
/
sizeof
(
int4
)),
num_topk
,
num_experts
,
num_scales
,
num_tokens
,
num_worst_tokens
,
static_cast
<
int
>
(
hidden
*
recv_x
.
element_size
()
/
sizeof
(
int4
)),
num_topk
,
num_experts
,
num_scales
,
scale_token_stride
,
scale_hidden_stride
,
buffer_ptrs_gpu
,
rank
,
num_ranks
,
comm_stream
,
config
.
num_sms
,
buffer_ptrs_gpu
,
rank
,
num_ranks
,
comm_stream
,
config
.
num_sms
,
config
.
num_max_nvl_chunked_send_tokens
,
config
.
num_max_nvl_chunked_recv_tokens
);
config
.
num_max_nvl_chunked_send_tokens
,
config
.
num_max_nvl_chunked_recv_tokens
);
...
@@ -708,14 +712,16 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
...
@@ -708,14 +712,16 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// FP8 scales checks
// FP8 scales checks
float
*
x_scales_ptr
=
nullptr
;
float
*
x_scales_ptr
=
nullptr
;
int
num_scales
=
0
;
int
num_scales
=
0
,
scale_token_stride
=
0
,
scale_hidden_stride
=
0
;
if
(
x_scales
.
has_value
())
{
if
(
x_scales
.
has_value
())
{
EP_HOST_ASSERT
(
x
.
element_size
()
==
1
);
EP_HOST_ASSERT
(
x
.
element_size
()
==
1
);
EP_HOST_ASSERT
(
x_scales
->
scalar_type
()
==
torch
::
kFloat32
);
EP_HOST_ASSERT
(
x_scales
->
scalar_type
()
==
torch
::
kFloat32
or
x_scales
->
scalar_type
()
==
torch
::
kInt
);
EP_HOST_ASSERT
(
x_scales
->
dim
()
>
0
and
x_scales
->
dim
()
<
3
and
x_scales
->
is_contiguous
()
);
EP_HOST_ASSERT
(
x_scales
->
dim
()
==
2
);
EP_HOST_ASSERT
(
x_scales
->
size
(
0
)
==
num_tokens
);
EP_HOST_ASSERT
(
x_scales
->
size
(
0
)
==
num_tokens
);
num_scales
=
x_scales
->
dim
()
==
1
?
1
:
static_cast
<
int
>
(
x_scales
->
size
(
1
));
num_scales
=
x_scales
->
dim
()
==
1
?
1
:
static_cast
<
int
>
(
x_scales
->
size
(
1
));
x_scales_ptr
=
x_scales
->
data_ptr
<
float
>
();
x_scales_ptr
=
static_cast
<
float
*>
(
x_scales
->
data_ptr
());
scale_token_stride
=
static_cast
<
int
>
(
x_scales
->
stride
(
0
));
scale_hidden_stride
=
static_cast
<
int
>
(
x_scales
->
stride
(
1
));
}
}
// Allocate all tensors on comm stream if set
// Allocate all tensors on comm stream if set
...
@@ -838,7 +844,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
...
@@ -838,7 +844,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
recv_x_scales
=
x_scales
->
dim
()
==
1
?
recv_x_scales
=
x_scales
->
dim
()
==
1
?
torch
::
empty
({
num_recv_tokens
},
x_scales
->
options
())
:
torch
::
empty
({
num_recv_tokens
},
x_scales
->
options
())
:
torch
::
empty
({
num_recv_tokens
,
num_scales
},
x_scales
->
options
());
torch
::
empty
({
num_recv_tokens
,
num_scales
},
x_scales
->
options
());
recv_x_scales_ptr
=
recv_x_scales
->
data_ptr
<
float
>
();
recv_x_scales_ptr
=
static_cast
<
float
*>
(
recv_x_scales
->
data_ptr
(
)
);
}
}
// Launch data dispatch
// Launch data dispatch
...
@@ -851,8 +857,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
...
@@ -851,8 +857,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
cached_mode
?
nullptr
:
recv_gbl_channel_prefix_matrix
->
data_ptr
<
int
>
(),
cached_mode
?
nullptr
:
recv_gbl_channel_prefix_matrix
->
data_ptr
<
int
>
(),
rdma_channel_prefix_matrix
.
data_ptr
<
int
>
(),
recv_rdma_rank_prefix_sum
.
data_ptr
<
int
>
(),
rdma_channel_prefix_matrix
.
data_ptr
<
int
>
(),
recv_rdma_rank_prefix_sum
.
data_ptr
<
int
>
(),
gbl_channel_prefix_matrix
.
data_ptr
<
int
>
(),
recv_gbl_rank_prefix_sum
.
data_ptr
<
int
>
(),
gbl_channel_prefix_matrix
.
data_ptr
<
int
>
(),
recv_gbl_rank_prefix_sum
.
data_ptr
<
int
>
(),
num_tokens
,
hidden_int4
,
num_scales
,
num_topk
,
num_experts
,
is_token_in_rank
.
data_ptr
<
bool
>
(),
is_token_in_rank
.
data_ptr
<
bool
>
(),
num_tokens
,
hidden_int4
,
num_scales
,
num_topk
,
num_experts
,
scale_token_stride
,
scale_hidden_stride
,
rdma_buffer_ptr
,
config
.
num_max_rdma_chunked_send_tokens
,
config
.
num_max_rdma_chunked_recv_tokens
,
rdma_buffer_ptr
,
config
.
num_max_rdma_chunked_send_tokens
,
config
.
num_max_rdma_chunked_recv_tokens
,
buffer_ptrs_gpu
,
config
.
num_max_nvl_chunked_send_tokens
,
config
.
num_max_nvl_chunked_recv_tokens
,
buffer_ptrs_gpu
,
config
.
num_max_nvl_chunked_send_tokens
,
config
.
num_max_nvl_chunked_recv_tokens
,
rank
,
num_ranks
,
cached_mode
,
rank
,
num_ranks
,
cached_mode
,
...
@@ -1057,7 +1064,8 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
...
@@ -1057,7 +1064,8 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
cumulative_local_expert_recv_stats
,
const
std
::
optional
<
torch
::
Tensor
>&
cumulative_local_expert_recv_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
)
{
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
async
,
bool
return_recv_hook
)
{
#ifndef DISABLE_NVSHMEM
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
...
@@ -1077,7 +1085,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1077,7 +1085,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto
num_tokens
=
static_cast
<
int
>
(
x
.
size
(
0
)),
hidden
=
static_cast
<
int
>
(
x
.
size
(
1
));
auto
num_tokens
=
static_cast
<
int
>
(
x
.
size
(
0
)),
hidden
=
static_cast
<
int
>
(
x
.
size
(
1
));
auto
num_scales
=
hidden
/
128
,
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
auto
num_scales
=
hidden
/
128
,
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
int
num_local_experts
=
num_experts
/
num_ranks
;
auto
num_local_experts
=
num_experts
/
num_ranks
;
// Buffer control
// Buffer control
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
...
@@ -1102,12 +1110,22 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1102,12 +1110,22 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
// Allocate column-majored scales
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
float
*
packed_recv_x_scales_ptr
=
nullptr
;
void
*
packed_recv_x_scales_ptr
=
nullptr
;
if
(
use_fp8
)
{
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
num_scales
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
if
(
use_fp8
)
{
// TODO: support unaligned cases
EP_HOST_ASSERT
(
hidden
%
512
==
0
);
if
(
not
use_ue8m0
)
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
128
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
else
{
EP_HOST_ASSERT
(
round_scale
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
512
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
}
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
<
float
>
();
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
();
}
}
// Kernel launch
// Kernel launch
...
@@ -1122,7 +1140,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1122,7 +1140,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
round_scale
,
use_ue8m0
,
workspace
,
low_latency_usage_flag_mapped
,
launch_stream
,
phases
);
workspace
,
low_latency_usage_flag_mapped
,
launch_stream
,
phases
);
};
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
...
csrc/deep_ep.hpp
View file @
21efbe9b
...
@@ -141,7 +141,8 @@ public:
...
@@ -141,7 +141,8 @@ public:
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
cumulative_local_expert_recv_stats
,
const
std
::
optional
<
torch
::
Tensor
>&
cumulative_local_expert_recv_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
);
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
async
,
bool
return_recv_hook
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
...
...
csrc/kernels/CMakeLists.txt
View file @
21efbe9b
...
@@ -4,8 +4,8 @@ function(add_deep_ep_library target_name source_file)
...
@@ -4,8 +4,8 @@ function(add_deep_ep_library target_name source_file)
POSITION_INDEPENDENT_CODE ON
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD_REQUIRED ON
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD_REQUIRED ON
CUDA_STANDARD_REQUIRED ON
CXX_STANDARD 1
4
CXX_STANDARD 1
7
CUDA_STANDARD 1
4
CUDA_STANDARD 1
7
CUDA_SEPARABLE_COMPILATION ON
CUDA_SEPARABLE_COMPILATION ON
)
)
target_link_libraries
(
${
target_name
}
PUBLIC nvshmem cudart cudadevrt mlx5
)
target_link_libraries
(
${
target_name
}
PUBLIC nvshmem cudart cudadevrt mlx5
)
...
...
csrc/kernels/api.cuh
View file @
21efbe9b
...
@@ -57,6 +57,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
...
@@ -57,6 +57,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
int
*
send_head
,
const
void
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
int
*
send_head
,
const
void
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
int
num_tokens
,
int
num_worst_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
num_tokens
,
int
num_worst_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
scale_token_stride
,
int
scale_hidden_stride
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
cudaStream_t
stream
,
int
num_sms
,
cudaStream_t
stream
,
int
num_sms
,
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
);
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
);
...
@@ -99,8 +100,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
...
@@ -99,8 +100,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
int
*
recv_rdma_channel_prefix_matrix
,
int
*
recv_gbl_channel_prefix_matrix
,
int
*
recv_rdma_channel_prefix_matrix
,
int
*
recv_gbl_channel_prefix_matrix
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
recv_rdma_rank_prefix_sum
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
recv_rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
const
int
*
recv_gbl_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
const
int
*
recv_gbl_rank_prefix_sum
,
int
num_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
const
bool
*
is_token_in_rank
,
const
bool
*
is_token_in_rank
,
int
num_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
int
scale_token_stride
,
int
scale_hidden_stride
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
,
bool
is_cached_dispatch
,
int
rank
,
int
num_ranks
,
bool
is_cached_dispatch
,
...
@@ -135,7 +137,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
...
@@ -135,7 +137,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int
*
clean_1
,
int
num_clean_int_1
,
int
*
clean_1
,
int
num_clean_int_1
,
cudaStream_t
stream
);
cudaStream_t
stream
);
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
cumulative_local_expert_recv_stats
,
int
*
cumulative_local_expert_recv_stats
,
...
@@ -143,7 +145,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -143,7 +145,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
*
usage_flag
,
void
*
workspace
,
int
*
usage_flag
,
cudaStream_t
stream
,
int
phases
);
cudaStream_t
stream
,
int
phases
);
...
...
csrc/kernels/internode.cu
View file @
21efbe9b
...
@@ -343,8 +343,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
...
@@ -343,8 +343,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
int
*
recv_rdma_channel_prefix_matrix
,
int
*
recv_gbl_channel_prefix_matrix
,
int
*
recv_rdma_channel_prefix_matrix
,
int
*
recv_gbl_channel_prefix_matrix
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
recv_rdma_rank_prefix_sum
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
recv_rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
const
int
*
recv_gbl_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
const
int
*
recv_gbl_rank_prefix_sum
,
int
num_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
const
bool
*
is_token_in_rank
,
const
bool
*
is_token_in_rank
,
int
num_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
int
scale_token_stride
,
int
scale_hidden_stride
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
)
{
int
rank
,
int
num_ranks
)
{
...
@@ -536,7 +537,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
...
@@ -536,7 +537,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Copy `x_scales` into symmetric send buffer
// Copy `x_scales` into symmetric send buffer
#pragma unroll
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
32
)
{
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
32
)
{
auto
value
=
ld_nc_global
(
x_scales
+
token_idx
*
num_scales
+
i
);
auto
offset
=
token_idx
*
scale_token_stride
+
i
*
scale_hidden_stride
;
auto
value
=
ld_nc_global
(
x_scales
+
offset
);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
j
])
+
i
,
value
);
st_na_global
(
reinterpret_cast
<
float
*>
(
dst_send_buffers
[
j
])
+
i
,
value
);
...
@@ -938,14 +940,18 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
...
@@ -938,14 +940,18 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
int
*
recv_rdma_channel_prefix_matrix
,
int
*
recv_gbl_channel_prefix_matrix
,
int
*
recv_rdma_channel_prefix_matrix
,
int
*
recv_gbl_channel_prefix_matrix
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
recv_rdma_rank_prefix_sum
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
recv_rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
const
int
*
recv_gbl_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
const
int
*
recv_gbl_rank_prefix_sum
,
int
num_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
const
bool
*
is_token_in_rank
,
const
bool
*
is_token_in_rank
,
int
num_tokens
,
int
hidden_int4
,
int
num_scales
,
int
num_topk
,
int
num_experts
,
int
scale_token_stride
,
int
scale_hidden_stride
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
*
rdma_buffer_ptr
,
int
num_max_rdma_chunked_send_tokens
,
int
num_max_rdma_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
void
**
buffer_ptrs
,
int
num_max_nvl_chunked_send_tokens
,
int
num_max_nvl_chunked_recv_tokens
,
int
rank
,
int
num_ranks
,
bool
is_cached_dispatch
,
int
rank
,
int
num_ranks
,
bool
is_cached_dispatch
,
cudaStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
)
{
cudaStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
)
{
constexpr
int
kNumDispatchRDMASenderWarps
=
7
;
constexpr
int
kNumDispatchRDMASenderWarps
=
7
;
// Make sure never OOB
EP_HOST_ASSERT
(
static_cast
<
int64_t
>
(
num_scales
)
*
scale_hidden_stride
<
std
::
numeric_limits
<
int
>::
max
());
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
auto dispatch_func = low_latency_mode ? \
auto dispatch_func = low_latency_mode ? \
(is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>) : \
(is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>) : \
...
@@ -957,8 +963,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
...
@@ -957,8 +963,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \
recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
is_token_in_rank, \
is_token_in_rank, \
num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
scale_token_stride, scale_hidden_stride, \
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
rank, num_ranks); } break
rank, num_ranks); } break
...
...
csrc/kernels/internode_ll.cu
View file @
21efbe9b
...
@@ -36,9 +36,10 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
...
@@ -36,9 +36,10 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
);
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
);
}
}
template
<
bool
kUseFP8
,
int
kNumWarpGroups
,
int
kNumWarpsPerGroup
,
int
kHidden
>
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
int
kNumWarpGroups
,
int
kNumWarpsPerGroup
,
int
kHidden
>
__global__
__launch_bounds__
(
kNumWarpGroups
*
kNumWarpsPerGroup
*
32
,
1
)
void
__global__
__launch_bounds__
(
kNumWarpGroups
*
kNumWarpsPerGroup
*
32
,
1
)
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
cumulative_local_expert_recv_stats
,
int
*
cumulative_local_expert_recv_stats
,
...
@@ -48,7 +49,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -48,7 +49,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
*
usage_flag
,
int
phases
)
{
bool
round_scale
,
int
*
usage_flag
,
int
phases
)
{
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
warp_id
=
thread_id
/
32
,
lane_id
=
get_lane_id
();
const
auto
warp_id
=
thread_id
/
32
,
lane_id
=
get_lane_id
();
...
@@ -59,9 +60,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -59,9 +60,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
auto
sub_warp_id
=
warp_id
%
kNumWarpsPerGroup
;
const
auto
sub_warp_id
=
warp_id
%
kNumWarpsPerGroup
;
const
auto
responsible_expert_idx
=
sm_id
*
kNumWarpGroups
+
warp_group_id
;
const
auto
responsible_expert_idx
=
sm_id
*
kNumWarpGroups
+
warp_group_id
;
// May extract UE8M0 from the scales
using
scale_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint8_t
,
float
>
;
using
packed_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint32_t
,
float
>
;
EP_STATIC_ASSERT
(
sizeof
(
packed_t
)
%
sizeof
(
scale_t
)
==
0
,
"Invalid vector length"
);
// FP8 staffs
// FP8 staffs
constexpr
int
kNumPerChannels
=
128
;
constexpr
int
kNumPerChannels
=
128
;
constexpr
float
kFP8Margin
=
1e-4
,
kFP8Amax
=
448
,
kFP8AmaxInv
=
1.0
f
/
448.0
f
;
const
int
num_scales
=
kHidden
/
kNumPerChannels
;
const
int
num_scales
=
kHidden
/
kNumPerChannels
;
const
size_t
hidden_bytes
=
kHidden
*
(
kUseFP8
?
sizeof
(
__nv_fp8_storage_t
)
:
sizeof
(
nv_bfloat16
));
const
size_t
hidden_bytes
=
kHidden
*
(
kUseFP8
?
sizeof
(
__nv_fp8_storage_t
)
:
sizeof
(
nv_bfloat16
));
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
...
@@ -96,7 +101,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -96,7 +101,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
auto
rdma_x_vec
=
reinterpret_cast
<
vec_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_src_idx
)
+
sizeof
(
int4
));
const
auto
rdma_x_vec
=
reinterpret_cast
<
vec_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_src_idx
)
+
sizeof
(
int4
));
const
auto
rdma_x_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_vec
)
+
hidden_bytes
);
const
auto
rdma_x_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_vec
)
+
hidden_bytes
);
// Overlap top-k index read and source token index write
// Overlap top-k index read and source token index write
s
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
...
@@ -106,7 +111,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -106,7 +111,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Read
// Read
auto
int4_value
=
__ldg
(
x_int4
+
i
);
auto
int4_value
=
__ldg
(
x_int4
+
i
);
if
(
kUseFP8
)
{
if
constexpr
(
kUseFP8
)
{
// Calculate local amax
// Calculate local amax
auto
bf16_values
=
reinterpret_cast
<
nv_bfloat16
*>
(
&
int4_value
);
auto
bf16_values
=
reinterpret_cast
<
nv_bfloat16
*>
(
&
int4_value
);
float
fp32_values
[
kNumElemsPerRead
];
float
fp32_values
[
kNumElemsPerRead
];
...
@@ -119,7 +124,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -119,7 +124,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Reduce amax and scale
// Reduce amax and scale
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
32
/
kNumPerChannels
==
2
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
32
/
kNumPerChannels
==
2
,
"Invalid vectorization"
);
amax
=
half_warp_reduce_max
(
amax
),
scale
=
kFP8Amax
/
amax
,
scale_inv
=
amax
*
kFP8AmaxInv
;
amax
=
half_warp_reduce_max
(
amax
);
calculate_fp8_scales
(
amax
,
scale
,
scale_inv
,
round_scale
);
if
(
lane_id
==
0
or
lane_id
==
16
)
if
(
lane_id
==
0
or
lane_id
==
16
)
rdma_x_scales
[
i
*
kNumElemsPerRead
/
128
]
=
scale_inv
;
rdma_x_scales
[
i
*
kNumElemsPerRead
/
128
]
=
scale_inv
;
...
@@ -256,9 +262,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -256,9 +262,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
src_rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
;
src_rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
;
const
auto
recv_x_int4
=
reinterpret_cast
<
int4
*>
(
packed_recv_x
)
+
const
auto
recv_x_int4
=
reinterpret_cast
<
int4
*>
(
packed_recv_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
const
auto
recv_x_scales
=
packed_recv_x_scales
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_scales
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
const
auto
num_aligned_scales
=
align
<
int
>
(
num_scales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
recv_x_scales
=
reinterpret_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_aligned_scales
;
// Shared between sub-warps in warp groups
// Shared between sub-warps in warp groups
__shared__
int
shared_num_recv_tokens
[
kNumWarpGroups
],
shared_recv_token_begin_idx
[
kNumWarpGroups
];
__shared__
int
shared_num_recv_tokens
[
kNumWarpGroups
],
shared_recv_token_begin_idx
[
kNumWarpGroups
];
...
@@ -297,20 +304,32 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -297,20 +304,32 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
// Copy scales
// Copy scales
if
(
kUseFP8
)
{
if
constexpr
(
kUseFP8
)
{
// Equivalent CuTe layout:
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
dst_scales
=
reinterpret_cast
<
float
*>
(
recv_x_scales
+
recv_token_begin_idx
+
i
);
const
auto
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
packed_t
)
/
sizeof
(
scale_t
));
const
auto
scale_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
token_idx
=
recv_token_begin_idx
+
i
;
auto
scale_0
=
lane_id
<
num_scales
?
ld_nc_global
(
src_scales
+
lane_id
)
:
0
;
const
auto
token_stride
=
num_elems_per_pack
;
auto
scale_1
=
(
lane_id
+
32
)
<
num_scales
?
ld_nc_global
(
src_scales
+
lane_id
+
32
)
:
0
;
const
auto
pack_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_elems_per_pack
;
lane_id
<
num_scales
?
dst_scales
[
lane_id
*
scale_stride
]
=
scale_0
:
0.0
f
;
if
(
lane_id
<
num_scales
)
{
(
lane_id
+
32
)
<
num_scales
?
dst_scales
[(
lane_id
+
32
)
*
scale_stride
]
=
scale_1
:
0.0
f
;
const
auto
pack_idx
=
lane_id
/
num_elems_per_pack
;
const
auto
elem_idx
=
lane_id
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
if
(
lane_id
+
32
<
num_scales
)
{
const
auto
pack_idx
=
(
lane_id
+
32
)
/
num_elems_per_pack
;
const
auto
elem_idx
=
(
lane_id
+
32
)
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
+
32
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
}
}
}
}
}
}
}
}
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
cumulative_local_expert_recv_stats
,
int
*
cumulative_local_expert_recv_stats
,
...
@@ -318,7 +337,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -318,7 +337,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
*
usage_flag
,
void
*
workspace
,
int
*
usage_flag
,
cudaStream_t
stream
,
int
phases
)
{
cudaStream_t
stream
,
int
phases
)
{
constexpr
int
kNumMaxTopK
=
9
;
constexpr
int
kNumMaxTopK
=
9
;
...
@@ -331,13 +351,20 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -331,13 +351,20 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopK
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopK
);
// Workspace checks
// Workspace checks
auto
atomic_counter_per_expert
=
reinterpret
_cast
<
int
*>
(
workspace
);
auto
atomic_counter_per_expert
=
static
_cast
<
int
*>
(
workspace
);
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
// FP8 checks
if
(
use_ue8m0
)
EP_HOST_ASSERT
(
round_scale
and
"UE8M0 SF requires `round_scale=True`"
);
#define DISPATCH_LAUNCH_CASE(hidden) { \
#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden> : \
auto dispatch_func = dispatch<false, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
LAUNCH_KERNEL(&cfg, dispatch_func, \
LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_src_info, packed_recv_layout_range, \
...
@@ -349,7 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
...
@@ -349,7 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
next_clean, num_next_clean_int, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_topk, num_experts, rank, num_ranks, \
usage_flag, phases); } break
round_scale,
usage_flag, phases); } break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
32
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
32
,
stream
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
...
...
csrc/kernels/intranode.cu
View file @
21efbe9b
...
@@ -174,6 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -174,6 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int
*
send_head
,
const
int4
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
int
*
send_head
,
const
int4
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
int
num_tokens
,
int
num_worst_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
num_tokens
,
int
num_worst_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
scale_token_stride
,
int
scale_hidden_stride
,
void
**
buffer_ptrs
,
int
rank
,
void
**
buffer_ptrs
,
int
rank
,
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
)
{
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
)
{
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
),
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
),
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
...
@@ -326,8 +327,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
...
@@ -326,8 +327,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Copy `x_scales`
// Copy `x_scales`
#pragma unroll
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
32
)
for
(
int
i
=
lane_id
;
i
<
num_scales
;
i
+=
32
)
{
channel_x_scales_buffers
[
dst_slot_idx
*
num_scales
+
i
]
=
__ldg
(
x_scales
+
token_idx
*
num_scales
+
i
);
auto
offset
=
token_idx
*
scale_token_stride
+
i
*
scale_hidden_stride
;
channel_x_scales_buffers
[
dst_slot_idx
*
num_scales
+
i
]
=
__ldg
(
x_scales
+
offset
);
}
}
}
// Move token index
// Move token index
...
@@ -478,6 +481,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
...
@@ -478,6 +481,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
int
*
send_head
,
const
void
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
int
*
send_head
,
const
void
*
x
,
const
float
*
x_scales
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
const
bool
*
is_token_in_rank
,
const
int
*
channel_prefix_matrix
,
int
num_tokens
,
int
num_worst_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
num_tokens
,
int
num_worst_tokens
,
int
hidden_int4
,
int
num_topk
,
int
num_experts
,
int
num_scales
,
int
scale_token_stride
,
int
scale_hidden_stride
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
cudaStream_t
stream
,
int
num_sms
,
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
)
{
cudaStream_t
stream
,
int
num_sms
,
int
num_max_send_tokens
,
int
num_recv_buffer_tokens
)
{
constexpr
int
kNumThreads
=
768
;
constexpr
int
kNumThreads
=
768
;
...
@@ -486,6 +490,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
...
@@ -486,6 +490,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
constexpr
int
smem_size
=
kNumTMABytesPerWarp
*
(
kNumThreads
/
32
);
constexpr
int
smem_size
=
kNumTMABytesPerWarp
*
(
kNumThreads
/
32
);
#endif
#endif
// Make sure never OOB
EP_HOST_ASSERT
(
static_cast
<
int64_t
>
(
num_scales
)
*
scale_hidden_stride
<
std
::
numeric_limits
<
int
>::
max
());
#define DISPATCH_LAUNCH_CASE(ranks) { \
#define DISPATCH_LAUNCH_CASE(ranks) { \
auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
SET_SHARED_MEMORY_FOR_TMA(kernel); \
SET_SHARED_MEMORY_FOR_TMA(kernel); \
...
@@ -494,6 +501,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
...
@@ -494,6 +501,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
is_token_in_rank, channel_prefix_matrix, \
is_token_in_rank, channel_prefix_matrix, \
num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
scale_token_stride, scale_hidden_stride, \
buffer_ptrs, rank, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
num_max_send_tokens, num_recv_buffer_tokens); \
} break
} break
...
...
csrc/kernels/utils.cuh
View file @
21efbe9b
...
@@ -401,6 +401,43 @@ __forceinline__ __device__ int get_lane_id() {
...
@@ -401,6 +401,43 @@ __forceinline__ __device__ int get_lane_id() {
return
lane_id
;
return
lane_id
;
}
}
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFinfoAmaxE4M3
=
448.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1
/
448.0
f
;
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
// We can ensure `-126 <= x and x <= 127`
uint32_t
bits_x
=
(
x
+
127
)
<<
23
;
return
*
reinterpret_cast
<
float
*>
(
&
bits_x
);
}
__forceinline__
__device__
int
fast_log2_ceil
(
float
x
)
{
auto
bits_x
=
*
reinterpret_cast
<
uint32_t
*>
(
&
x
);
auto
exp_x
=
(
bits_x
>>
23
)
&
0xff
;
auto
man_bits
=
bits_x
&
((
1
<<
23
)
-
1
);
return
exp_x
-
127
+
(
man_bits
!=
0
);
}
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
)
{
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
scale
=
kFinfoAmaxE4M3
/
amax
;
}
}
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
if
constexpr
(
kIsUE8M0
)
{
return
static_cast
<
uint8_t
>
((
*
reinterpret_cast
<
uint32_t
*>
(
&
value
))
>>
23
);
}
else
{
return
value
;
}
}
template
<
int
kNumRanks
>
template
<
int
kNumRanks
>
__forceinline__
__device__
void
__forceinline__
__device__
void
barrier_block
(
int
**
barrier_signal_ptrs
,
int
rank
)
{
barrier_block
(
int
**
barrier_signal_ptrs
,
int
rank
)
{
...
...
deep_ep/buffer.py
View file @
21efbe9b
...
@@ -178,6 +178,7 @@ class Buffer:
...
@@ -178,6 +178,7 @@ class Buffer:
config: the recommended config.
config: the recommended config.
"""
"""
# TODO: automatically tune
config_map
=
{
config_map
=
{
2
:
Config
(
Buffer
.
num_sms
,
24
,
256
,
6
,
128
),
2
:
Config
(
Buffer
.
num_sms
,
24
,
256
,
6
,
128
),
4
:
Config
(
Buffer
.
num_sms
,
6
,
256
,
6
,
128
),
4
:
Config
(
Buffer
.
num_sms
,
6
,
256
,
6
,
128
),
...
@@ -205,6 +206,7 @@ class Buffer:
...
@@ -205,6 +206,7 @@ class Buffer:
config: the recommended config.
config: the recommended config.
"""
"""
# TODO: automatically tune
config_map
=
{
config_map
=
{
2
:
Config
(
Buffer
.
num_sms
,
10
,
256
,
6
,
128
),
2
:
Config
(
Buffer
.
num_sms
,
10
,
256
,
6
,
128
),
4
:
Config
(
Buffer
.
num_sms
,
9
,
256
,
6
,
128
),
4
:
Config
(
Buffer
.
num_sms
,
9
,
256
,
6
,
128
),
...
@@ -486,14 +488,14 @@ class Buffer:
...
@@ -486,14 +488,14 @@ class Buffer:
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
cumulative_local_expert_recv_stats
:
Optional
[
torch
.
Tensor
]
=
None
,
cumulative_local_expert_recv_stats
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp8
:
bool
=
True
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
use_fp8
:
bool
=
True
,
round_scale
:
bool
=
False
,
use_ue8m0
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
"""
A low-latency implementation for dispatching with IBGDA.
A low-latency implementation for dispatching with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensors at a single moment.
low-latency kernels' result tensors at a single moment.
Arguments:
Arguments:
...
@@ -507,17 +509,21 @@ class Buffer:
...
@@ -507,17 +509,21 @@ class Buffer:
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring.
monitoring.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
If you
do
not set this flag, the kernel will ensure the data's arrival.
Returns:
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
...
@@ -533,7 +539,8 @@ class Buffer:
...
@@ -533,7 +539,8 @@ class Buffer:
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
cumulative_local_expert_recv_stats
,
cumulative_local_expert_recv_stats
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
async_finish
,
return_recv_hook
)
use_fp8
,
round_scale
,
use_ue8m0
,
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
...
@@ -551,9 +558,8 @@ class Buffer:
...
@@ -551,9 +558,8 @@ class Buffer:
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensors at a single moment.
low-latency kernels' result tensor at a single moment.
Arguments:
Arguments:
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
...
@@ -569,7 +575,7 @@ class Buffer:
...
@@ -569,7 +575,7 @@ class Buffer:
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
If you
do
not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
Returns:
Returns:
...
...
install.sh
0 → 100755
View file @
21efbe9b
# Change current directory into project root
original_dir
=
$(
pwd
)
script_dir
=
$(
dirname
"
$0
"
)
cd
"
$script_dir
"
# Remove old dist file, build, and install
rm
-rf
dist
python setup.py bdist_wheel
pip
install
dist/
*
.whl
# Open users' original directory
cd
"
$original_dir
"
tests/test_internode.py
View file @
21efbe9b
...
@@ -22,6 +22,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
...
@@ -22,6 +22,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
rank
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
rank
x_pure_rand
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
x_pure_rand
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
x_e4m3
=
per_token_cast_to_fp8
(
x
)
x_e4m3
=
per_token_cast_to_fp8
(
x
)
x_e4m3
=
(
x_e4m3
[
0
],
x_e4m3
[
1
].
T
.
contiguous
().
T
)
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
group_scores
=
scores
.
view
(
num_tokens
,
num_nodes
,
-
1
).
amax
(
dim
=-
1
)
group_scores
=
scores
.
view
(
num_tokens
,
num_nodes
,
-
1
).
amax
(
dim
=-
1
)
group_idx
=
torch
.
topk
(
group_scores
,
k
=
num_topk_groups
,
dim
=-
1
,
sorted
=
False
).
indices
group_idx
=
torch
.
topk
(
group_scores
,
k
=
num_topk_groups
,
dim
=-
1
,
sorted
=
False
).
indices
...
@@ -241,6 +242,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
...
@@ -241,6 +242,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
buffer
.
clean_low_latency_buffer
(
ll_num_tokens
,
ll_hidden
,
ll_num_experts
)
buffer
.
clean_low_latency_buffer
(
ll_num_tokens
,
ll_hidden
,
ll_num_experts
)
test_low_latency
.
test_main
(
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
test_low_latency
.
test_main
(
ll_num_tokens
,
ll_hidden
,
ll_num_experts
,
ll_num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
# Destroy the communication group
dist
.
barrier
()
dist
.
destroy_process_group
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
num_processes
=
8
num_processes
=
8
...
...
tests/test_intranode.py
View file @
21efbe9b
...
@@ -21,6 +21,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
...
@@ -21,6 +21,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
rank
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
rank
x_pure_rand
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
x_pure_rand
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
x_e4m3
=
per_token_cast_to_fp8
(
x
)
if
deep_ep
.
Buffer
.
is_sm90_compiled
()
else
None
x_e4m3
=
per_token_cast_to_fp8
(
x
)
if
deep_ep
.
Buffer
.
is_sm90_compiled
()
else
None
x_e4m3
=
(
x_e4m3
[
0
],
x_e4m3
[
1
].
T
.
contiguous
().
T
)
if
x_e4m3
is
not
None
else
None
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
False
)[
1
]
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
False
)[
1
]
topk_weights
=
torch
.
ones
((
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
*
rank
topk_weights
=
torch
.
ones
((
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
*
rank
...
...
tests/test_low_latency.py
View file @
21efbe9b
...
@@ -34,11 +34,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -34,11 +34,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hash_value
,
num_times
=
0
,
0
hash_value
,
num_times
=
0
,
0
for
return_recv_hook
in
(
False
,
True
):
for
return_recv_hook
in
(
False
,
True
):
for
dispatch_use_fp8
in
(
False
,
True
):
for
dispatch_use_fp8
in
(
False
,
True
):
for
round_scale
in
(
False
,
True
)
if
dispatch_use_fp8
else
(
False
,
):
for
use_ue8m0
in
(
False
,
True
)
if
round_scale
else
(
False
,
):
num_times
+=
1
num_times
+=
1
for
i
in
range
((
num_times
%
2
)
+
1
):
for
i
in
range
((
num_times
%
2
)
+
1
):
cumulative_local_expert_recv_stats
=
torch
.
zeros
((
num_local_experts
,
),
dtype
=
torch
.
int
,
device
=
'cuda'
)
cumulative_local_expert_recv_stats
=
torch
.
zeros
((
num_local_experts
,
),
dtype
=
torch
.
int
,
device
=
'cuda'
)
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
cumulative_local_expert_recv_stats
=
cumulative_local_expert_recv_stats
,
cumulative_local_expert_recv_stats
=
cumulative_local_expert_recv_stats
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
...
@@ -64,9 +67,13 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -64,9 +67,13 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
if
round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
for
j
in
range
(
num_ranks
):
for
j
in
range
(
num_ranks
):
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
if
not
round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
][:
-
128
]
-
j
).
sum
().
item
()
==
0
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
][:
-
128
]
-
j
).
sum
().
item
()
==
0
if
dispatch_use_fp8
:
if
dispatch_use_fp8
:
...
@@ -87,7 +94,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -87,7 +94,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
if
do_check
:
if
do_check
:
diff
=
calc_diff
(
x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
diff
=
calc_diff
(
x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
assert
diff
<
1e-5
,
f
'Error:
{
diff
=
}
,
{
zero_copy
=
}
'
assert
diff
<
(
7e-4
if
round_scale
else
1e-5
)
,
f
'Error:
{
diff
=
}
,
{
zero_copy
=
}
'
hash_value
^=
hash_tensor
(
combined_x
)
hash_value
^=
hash_tensor
(
combined_x
)
def
create_test_cast_with_outliers
(
num_outliers
):
def
create_test_cast_with_outliers
(
num_outliers
):
...
@@ -112,7 +119,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -112,7 +119,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
cumulative_local_expert_recv_stats
=
cumulative_local_expert_recv_stats
,
cumulative_local_expert_recv_stats
=
cumulative_local_expert_recv_stats
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
use_fp8
=
True
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
if
zero_copy
:
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
...
@@ -170,6 +177,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
...
@@ -170,6 +177,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
for
i
in
range
(
20
):
for
i
in
range
(
20
):
assert
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
assert
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
# Destroy the communication group
dist
.
barrier
()
dist
.
destroy_process_group
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: you may modify NUMA binding for less CPU overhead
...
...
tests/utils.py
View file @
21efbe9b
...
@@ -43,6 +43,9 @@ def per_token_cast_to_fp8(x: torch.Tensor):
...
@@ -43,6 +43,9 @@ def per_token_cast_to_fp8(x: torch.Tensor):
def
per_token_cast_back
(
x_fp8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
def
per_token_cast_back
(
x_fp8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
if
x_scales
.
dtype
==
torch
.
int
:
x_scales
=
x_scales
.
view
(
dtype
=
torch
.
int8
).
to
(
torch
.
int
)
<<
23
x_scales
=
x_scales
.
view
(
dtype
=
torch
.
float
)
x_fp32
=
x_fp8
.
to
(
torch
.
float32
).
view
(
x_fp8
.
size
(
0
),
-
1
,
128
)
x_fp32
=
x_fp8
.
to
(
torch
.
float32
).
view
(
x_fp8
.
size
(
0
),
-
1
,
128
)
x_scales
=
x_scales
.
view
(
x_fp8
.
size
(
0
),
-
1
,
1
)
x_scales
=
x_scales
.
view
(
x_fp8
.
size
(
0
),
-
1
,
1
)
return
(
x_fp32
*
x_scales
).
view
(
x_fp8
.
shape
).
to
(
torch
.
bfloat16
)
return
(
x_fp32
*
x_scales
).
view
(
x_fp8
.
shape
).
to
(
torch
.
bfloat16
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment