Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
ed7487c1
Commit
ed7487c1
authored
Mar 10, 2025
by
Chenggang Zhao
Browse files
Support BF16 for low-latency kernels
parent
1fc40d50
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
139 additions
and
112 deletions
+139
-112
README.md
README.md
+1
-1
csrc/config.hpp
csrc/config.hpp
+1
-1
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+14
-8
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+2
-2
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-1
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+63
-51
deep_ep/buffer.py
deep_ep/buffer.py
+10
-6
tests/test_low_latency.py
tests/test_low_latency.py
+47
-42
No files found.
README.md
View file @
ed7487c1
...
@@ -282,7 +282,7 @@ For two micro-batch overlapping, you can refer to the following figure. With our
...
@@ -282,7 +282,7 @@ For two micro-batch overlapping, you can refer to the following figure. With our
-
[x] AR support
-
[x] AR support
-
[ ] Refactor low-latency mode AR code
-
[ ] Refactor low-latency mode AR code
-
[ ] A100 support (intranode only)
-
[ ] A100 support (intranode only)
-
[
] Support BF16 for the low-latency dispatch kernel
-
[
x
] Support BF16 for the low-latency dispatch kernel
-
[ ] Support NVLink protocol for intranode low-latency kernels
-
[ ] Support NVLink protocol for intranode low-latency kernels
-
[ ] SM-free normal kernels
-
[ ] SM-free normal kernels
...
...
csrc/config.hpp
View file @
ed7487c1
...
@@ -128,7 +128,7 @@ struct LowLatencyLayout {
...
@@ -128,7 +128,7 @@ struct LowLatencyLayout {
// Message sizes
// Message sizes
EP_HOST_ASSERT
(
num_scales
*
sizeof
(
float
)
<=
hidden
);
EP_HOST_ASSERT
(
num_scales
*
sizeof
(
float
)
<=
hidden
);
size_t
num_bytes_per_dispatch_msg
=
hidden
+
num_scales
*
sizeof
(
float
)
+
sizeof
(
int4
);
size_t
num_bytes_per_dispatch_msg
=
sizeof
(
int4
)
+
std
::
max
(
hidden
*
sizeof
(
nv_b
float
16
),
hidden
+
num_scales
*
sizeof
(
float
)
);
size_t
num_bytes_per_combine_msg
=
sizeof
(
int4
)
+
hidden
*
sizeof
(
nv_bfloat16
);
size_t
num_bytes_per_combine_msg
=
sizeof
(
int4
)
+
hidden
*
sizeof
(
nv_bfloat16
);
// Send buffer
// Send buffer
...
...
csrc/deep_ep.cpp
View file @
ed7487c1
...
@@ -1011,10 +1011,10 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
...
@@ -1011,10 +1011,10 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
at
::
cuda
::
getCurrentCUDAStream
());
at
::
cuda
::
getCurrentCUDAStream
());
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
async
,
bool
return_recv_hook
)
{
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
// Tensor checks
// Tensor checks
...
@@ -1045,20 +1045,26 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1045,20 +1045,26 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait
(
launch_stream
,
compute_stream
);
stream_wait
(
launch_stream
,
compute_stream
);
// Allocate packed tensors
// Allocate packed tensors
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
torch
::
kFloat8_e4m3fn
));
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
use_fp8
?
torch
::
kFloat8_e4m3fn
:
torch
::
kBFloat16
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
// Allocate column-majored scales
// Allocate column-majored scales
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
float
*
packed_recv_x_scales_ptr
=
nullptr
;
if
(
use_fp8
)
{
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
auto
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
num_scales
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
num_scales
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
,
1
,
2
);
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
<
float
>
();
}
// Kernel launch
// Kernel launch
auto
next_clean_meta
=
next_buffer
.
clean_meta
();
auto
next_clean_meta
=
next_buffer
.
clean_meta
();
auto
launcher
=
[
=
](
int
phases
)
{
auto
launcher
=
[
=
](
int
phases
)
{
internode_ll
::
dispatch
(
packed_recv_x
.
data_ptr
(),
packed_recv_x_scales
.
data_ptr
<
float
>
()
,
internode_ll
::
dispatch
(
packed_recv_x
.
data_ptr
(),
packed_recv_x_scales
_ptr
,
packed_recv_src_info
.
data_ptr
<
int
>
(),
packed_recv_layout_range
.
data_ptr
<
int64_t
>
(),
packed_recv_src_info
.
data_ptr
<
int
>
(),
packed_recv_layout_range
.
data_ptr
<
int64_t
>
(),
packed_recv_count
.
data_ptr
<
int
>
(),
packed_recv_count
.
data_ptr
<
int
>
(),
buffer
.
dispatch_rdma_recv_data_buffer
,
buffer
.
dispatch_rdma_recv_count_buffer
,
buffer
.
dispatch_rdma_recv_data_buffer
,
buffer
.
dispatch_rdma_recv_count_buffer
,
...
@@ -1066,7 +1072,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1066,7 +1072,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
workspace
,
launch_stream
,
phases
);
workspace
,
launch_stream
,
phases
);
};
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
...
csrc/deep_ep.hpp
View file @
ed7487c1
...
@@ -134,10 +134,10 @@ public:
...
@@ -134,10 +134,10 @@ public:
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
async
,
bool
return_recv_hook
);
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
...
...
csrc/kernels/api.cuh
View file @
ed7487c1
...
@@ -137,7 +137,7 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -137,7 +137,7 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
void
*
workspace
,
cudaStream_t
stream
,
int
phases
);
void
*
workspace
,
cudaStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
void
combine
(
void
*
combined_x
,
...
...
csrc/kernels/internode_ll.cu
View file @
ed7487c1
...
@@ -36,7 +36,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
...
@@ -36,7 +36,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
);
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
);
}
}
template
<
int
kNumWarpGroups
,
int
kNumWarpsPerGroup
,
int
kHidden
>
template
<
bool
kUseFP8
,
int
kNumWarpGroups
,
int
kNumWarpsPerGroup
,
int
kHidden
>
__global__
__launch_bounds__
(
kNumWarpGroups
*
kNumWarpsPerGroup
*
32
,
1
)
void
__global__
__launch_bounds__
(
kNumWarpGroups
*
kNumWarpsPerGroup
*
32
,
1
)
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
...
@@ -62,11 +62,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -62,11 +62,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
constexpr
int
kNumPerChannels
=
128
;
constexpr
int
kNumPerChannels
=
128
;
constexpr
float
kFP8Margin
=
1e-4
,
kFP8Amax
=
448
,
kFP8AmaxInv
=
1.0
f
/
448.0
f
;
constexpr
float
kFP8Margin
=
1e-4
,
kFP8Amax
=
448
,
kFP8AmaxInv
=
1.0
f
/
448.0
f
;
const
int
num_scales
=
kHidden
/
kNumPerChannels
;
const
int
num_scales
=
kHidden
/
kNumPerChannels
;
const
size_t
hidden_int4
=
kHidden
/
sizeof
(
int4
);
const
size_t
hidden_bytes
=
kHidden
*
(
kUseFP8
?
sizeof
(
__nv_fp8_storage_t
)
:
sizeof
(
nv_bfloat16
));
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
// Message package: hidden data, FP8 scales, index at source
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
// NOTES: currently we have 3 reserved int fields for future use
const
size_t
num_bytes_per_msg
=
kHidden
+
num_scales
*
sizeof
(
float
)
+
sizeof
(
int4
);
using
vec_t
=
typename
std
::
conditional
<
kUseFP8
,
int2
,
int4
>::
type
;
const
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseFP8
?
(
kHidden
+
num_scales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
nv_bfloat16
)));
const
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
const
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
EP_DEVICE_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
);
EP_DEVICE_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
);
...
@@ -89,9 +91,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -89,9 +91,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_tokens
;
token_idx
+=
num_sms
)
{
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_tokens
;
token_idx
+=
num_sms
)
{
const
auto
x_int4
=
reinterpret_cast
<
const
int4
*>
(
x
)
+
token_idx
*
hidden_bf16_int4
;
const
auto
x_int4
=
reinterpret_cast
<
const
int4
*>
(
x
)
+
token_idx
*
hidden_bf16_int4
;
const
auto
rdma_x_
int2
=
reinterpret_cast
<
int
2
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x
)
+
token_idx
*
num_bytes_per_msg
);
const
auto
rdma_x_
src_idx
=
reinterpret_cast
<
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x
)
+
token_idx
*
num_bytes_per_msg
);
const
auto
rdma_x_
scales
=
reinterpret_cast
<
floa
t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_
int2
)
+
kHidden
);
const
auto
rdma_x_
vec
=
reinterpret_cast
<
vec_
t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_
src_idx
)
+
sizeof
(
int4
)
);
const
auto
rdma_x_s
rc_idx
=
reinterpret_cast
<
int
*>
(
rdma_x_scales
+
num_scal
es
);
const
auto
rdma_x_s
cales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_vec
)
+
hidden_byt
es
);
// Overlap top-k index read and source token index write
// Overlap top-k index read and source token index write
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
...
@@ -100,8 +102,11 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -100,8 +102,11 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// FP8 cast
// FP8 cast
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
hidden_bf16_int4
;
i
+=
num_threads
)
{
for
(
int
i
=
thread_id
;
i
<
hidden_bf16_int4
;
i
+=
num_threads
)
{
// Read
and calculate local amax
// Read
auto
int4_value
=
__ldg
(
x_int4
+
i
);
auto
int4_value
=
__ldg
(
x_int4
+
i
);
if
(
kUseFP8
)
{
// Calculate local amax
auto
bf16_values
=
reinterpret_cast
<
nv_bfloat16
*>
(
&
int4_value
);
auto
bf16_values
=
reinterpret_cast
<
nv_bfloat16
*>
(
&
int4_value
);
float
fp32_values
[
kNumElemsPerRead
];
float
fp32_values
[
kNumElemsPerRead
];
float
amax
=
kFP8Margin
,
scale
,
scale_inv
;
float
amax
=
kFP8Margin
,
scale
,
scale_inv
;
...
@@ -118,14 +123,18 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -118,14 +123,18 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
rdma_x_scales
[
i
*
kNumElemsPerRead
/
128
]
=
scale_inv
;
rdma_x_scales
[
i
*
kNumElemsPerRead
/
128
]
=
scale_inv
;
// Cast into send buffer
// Cast into send buffer
int2
int2_value
;
vec_t
int2_value
;
auto
fp8x2_values
=
reinterpret_cast
<
__nv_fp8x2_storage_t
*>
(
&
int2_value
);
auto
fp8x2_values
=
reinterpret_cast
<
__nv_fp8x2_storage_t
*>
(
&
int2_value
);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
j
+=
2
)
{
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
j
+=
2
)
{
float2
fp32x2
=
{
fp32_values
[
j
]
*
scale
,
fp32_values
[
j
+
1
]
*
scale
};
float2
fp32x2
=
{
fp32_values
[
j
]
*
scale
,
fp32_values
[
j
+
1
]
*
scale
};
fp8x2_values
[
j
/
2
]
=
__nv_cvt_float2_to_fp8x2
(
fp32x2
,
__NV_SATFINITE
,
__NV_E4M3
);
fp8x2_values
[
j
/
2
]
=
__nv_cvt_float2_to_fp8x2
(
fp32x2
,
__NV_SATFINITE
,
__NV_E4M3
);
}
}
rdma_x_int2
[
i
]
=
int2_value
;
rdma_x_vec
[
i
]
=
int2_value
;
}
else
{
// Reinterpret-cast is for C++14 compatibility
rdma_x_vec
[
i
]
=
*
reinterpret_cast
<
vec_t
*>
(
&
int4_value
);
}
}
}
asm
volatile
(
"bar.sync 1, %0;"
::
"r"
(
num_threads
));
asm
volatile
(
"bar.sync 1, %0;"
::
"r"
(
num_threads
));
...
@@ -135,7 +144,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -135,7 +144,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
slot_idx
=
__shfl_sync
(
0xffffffff
,
slot_idx
,
0
);
slot_idx
=
__shfl_sync
(
0xffffffff
,
slot_idx
,
0
);
const
auto
dst_rank
=
dst_expert_idx
/
num_local_experts
;
const
auto
dst_rank
=
dst_expert_idx
/
num_local_experts
;
const
auto
dst_expert_local_idx
=
dst_expert_idx
%
num_local_experts
;
const
auto
dst_expert_local_idx
=
dst_expert_idx
%
num_local_experts
;
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_x_
int2
);
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_x_
src_idx
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
dst_expert_local_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
dst_expert_local_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
...
@@ -273,26 +282,28 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -273,26 +282,28 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Copy tokens
// Copy tokens
EP_DEVICE_ASSERT
(
num_scales
<=
64
);
EP_DEVICE_ASSERT
(
num_scales
<=
64
);
for
(
int
i
=
sub_warp_id
;
i
<
num_recv_tokens
;
i
+=
kNumWarpsPerGroup
)
{
for
(
int
i
=
sub_warp_id
;
i
<
num_recv_tokens
;
i
+=
kNumWarpsPerGroup
)
{
// Copy source info
const
auto
src_src_idx
=
reinterpret_cast
<
int
*>
(
rdma_recv_x_uint8
+
i
*
num_bytes_per_msg
);
if
(
lane_id
==
0
)
recv_src_info
[
recv_token_begin_idx
+
i
]
=
ld_nc_global
(
src_src_idx
);
__syncwarp
();
// Copy data
// Copy data
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
const
auto
src
=
reinterpret_cast
<
int4
*>
(
r
dma_recv_x_uint8
+
i
*
num_bytes_per_msg
);
const
auto
src
_data
=
reinterpret_cast
<
int4
*>
(
r
einterpret_cast
<
uint8_t
*>
(
src_src_idx
)
+
sizeof
(
int4
)
);
const
auto
dst
=
recv_x_int4
+
(
recv_token_begin_idx
+
i
)
*
hidden_int4
;
const
auto
dst
_data
=
recv_x_int4
+
(
recv_token_begin_idx
+
i
)
*
hidden_int4
;
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_int4
,
dst
,
src
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_int4
,
dst
_data
,
src_data
,
ld_nc_global
,
st_na_global
);
// Copy scales
// Copy scales
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
rdma_recv_x_uint8
+
i
*
num_bytes_per_msg
+
kHidden
);
if
(
kUseFP8
)
{
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
dst_scales
=
reinterpret_cast
<
float
*>
(
recv_x_scales
+
recv_token_begin_idx
+
i
);
const
auto
dst_scales
=
reinterpret_cast
<
float
*>
(
recv_x_scales
+
recv_token_begin_idx
+
i
);
const
auto
scale_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
scale_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
;
auto
scale_0
=
lane_id
<
num_scales
?
ld_nc_global
(
src_scales
+
lane_id
)
:
0
;
auto
scale_0
=
lane_id
<
num_scales
?
ld_nc_global
(
src_scales
+
lane_id
)
:
0
;
auto
scale_1
=
(
lane_id
+
32
)
<
num_scales
?
ld_nc_global
(
src_scales
+
lane_id
+
32
)
:
0
;
auto
scale_1
=
(
lane_id
+
32
)
<
num_scales
?
ld_nc_global
(
src_scales
+
lane_id
+
32
)
:
0
;
lane_id
<
num_scales
?
dst_scales
[
lane_id
*
scale_stride
]
=
scale_0
:
0.0
f
;
lane_id
<
num_scales
?
dst_scales
[
lane_id
*
scale_stride
]
=
scale_0
:
0.0
f
;
(
lane_id
+
32
)
<
num_scales
?
dst_scales
[(
lane_id
+
32
)
*
scale_stride
]
=
scale_1
:
0.0
f
;
(
lane_id
+
32
)
<
num_scales
?
dst_scales
[(
lane_id
+
32
)
*
scale_stride
]
=
scale_1
:
0.0
f
;
}
// Copy source info
const
auto
src_src_idx
=
reinterpret_cast
<
int
*>
(
src_scales
+
num_scales
);
if
(
lane_id
==
0
)
recv_src_info
[
recv_token_begin_idx
+
i
]
=
ld_nc_global
(
src_src_idx
);
__syncwarp
();
}
}
}
}
}
}
...
@@ -304,7 +315,7 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -304,7 +315,7 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
void
*
workspace
,
cudaStream_t
stream
,
int
phases
)
{
void
*
workspace
,
cudaStream_t
stream
,
int
phases
)
{
constexpr
int
kNumMaxTopK
=
9
;
constexpr
int
kNumMaxTopK
=
9
;
constexpr
int
kNumWarpsPerGroup
=
10
;
constexpr
int
kNumWarpsPerGroup
=
10
;
...
@@ -314,15 +325,16 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
...
@@ -314,15 +325,16 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
const
auto
num_warps
=
kNumWarpGroups
*
kNumWarpsPerGroup
;
const
auto
num_warps
=
kNumWarpGroups
*
kNumWarpsPerGroup
;
const
auto
num_sms
=
cell_div
(
num_experts
,
kNumWarpGroups
);
const
auto
num_sms
=
cell_div
(
num_experts
,
kNumWarpGroups
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopK
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopK
);
EP_HOST_ASSERT
(
cell_div
(
static_cast
<
int
>
(
hidden
*
2
/
sizeof
(
int4
)),
32
*
(
num_warps
-
1
))
<=
2
);
// Workspace checks
// Workspace checks
auto
atomic_counter_per_expert
=
reinterpret_cast
<
int
*>
(
workspace
);
auto
atomic_counter_per_expert
=
reinterpret_cast
<
int
*>
(
workspace
);
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
#define DISPATCH_LAUNCH_CASE(hidden) \
#define DISPATCH_LAUNCH_CASE(hidden) { \
LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
auto dispatch_func = use_fp8 ? dispatch<true, kNumWarpGroups, kNumWarpsPerGroup, hidden> : \
dispatch<false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
packed_recv_count, \
...
@@ -331,7 +343,7 @@ LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
...
@@ -331,7 +343,7 @@ LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, phases); break
num_topk, num_experts, rank, num_ranks, phases);
}
break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
32
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
32
,
stream
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
...
...
deep_ep/buffer.py
View file @
ed7487c1
...
@@ -444,10 +444,10 @@ class Buffer:
...
@@ -444,10 +444,10 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
use_fp8
:
bool
=
True
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
"""
A low-latency implementation for dispatching with IBGDA
**with implicit FP8 casting**
.
A low-latency implementation for dispatching with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Even for ranks in the same node, NVLink are fully disabled for simplicity.
...
@@ -461,19 +461,23 @@ class Buffer:
...
@@ -461,19 +461,23 @@ class Buffer:
are supported. `-1` indices (not selecting any expert) are supported.
are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
Returns:
Returns:
recv_x: a tuple with received tokens for each expert. The first element is a `torch.Tensor` shaped as
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph).
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph
if synced
).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receive. As mentioned before, all not tokens are valid in `recv_x`.
expert receive. As mentioned before, all not tokens are valid in `recv_x`.
handle: the communication handle to be used in the `low_latency_combine` function.
handle: the communication handle to be used in the `low_latency_combine` function.
...
@@ -483,12 +487,12 @@ class Buffer:
...
@@ -483,12 +487,12 @@ class Buffer:
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
async_finish
,
return_recv_hook
)
use_fp8
,
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
num_experts
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
)
packed_recv_src_info
,
packed_recv_layout_range
)
return
(
packed_recv_x
,
packed_recv_x_scales
),
packed_recv_count
,
handle
,
\
return
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
\
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
# noinspection PyTypeChecker
# noinspection PyTypeChecker
...
...
tests/test_low_latency.py
View file @
ed7487c1
...
@@ -33,19 +33,21 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -33,19 +33,21 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
do_check
=
True
do_check
=
True
hash_value
,
num_times
=
0
,
0
hash_value
,
num_times
=
0
,
0
for
return_recv_hook
in
(
False
,
True
):
for
return_recv_hook
in
(
False
,
True
):
for
dispatch_use_fp8
in
(
False
,
True
):
num_times
+=
1
num_times
+=
1
for
i
in
range
((
num_times
%
2
)
+
1
):
for
i
in
range
((
num_times
%
2
)
+
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
buffer
.
low_latency_dispatch
(
x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
simulated_gemm_x
=
per_token_cast_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
\
if
dispatch_use_fp8
else
packed_recv_x
.
clone
()
all_topk_idx
=
torch
.
empty
((
num_ranks
,
num_tokens
,
num_topk
),
dtype
=
topk_idx
.
dtype
,
device
=
'cuda'
)
all_topk_idx
=
torch
.
empty
((
num_ranks
,
num_tokens
,
num_topk
),
dtype
=
topk_idx
.
dtype
,
device
=
'cuda'
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
expert_id
=
rank
*
num_local_experts
+
i
recv_x
=
per_token_cast_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
recv_x
=
per_token_cast_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
# Check expert indices
...
@@ -64,8 +66,11 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -64,8 +66,11 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
][:
-
128
]
-
j
).
sum
().
item
()
==
0
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
][:
-
128
]
-
j
).
sum
().
item
()
==
0
if
dispatch_use_fp8
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
# Check combine correctness
# Check combine correctness
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment