Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
35735902
Commit
35735902
authored
Nov 25, 2025
by
lijian6
Browse files
Merge branch 'int8-main' into 'main'
支持int8类型的kernel接口 See merge request dcutoolkit/deeplearing/DeepEP!3
parents
baa261b5
6dfe3bc2
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
385 additions
and
77 deletions
+385
-77
csrc/deep_ep.cu
csrc/deep_ep.cu
+14
-20
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-1
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-1
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+107
-44
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+8
-9
deep_ep/buffer.py
deep_ep/buffer.py
+3
-2
tests/test_low_latency_new_int8.py
tests/test_low_latency_new_int8.py
+228
-0
tests/utils.py
tests/utils.py
+23
-0
No files found.
csrc/deep_ep.cu
View file @
35735902
...
@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
...
@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
bool
async
,
bool
return_recv_hook
)
{
bool
async
,
bool
return_recv_hook
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
...
@@ -1316,13 +1316,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1316,13 +1316,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
// Buffer control
LowLatencyLayout
nvl_layout
(
nvl_buffer_ptrs
[
nvl_rank
],
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
EP_HOST_ASSERT
(
nvl_layout
.
total_bytes
<=
num_rdma_bytes
);
auto
nvl_buffer
=
nvl_layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
auto
nvl_next_buffer
=
nvl_layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
auto
global_atomic_counter
=
torch
::
zeros
({
1
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
global_atomic_counter
=
torch
::
zeros
({
1
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
// Wait previous tasks to be finished
// Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream
// NOTES: the hook mode will always use the default stream
auto
compute_stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
auto
compute_stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
...
@@ -1333,7 +1328,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1333,7 +1328,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate packed tensors
// Allocate packed tensors
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
use_fp8
?
torch
::
kFloat8_e4m3fnuz
:
torch
::
kBFloat16
));
x
.
options
().
dtype
(
use_int8
?
torch
::
kInt8
:
use_fp8
?
torch
::
kFloat8_e4m3fnuz
:
torch
::
kBFloat16
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
...
@@ -1345,13 +1340,18 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1345,13 +1340,18 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
// TODO: support unaligned cases
// TODO: support unaligned cases
EP_HOST_ASSERT
(
hidden
%
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
)
==
0
);
EP_HOST_ASSERT
(
hidden
%
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
)
==
0
);
if
(
not
use_ue8m0
)
{
EP_HOST_ASSERT
(
!
(
use_ue8m0
&&
use_int8
));
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
if
(
use_ue8m0
)
{
}
else
{
EP_HOST_ASSERT
(
round_scale
);
EP_HOST_ASSERT
(
round_scale
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
),
num_ranks
*
num_max_dispatch_tokens_per_rank
},
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
),
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
}
else
if
(
use_int8
)
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
1
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
else
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
}
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
();
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
();
...
@@ -1369,8 +1369,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1369,8 +1369,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_int8
,
workspace
,
num_device_sms
,
launch_stream
,
phases
);
workspace
,
num_device_sms
,
launch_stream
,
phases
);
};
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
@@ -1427,12 +1427,6 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1427,12 +1427,6 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
// Buffer control
LowLatencyLayout
nvl_layout
(
nvl_buffer_ptrs
[
nvl_rank
],
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
EP_HOST_ASSERT
(
nvl_layout
.
total_bytes
<=
num_rdma_bytes
);
auto
nvl_buffer
=
nvl_layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
auto
nvl_next_buffer
=
nvl_layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
// Wait previous tasks to be finished
// Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream
// NOTES: the hook mode will always use the default stream
...
...
csrc/deep_ep.hpp
View file @
35735902
...
@@ -177,7 +177,7 @@ public:
...
@@ -177,7 +177,7 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
bool
async
,
bool
return_recv_hook
);
bool
async
,
bool
return_recv_hook
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
...
...
csrc/kernels/api.cuh
View file @
35735902
...
@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
void
combine
(
void
*
combined_x
,
...
...
csrc/kernels/internode_ll.cu
View file @
35735902
...
@@ -31,8 +31,7 @@ __device__ void grid_barrier(int* global_counter, int num_blocks) {
...
@@ -31,8 +31,7 @@ __device__ void grid_barrier(int* global_counter, int num_blocks) {
__syncthreads
();
__syncthreads
();
__threadfence
();
__threadfence
();
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
// ret = __hip_atomic_fetch_add(&global_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
ret
=
__hip_atomic_fetch_add
(
&
global_counter
[
0
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
ret
=
atomicAdd
(
&
global_counter
[
0
],
1
);
}
}
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
...
@@ -84,7 +83,7 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
...
@@ -84,7 +83,7 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
);
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
);
}
}
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
int
kHidden
>
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
bool
kUseInt8
,
int
kHidden
>
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
...
@@ -115,14 +114,14 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -115,14 +114,14 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// FP8 staffs
// FP8 staffs
constexpr
int
kNumPerChannels
=
FP8_QUANTIZATION_NUM_PER_CHANNEL
;
constexpr
int
kNumPerChannels
=
FP8_QUANTIZATION_NUM_PER_CHANNEL
;
const
int
num_s
cales
=
kHidden
/
kNumPerChannels
;
const
expr
int
kNumS
cales
=
kHidden
/
kNumPerChannels
;
const
size_t
hidden_bytes
=
kHidden
*
(
kUseFP8
?
sizeof
(
__hip_fp8_storage_t
)
:
sizeof
(
hip_bfloat16
));
const
size_t
hidden_bytes
=
kHidden
*
(
kUseFP8
?
sizeof
(
__hip_fp8_storage_t
)
:
sizeof
(
hip_bfloat16
));
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
// Message package: hidden data, FP8 scales, index at source
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
// NOTES: currently we have 3 reserved int fields for future use
using
vec_t
=
typename
std
::
conditional
<
kUseFP8
,
int2
,
int4
>::
type
;
using
vec_t
=
typename
std
::
conditional
<
kUseFP8
,
int2
,
int4
>::
type
;
const
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseFP8
?
(
kHidden
+
num_s
cales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
const
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseFP8
?
(
kHidden
+
kNumS
cales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
const
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
const
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
EP_DEVICE_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
);
EP_DEVICE_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
);
...
@@ -147,7 +146,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -147,7 +146,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
EP_DEVICE_ASSERT
(
kHidden
%
kNumElemsPerRead
==
0
);
EP_DEVICE_ASSERT
(
kHidden
%
kNumElemsPerRead
==
0
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
const
size_
t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerRead
;
const
expr
in
t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerRead
;
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_tokens
;
token_idx
+=
num_sms
)
{
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_tokens
;
token_idx
+=
num_sms
)
{
const
auto
x_int4
=
reinterpret_cast
<
const
int4
*>
(
x
)
+
token_idx
*
hidden_bf16_int4
;
const
auto
x_int4
=
reinterpret_cast
<
const
int4
*>
(
x
)
+
token_idx
*
hidden_bf16_int4
;
...
@@ -159,13 +158,21 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -159,13 +158,21 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
__shared__
float
int8_amaxf
[
kNumScales
];
if
constexpr
(
kUseInt8
)
{
if
(
thread_id
<
kNumScales
)
{
int8_amaxf
[
thread_id
]
=
kFP8Margin
;
}
__syncthreads
();
}
// FP8 cast
// FP8 cast
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
hidden_bf16_int4
;
i
+=
num_threads
)
{
for
(
int
i
=
thread_id
;
i
<
hidden_bf16_int4
;
i
+=
num_threads
)
{
// Read
// Read
auto
int4_value
=
__ldg
(
x_int4
+
i
);
auto
int4_value
=
__ldg
(
x_int4
+
i
);
if
(
kUseFP8
)
{
if
constexpr
(
kUseFP8
)
{
// Calculate local amax
// Calculate local amax
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
float
fp32_values
[
kNumElemsPerRead
];
float
fp32_values
[
kNumElemsPerRead
];
...
@@ -178,25 +185,74 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -178,25 +185,74 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Reduce amax and scale
// Reduce amax and scale
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
/
kNumPerChannels
==
4
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
/
kNumPerChannels
==
4
,
"Invalid vectorization"
);
amax
=
warp_reduce_max
<
16
>
(
amax
);
amax
=
warp_reduce_max
<
16
>
(
amax
);
calculate_fp8_scales
(
amax
,
scale
,
scale_inv
,
round_scale
);
const
int
scale_offset
=
i
*
kNumElemsPerRead
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
;
if
(
lane_id
%
16
==
0
)
rdma_x_scales
[
i
*
kNumElemsPerRead
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
]
=
scale_inv
;
if
constexpr
(
kUseInt8
)
{
// 记录每128个数的最大值
// Cast into send buffer
int8_amaxf
[
scale_offset
]
=
fmaxf
(
amax
,
int8_amaxf
[
scale_offset
]);
vec_t
int2_value
;
}
else
{
auto
fp8x2_values
=
reinterpret_cast
<
__hip_fp8x2_storage_t
*>
(
&
int2_value
);
calculate_fp8_scales
(
amax
,
scale
,
scale_inv
,
round_scale
);
#pragma unroll
if
(
lane_id
%
16
==
0
)
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
j
+=
2
)
{
rdma_x_scales
[
scale_offset
]
=
scale_inv
;
float2
fp32x2
=
{
fp32_values
[
j
]
*
scale
,
fp32_values
[
j
+
1
]
*
scale
};
fp8x2_values
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E4M3_FNUZ
);
// Cast into send buffer
vec_t
int2_value
;
auto
fp8x2_values
=
reinterpret_cast
<
__hip_fp8x2_storage_t
*>
(
&
int2_value
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
j
+=
2
)
{
float2
fp32x2
=
{
fp32_values
[
j
]
*
scale
,
fp32_values
[
j
+
1
]
*
scale
};
fp8x2_values
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E4M3_FNUZ
);
}
rdma_x_vec
[
i
]
=
int2_value
;
}
}
rdma_x_vec
[
i
]
=
int2_value
;
}
else
{
}
else
{
// Reinterpret-cast is for C++14 compatibility
// Reinterpret-cast is for C++14 compatibility
rdma_x_vec
[
i
]
=
*
reinterpret_cast
<
vec_t
*>
(
&
int4_value
);
rdma_x_vec
[
i
]
=
*
reinterpret_cast
<
vec_t
*>
(
&
int4_value
);
}
}
}
}
__syncthreads
();
__syncthreads
();
if
constexpr
(
kUseInt8
)
{
float
amax_per_token
=
kFP8Margin
;
// 并行规约,计算每个token的amax
for
(
int
s
=
0
;
s
<
kNumScales
;
s
+=
kWarpSize
)
{
int
src_idx
=
s
+
lane_id
;
float
tmp_amaxf
=
0
;
if
(
src_idx
<
kNumScales
)
{
tmp_amaxf
=
int8_amaxf
[
src_idx
];
}
tmp_amaxf
=
warp_reduce_max
<
kWarpSize
>
(
tmp_amaxf
);
int8_amaxf
[
0
]
=
fmaxf
(
tmp_amaxf
,
int8_amaxf
[
0
]);
__syncthreads
();
}
amax_per_token
=
int8_amaxf
[
0
];
// 根据最大值计算scale
float
scale
,
scale_inv
;
calculate_int8_scales
(
amax_per_token
,
scale
,
scale_inv
);
if
(
thread_id
==
0
)
{
rdma_x_scales
[
0
]
=
scale_inv
;
}
for
(
int
i
=
thread_id
;
i
<
hidden_bf16_int4
;
i
+=
num_threads
)
{
// Read
auto
int4_value
=
__ldg
(
x_int4
+
i
);
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
// Cast into send buffer
vec_t
int2_value
;
auto
int8_values
=
reinterpret_cast
<
int8_t
*>
(
&
int2_value
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
auto
fp32_value
=
static_cast
<
float
>
(
bf16_values
[
j
]);
auto
fp32_value_scaled
=
fp32_value
*
scale
;
int8_values
[
j
]
=
static_cast
<
int8_t
>
(
nearbyintf
(
fp32_value_scaled
));
}
rdma_x_vec
[
i
]
=
int2_value
;
}
__syncthreads
();
}
// Issue IBGDA sends
// Issue IBGDA sends
if
(
dst_expert_idx
>=
0
)
{
if
(
dst_expert_idx
>=
0
)
{
int
slot_idx
=
lane_id
==
0
?
atomicAdd
(
atomic_counter_per_expert
+
dst_expert_idx
,
1
)
:
0
;
int
slot_idx
=
lane_id
==
0
?
atomicAdd
(
atomic_counter_per_expert
+
dst_expert_idx
,
1
)
:
0
;
...
@@ -339,9 +395,10 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -339,9 +395,10 @@ LOW_LATENCY_DISPATCH_RECV:
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
const
auto
num_aligned_scales
=
ALIGN
<
int
>
(
num_s
cales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
num_aligned_scales
=
ALIGN
<
int
>
(
kNumS
cales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
recv_x_scales
=
static_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
const
auto
recv_x_scales
=
static_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_aligned_scales
;
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
(
kUseInt8
?
1
:
num_aligned_scales
);
// Shared between sub-warps in warp groups
// Shared between sub-warps in warp groups
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
...
@@ -362,8 +419,7 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -362,8 +419,7 @@ LOW_LATENCY_DISPATCH_RECV:
// no needs to reset because there is no iteration
// no needs to reset because there is no iteration
if
(
lane_id
==
0
){
if
(
lane_id
==
0
){
// volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
sync_large_warp_counters
[
warp_group_id
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
warp_group_id
],
1
);
}
}
syncwarp
();
syncwarp
();
...
@@ -372,7 +428,7 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -372,7 +428,7 @@ LOW_LATENCY_DISPATCH_RECV:
recv_token_begin_idx
=
shared_recv_token_begin_idx
[
warp_group_id
];
recv_token_begin_idx
=
shared_recv_token_begin_idx
[
warp_group_id
];
// Copy tokens
// Copy tokens
EP_DEVICE_ASSERT
(
num_s
cales
<=
64
);
EP_DEVICE_ASSERT
(
kNumS
cales
<=
64
);
for
(
int
i
=
sub_warp_id
;
i
<
num_recv_tokens
;
i
+=
num_warps_per_group
)
{
for
(
int
i
=
sub_warp_id
;
i
<
num_recv_tokens
;
i
+=
num_warps_per_group
)
{
// Copy source info
// Copy source info
const
auto
src_src_idx
=
reinterpret_cast
<
int
*>
(
rdma_recv_x_uint8
+
i
*
num_bytes_per_msg
);
const
auto
src_src_idx
=
reinterpret_cast
<
int
*>
(
rdma_recv_x_uint8
+
i
*
num_bytes_per_msg
);
...
@@ -387,24 +443,30 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -387,24 +443,30 @@ LOW_LATENCY_DISPATCH_RECV:
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
// Copy scales
// Copy scales
if
(
kUseFP8
)
{
if
constexpr
(
kUseFP8
)
{
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
packed_t
)
/
sizeof
(
scale_t
));
const
auto
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
packed_t
)
/
sizeof
(
scale_t
));
const
auto
token_idx
=
recv_token_begin_idx
+
i
;
const
auto
token_idx
=
recv_token_begin_idx
+
i
;
const
auto
token_stride
=
num_elems_per_pack
;
const
auto
token_stride
=
num_elems_per_pack
;
const
auto
pack_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_elems_per_pack
;
const
auto
pack_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_elems_per_pack
;
if
(
lane_id
<
num_scales
)
{
if
constexpr
(
kUseInt8
)
{
const
auto
pack_idx
=
lane_id
/
num_elems_per_pack
;
if
(
lane_id
==
0
)
{
const
auto
elem_idx
=
lane_id
%
num_elems_per_pack
;
recv_x_scales
[
token_idx
]
=
ld_nc_global
(
src_scales
);
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
));
}
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
else
{
}
if
(
lane_id
<
kNumScales
)
{
if
(
lane_id
+
kWarpSize
<
num_scales
)
{
const
auto
pack_idx
=
lane_id
/
num_elems_per_pack
;
const
auto
pack_idx
=
(
lane_id
+
kWarpSize
)
/
num_elems_per_pack
;
const
auto
elem_idx
=
lane_id
%
num_elems_per_pack
;
const
auto
elem_idx
=
(
lane_id
+
kWarpSize
)
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
));
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
+
kWarpSize
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
if
(
lane_id
+
kWarpSize
<
kNumScales
)
{
const
auto
pack_idx
=
(
lane_id
+
kWarpSize
)
/
num_elems_per_pack
;
const
auto
elem_idx
=
(
lane_id
+
kWarpSize
)
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
+
kWarpSize
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
}
}
}
}
}
}
...
@@ -420,7 +482,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -420,7 +482,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
void
*
workspace
,
int
num_device_sms
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
)
{
hipStream_t
stream
,
int
phases
)
{
constexpr
int
kNumMaxTopK
=
11
;
constexpr
int
kNumMaxTopK
=
11
;
...
@@ -439,11 +501,13 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -439,11 +501,13 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
#define DISPATCH_LAUNCH_CASE(hidden) { \
#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = dispatch<false, false, hidden>; \
auto dispatch_func = dispatch<false, false, false, hidden>; \
if (use_fp8 and not use_ue8m0) \
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, hidden>; \
dispatch_func = dispatch<true, false, false, hidden>; \
if (use_fp8 and use_ue8m0) \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, hidden>; \
dispatch_func = dispatch<true, true, false, hidden>; \
if (use_int8) \
dispatch_func = dispatch<true, false, true, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_src_info, packed_recv_layout_range, \
...
@@ -575,8 +639,7 @@ combine(void* combined_x,
...
@@ -575,8 +639,7 @@ combine(void* combined_x,
// Put finishing flag
// Put finishing flag
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
if
(
lane_id
==
0
){
if
(
lane_id
==
0
){
// volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
volatile
int
ret
=
__hip_atomic_fetch_add
(
&
sync_large_warp_counters
[
warp_group_id
],
1
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
warp_group_id
],
1
);
}
}
syncwarp
();
syncwarp
();
while
(
sync_large_warp_counters
[
warp_group_id
]
<
num_warps_per_group
);
while
(
sync_large_warp_counters
[
warp_group_id
]
<
num_warps_per_group
);
...
...
csrc/kernels/utils.cuh
View file @
35735902
...
@@ -184,9 +184,8 @@ __device__ __forceinline__ int64_t ld_acquire_global(const int64_t *ptr) {
...
@@ -184,9 +184,8 @@ __device__ __forceinline__ int64_t ld_acquire_global(const int64_t *ptr) {
__device__
__forceinline__
int
atomic_add_release_global
(
const
int
*
ptr
,
int
value
)
{
__device__
__forceinline__
int
atomic_add_release_global
(
const
int
*
ptr
,
int
value
)
{
int
ret
;
int
ret
;
// ret = __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE,
ret
=
__hip_atomic_fetch_add
(
const_cast
<
int
*>
(
ptr
),
value
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
// __HIP_MEMORY_SCOPE_AGENT);
// ret = atomicAdd((int*)ptr, value);
ret
=
atomicAdd
((
int
*
)
ptr
,
value
);
return
ret
;
return
ret
;
}
}
...
@@ -342,15 +341,10 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
...
@@ -342,15 +341,10 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
}
}
#ifndef FORCE_NVSHMEM_API
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
#else
constexpr
float
kInt8Amax
=
127.0
f
;
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFinfoAmaxE4M3
=
448.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
#endif
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
// We can ensure `-126 <= x and x <= 127`
// We can ensure `-126 <= x and x <= 127`
...
@@ -376,6 +370,11 @@ __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, f
...
@@ -376,6 +370,11 @@ __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, f
}
}
}
}
__forceinline__
__device__
void
calculate_int8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
)
{
scale
=
kInt8Amax
/
amax
;
scale_inv
=
amax
/
kInt8Amax
;
}
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
if
constexpr
(
kIsUE8M0
)
{
if
constexpr
(
kIsUE8M0
)
{
...
...
deep_ep/buffer.py
View file @
35735902
...
@@ -804,7 +804,7 @@ class Buffer:
...
@@ -804,7 +804,7 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
use_fp8
:
bool
=
True
,
round_scale
:
bool
=
False
,
use_ue8m0
:
bool
=
False
,
use_fp8
:
bool
=
True
,
round_scale
:
bool
=
False
,
use_ue8m0
:
bool
=
False
,
use_int8
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
"""
...
@@ -824,6 +824,7 @@ class Buffer:
...
@@ -824,6 +824,7 @@ class Buffer:
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
use_int8: whether to enable INT8 casting.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
...
@@ -851,7 +852,7 @@ class Buffer:
...
@@ -851,7 +852,7 @@ class Buffer:
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_int8
,
async_finish
,
return_recv_hook
)
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
tensors_to_record
=
(
x
,
topk_idx
,
...
...
tests/test_low_latency_new_int8.py
0 → 100644
View file @
35735902
import
argparse
import
random
import
os
import
torch
import
torch.distributed
as
dist
from
functools
import
partial
from
typing
import
Literal
,
Set
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_back
,
per_token_cast_back_int8
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
rank
:
int
,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
assert
num_experts
%
num_ranks
==
0
num_local_experts
=
num_experts
//
num_ranks
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset
=
128
assert
num_ranks
-
rank_offset
<
257
,
'Too many ranks (exceeding test precision limit)'
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
(
rank
-
rank_offset
)
x
[:,
-
128
:]
=
torch
.
arange
(
num_tokens
,
device
=
'cuda'
).
to
(
torch
.
bfloat16
).
view
(
-
1
,
1
)
x_list
=
[
x
]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
topk_weights
=
torch
.
randn
((
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
# Randomly mask some positions
for
_
in
range
(
10
):
topk_idx
[
random
.
randint
(
0
,
num_tokens
-
1
),
random
.
randint
(
0
,
num_topk
-
1
)]
=
-
1
all_topk_idx
=
torch
.
empty
((
num_ranks
,
num_tokens
,
num_topk
),
dtype
=
topk_idx
.
dtype
,
device
=
'cuda'
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
# For failure simulation and shrink testing
mask_status
=
torch
.
zeros
((
num_ranks
,),
dtype
=
torch
.
int
,
device
=
'cuda'
)
# Check dispatch correctness
do_check
=
True
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
):
for
dispatch_use_fp8
in
(
True
,
):
for
round_scale
in
(
False
,
):
for
use_ue8m0
in
(
False
,
):
num_times
+=
1
use_int8
=
True
for
_
in
range
(
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
use_int8
=
use_int8
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
recv_x
=
per_token_cast_back_int8
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
int_mask
=
(
2
**
32
)
-
1
num_valid_tokens
=
recv_count
.
item
()
assert
num_valid_tokens
==
(
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
(
),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
()
}
'
if
num_valid_tokens
==
0
:
continue
# Check received data
if
current_x
is
x
:
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
elif
use_int8
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.01
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
# for j in range(num_ranks):
# if (not round_scale):
# check_tmp1 = (recv_x_amin == j - rank_offset).sum().item()
# check_tmp2 = (all_topk_idx[j] == expert_id).sum().item()
# print(f'rank: {rank}, j: {j}, check_tmp1: {check_tmp1}, check_tmp2: {check_tmp2}, diff: {abs(check_tmp1 - check_tmp2)}')
# assert abs(check_tmp1 - check_tmp2) < 3
# assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if
dispatch_use_fp8
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
print
(
"dispatch int 8 pass"
)
# noinspection PyShadowingNames
def
large_gemm_with_hook
(
hook
):
mat_0
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_1
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_0
@
mat_1
hook
()
# noinspection PyShadowingNames
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
True
,
round_scale
=
False
,
use_ue8m0
=
False
,
use_int8
=
True
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
# Calculate bandwidth
num_fp8_bytes
,
num_bf16_bytes
=
(
hidden
+
hidden
/
128
*
4
+
16
),
hidden
*
2
num_logfmt10_bytes
=
hidden
*
10
/
8
+
hidden
/
128
*
4
num_dispatch_comm_bytes
,
num_combine_comm_bytes
=
0
,
0
for
i
in
range
(
num_tokens
):
num_selections
=
(
topk_idx
[
i
]
!=
-
1
).
sum
().
item
()
num_dispatch_comm_bytes
+=
num_fp8_bytes
*
num_selections
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
# Separate profiling
for
return_recv_hook
in
(
True
,
):
group
.
barrier
()
dispatch_t
=
bench_kineto
(
partial
(
test_func
,
return_recv_hook
=
return_recv_hook
),
kernel_names
=
'dispatch'
,
barrier_comm_profiling
=
True
,
suppress_kineto_output
=
True
,
num_kernels_per_period
=
2
if
return_recv_hook
else
1
)
if
not
return_recv_hook
:
print
(
f
'[rank
{
rank
}
] Dispatch bandwidth:
{
num_dispatch_comm_bytes
/
1e9
/
dispatch_t
:.
2
f
}
GB/s, avg_t=
{
dispatch_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
else
:
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
dispatch_t
[
0
]
*
1e6
:.
2
f
}
+
{
dispatch_t
[
1
]
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
return
hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
:
argparse
.
Namespace
):
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
num_topk
,
num_experts
=
args
.
num_topk
,
args
.
num_experts
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
)
if
local_rank
==
0
:
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
buffer
=
deep_ep
.
Buffer
(
group
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
True
,
num_qps_per_rank
=
num_experts
//
num_ranks
,
allow_nvlink_for_low_latency_mode
=
not
args
.
disable_nvlink
,
explicitly_destroy
=
True
,
allow_mnnvl
=
args
.
allow_mnnvl
)
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
# do_pressure_test = args.pressure_test
# for seed in range(int(1e9) if do_pressure_test else 0):
# if local_rank == 0:
# print(f'Testing with seed {seed} ...', flush=True)
# ref_hash = test_main(num_tokens,
# hidden,
# num_experts,
# num_topk,
# rank,
# num_ranks,
# group,
# buffer,
# seed=seed)
# for _ in range(20):
# assert test_main(num_tokens,
# hidden,
# num_experts,
# num_topk,
# rank,
# num_ranks,
# group,
# buffer,
# seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group
buffer
.
destroy
()
dist
.
barrier
()
dist
.
destroy_process_group
()
if
__name__
==
'__main__'
:
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512`
parser
=
argparse
.
ArgumentParser
(
description
=
'Test low-latency EP kernels'
)
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
help
=
'Number of processes to spawn (default: 8)'
)
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
128
,
help
=
'Number of tokens (default: 128)'
)
parser
.
add_argument
(
'--hidden'
,
type
=
int
,
default
=
2560
,
help
=
'Hidden dimension size (default: 7168)'
)
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
help
=
'Number of top-k experts (default: 8)'
)
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
256
,
help
=
'Number of experts (default: 288)'
)
parser
.
add_argument
(
'--allow-mnnvl'
,
action
=
"store_true"
,
help
=
'Allow MNNVL for communication'
)
parser
.
add_argument
(
'--disable-nvlink'
,
action
=
'store_true'
,
help
=
'Whether to disable NVLink for testing'
)
parser
.
add_argument
(
"--pressure-test"
,
action
=
'store_true'
,
help
=
'Whether to do pressure test'
)
parser
.
add_argument
(
"--shrink-test"
,
action
=
'store_true'
,
help
=
'Whether to simulate failure and test shrink mode'
)
parser
.
add_argument
(
'--use-logfmt'
,
action
=
'store_true'
,
help
=
'Whether to test LogFMT combine'
)
args
=
parser
.
parse_args
()
num_processes
=
args
.
num_processes
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,
args
),
nprocs
=
num_processes
)
tests/utils.py
View file @
35735902
...
@@ -73,6 +73,29 @@ def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
...
@@ -73,6 +73,29 @@ def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
return
(
x_fp32_padded
*
x_scales
).
view
(
x_fp8_padded
.
shape
).
to
(
torch
.
bfloat16
)[:,:
n
].
contiguous
()
return
(
x_fp32_padded
*
x_scales
).
view
(
x_fp8_padded
.
shape
).
to
(
torch
.
bfloat16
)[:,:
n
].
contiguous
()
def
per_token_cast_back_int8
(
x_int8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
"""
x_int8: [m, n] int8 tensor
x_scales: [m, n] 或 [m, 1] 或 [m, n/128] 量化 scale float
return: [m, n] bf16 tensor
"""
if
x_int8
.
numel
()
==
0
:
return
x_int8
.
to
(
torch
.
bfloat16
)
assert
x_int8
.
dim
()
==
2
m
,
n
=
x_int8
.
shape
aligned_n
=
align_up
(
n
,
128
)
x_int8_padded
=
torch
.
nn
.
functional
.
pad
(
x_int8
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_fp32_padded
=
x_int8_padded
.
to
(
torch
.
float32
).
view
(
m
,
-
1
,
1
)
x_scales
=
x_scales
.
view
(
m
,
-
1
,
1
).
to
(
torch
.
float32
)
# print(f'x_int8.shape: {x_int8.shape}, x_fp32_padded: {x_fp32_padded.shape}, x_scales: {x_scales.shape}')
x_deq
=
(
x_fp32_padded
*
x_scales
).
view
(
m
,
aligned_n
)
return
x_deq
[:,
:
n
].
to
(
torch
.
bfloat16
).
contiguous
()
def
inplace_unique
(
x
:
torch
.
Tensor
,
num_slots
:
int
):
def
inplace_unique
(
x
:
torch
.
Tensor
,
num_slots
:
int
):
assert
x
.
dim
()
==
2
assert
x
.
dim
()
==
2
mask
=
x
<
0
mask
=
x
<
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment