Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
09cb2b03
Commit
09cb2b03
authored
Oct 30, 2025
by
lishen
Browse files
添加low latency接口,正确性需补充
parent
0b14d3b2
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1579 additions
and
1059 deletions
+1579
-1059
build.sh
build.sh
+11
-6
csrc/config.hpp
csrc/config.hpp
+29
-26
csrc/deep_ep.cu
csrc/deep_ep.cu
+355
-64
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+12
-3
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+42
-0
csrc/kernels/configs.cuh
csrc/kernels/configs.cuh
+2
-0
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+927
-949
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+169
-9
deep_ep/buffer.py
deep_ep/buffer.py
+32
-2
No files found.
build.sh
View file @
09cb2b03
...
...
@@ -8,12 +8,17 @@ fi
PYTHON_INCLUDE
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['include'])"
)
PYTHON_PLATLIB
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['platlib'])"
)
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/kernels/intranode.cu
-o
build_/intranode.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/kernels/runtime.cu
-o
build_/runtime.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/kernels/layout.cu
-o
build_/layout.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/deep_ep.cu
-o
build_/deep_ep.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/kernels/internode.cu
-o
build_/internode.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o
-L
$(
pwd
)
/rocshmem_dir/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,
$(
pwd
)
/rocshmem_dir/lib/
-L
"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
-l
:librocshmem.a
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
INCLUDE_PATHS
=
${
INCLUDE_PATHS
:
=-Icsrc/ -I
$(
pwd
)
/rocshmem_dir/include/ -I/opt/mpi/include -I
${
PYTHON_PLATLIB
}
/torch/include -I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include -I
${
PYTHON_PLATLIB
}
/torch/include/TH -I
${
PYTHON_PLATLIB
}
/torch/include/THC -I
${
PYTHON_PLATLIB
}
/torch/include/THH -I/opt/dtk/include -I
${
PYTHON_INCLUDE
}}
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/runtime.cu
-o
build_/runtime.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/layout.cu
-o
build_/layout.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/intranode.cu
-o
build_/intranode.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode.cu
-o
build_/internode.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode_ll.cu
-o
build_/internode_ll.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/deep_ep.cu
-o
build_/deep_ep.o
${
COMPILE_OPTIONS
}
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o
-L
$(
pwd
)
/rocshmem_dir/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,
$(
pwd
)
/rocshmem_dir/lib/
-L
"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
-l
:librocshmem.a
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
# build whl
echo
"Using Python:
$(
which python3
)
"
...
...
csrc/config.hpp
View file @
09cb2b03
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once
#include "kernels/api.cuh"
...
...
@@ -105,18 +107,18 @@ struct Config {
struct
LowLatencyBuffer
{
int
num_clean_int
=
0
;
void
*
dispatch_rdma_send_buffer
=
nullptr
;
void
*
dispatch_rdma_recv_data_buffer
=
nullptr
;
int
*
dispatch_rdma_recv_count_buffer
=
nullptr
;
void
*
dispatch_rdma_send_buffer
=
nullptr
;
void
*
dispatch_rdma_recv_data_buffer
=
nullptr
;
int
64_t
*
dispatch_rdma_recv_count_buffer
=
nullptr
;
void
*
combine_rdma_send_buffer
=
nullptr
;
void
*
combine_rdma_recv_data_buffer
=
nullptr
;
int
*
combine_rdma_recv_flag_buffer
=
nullptr
;
void
*
combine_rdma_send_buffer
=
nullptr
;
void
*
combine_rdma_recv_data_buffer
=
nullptr
;
int
64_t
*
combine_rdma_recv_flag_buffer
=
nullptr
;
void
*
combine_rdma_send_buffer_data_start
=
nullptr
;
void
*
combine_rdma_send_buffer_data_start
=
nullptr
;
size_t
num_bytes_per_combine_msg
=
0
;
std
::
pair
<
int
*
,
int
>
clean_meta
()
{
std
::
pair
<
int
64_t
*
,
int
>
clean_meta
()
{
EP_HOST_ASSERT
(
dispatch_rdma_recv_count_buffer
==
combine_rdma_recv_flag_buffer
);
return
{
dispatch_rdma_recv_count_buffer
,
num_clean_int
};
}
...
...
@@ -171,29 +173,30 @@ struct LowLatencyLayout {
total_bytes
+=
recv_buffer_bytes
*
2
;
// Symmetric signaling buffers
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int
);
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int
64_t
);
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
signaling_buffer_bytes
=
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
size_t
signaling_buffer_bytes_aligned
=
ALIGN
<
size_t
>
(
signaling_buffer_bytes
,
128
);
total_bytes
+=
signaling_buffer_bytes_aligned
*
2
;
size_t
signaling_buffer_bytes
=
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
total_bytes
+=
signaling_buffer_bytes
*
2
;
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// so you may see some parameters are duplicated
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
buffers
[
i
]
=
{
static_cast
<
int
>
(
signaling_buffer_bytes
/
sizeof
(
int
)),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
num_bytes_per_combine_msg
};
static_cast
<
int
>
(
signaling_buffer_bytes
/
sizeof
(
int64_t
)),
// dispatch:send_buffer + recv_buffer + recv_count
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
// combine:send_buffer + recv_buffer + recv_count
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
// combine_rdma_send_buffer_data_start
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
+
sizeof
(
int4
)),
//
num_bytes_per_combine_msg
};
}
}
};
...
...
csrc/deep_ep.cu
View file @
09cb2b03
//
#include <ATen/dtk_macros.h>
#include <ATen/dtk_macros.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/HIPDataType.h>
#include <chrono>
...
...
@@ -13,20 +13,19 @@
namespace
deep_ep
{
Buffer
::
Buffer
(
int
rank
,
int
num_ranks
,
int64_t
num_nvl_bytes
,
int64_t
num_rdma_bytes
,
bool
low_latency_mode
,
bool
explicitly_destroy
,
bool
use_default_stream_as_comm_stream
)
bool
low_latency_mode
,
bool
explicitly_destroy
,
bool
enable_shrink
)
:
rank
(
rank
),
num_ranks
(
num_ranks
),
num_nvl_bytes
(
num_nvl_bytes
),
num_rdma_bytes
(
num_rdma_bytes
),
low_latency_mode
(
low_latency_mode
),
explicitly_destroy
(
explicitly_destroy
),
use_default_stream_as_comm_stream
(
use_default_stream_as_comm_stream
),
comm_stream
(
use_default_stream_as_comm_stream
?
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
()
:
at
::
hip
::
getStreamFromPoolMasqueradingAsCUDA
(
true
))
{
enable_shrink
(
enable_shrink
),
comm_stream
(
at
::
hip
::
getStreamFromPoolMasqueradingAsCUDA
(
true
))
{
// Metadata memory
int64_t
barrier_signal_bytes
=
NUM_MAX_NVL_PEERS
*
sizeof
(
int
);
int64_t
buffer_ptr_bytes
=
NUM_MAX_NVL_PEERS
*
sizeof
(
void
*
);
int64_t
barrier_signal_ptr_bytes
=
NUM_MAX_NVL_PEERS
*
sizeof
(
int
*
);
EP_HOST_ASSERT
(
enable_shrink
==
false
);
// Common checks
EP_HOST_ASSERT
(
num_nvl_bytes
%
NUM_BUFFER_ALIGNMENT_BYTES
==
0
and
(
num_nvl_bytes
<=
std
::
numeric_limits
<
int
>::
max
()
or
num_rdma_bytes
==
0
));
...
...
@@ -77,7 +76,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
}
// Create 32 MiB workspace
CUDA_CHECK
(
hipMalloc
(
&
workspace
,
NUM_WORKSPACE_BYTES
));
CUDA_CHECK
(
hip
Ext
Malloc
WithFlags
(
&
workspace
,
NUM_WORKSPACE_BYTES
,
hipDeviceMallocUncached
));
CUDA_CHECK
(
hipMemsetAsync
(
workspace
,
0
,
NUM_WORKSPACE_BYTES
,
comm_stream
));
// MoE counter
...
...
@@ -200,6 +199,10 @@ void Buffer::destroy() {
CUDA_CHECK
(
hipDeviceSynchronize
());
internode
::
barrier
();
internode
::
free
(
rdma_buffer_ptr
);
if
(
enable_shrink
)
{
internode
::
free
(
mask_buffer_ptr
);
internode
::
free
(
sync_buffer_ptr
);
}
internode
::
finalize
();
}
#endif
...
...
@@ -253,25 +256,32 @@ void Buffer::sync(const std::vector<int> &device_
// Sync ROCSHMEM handles and allocate memory
if
(
num_rdma_bytes
>
0
)
{
// Initialize
NV
SHMEM
// Initialize
ROC
SHMEM
EP_HOST_ASSERT
(
root_unique_id_opt
.
has_value
());
std
::
vector
<
uint8_t
>
root_unique_id
(
root_unique_id_opt
->
size
());
auto
root_unique_id_str
=
root_unique_id_opt
->
cast
<
std
::
string
>
();
std
::
memcpy
(
root_unique_id
.
data
(),
root_unique_id_str
.
c_str
(),
root_unique_id_opt
->
size
());
auto
nvshmem_rank
=
low_latency_mode
?
rank
:
rdma_rank
;
auto
num_nvshmem_ranks
=
low_latency_mode
?
num_ranks
:
num_rdma_ranks
;
EP_HOST_ASSERT
(
nvshmem_rank
==
internode
::
init
(
root_unique_id
,
nvshmem_rank
,
num_nvshmem_ranks
,
low_latency_mode
));
EP_HOST_ASSERT
(
nvshmem_rank
==
internode
::
init
(
root_unique_id
,
nvshmem_rank
,
num_nvshmem_ranks
,
low_latency_mode
));
internode
::
barrier
();
// Allocate
rdma_buffer_ptr
=
internode
::
alloc
(
num_rdma_bytes
,
NUM_BUFFER_ALIGNMENT_BYTES
);
rdma_buffer_ptr
=
internode
::
alloc
(
num_rdma_bytes
,
NUM_BUFFER_ALIGNMENT_BYTES
);
// Clean buffer (mainly for low-latency mode)
CUDA_CHECK
(
hipMemset
(
rdma_buffer_ptr
,
0
,
num_rdma_bytes
));
// Allocate and clean shrink buffer
if
(
enable_shrink
)
{
int
num_mask_buffer_bytes
=
num_ranks
*
sizeof
(
int
);
int
num_sync_buffer_bytes
=
num_ranks
*
sizeof
(
int
);
mask_buffer_ptr
=
reinterpret_cast
<
int
*>
(
internode
::
alloc
(
num_mask_buffer_bytes
,
NUM_BUFFER_ALIGNMENT_BYTES
));
sync_buffer_ptr
=
reinterpret_cast
<
int
*>
(
internode
::
alloc
(
num_sync_buffer_bytes
,
NUM_BUFFER_ALIGNMENT_BYTES
));
CUDA_CHECK
(
hipMemset
(
mask_buffer_ptr
,
0
,
num_mask_buffer_bytes
));
CUDA_CHECK
(
hipMemset
(
sync_buffer_ptr
,
0
,
num_sync_buffer_bytes
));
}
// Barrier
internode
::
barrier
();
CUDA_CHECK
(
hipDeviceSynchronize
());
...
...
@@ -298,14 +308,12 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
at
::
hip
::
setCurrentHIPStreamMasqueradingAsCUDA
(
comm_stream
);
}
if
(
not
use_default_stream_as_comm_stream
)
{
// Wait previous tasks to be finished
if
(
previous_event
.
has_value
())
{
stream_wait
(
comm_stream
,
previous_event
.
value
());
}
else
{
stream_wait
(
comm_stream
,
compute_stream
);
}
}
auto
num_tokens
=
static_cast
<
int
>
(
topk_idx
.
size
(
0
)),
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
...
...
@@ -342,10 +350,8 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
}
}
else
{
if
(
not
use_default_stream_as_comm_stream
)
{
stream_wait
(
compute_stream
,
comm_stream
);
}
}
// Switch back compute stream
if
(
allocate_on_comm_stream
)
...
...
@@ -461,13 +467,11 @@ Buffer::intranode_dispatch(
}
// Wait previous tasks to be finished
if
(
not
use_default_stream_as_comm_stream
)
{
if
(
previous_event
.
has_value
())
{
stream_wait
(
comm_stream
,
previous_event
.
value
());
}
else
{
stream_wait
(
comm_stream
,
compute_stream
);
}
}
// Create handles (only return for non-cached mode)
int
num_recv_tokens
=
-
1
;
...
...
@@ -623,10 +627,8 @@ Buffer::intranode_dispatch(
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
}
}
else
{
if
(
not
use_default_stream_as_comm_stream
)
{
stream_wait
(
compute_stream
,
comm_stream
);
}
}
// Switch back compute stream
if
(
allocate_on_comm_stream
)
...
...
@@ -691,13 +693,11 @@ Buffer::intranode_combine(const torch::Tensor &x, const std::optional<torch::Ten
}
// Wait previous tasks to be finished
if
(
not
use_default_stream_as_comm_stream
)
{
if
(
previous_event
.
has_value
())
{
stream_wait
(
comm_stream
,
previous_event
.
value
());
}
else
{
stream_wait
(
comm_stream
,
compute_stream
);
}
}
int
num_topk
=
0
;
auto
recv_topk_weights
=
std
::
optional
<
torch
::
Tensor
>
();
...
...
@@ -765,10 +765,8 @@ Buffer::intranode_combine(const torch::Tensor &x, const std::optional<torch::Ten
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
}
}
else
{
if
(
not
use_default_stream_as_comm_stream
)
{
stream_wait
(
compute_stream
,
comm_stream
);
}
}
// Switch back compute stream
if
(
allocate_on_comm_stream
)
...
...
@@ -804,8 +802,8 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
// here.
pybind11
::
gil_scoped_release
release
;
const
int
num_channels
=
config
.
num_sms
/
3
;
EP_HOST_ASSERT
(
config
.
num_sms
%
3
==
0
);
const
int
num_channels
=
config
.
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
EP_HOST_ASSERT
(
config
.
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
EP_HOST_ASSERT
(
0
<
get_num_rdma_ranks
()
and
get_num_rdma_ranks
()
<=
NUM_MAX_RDMA_PEERS
);
bool
cached_mode
=
cached_rdma_channel_prefix_matrix
.
has_value
();
...
...
@@ -1125,8 +1123,8 @@ Buffer::internode_combine(
const
torch
::
Tensor
&
combined_nvl_head
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>
&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
)
{
#ifndef DISABLE_ROCSHMEM
const
int
num_channels
=
config
.
num_sms
/
3
;
EP_HOST_ASSERT
(
config
.
num_sms
%
3
==
0
);
const
int
num_channels
=
config
.
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
EP_HOST_ASSERT
(
config
.
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
// Shape and contiguous checks
EP_HOST_ASSERT
(
x
.
dim
()
==
2
and
x
.
is_contiguous
());
...
...
@@ -1272,39 +1270,329 @@ Buffer::internode_combine(
#endif
}
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
EP_HOST_ASSERT
(
false
and
"not support low latency"
);
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT
(
low_latency_mode
);
auto
layout
=
LowLatencyLayout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
auto
clean_meta_0
=
layout
.
buffers
[
0
].
clean_meta
();
auto
clean_meta_1
=
layout
.
buffers
[
1
].
clean_meta
();
auto
check_boundary
=
[
=
](
void
*
ptr
,
size_t
num_bytes
)
{
auto
offset
=
reinterpret_cast
<
int64_t
>
(
ptr
)
-
reinterpret_cast
<
int64_t
>
(
rdma_buffer_ptr
);
EP_HOST_ASSERT
(
0
<=
offset
and
offset
+
num_bytes
<=
num_rdma_bytes
);
};
check_boundary
(
clean_meta_0
.
first
,
clean_meta_0
.
second
*
sizeof
(
int64_t
));
check_boundary
(
clean_meta_1
.
first
,
clean_meta_1
.
second
*
sizeof
(
int64_t
));
internode_ll
::
clean_low_latency_buffer
(
clean_meta_0
.
first
,
clean_meta_0
.
second
,
clean_meta_1
.
first
,
clean_meta_1
.
second
,
rank
,
num_ranks
,
mask_buffer_ptr
,
sync_buffer_ptr
,
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
());
#else
EP_HOST_ASSERT
(
false
and
"ROCSHMEM is disabled during compilation"
);
#endif
}
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
std
::
optional
<
torch
::
Tensor
>
&
cumulative_local_expert_recv_stats
,
const
std
::
optional
<
torch
::
Tensor
>
&
dispatch_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
async
,
bool
return_recv_hook
)
{
EP_HOST_ASSERT
(
false
and
"not support low latency"
);
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT
(
low_latency_mode
);
// Tensor checks
// By default using `ptp128c` FP8 cast
EP_HOST_ASSERT
(
x
.
dim
()
==
2
and
x
.
is_contiguous
()
and
x
.
scalar_type
()
==
torch
::
kBFloat16
);
EP_HOST_ASSERT
(
x
.
size
(
1
)
%
sizeof
(
int4
)
==
0
and
x
.
size
(
1
)
%
128
==
0
);
EP_HOST_ASSERT
(
topk_idx
.
dim
()
==
2
and
topk_idx
.
is_contiguous
());
EP_HOST_ASSERT
(
x
.
size
(
0
)
==
topk_idx
.
size
(
0
)
and
x
.
size
(
0
)
<=
num_max_dispatch_tokens_per_rank
);
EP_HOST_ASSERT
(
topk_idx
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
num_experts
%
num_ranks
==
0
);
// Diagnosis tensors
if
(
cumulative_local_expert_recv_stats
.
has_value
())
{
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
scalar_type
()
==
torch
::
kInt
);
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
dim
()
==
1
and
cumulative_local_expert_recv_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
size
(
0
)
==
num_experts
/
num_ranks
);
}
if
(
dispatch_wait_recv_cost_stats
.
has_value
())
{
EP_HOST_ASSERT
(
dispatch_wait_recv_cost_stats
->
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
dispatch_wait_recv_cost_stats
->
dim
()
==
1
and
dispatch_wait_recv_cost_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
dispatch_wait_recv_cost_stats
->
size
(
0
)
==
num_ranks
);
}
auto
num_tokens
=
static_cast
<
int
>
(
x
.
size
(
0
)),
hidden
=
static_cast
<
int
>
(
x
.
size
(
1
));
auto
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
auto
num_local_experts
=
num_experts
/
num_ranks
;
// Buffer control
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
// 双buffer操作
auto
global_atomic_counter
=
torch
::
zeros
({
1
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
// Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream
auto
compute_stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
auto
launch_stream
=
return_recv_hook
?
compute_stream
:
comm_stream
;
EP_HOST_ASSERT
(
not
(
async
and
return_recv_hook
));
if
(
not
return_recv_hook
)
stream_wait
(
launch_stream
,
compute_stream
);
// Allocate packed tensors
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
use_fp8
?
torch
::
kFloat8_e4m3fn
:
torch
::
kBFloat16
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
// Allocate column-majored scales
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
void
*
packed_recv_x_scales_ptr
=
nullptr
;
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
// TODO: support unaligned cases
EP_HOST_ASSERT
(
hidden
%
512
==
0
);
if
(
use_fp8
)
{
if
(
not
use_ue8m0
)
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
128
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
else
{
EP_HOST_ASSERT
(
round_scale
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
512
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
}
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
();
}
// Kernel launch
auto
next_clean_meta
=
next_buffer
.
clean_meta
();
auto
launcher
=
[
=
](
int
phases
)
{
internode_ll
::
dispatch
(
packed_recv_x
.
data_ptr
(),
packed_recv_x_scales_ptr
,
packed_recv_src_info
.
data_ptr
<
int
>
(),
packed_recv_layout_range
.
data_ptr
<
int64_t
>
(),
packed_recv_count
.
data_ptr
<
int
>
(),
global_atomic_counter
.
data_ptr
<
int
>
(),
mask_buffer_ptr
,
cumulative_local_expert_recv_stats
.
has_value
()
?
cumulative_local_expert_recv_stats
->
data_ptr
<
int
>
()
:
nullptr
,
dispatch_wait_recv_cost_stats
.
has_value
()
?
dispatch_wait_recv_cost_stats
->
data_ptr
<
int64_t
>
()
:
nullptr
,
buffer
.
dispatch_rdma_recv_data_buffer
,
buffer
.
dispatch_rdma_recv_count_buffer
,
buffer
.
dispatch_rdma_send_buffer
,
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
round_scale
,
use_ue8m0
,
workspace
,
num_device_sms
,
launch_stream
,
phases
);
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
// Wait streams
std
::
optional
<
EventHandle
>
event
;
if
(
async
)
{
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event
=
EventHandle
(
launch_stream
);
}
else
if
(
not
return_recv_hook
)
{
stream_wait
(
compute_stream
,
launch_stream
);
}
// Receiver callback
std
::
optional
<
std
::
function
<
void
()
>>
recv_hook
=
std
::
nullopt
;
if
(
return_recv_hook
)
recv_hook
=
[
=
]()
{
launcher
(
LOW_LATENCY_RECV_PHASE
);
};
// Return values
return
{
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
recv_hook
};
#else
EP_HOST_ASSERT
(
false
and
"ROCSHMEM is disabled during compilation"
);
return
{};
#endif
}
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
std
::
optional
<
torch
::
Tensor
>
&
combine_wait_recv_cost_stats
,
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_logfmt
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>
&
out
)
{
EP_HOST_ASSERT
(
false
and
"not support low latency"
);
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT
(
low_latency_mode
);
// Tensor checks
EP_HOST_ASSERT
(
x
.
dim
()
==
3
and
x
.
is_contiguous
()
and
x
.
scalar_type
()
==
torch
::
kBFloat16
);
EP_HOST_ASSERT
(
x
.
size
(
0
)
==
num_experts
/
num_ranks
);
EP_HOST_ASSERT
(
x
.
size
(
1
)
==
num_ranks
*
num_max_dispatch_tokens_per_rank
);
EP_HOST_ASSERT
(
x
.
size
(
2
)
%
sizeof
(
int4
)
==
0
and
x
.
size
(
2
)
%
128
==
0
);
EP_HOST_ASSERT
(
topk_idx
.
dim
()
==
2
and
topk_idx
.
is_contiguous
());
EP_HOST_ASSERT
(
topk_idx
.
size
(
0
)
==
topk_weights
.
size
(
0
)
and
topk_idx
.
size
(
1
)
==
topk_weights
.
size
(
1
));
EP_HOST_ASSERT
(
topk_idx
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
topk_weights
.
dim
()
==
2
and
topk_weights
.
is_contiguous
());
EP_HOST_ASSERT
(
topk_weights
.
size
(
0
)
<=
num_max_dispatch_tokens_per_rank
);
EP_HOST_ASSERT
(
topk_weights
.
scalar_type
()
==
torch
::
kFloat32
);
EP_HOST_ASSERT
(
src_info
.
dim
()
==
2
and
src_info
.
is_contiguous
());
EP_HOST_ASSERT
(
src_info
.
scalar_type
()
==
torch
::
kInt32
and
x
.
size
(
0
)
==
src_info
.
size
(
0
));
EP_HOST_ASSERT
(
layout_range
.
dim
()
==
2
and
layout_range
.
is_contiguous
());
EP_HOST_ASSERT
(
layout_range
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
layout_range
.
size
(
0
)
==
num_experts
/
num_ranks
and
layout_range
.
size
(
1
)
==
num_ranks
);
if
(
combine_wait_recv_cost_stats
.
has_value
())
{
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
dim
()
==
1
and
combine_wait_recv_cost_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
size
(
0
)
==
num_ranks
);
}
auto
hidden
=
static_cast
<
int
>
(
x
.
size
(
2
));
auto
num_topk
=
static_cast
<
int
>
(
topk_weights
.
size
(
1
));
auto
num_combined_tokens
=
static_cast
<
int
>
(
topk_weights
.
size
(
0
));
auto
global_atomic_counter
=
torch
::
zeros
({
1
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
// Buffer control
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
// Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream
auto
compute_stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
auto
launch_stream
=
return_recv_hook
?
compute_stream
:
comm_stream
;
EP_HOST_ASSERT
(
not
(
async
and
return_recv_hook
));
if
(
not
return_recv_hook
)
stream_wait
(
launch_stream
,
compute_stream
);
// Allocate output tensor
torch
::
Tensor
combined_x
;
if
(
out
.
has_value
())
{
EP_HOST_ASSERT
(
out
->
dim
()
==
2
and
out
->
is_contiguous
());
EP_HOST_ASSERT
(
out
->
size
(
0
)
==
num_combined_tokens
and
out
->
size
(
1
)
==
hidden
);
EP_HOST_ASSERT
(
out
->
scalar_type
()
==
x
.
scalar_type
());
combined_x
=
out
.
value
();
}
else
{
combined_x
=
torch
::
empty
({
num_combined_tokens
,
hidden
},
x
.
options
());
}
// Kernel launch
auto
next_clean_meta
=
next_buffer
.
clean_meta
();
auto
launcher
=
[
=
](
int
phases
)
{
internode_ll
::
combine
(
combined_x
.
data_ptr
(),
buffer
.
combine_rdma_recv_data_buffer
,
buffer
.
combine_rdma_recv_flag_buffer
,
buffer
.
combine_rdma_send_buffer
,
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
topk_weights
.
data_ptr
<
float
>
(),
src_info
.
data_ptr
<
int
>
(),
layout_range
.
data_ptr
<
int64_t
>
(),
global_atomic_counter
.
data_ptr
<
int
>
(),
mask_buffer_ptr
,
combine_wait_recv_cost_stats
.
has_value
()
?
combine_wait_recv_cost_stats
->
data_ptr
<
int64_t
>
()
:
nullptr
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_logfmt
,
workspace
,
num_device_sms
,
launch_stream
,
phases
,
zero_copy
);
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
// Wait streams
std
::
optional
<
EventHandle
>
event
;
if
(
async
)
{
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event
=
EventHandle
(
launch_stream
);
}
else
if
(
not
return_recv_hook
)
{
stream_wait
(
compute_stream
,
launch_stream
);
}
// Receiver callback
std
::
optional
<
std
::
function
<
void
()
>>
recv_hook
=
std
::
nullopt
;
if
(
return_recv_hook
)
recv_hook
=
[
=
]()
{
launcher
(
LOW_LATENCY_RECV_PHASE
);
};
// Return values
return
{
combined_x
,
event
,
recv_hook
};
#else
EP_HOST_ASSERT
(
false
and
"ROCSHMEM is disabled during compilation"
);
return
{};
#endif
}
torch
::
Tensor
Buffer
::
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
const
{
EP_HOST_ASSERT
(
false
and
"not support low latency"
);
torch
::
Tensor
Buffer
::
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
const
{
#ifndef DISABLE_ROCSHMEM
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
dtype
=
torch
::
kBFloat16
;
auto
num_msg_elems
=
static_cast
<
int
>
(
buffer
.
num_bytes_per_combine_msg
/
elementSize
(
torch
::
kBFloat16
));
// buffer.num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(hip_bfloat16);
EP_HOST_ASSERT
(
buffer
.
num_bytes_per_combine_msg
%
elementSize
(
torch
::
kBFloat16
)
==
0
);
return
torch
::
from_blob
(
buffer
.
combine_rdma_send_buffer_data_start
,
{
num_experts
/
num_ranks
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
{
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_msg_elems
,
num_msg_elems
,
1
},
torch
::
TensorOptions
().
dtype
(
dtype
).
device
(
torch
::
kCUDA
));
#else
EP_HOST_ASSERT
(
false
and
"ROCSHMEM is disabled during compilation"
);
return
{};
#endif
}
void
Buffer
::
low_latency_update_mask_buffer
(
int
rank_to_mask
,
bool
mask
)
{
EP_HOST_ASSERT
(
mask_buffer_ptr
!=
nullptr
and
"Shrink mode must be enabled"
);
EP_HOST_ASSERT
(
rank_to_mask
>=
0
and
rank_to_mask
<
num_ranks
);
internode_ll
::
update_mask_buffer
(
mask_buffer_ptr
,
rank_to_mask
,
mask
,
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
());
}
void
Buffer
::
low_latency_query_mask_buffer
(
const
torch
::
Tensor
&
mask_status
)
{
EP_HOST_ASSERT
(
mask_buffer_ptr
!=
nullptr
and
"Shrink mode must be enabled"
);
EP_HOST_ASSERT
(
mask_status
.
numel
()
==
num_ranks
&&
mask_status
.
scalar_type
()
==
torch
::
kInt32
);
internode_ll
::
query_mask_buffer
(
mask_buffer_ptr
,
num_ranks
,
reinterpret_cast
<
int
*>
(
mask_status
.
data_ptr
()),
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
());
}
void
Buffer
::
low_latency_clean_mask_buffer
()
{
EP_HOST_ASSERT
(
mask_buffer_ptr
!=
nullptr
and
"Shrink mode must be enabled"
);
internode_ll
::
clean_mask_buffer
(
mask_buffer_ptr
,
num_ranks
,
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
());
}
}
// namespace deep_ep
...
...
@@ -1346,8 +1634,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"clean_low_latency_buffer"
,
&
deep_ep
::
Buffer
::
clean_low_latency_buffer
)
.
def
(
"low_latency_dispatch"
,
&
deep_ep
::
Buffer
::
low_latency_dispatch
)
.
def
(
"low_latency_combine"
,
&
deep_ep
::
Buffer
::
low_latency_combine
)
.
def
(
"get_next_low_latency_combine_buffer"
,
&
deep_ep
::
Buffer
::
get_next_low_latency_combine_buffer
);
.
def
(
"get_next_low_latency_combine_buffer"
,
&
deep_ep
::
Buffer
::
get_next_low_latency_combine_buffer
)
.
def
(
"low_latency_update_mask_buffer"
,
&
deep_ep
::
Buffer
::
low_latency_update_mask_buffer
)
.
def
(
"low_latency_query_mask_buffer"
,
&
deep_ep
::
Buffer
::
low_latency_query_mask_buffer
)
.
def
(
"low_latency_clean_mask_buffer"
,
&
deep_ep
::
Buffer
::
low_latency_clean_mask_buffer
);
// m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
// m.attr("
topk_idx
_t") = py::cast(c10::CppTypeToScalarType<deep_ep::
topk_idx
_t>::value);
// m.attr("
int64
_t") = py::cast(c10::CppTypeToScalarType<deep_ep::
int64
_t>::value);
}
csrc/deep_ep.hpp
View file @
09cb2b03
...
...
@@ -30,6 +30,11 @@ private:
int64_t
num_rdma_bytes
;
void
*
rdma_buffer_ptr
=
nullptr
;
// Shrink mode buffer
bool
enable_shrink
=
false
;
int
*
mask_buffer_ptr
=
nullptr
;
int
*
sync_buffer_ptr
=
nullptr
;
// Device info and communication
int
device_id
;
int
num_device_sms
;
...
...
@@ -67,11 +72,9 @@ private:
volatile
int
*
moe_recv_rdma_counter
=
nullptr
;
int
*
moe_recv_rdma_counter_mapped
=
nullptr
;
bool
use_default_stream_as_comm_stream
=
false
;
public:
Buffer
(
int
rank
,
int
num_ranks
,
int64_t
num_nvl_bytes
,
int64_t
num_rdma_bytes
,
bool
low_latency_mode
,
bool
explicitly_destroy
,
bool
use_default_stream_as_comm_stream
);
bool
low_latency_mode
,
bool
explicitly_destroy
,
bool
enable_shrink
);
~
Buffer
()
noexcept
(
false
);
...
...
@@ -187,6 +190,12 @@ public:
torch
::
Tensor
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
const
;
void
low_latency_update_mask_buffer
(
int
rank_to_mask
,
bool
mask
);
void
low_latency_query_mask_buffer
(
const
torch
::
Tensor
&
mask_status
);
void
low_latency_clean_mask_buffer
();
};
}
// namespace deep_ep
csrc/kernels/api.cuh
View file @
09cb2b03
...
...
@@ -131,4 +131,46 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
int
num_ranks
,
hipStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
);
}
// namespace internode
// Internode low-latency kernels
namespace
internode_ll
{
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
int
rank
,
int
num_ranks
,
int
*
mask_buffer
,
int
*
sync_buffer
,
hipStream_t
stream
);
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
mask_buffer
,
int
*
cumulative_local_expert_recv_stats
,
int64_t
*
dispatch_wait_recv_cost_stats
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
mask_buffer
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_logfmt
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
);
void
query_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
int
*
output_mask_tensor
,
hipStream_t
stream
);
void
update_mask_buffer
(
int
*
mask_buffer_ptr
,
int
rank_to_mask
,
bool
mask
,
hipStream_t
stream
);
void
clean_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
hipStream_t
stream
);
}
// namespace internode_ll
}
// namespace deep_ep
csrc/kernels/configs.cuh
View file @
09cb2b03
...
...
@@ -22,6 +22,8 @@
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
...
...
csrc/kernels/internode_ll.cu
View file @
09cb2b03
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
// #include "ibgda_device.cuh"
#include "buffer.cuh"
#include "utils.cuh"
// #include <cooperative_groups.h>
#include <iostream>
// low latency+RocSHMEM has issue with CTX.
#define ROCM_DISABLE_CTX
#ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#include <rocshmem/rocshmem_COLL.hpp>
namespace
deep_ep
{
namespace
internode_ll
{
template
<
bool
use_warp_sync
=
false
>
__forceinline__
__device__
bool
is_rank_masked
(
int
*
mask_buffer_ptr
,
int
rank
)
{
if
(
mask_buffer_ptr
==
nullptr
)
{
return
false
;
}
if
constexpr
(
use_warp_sync
)
{
return
shfl_sync
(
ld_acquire_global
(
mask_buffer_ptr
+
rank
),
0
)
!=
0
;
}
else
{
return
ld_acquire_global
(
mask_buffer_ptr
+
rank
)
!=
0
;
}
}
__device__
void
grid_barrier
(
int
*
global_counter
,
int
num_blocks
)
{
volatile
int
ret
;
__syncthreads
();
memory_fence_gpu
();
if
(
threadIdx
.
x
==
0
)
{
ret
=
atomicAdd
((
int
*
)
&
global_counter
[
0
],
1
);
}
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
while
(
ld_relaxed_global
(
global_counter
)
!=
num_blocks
);
}
__syncthreads
();
}
template
<
int
kNumThreads
>
__launch_bounds__
(
kNumThreads
,
1
)
__global__
void
clean_low_latency_buffer
(
int
*
clean_0
,
int
num_clean_int_0
,
int
*
clean_1
,
int
num_clean_int_1
)
{
__global__
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
int
rank
,
int
num_ranks
,
int
*
mask_buffer_ptr
,
int
*
sync_buffer_ptr
)
{
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
// Barrier before cleaning (in case of unfinished chunked EP)
// nvshmemx_barrier_all_block();
// // Clean
// auto thread_id = static_cast<int>(threadIdx.x);
// #pragma unroll
// for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
// clean_0[i] = 0;
// #pragma unroll
// for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
// clean_1[i] = 0;
// // Barrier after cleaning (make sure the low-latency mode works fine)
// nvshmemx_barrier_all_block();
if
(
sync_buffer_ptr
==
nullptr
)
{
// rocshmem::rocshmem_barrier_all_wg();
if
(
thread_id
==
0
)
rocshmem
::
rocshmem_barrier_all
();
}
else
{
// barrier<kNumThreads>(thread_id, rank, num_ranks, mask_buffer_ptr, sync_buffer_ptr);
EP_DEVICE_ASSERT
(
0
);
}
// Clean
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_clean_int_0
;
i
+=
kNumThreads
)
clean_0
[
i
]
=
0
;
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_clean_int_1
;
i
+=
kNumThreads
)
clean_1
[
i
]
=
0
;
// Barrier after cleaning (make sure low-latency mode work
if
(
sync_buffer_ptr
==
nullptr
)
{
// rocshmem::rocshmem_barrier_all_wg();
if
(
thread_id
==
0
)
rocshmem
::
rocshmem_barrier_all
();
}
else
{
// barrier<kNumThreads>(thread_id, rank, num_ranks, mask_buffer_ptr, sync_buffer_ptr);
EP_DEVICE_ASSERT
(
0
);
}
}
void
clean_low_latency_buffer
(
int
*
clean_0
,
int
num_clean_int_0
,
int
*
clean_1
,
int
num_clean_int_1
,
cudaStream_t
stream
)
{
// constexpr int kNumThreads = 256;
// SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
// LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
// clean_0, num_clean_int_0, clean_1, num_clean_int_1);
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
int
rank
,
int
num_ranks
,
int
*
mask_buffer_ptr
,
int
*
sync_buffer_ptr
,
hipStream_t
stream
)
{
constexpr
int
kNumThreads
=
256
;
SETUP_LAUNCH_CONFIG
(
1
,
kNumThreads
,
stream
);
LAUNCH_KERNEL
(
&
cfg
,
clean_low_latency_buffer
<
kNumThreads
>
,
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
,
rank
,
num_ranks
,
mask_buffer_ptr
,
sync_buffer_ptr
);
}
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
int
kHidden
>
__global__
__launch_bounds__
(
1024
,
1
)
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
__launch_bounds__
(
1024
,
1
)
__global__
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
mask_buffer_ptr
,
int
*
cumulative_local_expert_recv_stats
,
int64_t
*
dispatch_wait_recv_cost_stats
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
topk_idx_t
*
topk_idx
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
bool
round_scale
,
int
phases
)
{
// const auto sm_id = static_cast<int>(blockIdx.x);
// const auto thread_id = static_cast<int>(threadIdx.x);
// const auto warp_id = thread_id / 32, lane_id = get_lane_id();
// const auto num_sms = static_cast<int>(gridDim.x);
// const auto num_warps = num_warp_groups * num_warps_per_group;
// const auto num_local_experts = num_experts / num_ranks;
// const auto warp_group_id = warp_id / num_warps_per_group;
// const auto sub_warp_id = warp_id % num_warps_per_group;
// const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// // May extract UE8M0 from the scales
// using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
// using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
// EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
// // FP8 staffs
// constexpr int kNumPerChannels = 128;
// const int num_scales = kHidden / kNumPerChannels;
// const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
// const size_t hidden_int4 = hidden_bytes / sizeof(int4);
// // Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales
// // NOTES: currently we have 3 reserved int fields for future use
// using vec_t = std::conditional_t<kUseFP8, int2, int4>;
// const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
// const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
// EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
// // Expert counts
// constexpr int kNumMaxWarpGroups = 32;
// __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
// // Sending phase
// if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
// goto LOW_LATENCY_DISPATCH_RECV;
// // There are 2 kinds of warps in this part:
// // 1. The first-kind warps for FP8 cast and sending top-k tokens
// // 2. The last warp for reading `topk_idx` and count for per-expert information
// if (warp_id < num_warps - 1) {
// constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
// EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerRead) == 0, "Invalid hidden");
// EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
// const auto num_threads = (num_warps - 1) * 32;
// const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
// for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
// const auto x_int4 = static_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
// const auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
// const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
// const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
// // Overlap top-k index read and source token index writes
// auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
// thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
// // FP8 cast
// EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce");
// #pragma unroll
// for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// // Read
// auto int4_value = __ldg(x_int4 + i);
// if constexpr (kUseFP8) {
// // Calculate local amax
// auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
// float fp32_values[kNumElemsPerRead];
// float amax = kFP8Margin, scale, scale_inv;
// #pragma unroll
// for (int j = 0; j < kNumElemsPerRead; ++ j) {
// fp32_values[j] = static_cast<float>(bf16_values[j]);
// amax = fmaxf(amax, fabsf(fp32_values[j]));
// }
// // Reduce amax and scale
// EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
// amax = warp_reduce_max<16>(amax);
// calculate_fp8_scales(amax, scale, scale_inv, round_scale);
// if (lane_id == 0 or lane_id == 16)
// rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
// // Cast into send buffer
// vec_t int2_value;
// auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
// #pragma unroll
// for (int j = 0; j < kNumElemsPerRead; j += 2) {
// float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
// fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
// }
// rdma_x_vec[i] = int2_value;
// } else {
// // Reinterpret-cast is for C++14 compatibility
// rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
// }
// }
// asm volatile("bar.sync 1, %0;" :: "r"(num_threads));
// // Issue IBGDA sends
// if (dst_expert_idx >= 0) {
// int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
// slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
// const auto dst_rank = dst_expert_idx / num_local_experts;
// const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
// const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
// const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
// dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
// rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
// slot_idx * num_bytes_per_msg;
// const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
// if (dst_p2p_ptr == 0) {
// nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
// } else {
// // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
// const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
// UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
// }
// // Increase counter after finishing
// __syncwarp();
// lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
// }
// }
// } else if (warp_id == num_warps - 1) {
// EP_DEVICE_ASSERT(num_sms > 1);
// if (sm_id == 0) {
// // The first SM is also responsible for checking QPs
// EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts);
// // The first SM is also responsible for cleaning the next buffer
// #pragma unroll
// for (int i = lane_id; i < num_next_clean_int; i += 32)
// next_clean[i] = 0;
// // Notify before executing `int_p`
// __syncwarp();
// #pragma unroll
// for (int i = lane_id; i < num_experts; i += 32)
// atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
// }
// // This SM should be responsible for some destination experts, read `topk_idx` for them
// int expert_count[kNumMaxWarpGroups] = {0};
// const auto expert_begin_idx = sm_id * num_warp_groups;
// const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);
// // Per lane count
// #pragma unroll 8
// for (int i = lane_id; i < num_tokens * num_topk; i += 32) {
// auto idx = static_cast<int>(__ldg(topk_idx + i));
// if (idx >= expert_begin_idx and idx < expert_end_idx)
// expert_count[idx - expert_begin_idx] ++;
// }
// // Warp reduce
// #pragma unroll
// for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
// auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
// if (lane_id == 0) {
// shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
// atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
// }
// }
// }
// __syncthreads();
// // Issue count sends
// if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
// const auto dst_rank = responsible_expert_idx / num_local_experts;
// const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
// const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];
// // Wait local sends issued and send expert counts
// while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
// auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
// auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
// if (dst_p2p_ptr == 0) {
// nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
// } else {
// st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), -num_tokens_sent - 1);
// }
// // Clean workspace for next use
// atomic_counter_per_expert[responsible_expert_idx] = 0;
// atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
// // Clean `packed_recv_count`
// if (dst_rank == 0)
// packed_recv_count[dst_expert_local_idx] = 0;
// }
// __syncwarp();
// // Receiving phase
// LOW_LATENCY_DISPATCH_RECV:
// if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
// return;
// // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
// if (phases & LOW_LATENCY_SEND_PHASE)
// cg::this_grid().sync();
// // Receiving and packing
// if (responsible_expert_idx < num_experts) {
// const auto src_rank = responsible_expert_idx / num_local_experts;
// const auto local_expert_idx = responsible_expert_idx % num_local_experts;
// const auto rdma_recv_x_uint8 = static_cast<uint8_t*>(rdma_recv_x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
// src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
// const auto recv_x_int4 = static_cast<int4*>(packed_recv_x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
// const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
// const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
// const auto num_aligned_scales = align_up<int>(num_scales, sizeof(float) / sizeof(scale_t));
// const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
// // Shared between sub-warps in warp groups
// __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
// // Wait tokens to arrive
// // NOTES: using sub-warp 1 to overlap with sub-warp 0
// int num_recv_tokens, recv_token_begin_idx;
// EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
// if (sub_warp_id == 1 and lane_id == 0) {
// auto start_time = clock64();
// while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
// auto wait_recv_cost = clock64() - start_time;
// num_recv_tokens = -num_recv_tokens - 1;
// recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
// shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
// shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
// recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
// // Add stats for diagnosis
// if (cumulative_local_expert_recv_stats != nullptr)
// atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
// if (dispatch_wait_recv_cost_stats != nullptr)
// atomicAdd(reinterpret_cast<unsigned long long*>(dispatch_wait_recv_cost_stats + src_rank), wait_recv_cost);
// }
// asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
// num_recv_tokens = shared_num_recv_tokens[warp_group_id];
// recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
// // Copy tokens
// EP_DEVICE_ASSERT(num_scales <= 64);
// for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
// // Copy source info
// const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
// if (lane_id == 0)
// recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
// __syncwarp();
// // Copy data
// // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
// const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
// const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
// UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// // Copy scales
// if constexpr (kUseFP8) {
// // Equivalent CuTe layout:
// // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
// const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
// const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
// const auto token_idx = recv_token_begin_idx + i;
// const auto token_stride = num_elems_per_pack;
// const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
// if (lane_id < num_scales) {
// const auto pack_idx = lane_id / num_elems_per_pack;
// const auto elem_idx = lane_id % num_elems_per_pack;
// auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
// recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
// }
// if (lane_id + 32 < num_scales) {
// const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
// const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
// auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
// recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
// }
// }
// }
// }
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
bool
round_scale
,
int
phases
)
{
#if !defined(ROCM_DISABLE_CTX)
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
#endif
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
const
auto
num_warps
=
num_warp_groups
*
num_warps_per_group
;
const
auto
num_local_experts
=
num_experts
/
num_ranks
;
const
auto
warp_group_id
=
warp_id
/
num_warps_per_group
;
const
auto
sub_warp_id
=
warp_id
%
num_warps_per_group
;
// 每个warp处理一个expert
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
// May extract UE8M0 from the scales
using
scale_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint8_t
,
float
>
;
using
packed_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint32_t
,
float
>
;
EP_STATIC_ASSERT
(
sizeof
(
packed_t
)
%
sizeof
(
scale_t
)
==
0
,
"Invalid vector length"
);
// FP8 staffs
constexpr
int
kNumPerChannels
=
FP8_QUANTIZATION_NUM_PER_CHANNEL
;
const
int
num_scales
=
kHidden
/
kNumPerChannels
;
const
size_t
hidden_bytes
=
kHidden
*
(
kUseFP8
?
sizeof
(
__hip_fp8_storage_t
)
:
sizeof
(
hip_bfloat16
));
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
// Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales
// NOTES: currently we have 3 reserved int fields for future use
using
vec_t
=
std
::
conditional_t
<
kUseFP8
,
int2
,
int4
>
;
const
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseFP8
?
(
kHidden
+
num_scales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
const
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
EP_DEVICE_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
);
// Expert counts
constexpr
int
kNumMaxWarpGroups
=
16
;
// 每个kernel最多warp group数量,即每个block负责的专家数
__shared__
int
shared_num_tokens_sent_per_expert
[
kNumMaxWarpGroups
];
#ifdef USE_ROCM
// 用于同步
// 16 is the max possible number of warps in AMD GPUs
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
constexpr
int
num_sync_large_iteration
=
kMaxNumWarps
;
__shared__
volatile
int
sync_large_warp_counters
[
num_sync_large_iteration
];
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_sync_large_iteration
;
i
+=
blockDim
.
x
)
{
sync_large_warp_counters
[
i
]
=
0
;
}
__syncthreads
();
#endif
// Sending phase,如果没有发送任务,则直接跳到接收阶段
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
goto
LOW_LATENCY_DISPATCH_RECV
;
// There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information
if
(
warp_id
<
num_warps
-
1
)
{
constexpr
int
kNumElemsPerRead
=
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
);
// 128/16 = 8
EP_STATIC_ASSERT
(
kHidden
%
(
kWarpSize
*
kNumElemsPerRead
)
==
0
,
"Invalid hidden"
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
const
size_t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerRead
;
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_tokens
;
token_idx
+=
num_sms
)
{
const
auto
x_int4
=
static_cast
<
const
int4
*>
(
x
)
+
token_idx
*
hidden_bf16_int4
;
const
auto
rdma_x_src_idx
=
reinterpret_cast
<
int
*>
(
static_cast
<
uint8_t
*>
(
rdma_x
)
+
token_idx
*
num_bytes_per_msg
);
const
auto
rdma_x_vec
=
reinterpret_cast
<
vec_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_src_idx
)
+
sizeof
(
int4
));
const
auto
rdma_x_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_vec
)
+
hidden_bytes
);
// Overlap top-k index read and source token index writes
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
// FP8 cast
EP_STATIC_ASSERT
(
hidden_bf16_int4
%
kWarpSize
==
0
,
"Must use the full warp to reduce"
);
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
hidden_bf16_int4
;
i
+=
num_threads
)
{
// Read
auto
int4_value
=
__ldg
(
x_int4
+
i
);
if
constexpr
(
kUseFP8
)
{
// Calculate local amax
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
float
fp32_values
[
kNumElemsPerRead
];
float
amax
=
kFP8Margin
,
scale
,
scale_inv
;
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
fp32_values
[
j
]
=
static_cast
<
float
>
(
bf16_values
[
j
]);
amax
=
fmaxf
(
amax
,
fabsf
(
fp32_values
[
j
]));
}
// Reduce amax and scale
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
/
kNumPerChannels
==
4
,
"Invalid vectorization"
);
amax
=
warp_reduce_max
<
16
>
(
amax
);
calculate_fp8_scales
(
amax
,
scale
,
scale_inv
,
round_scale
);
if
(
lane_id
%
16
==
0
)
rdma_x_scales
[
i
*
kNumElemsPerRead
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
]
=
scale_inv
;
// Cast into send buffer
vec_t
int2_value
;
auto
fp8x2_values
=
reinterpret_cast
<
__hip_fp8x2_storage_t
*>
(
&
int2_value
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
j
+=
2
)
{
float2
fp32x2
=
{
fp32_values
[
j
]
*
scale
,
fp32_values
[
j
+
1
]
*
scale
};
fp8x2_values
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E4M3_FNUZ
);
}
rdma_x_vec
[
i
]
=
int2_value
;
}
else
{
// Reinterpret-cast is for C++14 compatibility
rdma_x_vec
[
i
]
=
*
reinterpret_cast
<
vec_t
*>
(
&
int4_value
);
}
}
__syncthreads
();
// Issue IBGDA sends
if
(
dst_expert_idx
>=
0
)
{
int
slot_idx
=
lane_id
==
0
?
atomicAdd
(
atomic_counter_per_expert
+
dst_expert_idx
,
1
)
:
0
;
slot_idx
=
shfl_sync
(
slot_idx
,
0
);
const
int
dst_rank
=
dst_expert_idx
/
num_local_experts
;
const
int
dst_expert_local_idx
=
dst_expert_idx
%
num_local_experts
;
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_x_src_idx
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
dst_expert_local_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
if
(
dst_rank
!=
rank
)
{
#if !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_ctx_schar_put_nbi_wave
(
ctx
,
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
rocshmem
::
rocshmem_ctx_quiet
(
ctx
);
#else
rocshmem
::
rocshmem_schar_put_nbi_wave
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
rocshmem
::
rocshmem_fence
();
#endif
}
else
{
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
// Increase counter after finishing
syncwarp
();
lane_id
==
0
?
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
dst_expert_idx
,
1
)
:
0
;
}
}
}
else
if
(
warp_id
==
num_warps
-
1
)
{
EP_DEVICE_ASSERT
(
num_sms
>
1
);
if
(
sm_id
==
0
)
{
// The first SM is also responsible for cleaning the next buffer
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
num_next_clean_int
;
i
+=
kWarpSize
)
next_clean
[
i
]
=
0
;
// Notify before executing `int_p`
syncwarp
();
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
num_experts
;
i
+=
kWarpSize
)
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
);
}
// This SM should be responsible for some destination experts, read `topk_idx` for them
int
expert_count
[
kNumMaxWarpGroups
]
=
{
0
};
const
auto
expert_begin_idx
=
sm_id
*
num_warp_groups
;
const
auto
expert_end_idx
=
min
(
expert_begin_idx
+
num_warp_groups
,
num_experts
);
// Per lane count
#pragma unroll 8
for
(
int
i
=
lane_id
;
i
<
num_tokens
*
num_topk
;
i
+=
kWarpSize
)
{
auto
idx
=
static_cast
<
int
>
(
__ldg
(
topk_idx
+
i
));
if
(
idx
>=
expert_begin_idx
and
idx
<
expert_end_idx
)
expert_count
[
idx
-
expert_begin_idx
]
++
;
}
// Warp reduce
#pragma unroll
for
(
int
i
=
expert_begin_idx
;
i
<
expert_end_idx
;
++
i
)
{
auto
sum
=
warp_reduce_sum
(
expert_count
[
i
-
expert_begin_idx
]);
if
(
lane_id
==
0
)
{
shared_num_tokens_sent_per_expert
[
i
-
expert_begin_idx
]
=
sum
;
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
-
sum
);
}
}
}
__syncthreads
();
// Issue count sends
if
(
responsible_expert_idx
<
num_experts
and
sub_warp_id
==
0
and
lane_id
==
0
)
{
const
auto
dst_rank
=
responsible_expert_idx
/
num_local_experts
;
const
auto
dst_expert_local_idx
=
responsible_expert_idx
%
num_local_experts
;
const
auto
num_tokens_sent
=
shared_num_tokens_sent_per_expert
[
responsible_expert_idx
-
sm_id
*
num_warp_groups
];
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
not
is_rank_masked
(
mask_buffer_ptr
,
dst_rank
))
{
auto
dst_ptr
=
reinterpret_cast
<
int64_t
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
);
if
(
dst_rank
!=
rank
)
{
#if !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_ctx_long_atomic_add
(
ctx
,
dst_ptr
,
-
num_tokens_sent
-
1
,
dst_rank
);
#else
rocshmem
::
rocshmem_long_atomic_add
(
dst_ptr
,
-
num_tokens_sent
-
1
,
dst_rank
);
#endif
}
else
{
st_release_sys_global
(
dst_ptr
,
-
num_tokens_sent
-
1
);
}
}
// Clean workspace for next use
atomic_counter_per_expert
[
responsible_expert_idx
]
=
0
;
atomic_finish_counter_per_expert
[
responsible_expert_idx
]
=
0
;
// Clean `packed_recv_count`
if
(
dst_rank
==
0
)
packed_recv_count
[
dst_expert_local_idx
]
=
0
;
}
syncwarp
();
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
// 如果没有接收直接返回
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
return
;
// For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
if
(
phases
&
LOW_LATENCY_SEND_PHASE
){
grid_barrier
(
global_atomic_counter
,
num_sms
);
}
// Receiving and packing
if
(
responsible_expert_idx
<
num_experts
)
{
const
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
const
auto
local_expert_idx
=
responsible_expert_idx
%
num_local_experts
;
const
auto
rdma_recv_x_uint8
=
static_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
src_rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
;
const
auto
recv_x_int4
=
static_cast
<
int4
*>
(
packed_recv_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
const
auto
num_aligned_scales
=
ALIGN
<
int
>
(
num_scales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
recv_x_scales
=
static_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_aligned_scales
;
// Shared between sub-warps in warp groups
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
// Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0
int64_t
num_recv_tokens
;
int
recv_token_begin_idx
;
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
and
num_warp_groups
<
15
);
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
auto
start_time
=
wall_clock64
();
int64_t
wait_recv_cost
=
0
;
int
offset
=
local_expert_idx
*
num_ranks
+
src_rank
;
if
(
not
is_rank_masked
(
mask_buffer_ptr
,
src_rank
))
{
while
((
wait_recv_cost
=
wall_clock64
()
-
start_time
)
<=
NUM_TIMEOUT_CYCLES
)
{
// not timeout
if
((
num_recv_tokens
=
ld_acquire_global
(
reinterpret_cast
<
int64_t
*>
(
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
)))
!=
0
)
{
break
;
}
}
}
// Mask rank if timeout
if
(
wait_recv_cost
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"Warning: DeepEP timeout for dispatch receive, rank %d, local_expert_idx %d, src_rank %d
\n
"
,
rank
,
local_expert_idx
,
src_rank
);
if
(
mask_buffer_ptr
==
nullptr
)
trap
();
atomicExch
(
mask_buffer_ptr
+
src_rank
,
1
);
}
// Do not receive tokens if rank timeout or masked
if
(
num_recv_tokens
==
0
)
num_recv_tokens
=
-
1
;
#if 1
num_recv_tokens
=
-
num_recv_tokens
-
1
;
int
num_recv_tokens_int32
=
static_cast
<
int
>
(
num_recv_tokens
);
recv_token_begin_idx
=
atomicAdd
(
packed_recv_count
+
local_expert_idx
,
num_recv_tokens_int32
);
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens_int32
;
shared_recv_token_begin_idx
[
warp_group_id
]
=
recv_token_begin_idx
;
recv_range
[
src_rank
]
=
pack2
<
int
,
int64_t
>
(
num_recv_tokens_int32
,
recv_token_begin_idx
);
// Add stats for diagnosis
if
(
cumulative_local_expert_recv_stats
!=
nullptr
)
atomicAdd
(
cumulative_local_expert_recv_stats
+
local_expert_idx
,
num_recv_tokens_int32
);
if
(
dispatch_wait_recv_cost_stats
!=
nullptr
)
{
atomicAdd
(
reinterpret_cast
<
uint64_t
*>
(
dispatch_wait_recv_cost_stats
+
src_rank
),
static_cast
<
uint64_t
>
(
wait_recv_cost
));
}
#endif
}
#if 1
#ifdef USE_ROCM
// no needs to reset because there is no iteration
if
(
lane_id
==
0
){
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
warp_group_id
],
1
);
}
syncwarp
();
while
(
sync_large_warp_counters
[
warp_group_id
]
<
num_warps_per_group
)
{}
#else
// asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
#endif
num_recv_tokens
=
shared_num_recv_tokens
[
warp_group_id
];
recv_token_begin_idx
=
shared_recv_token_begin_idx
[
warp_group_id
];
// Copy tokens
EP_DEVICE_ASSERT
(
num_scales
<=
64
);
for
(
int
i
=
sub_warp_id
;
i
<
num_recv_tokens
;
i
+=
num_warps_per_group
)
{
// Copy source info
const
auto
src_src_idx
=
reinterpret_cast
<
int
*>
(
rdma_recv_x_uint8
+
i
*
num_bytes_per_msg
);
if
(
lane_id
==
0
)
recv_src_info
[
recv_token_begin_idx
+
i
]
=
ld_nc_global
(
src_src_idx
);
syncwarp
();
// Copy data
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
const
auto
src_data
=
reinterpret_cast
<
int4
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_src_idx
)
+
sizeof
(
int4
));
const
auto
dst_data
=
recv_x_int4
+
(
recv_token_begin_idx
+
i
)
*
hidden_int4
;
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
// Copy scales
if
constexpr
(
kUseFP8
)
{
// Equivalent CuTe layout:
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
packed_t
)
/
sizeof
(
scale_t
));
const
auto
token_idx
=
recv_token_begin_idx
+
i
;
const
auto
token_stride
=
num_elems_per_pack
;
const
auto
pack_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_elems_per_pack
;
if
(
lane_id
<
num_scales
)
{
const
auto
pack_idx
=
lane_id
/
num_elems_per_pack
;
const
auto
elem_idx
=
lane_id
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
if
(
lane_id
+
kWarpSize
<
num_scales
)
{
const
auto
pack_idx
=
(
lane_id
+
kWarpSize
)
/
num_elems_per_pack
;
const
auto
elem_idx
=
(
lane_id
+
kWarpSize
)
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
+
kWarpSize
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
}
}
#endif
}
#if !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_wg_ctx_destroy
(
&
ctx
);
#endif
}
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
mask_buffer_ptr
,
int
*
cumulative_local_expert_recv_stats
,
int64_t
*
dispatch_wait_recv_cost_stats
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
topk_idx_t
*
topk_idx
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
num_device_sms
,
cudaStream_t
stream
,
int
phases
)
{
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
)
{
constexpr
int
kNumMaxTopK
=
11
;
// const int num_warp_groups = ceil_div(num_experts, num_device_sms);
// const int num_warps_per_group = 32 / num_warp_groups;
// EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
// EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
// const auto num_warps = num_warp_groups * num_warps_per_group;
// const auto num_sms = ceil_div(num_experts, num_warp_groups);
// EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
// // Workspace checks
// auto atomic_counter_per_expert = static_cast<int*>(workspace);
// auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
// EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
// // FP8 checks
// if (use_ue8m0)
// EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
// #define DISPATCH_LAUNCH_CASE(hidden) { \
// auto dispatch_func = dispatch<false, false, hidden>; \
// if (use_fp8 and not use_ue8m0) \
// dispatch_func = dispatch<true, false, hidden>; \
// if (use_fp8 and use_ue8m0) \
// dispatch_func = dispatch<true, true, hidden>; \
// LAUNCH_KERNEL(&cfg, dispatch_func, \
// packed_recv_x, packed_recv_x_scales, \
// packed_recv_src_info, packed_recv_layout_range, \
// packed_recv_count, \
// cumulative_local_expert_recv_stats, \
// dispatch_wait_recv_cost_stats, \
// rdma_recv_x, rdma_recv_count, rdma_x, \
// x, topk_idx, \
// atomic_counter_per_expert, atomic_finish_counter_per_expert, \
// next_clean, num_next_clean_int, \
// num_tokens, num_max_dispatch_tokens_per_rank, \
// num_topk, num_experts, rank, num_ranks, \
// num_warp_groups, num_warps_per_group, \
// round_scale, phases); } break
// SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
// SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
// #undef DISPATCH_LAUNCH_CASE
}
template
<
int
kNumSendUnrolls
>
__forceinline__
__device__
int
logfmt_encode
(
void
*
buffer
,
nv_bfloat162
*
shared_amaxmin
,
const
int
&
lane_id
)
{
// constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
// constexpr float kLogThreshold = 0;
// constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))`
// constexpr int kNumBits = 10;
// constexpr int kNumValues = 1 << (kNumBits - 1);
// int4 int4_values[kNumSendUnrolls];
// const auto& uint32_values = reinterpret_cast<uint32_t*>(int4_values);
// const auto& bf162_values = reinterpret_cast<nv_bfloat162*>(int4_values);
// // Calculate lane offset
// const auto& ld_buffer = reinterpret_cast<uint32_t*>(static_cast<uint8_t*>(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4)));
// const auto& st_buffer = reinterpret_cast<uint32_t*>(static_cast<uint8_t*>(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4) * 10 / 16));
// // Local log amax
// auto bf162_amax = __nv_bfloat162(CUDART_ZERO_BF16, CUDART_ZERO_BF16);
// auto bf162_amin = __nv_bfloat162(CUDART_INF_BF16, CUDART_INF_BF16);
// uint32_t local_signs = 0;
// #pragma unroll
// for (int k = 0; k < kNumSendUnrolls * kNumElemsPerInt4 / 2; ++ k) {
// // TODO: eliminate bank conflicts
// uint32_values[k] = ld_buffer[k];
// local_signs |= ((uint32_values[k] >> 15) & 1) << (k * 2);
// local_signs |= ((uint32_values[k] >> 31) & 1) << (k * 2 + 1);
// uint32_values[k] &= 0x7fff7fff;
// bf162_amax = __hmax2(bf162_amax, bf162_values[k]);
// bf162_amin = __hmin2(bf162_amin, bf162_values[k]);
// }
// // Reduce per 128 channels
// // TODO: figure out how hardware do 2-byte min/max
// auto amax = std::max(static_cast<float>(bf162_amax.x), static_cast<float>(bf162_amax.y));
// auto amin = std::min(static_cast<float>(bf162_amin.x), static_cast<float>(bf162_amin.y));
// constexpr static int kNumLanesToReduce = 128 * sizeof(nv_bfloat16) / (kNumSendUnrolls * sizeof(int4));
// amax = warp_reduce_max<kNumLanesToReduce>(amax);
// amin = warp_reduce_min<kNumLanesToReduce>(amin);
// // Write min/max into the shared memory
// if (shared_amaxmin != nullptr)
// *shared_amaxmin = __nv_bfloat162(amax, amin);
// __syncwarp();
// // Calculate log amin/amax float
// const auto& log_amax = log2f_approx(amax);
// const auto& log_amin = fmaxf(log2f_approx(amin), log_amax - kMinClip);
// const bool& enable_cast = warp_reduce_and<kNumLanesToReduce, true>(log_amax < kLogThreshold and log_amin < log_amax);
// // Case into LogFMT-10 if satisfied
// if (enable_cast) {
// const auto step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
// const auto step_inv = 1.0f / step;
// const auto rounding = 2.0f - log2f_approx((1.0f + exp2f_approx(step)) * 0.5f) * step_inv;
// const auto fused_rounding = rounding - log_amin * step_inv;
// // Pack every 256 bits into 160 bits
// EP_STATIC_ASSERT(kNumSendUnrolls == 2 or kNumSendUnrolls == 4, "kNumSendUnrolls == 2 or 4 only");
// uint32_t encoded[kNumElemsPerInt4 * 2];
// #pragma unroll 1
// for (int i = 0; i < kNumSendUnrolls / 2; ++ i) {
// #pragma unroll
// for (int k = 0; k < kNumElemsPerInt4; ++ k) {
// const auto& [x, y] = __bfloat1622float2(bf162_values[i * kNumElemsPerInt4 + k]);
// encoded[k * 2 + 0] = __float2uint_rd(fmaxf(log2f_approx(x) * step_inv + fused_rounding, 0));
// encoded[k * 2 + 1] = __float2uint_rd(fmaxf(log2f_approx(y) * step_inv + fused_rounding, 0));
// }
// st_buffer[i * 5 + 0] = (encoded[ 0] >> 0) | (encoded[ 1] << 9) | (encoded[ 2] << 18) | (encoded[ 3] << 27);
// st_buffer[i * 5 + 1] = (encoded[ 3] >> 5) | (encoded[ 4] << 4) | (encoded[ 5] << 13) | (encoded[ 6] << 22) | (encoded[7] << 31);
// st_buffer[i * 5 + 2] = (encoded[ 7] >> 1) | (encoded[ 8] << 8) | (encoded[ 9] << 17) | (encoded[10] << 26);
// st_buffer[i * 5 + 3] = (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | (encoded[14] << 30);
// st_buffer[i * 5 + 4] = (encoded[14] >> 2) | (encoded[15] << 7) | ((i == 0) ? (local_signs << 16) : (local_signs & 0xffff0000u));
// }
// tma_store_fence();
// __syncwarp();
// }
// // Return TMA copy bytes
// return enable_cast ? (32 * (kNumSendUnrolls * sizeof(int4) * 8 * 10 / 16 / 8)):
// (32 * (kNumSendUnrolls * sizeof(int4)));
}
template
<
int
kNumLanes
,
int
kNumSendUnrolls
,
int
kNumRecvUnrolls
>
__forceinline__
__device__
void
logfmt_check_amaxmin
(
uint8_t
*
meta_buffer
,
float2
*
shared_log_amax
,
float2
*
shared_log_amin
,
int
*
shared_cast_info
,
const
int
lane_id
)
{
// constexpr float kLogThreshold = 0;
// constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))`
// bool enable_cast = true;
// if (lane_id < kNumLanes) {
// // Calculate log amin/amax float
// auto amaxmin2 = reinterpret_cast<uint64_t*>(meta_buffer)[lane_id];
// const auto& bf162_amaxmin = reinterpret_cast<__nv_bfloat162*>(&amaxmin2);
// float log_amax[2], log_amin[2];
// #pragma unroll
// for (int i = 0; i < 2; ++ i) {
// auto amax = static_cast<float>(bf162_amaxmin[i].x);
// auto amin = static_cast<float>(bf162_amaxmin[i].y);
// log_amax[i] = log2f_approx(amax);
// log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : fmaxf(log2f_approx(amin), log_amax[i] - kMinClip);
// enable_cast = enable_cast and log_amax[i] < kLogThreshold and log_amin[i] < log_amax[i];
// }
// shared_log_amax[lane_id] = make_float2(log_amax[0], log_amax[1]);
// shared_log_amin[lane_id] = make_float2(log_amin[0], log_amin[1]);
// }
// const auto& casted = warp_reduce_and<kNumSendUnrolls>(enable_cast) ? 1u << (lane_id / kNumRecvUnrolls): 0u;
// const auto& num_casted_prefix = __popc(warp_reduce_or<kNumRecvUnrolls, true>(casted) & ((1u << (lane_id / kNumRecvUnrolls)) - 1));
// if (lane_id < kNumLanes and lane_id % kNumRecvUnrolls == 0)
// shared_cast_info[lane_id / kNumRecvUnrolls] = (num_casted_prefix << 1) | (casted ? 1u : 0u);
// __syncwarp();
}
template
<
int
kNumRecvUnrolls
>
__forceinline__
__device__
void
decode_and_accumulate
(
uint32_t
*
ld_buffer
,
float
*
accum
,
const
float
&
log_amax
,
const
float
&
log_amin
,
const
bool
&
enable_cast
,
const
float
&
weight
)
{
// if (enable_cast) {
// constexpr int kNumBits = 10;
// constexpr int kNumValues = 1 << (kNumBits - 1);
// const auto& step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
// auto decode = [=](const uint32_t &encoded, const uint32_t &sign) {
// const auto decoded = encoded == 0 ? .0f : exp2f_approx((encoded - 1) * step + log_amin);
// return sign ? -decoded : decoded;
// };
// EP_STATIC_ASSERT(kNumRecvUnrolls == 2 or kNumRecvUnrolls == 4, "kNumRecvUnrolls == 2 or 4 only");
// #pragma unroll
// for (int i = 0; i < kNumRecvUnrolls / 2; ++ i) {
// uint32_t concat[6];
// concat[0] = ld_buffer[i * 5];
// #pragma unroll
// for (int k = 1; k < 5; ++ k)
// concat[k] = (ld_buffer[i * 5 + k - 1] >> (32 - k * 5)) | (ld_buffer[i * 5 + k] << (k * 5));
// concat[5] = ld_buffer[i * 5 + 4] >> 7;
// const uint32_t& local_signs = ld_buffer[i * 5 + 4] >> 16;
// #pragma unroll
// for (int k = 0; k < 5; ++ k) {
// accum[i * 16 + k * 3 + 0] += decode((concat[k] >> 0) & 0x1ff, (local_signs >> (k * 3 + 0)) & 1) * weight;
// accum[i * 16 + k * 3 + 1] += decode((concat[k] >> 9) & 0x1ff, (local_signs >> (k * 3 + 1)) & 1) * weight;
// accum[i * 16 + k * 3 + 2] += decode((concat[k] >> 18) & 0x1ff, (local_signs >> (k * 3 + 2)) & 1) * weight;
// }
// accum[i * 16 + 15] += decode(concat[5] & 0x1ff, (local_signs >> 15) & 1) * weight;
// }
// } else {
// #pragma unroll
// for (int k = 0; k < kNumRecvUnrolls * 4; ++ k) {
// auto bf16_pack = *reinterpret_cast<__nv_bfloat162*>(ld_buffer + k);
// accum[k * 2 + 0] += static_cast<float>(bf16_pack.x) * weight;
// accum[k * 2 + 1] += static_cast<float>(bf16_pack.y) * weight;
// }
// }
const
int
num_warp_groups
=
DIVUP
(
num_experts
,
num_device_sms
);
EP_HOST_ASSERT
(
num_warp_groups
<=
16
);
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
// 每个kernel最大16个warp
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
);
EP_HOST_ASSERT
(
kNumMaxTopK
+
1
<=
num_warp_groups
*
num_warps_per_group
);
const
auto
num_warps
=
num_warp_groups
*
num_warps_per_group
;
const
auto
num_sms
=
DIVUP
(
num_experts
,
num_warp_groups
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopK
);
// Workspace checks
auto
atomic_counter_per_expert
=
static_cast
<
int
*>
(
workspace
);
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
#define DISPATCH_LAUNCH_CASE(hidden) \
{ \
auto dispatch_func = dispatch<false, false, hidden>; \
if(use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, hidden>; \
if(use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, \
dispatch_func, \
packed_recv_x, \
packed_recv_x_scales, \
packed_recv_src_info, \
packed_recv_layout_range, \
packed_recv_count, \
global_atomic_counter, \
mask_buffer_ptr, \
cumulative_local_expert_recv_stats, \
dispatch_wait_recv_cost_stats, \
rdma_recv_x, \
rdma_recv_count, \
rdma_x, \
x, \
topk_idx, \
atomic_counter_per_expert, \
atomic_finish_counter_per_expert, \
next_clean, \
num_next_clean_int, \
num_tokens, \
num_max_dispatch_tokens_per_rank, \
num_topk, \
num_experts, \
rank, \
num_ranks, \
num_warp_groups, \
num_warps_per_group, \
round_scale, \
phases); \
} \
break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
#undef DISPATCH_LAUNCH_CASE
}
template
<
bool
kUseLogFMT
,
int
kHidden
,
int
kNumMaxTopk
,
int
kNumMaxUnrolls
>
__global__
__launch_bounds__
(
1024
,
1
)
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_flag
,
void
*
rdma_send_x
,
const
void
*
x
,
const
topk_idx_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
__launch_bounds__
(
1024
,
1
)
__global__
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_flag
,
void
*
rdma_send_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
mask_buffer_ptr
,
int64_t
*
combine_wait_recv_cost_stats
,
int
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
*
atomic_clean_flag
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
phases
,
bool
zero_copy
)
{
// const auto sm_id = __shfl_sync(0xffffffff, static_cast<int>(blockIdx.x), 0);
// const auto num_sms = __shfl_sync(0xffffffff, static_cast<int>(gridDim.x), 0);
// const auto thread_id = static_cast<int>(threadIdx.x);
// const auto num_threads = __shfl_sync(0xffffffff, static_cast<int>(blockDim.x), 0);
// const auto warp_id = __shfl_sync(0xffffffff, thread_id / 32, 0), lane_id = get_lane_id();
// const auto num_local_experts = num_experts / num_ranks;
// const auto warp_group_id = warp_id / num_warps_per_group;
// const auto sub_warp_id = warp_id % num_warps_per_group;
// const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// extern __shared__ __align__(1024) uint8_t smem_buffer[];
// // Data type staffs
// constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
// constexpr int64_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// // Use different unroll factors for send and recv phases
// constexpr int kNumSendUnrolls = kHidden % (32 * 4 * sizeof(int4) / sizeof(nv_bfloat16)) == 0 ? 4 : 2;
// constexpr int kNumRecvUnrolls = 2;
// constexpr int hidden_bf16_int4_pad = align_up(static_cast<int>(hidden_bf16_int4), 32 * kNumSendUnrolls);
// EP_STATIC_ASSERT(kHidden % (32 * 2 * sizeof(int4) / sizeof(nv_bfloat16)) == 0, "Invalid hidden");
// EP_STATIC_ASSERT(kNumSendUnrolls <= kNumMaxUnrolls and kNumRecvUnrolls <= kNumMaxUnrolls, "Invalid unrolls");
// EP_STATIC_ASSERT(hidden_bf16_int4 % kNumSendUnrolls == 0, "Invalid hidden");
// EP_STATIC_ASSERT(kNumSendUnrolls >= kNumRecvUnrolls, "Invalid unroll factors");
// // Message package
// EP_STATIC_ASSERT(kHidden % 128 == 0, "Invalid hidden");
// constexpr int kNumDivisions = kHidden / 128;
// constexpr int kNumMetaBytes = kNumDivisions * sizeof(nv_bfloat162);
// constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16) + kNumMetaBytes;
// EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// // Sending phase
// if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
// goto LOW_LATENCY_COMBINE_RECV;
// // Clean up next buffer
// if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
// #pragma unroll
// for (int i = lane_id; i < num_next_clean_int; i += 32)
// next_clean[i] = 0;
// // Notify before executing `int_p`
// __syncwarp();
// if (lane_id == 0)
// atomic_add_release_global(atomic_clean_flag, num_experts);
// }
// // Issue IBGDA sends
// if (responsible_expert_idx < num_experts) {
// const auto dst_rank = responsible_expert_idx / num_local_experts;
// const auto local_expert_idx = responsible_expert_idx % num_local_experts;
// const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
// const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
// const auto local_x = static_cast<const int4*>(x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
// const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
// const auto rdma_send_x_vec = static_cast<uint8_t*>(rdma_send_x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// // Unpack layout
// int offset, num_tokens_to_send;
// unpack2(layout, num_tokens_to_send, offset);
// // TMA stuffs
// constexpr int kNumTMABufferBytes = sizeof(int4) * 32 * kNumSendUnrolls;
// constexpr int kNumStages = 3;
// constexpr int kNumPrefetch = 1;
// EP_STATIC_ASSERT(kNumStages == 3 and kNumPrefetch == 1, "Invalid stages");
// auto smem_ptr = smem_buffer + warp_id * (kNumStages * (kNumTMABufferBytes + 16) + kNumMetaBytes);
// uint32_t tma_phase = 0;
// auto tma_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<int4*>(smem_ptr + i * (kNumTMABufferBytes + 16)); });
// auto full_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_ptr + i * (kNumTMABufferBytes + 16) + kNumTMABufferBytes); });
// auto meta_buffers = kUseLogFMT ? reinterpret_cast<nv_bfloat162*>(smem_ptr + kNumStages * (kNumTMABufferBytes + 16)) : nullptr;
// EP_STATIC_ASSERT(kNumSendUnrolls * kNumStages <= 12, "TMA buffer size exceed limit");
// // Initialize m-barriers
// if (lane_id < kNumStages) {
// mbarrier_init(full_barriers[lane_id], 1);
// fence_barrier_init();
// }
// __syncwarp();
// constexpr int kNumIters = hidden_bf16_int4_pad / (32 * kNumSendUnrolls);
// auto tma_load_and_arrive = [&](const int& stage_idx, const int4* gmem_ptr, const int& num_bytes) {
// tma_load_1d(tma_buffers[stage_idx], gmem_ptr, full_barriers[stage_idx], num_bytes);
// mbarrier_arrive_and_expect_tx(full_barriers[stage_idx], num_bytes);
// };
// auto get_num_tma_bytes = [&](const int& offset_int4) {
// return min(kNumTMABufferBytes, static_cast<int>((hidden_bf16_int4 - offset_int4) * sizeof(int4)));
// };
// // Issue IBGDA send
// for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
// const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
// const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
// const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
// // Copy directly to local rank, or copy to buffer and issue RDMA
// const auto src_idx = __shfl_sync(0xffffffff, __ldg(local_src_info + token_idx), 0);
// const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
// const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
// const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
// int num_send_bytes = hidden * sizeof(nv_bfloat16);
// if (not zero_copy or dst_p2p_ptr != 0) {
// // Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
// const auto cpy_src_int4_ptr = zero_copy ? reinterpret_cast<int4*>(buf_ptr) : x_int4;
// const auto cpy_dst_int4_ptr = dst_p2p_ptr == 0 ? reinterpret_cast<int4*>(buf_ptr) : reinterpret_cast<int4*>(dst_p2p_ptr);
// // Prefetch
// if (elect_one_sync())
// tma_load_and_arrive(0, cpy_src_int4_ptr, get_num_tma_bytes(0));
// __syncwarp();
// int tma_offset_bytes = kNumMetaBytes;
// #pragma unroll
// for (int i = lane_id * kNumSendUnrolls, iter_idx = 0; i < hidden_bf16_int4_pad; i += 32 * kNumSendUnrolls, ++ iter_idx) {
// // Load the next iteration
// const int& stage_idx = iter_idx % kNumStages;
// const int& next_stage_idx = (iter_idx + 1) % kNumStages;
// if (iter_idx + 1 < kNumIters and elect_one_sync()) {
// tma_store_wait<kNumStages - kNumPrefetch - 1>();
// const auto& offset_int4 = i + 32 * kNumSendUnrolls;
// tma_load_and_arrive(next_stage_idx, cpy_src_int4_ptr + offset_int4, get_num_tma_bytes(offset_int4));
// }
// __syncwarp();
// // Wait the current TMA arrival
// EP_STATIC_ASSERT(kNumStages < 32, "Too many stages");
// mbarrier_wait<true>(full_barriers[stage_idx], tma_phase, stage_idx);
// if constexpr (kUseLogFMT) {
// // Cast if possible
// constexpr int kNumInt4PerDivision = 128 / kNumElemsPerInt4;
// int num_tma_bytes = logfmt_encode<kNumSendUnrolls>(
// tma_buffers[stage_idx],
// // NOTES: only the leader lane will write the result
// (i % kNumInt4PerDivision == 0) ? meta_buffers + i / kNumInt4PerDivision : nullptr,
// lane_id);
// if (elect_one_sync())
// tma_store_1d(tma_buffers[stage_idx], reinterpret_cast<uint8_t*>(cpy_dst_int4_ptr) + tma_offset_bytes, num_tma_bytes);
// tma_offset_bytes += num_tma_bytes;
// } else {
// // BF16 original values
// if (elect_one_sync())
// tma_store_1d(tma_buffers[stage_idx], cpy_dst_int4_ptr + i, get_num_tma_bytes(i));
// }
// __syncwarp();
// }
// // Store metadata (min/max values) for LogFMT
// if constexpr (kUseLogFMT) {
// num_send_bytes = tma_offset_bytes;
// if (elect_one_sync())
// tma_store_1d(meta_buffers, cpy_dst_int4_ptr, kNumMetaBytes);
// }
// // Flush all stores
// tma_store_wait<0>();
// __syncwarp();
// }
// // Issue RDMA
// // NOTES: for zero-copy mode, we assume the data is already in the send buffer
// if (dst_p2p_ptr == 0)
// nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, num_send_bytes, dst_rank, local_expert_idx, lane_id, token_idx - offset);
// }
// // Put the finishing flag
// EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16);
// asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(num_warps_per_group * 32));
// if (sub_warp_id == 1 and lane_id == 0) {
// while (ld_acquire_global(atomic_clean_flag) == 0);
// auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx);
// auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
// if (dst_p2p_ptr == 0) {
// nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), 1, dst_rank, local_expert_idx);
// } else {
// st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), 1);
// }
// atomic_add_release_global(atomic_clean_flag, -1);
// }
// __syncwarp();
// // Destroy m-barriers
// if (lane_id < kNumStages) {
// mbarrier_inval(full_barriers[lane_id]);
// fence_barrier_init();
// }
// __syncwarp();
// }
// // Receiving phase
// LOW_LATENCY_COMBINE_RECV:
// if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
// return;
// // Wait all ranks to arrive
// if (responsible_expert_idx < num_experts) {
// EP_DEVICE_ASSERT(num_warps_per_group > 1);
// if (sub_warp_id == 0 and lane_id == 0) {
// auto start_time = clock64();
// while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
// auto wait_recv_cost = clock64() - start_time;
// if (combine_wait_recv_cost_stats != nullptr) {
// const auto& src_rank = responsible_expert_idx / num_local_experts;
// atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
// }
// }
// }
// cg::this_grid().sync();
// // Reassign warp groups
// constexpr int kMaxNumGroups = 2;
// const int num_decode_warps = hidden_bf16_int4_pad / (kNumRecvUnrolls * 32);
// const int num_groups = min(kMaxNumGroups, (num_threads / 32) / (num_decode_warps + 1));
// const int decode_warp_idx = __shfl_sync(0xffffffff, warp_id % (num_decode_warps + 1), 0);
// const int group_idx = __shfl_sync(0xffffffff, warp_id / (num_decode_warps + 1), 0);
// EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
// EP_DEVICE_ASSERT(num_topk <= 32);
// EP_DEVICE_ASSERT(num_groups > 0);
// if (group_idx < num_groups) {
// constexpr int kNumStages = 3;
// constexpr int kNumTMABufferBytes = 16 * 2 + kHidden * 2;
// constexpr int kNumBF16PerWarpBytes = 32 * kNumRecvUnrolls * kNumElemsPerInt4 * 2;
// constexpr int kNumLogFMTPerWarpBytes = kNumBF16PerWarpBytes / 16 * 10;
// constexpr int kNumDivisionBytes = kNumDivisions * sizeof(uint32_t);
// constexpr int kNumBytesPerGroup = kNumStages * kNumTMABufferBytes + kHidden * 2 + kNumStages * kNumDivisionBytes * 3;
// // Reallocate shared memory
// const auto smem_group_buffer = smem_buffer + kNumBytesPerGroup * group_idx;
// auto full_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_group_buffer + i * kNumTMABufferBytes); });
// auto empty_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_group_buffer + i * kNumTMABufferBytes + 8); });
// auto tma_ld_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint8_t* >(smem_group_buffer + i * kNumTMABufferBytes + 16); });
// auto tma_st_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint32_t*>(smem_group_buffer + kNumStages * kNumTMABufferBytes + i * kNumBF16PerWarpBytes); });
// // Redundant when logfmt is disabled
// const auto smem_group_ptr = smem_group_buffer + kNumStages * kNumTMABufferBytes + kHidden * 2;
// auto log_amax_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<float*>(smem_group_ptr + i * kNumDivisionBytes); });
// auto log_amin_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<float*>(smem_group_ptr + kNumStages * kNumDivisionBytes + i * kNumDivisionBytes); });
// auto cast_info_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<int*> (smem_group_ptr + kNumStages * kNumDivisionBytes * 2 + i * kNumDivisionBytes); });
// uint32_t tma_phase = 0;
// EP_STATIC_ASSERT(kNumStages < 32, "Too many stages");
// if (decode_warp_idx == num_decode_warps)
// tma_phase = (1 << kNumStages) - 1;
// // Initialize m-barriers
// if (decode_warp_idx == num_decode_warps and lane_id < kNumStages) {
// mbarrier_init(full_barriers[lane_id], 1);
// mbarrier_init(empty_barriers[lane_id], num_decode_warps);
// }
// asm volatile("bar.sync %0, %1;" :: "r"(group_idx + 1), "r"((num_decode_warps + 1) * 32));
// int stage_idx = 0, topk_idx_by_lane = 0;
// EP_STATIC_ASSERT(kNumMaxTopk <= 32, "Invalid number of topks");
// if (decode_warp_idx == num_decode_warps) {
// // TMA load warp
// for (int token_idx = sm_id + num_sms * group_idx; token_idx < num_combined_tokens; token_idx += num_sms * num_groups) {
// if (lane_id < num_topk)
// topk_idx_by_lane = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id));
// for (int i = 0; i < num_topk; ++ i) {
// int topk_idx_reg = __shfl_sync(0xffffffff, topk_idx_by_lane, i);
// if (topk_idx_reg < 0)
// continue;
// mbarrier_wait<true>(empty_barriers[stage_idx], tma_phase, stage_idx);
// auto buffer = static_cast<uint8_t*>(rdma_recv_x) + (topk_idx_reg * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot;
// if constexpr (kUseLogFMT) {
// logfmt_check_amaxmin<kNumDivisions / 2, kNumSendUnrolls, kNumRecvUnrolls>(
// buffer, reinterpret_cast<float2*>(log_amax_buffers[stage_idx]),
// reinterpret_cast<float2*>(log_amin_buffers[stage_idx]), cast_info_buffers[stage_idx], lane_id);
// }
// if (elect_one_sync()) {
// int num_casted = 0;
// if constexpr (kUseLogFMT) {
// const auto& info = cast_info_buffers[stage_idx][num_decode_warps - 1];
// num_casted = (info >> 1) + (info & 1);
// }
// int num_tma_bytes = num_casted * kNumLogFMTPerWarpBytes + (num_decode_warps - num_casted) * kNumBF16PerWarpBytes;
// tma_load_1d(tma_ld_buffers[stage_idx], buffer + (kUseLogFMT ? kNumMetaBytes : 0), full_barriers[stage_idx], num_tma_bytes);
// mbarrier_arrive_and_expect_tx(full_barriers[stage_idx], num_tma_bytes);
// }
// __syncwarp();
// stage_idx = (stage_idx + 1) % kNumStages;
// }
// }
// } else {
// // Reduction warps
// float topk_weights_by_lane;
// for (int token_idx = sm_id + num_sms * group_idx; token_idx < num_combined_tokens; token_idx += num_sms * num_groups) {
// if (lane_id < num_topk) {
// topk_idx_by_lane = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id));
// topk_weights_by_lane = __ldg(topk_weights + token_idx * num_topk + lane_id);
// }
// __syncwarp();
// float combined_values[kNumElemsPerInt4 * kNumRecvUnrolls] = {0.0f};
// for (int i = 0; i < num_topk; ++ i) {
// if (__shfl_sync(0xffffffff, topk_idx_by_lane, i) < 0)
// continue;
// const auto& topk_weight = __shfl_sync(0xffffffff, topk_weights_by_lane, i);
// mbarrier_wait<true>(full_barriers[stage_idx], tma_phase, stage_idx);
// if constexpr (kUseLogFMT) {
// const auto& info = cast_info_buffers[stage_idx][decode_warp_idx];
// bool enable_cast = info & 1;
// int num_casted_prefix = info >> 1;
// int tma_offset = kNumLogFMTPerWarpBytes * num_casted_prefix + kNumBF16PerWarpBytes * (decode_warp_idx - num_casted_prefix);
// int division_idx = decode_warp_idx * (kNumRecvUnrolls * 2) + lane_id * kNumRecvUnrolls / 16;
// decode_and_accumulate<kNumRecvUnrolls>(
// reinterpret_cast<uint32_t*>(tma_ld_buffers[stage_idx] + tma_offset + (enable_cast ? kNumLogFMTPerWarpBytes : kNumBF16PerWarpBytes) / 32 * lane_id),
// combined_values, log_amax_buffers[stage_idx][division_idx], log_amin_buffers[stage_idx][division_idx], enable_cast, topk_weight
// );
// } else {
// int tma_offset = kNumBF16PerWarpBytes * decode_warp_idx;
// decode_and_accumulate<kNumRecvUnrolls>(
// reinterpret_cast<uint32_t*>(tma_ld_buffers[stage_idx] + tma_offset + kNumBF16PerWarpBytes / 32 * lane_id),
// combined_values, 0, 0, false, topk_weight
// );
// }
// if (elect_one_sync())
// mbarrier_arrive(empty_barriers[stage_idx]);
// stage_idx = (stage_idx + 1) % kNumStages;
// }
// tma_store_wait<0>();
// #pragma unroll
// for (int k = 0; k < kNumRecvUnrolls * 4; ++ k) {
// auto combined_pack = __nv_bfloat162(combined_values[k * 2], combined_values[k * 2 + 1]);
// tma_st_buffers[decode_warp_idx][kNumRecvUnrolls * 4 * lane_id + k] = *reinterpret_cast<uint32_t*>(&combined_pack);
// }
// tma_store_fence();
// if (elect_one_sync()) {
// tma_store_1d(tma_st_buffers[decode_warp_idx],
// static_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4 + decode_warp_idx * kNumRecvUnrolls * 32,
// kNumBF16PerWarpBytes);
// }
// __syncwarp();
// }
// }
// }
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
phases
,
bool
zero_copy
)
{
#if !defined(ROCM_DISABLE_CTX)
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
#endif
// const auto sm_id = static_cast<int>(blockIdx.x);
// const auto num_sms = static_cast<int>(gridDim.x);
// const auto thread_id = static_cast<int>(threadIdx.x);
// const auto num_threads = static_cast<int>(blockDim.x);
// const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
// const auto num_local_experts = num_experts / num_ranks;
// const auto warp_group_id = warp_id / kNumWarpsPerGroup;
// const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
// const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
// // Data type staffs
// constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(gpu_bfloat16_t);
// const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// // Message package
// // BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
// constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(gpu_bfloat16_t);
// EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// __syncthreads();
// #ifdef USE_ROCM
// // 16 is the max possible number of warps in AMD GPUs
// constexpr int kMaxNumWarps = 1024 / kWarpSize;
// __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
// if (threadIdx.x==0){
// // printf("combine");
// #pragma unroll
// for (int i = 0; i < kMaxNumWarps; ++i) {
// sync_large_warp_counters[i] = 0;
// }
// }
// __syncthreads();
// #endif
// // Sending phase
// if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
// goto LOW_LATENCY_COMBINE_RECV;
// // Clean up next buffer
// if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
// #pragma unroll
// for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
// next_clean[i] = 0;
// // Notify before executing `int_p`
// syncwarp();
// if (lane_id == 0)
// atomic_add_release_global(atomic_clean_flag, num_experts);
// }
// // Issue IBGDA sends
// if (responsible_expert_idx < num_experts) {
// const auto dst_rank = responsible_expert_idx / num_local_experts;
// const auto local_expert_idx = responsible_expert_idx % num_local_experts;
// const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
// const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
// const auto local_x = reinterpret_cast<const int4*>(x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
// const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
// const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// // Unpack layout
// int offset, num_tokens_to_send;
// unpack2(layout, num_tokens_to_send, offset);
// // Issue IBGDA send
// for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
// const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
// const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
// const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
// // Copy directly to local rank, or copy to buffer and issue RDMA
// auto src_idx = __ldg(local_src_info + token_idx);
// const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
// const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
// if (dst_rank == rank) {
// const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
// UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
// } else {
// const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
// if (not zero_copy)
// UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
// //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(gpu_bfloat16_t), dst_rank, local_expert_idx, lane_id, token_idx - offset);
// #if defined(ROCM_DISABLE_CTX)
// internode::shmemx_int8_put_nbi_warp(
// #else
// internode::shmem_ctx_schar_put_nbi_warp(ctx,
// #endif
// reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank);
// #if defined(ROCM_DISABLE_CTX)
// internode::shmem_fence();
// #else
// internode::shmem_ctx_quiet(ctx);
// #endif
// }
// }
// // Put finishing flag
// EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
// #ifdef USE_ROCM
// if (lane_id == 0){
// volatile int ret = __hip_atomic_fetch_add(
// &sync_large_warp_counters[warp_group_id], 1,
// __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
// }
// syncwarp();
// while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup));
// #else
// asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
// #endif
// if (sub_warp_id == 1 and lane_id == 0) {
// while (ld_acquire_global(atomic_clean_flag) == 0);
// if (dst_rank != rank) {
// #ifdef USE_ROCM
// #if defined(ROCM_DISABLE_CTX)
// internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank);
// #else
// internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank);
// #endif
// #else
// nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
// #endif
// } else {
// st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
// }
// atomic_add_release_global(atomic_clean_flag, -1);
// }
// syncwarp();
// }
// // Receiving phase
// LOW_LATENCY_COMBINE_RECV:
// if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
// return;
// // Wait all ranks to arrive and notify PCIe usage
// if (responsible_expert_idx < num_experts) {
// EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
// if (sub_warp_id == 0 and lane_id == 0){
// while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0);
// }
// }
// grid_barrier(global_atomic_counter, num_sms);
// // Reduce tokens with FP8 cast
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
// EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
// if (thread_id < hidden_bf16_int4) {
// for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// // Read top-k indices and weights
// int reg_topk_idx[kNumMaxTopk];
// float reg_topk_weights[kNumMaxTopk];
// #pragma unroll
// for (int i = 0; i < num_topk; ++ i) {
// reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
// reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
// }
// float combined_values[kNumElemsPerInt4] = {0.0f};
// #pragma unroll
// for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// // Read from sources
// auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
// auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
// // Reduce
// auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
// const auto x_bf16 = reinterpret_cast<gpu_bfloat16_t*>(&x_vec);
// #pragma unroll
// for (int j = 0; j < kNumElemsPerInt4; ++ j)
// combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
// }
// // Write results
// int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
// auto combined_bf16 = reinterpret_cast<gpu_bfloat16_t*>(&combined_values);
// #pragma unroll
// for (int j = 0; j < kNumElemsPerInt4; ++ j)
// combined_bf16[j] = static_cast<gpu_bfloat16_t>(combined_values[j]);
// (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
// }
// }
#if !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_wg_ctx_destroy
(
&
ctx
);
#endif
}
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_flag
,
void
*
rdma_send_x
,
const
void
*
x
,
const
topk_idx_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
mask_buffer_ptr
,
int64_t
*
combine_wait_recv_cost_stats
,
int
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_logfmt
,
void
*
workspace
,
int
num_device_sms
,
cudaStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
constexpr
int
kNumMaxTopk
=
11
;
// const int num_warp_groups = ceil_div(num_experts, num_device_sms);
// const int num_warps_per_group = 32 / num_warp_groups;
// const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
// EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
// const auto num_warps = num_warp_groups * num_warps_per_group;
// const auto num_sms = max(ceil_div(num_experts, num_warp_groups),
// num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));
// // Check workspace
// auto atomic_clean_flag = static_cast<int*>(workspace);
// EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
// EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
// // Online cast cannot use zero-copy
// EP_HOST_ASSERT(not (zero_copy and use_logfmt));
// constexpr int kNumStages = 3;
// constexpr int kNumMaxUnrolls = 4;
// constexpr int kMaxNumGroups = 2;
// // Send buffer size
// const int num_meta_bytes = hidden / 128 * 4;
// const int num_send_tma_bytes = 32 * sizeof(int4) * kNumMaxUnrolls + 16;
// const int smem_send_size = num_warps * (kNumStages * num_send_tma_bytes + num_meta_bytes);
// // Receive buffer size
// const int num_recv_tma_bytes = 16 + hidden * 2;
// const int smem_recv_size = kMaxNumGroups * (kNumStages * num_recv_tma_bytes + hidden * 2 + kNumStages * num_meta_bytes * 3);
// // Total requirement
// const int smem_size = max(smem_send_size, smem_recv_size);
// #define COMBINE_LAUNCH_CASE(hidden) { \
// auto combine_func = use_logfmt ? \
// combine<true, hidden, kNumMaxTopk, kNumMaxUnrolls> : \
// combine<false, hidden, kNumMaxTopk, kNumMaxUnrolls>; \
// SET_SHARED_MEMORY_FOR_TMA(combine_func); \
// LAUNCH_KERNEL(&cfg, combine_func, \
const
int
num_warp_groups
=
DIVUP
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
const
int
num_recv_per_sm
=
DIVUP
(
num_combined_tokens
,
num_device_sms
);
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
and
num_recv_per_sm
>=
0
);
const
auto
num_warps
=
num_warp_groups
*
num_warps_per_group
;
const
auto
num_sms
=
max
(
DIVUP
(
num_experts
,
num_warp_groups
),
num_recv_per_sm
==
0
?
1
:
DIVUP
(
num_combined_tokens
,
num_recv_per_sm
));
// Check workspace
auto
atomic_clean_flag
=
static_cast
<
int
*>
(
workspace
);
EP_HOST_ASSERT
(
sizeof
(
int
)
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopk
);
// Online cast cannot use zero-copy
EP_HOST_ASSERT
(
not
(
zero_copy
and
use_logfmt
));
EP_HOST_ASSERT
(
use_logfmt
==
0
);
constexpr
int
kNumMaxUnrolls
=
4
;
#ifdef USEING_TMA
constexpr
int
kNumStages
=
3
;
constexpr
int
kMaxNumGroups
=
2
;
// Send buffer size
const
int
num_meta_bytes
=
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
;
const
int
num_send_tma_bytes
=
32
*
sizeof
(
int4
)
*
kNumMaxUnrolls
+
16
;
const
int
smem_send_size
=
num_warps
*
(
kNumStages
*
num_send_tma_bytes
+
num_meta_bytes
);
// Receive buffer size
const
int
num_recv_tma_bytes
=
16
+
hidden
*
2
;
const
int
smem_recv_size
=
kMaxNumGroups
*
(
kNumStages
*
num_recv_tma_bytes
+
hidden
*
2
+
kNumStages
*
num_meta_bytes
*
3
);
// Total requirement
const
int
smem_size
=
max
(
smem_send_size
,
smem_recv_size
);
#endif
// #define COMBINE_LAUNCH_CASE(hidden) \
// { \
// auto combine_func = combine<false, hidden, kNumMaxTopk, kNumMaxUnrolls>; \
// LAUNCH_KERNEL(&cfg, \
// combine_func, \
// combined_x, \
// rdma_recv_x, rdma_recv_flag, rdma_send_x, \
// x, topk_idx, topk_weights, src_info, layout_range, \
// rdma_recv_x, \
// rdma_recv_flag, \
// rdma_send_x, \
// x, \
// topk_idx, \
// topk_weights, \
// src_info, \
// layout_range, \
// global_atomic_counter, \
// mask_buffer_ptr, \
// combine_wait_recv_cost_stats, \
// next_clean, num_next_clean_int, \
// next_clean, \
// num_next_clean_int, \
// atomic_clean_flag, \
// num_combined_tokens, hidden, num_topk, \
// num_combined_tokens, \
// hidden, \
// num_topk, \
// num_max_dispatch_tokens_per_rank, \
// num_experts, rank, num_ranks, \
// num_warp_groups, num_warps_per_group, \
// phases, zero_copy); } break
// SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
// num_experts, \
// rank, \
// num_ranks, \
// num_warp_groups, \
// num_warps_per_group, \
// phases, \
// zero_copy); \
// } \
// break
// SETUP_LAUNCH_CONFIG(num_sms, num_warps* kWarpSize, stream);
// SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
// #undef COMBINE_LAUNCH_CASE
}
template
<
int
kNumThreads
>
__launch_bounds__
(
kNumThreads
,
1
)
__global__
void
query_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
int
*
mask_tensor
)
{
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_threads
=
num_sms
*
kNumThreads
;
const
auto
thread_id
=
sm_id
*
kNumThreads
+
static_cast
<
int
>
(
threadIdx
.
x
);
for
(
int
rank_id
=
thread_id
;
rank_id
<
num_ranks
;
rank_id
+=
num_threads
)
{
mask_tensor
[
rank_id
]
=
mask_buffer_ptr
[
rank_id
];
}
}
void
query_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
int
*
mask_tensor
,
hipStream_t
stream
)
{
constexpr
int
num_sms
=
1
;
constexpr
int
kNumThreads
=
1024
;
SETUP_LAUNCH_CONFIG
(
num_sms
,
kNumThreads
,
stream
);
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
query_mask_buffer
<
kNumThreads
>
,
mask_buffer_ptr
,
num_ranks
,
mask_tensor
);
}
template
<
int
kNumThreads
>
__launch_bounds__
(
kNumThreads
,
1
)
__global__
void
update_mask_buffer
(
int
*
mask_buffer_ptr
,
int
rank_to_mask
,
bool
mask
)
{
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
if
(
sm_id
==
0
&&
thread_id
==
0
)
{
atomicExch
(
mask_buffer_ptr
+
rank_to_mask
,
mask
?
1
:
0
);
}
}
void
update_mask_buffer
(
int
*
mask_buffer_ptr
,
int
rank
,
bool
mask
,
hipStream_t
stream
)
{
constexpr
int
num_sms
=
1
;
constexpr
int
kNumThreads
=
64
;
SETUP_LAUNCH_CONFIG
(
num_sms
,
kNumThreads
,
stream
);
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
update_mask_buffer
<
kNumThreads
>
,
mask_buffer_ptr
,
rank
,
mask
);
}
template
<
int
kNumThreads
>
__launch_bounds__
(
kNumThreads
,
1
)
__global__
void
clean_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
)
{
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_ranks
;
i
+=
kNumThreads
)
mask_buffer_ptr
[
i
]
=
0
;
}
void
clean_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
hipStream_t
stream
)
{
constexpr
int
num_sms
=
1
;
constexpr
int
kNumThreads
=
64
;
SETUP_LAUNCH_CONFIG
(
num_sms
,
kNumThreads
,
stream
);
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
clean_mask_buffer
<
kNumThreads
>
,
mask_buffer_ptr
,
num_ranks
);
}
}
// namespace internode_ll
}
// namespace deep_ep
#endif
csrc/kernels/utils.cuh
View file @
09cb2b03
...
...
@@ -125,6 +125,10 @@ __device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
__hip_atomic_store
(
const_cast
<
int
*>
(
ptr
),
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_SYSTEM
);
}
__device__
__forceinline__
void
st_release_sys_global
(
const
int64_t
*
ptr
,
int64_t
val
)
{
__hip_atomic_store
(
const_cast
<
int64_t
*>
(
ptr
),
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_SYSTEM
);
}
__device__
__forceinline__
void
st_release_cta
(
const
int
*
ptr
,
int
val
)
{
__hip_atomic_store
(
const_cast
<
int
*>
(
ptr
),
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
}
...
...
@@ -157,6 +161,12 @@ __device__ __forceinline__ int ld_acquire_global(const int *ptr) {
return
ret
;
}
__device__
__forceinline__
int64_t
ld_acquire_global
(
const
int64_t
*
ptr
)
{
int64_t
ret
;
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_ACQUIRE
,
__HIP_MEMORY_SCOPE_AGENT
);
return
ret
;
}
__device__
__forceinline__
int
atomic_add_release_global
(
const
int
*
ptr
,
int
value
)
{
int
ret
;
// ret = __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE,
...
...
@@ -165,6 +175,12 @@ __device__ __forceinline__ int atomic_add_release_global(const int *ptr, int val
return
ret
;
}
__device__
__forceinline__
int
ld_relaxed_global
(
const
int
*
ptr
)
{
int
ret
;
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
return
ret
;
}
__device__
__forceinline__
int
ld_acquire_cta
(
const
int
*
ptr
)
{
int
ret
;
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_ACQUIRE
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
...
...
@@ -245,6 +261,11 @@ __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val)
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
}
__device__
__forceinline__
void
st_na_release
(
const
int64_t
*
ptr
,
int64_t
val
)
{
int64_t
*
non_const_ptr
=
const_cast
<
int64_t
*>
(
ptr
);
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
}
// TODO:: apply "st.global.L1::no_allocate" in ROCM
template
<
typename
dtype_t
>
__device__
__forceinline__
void
st_na_global
(
const
dtype_t
*
ptr
,
const
dtype_t
&
value
)
{
...
...
@@ -279,6 +300,22 @@ __forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_s
token_end_idx
=
min
(
token_start_idx
+
num_tokens_per_sm
,
num_tokens
);
}
template
<
typename
dtype_a_t
,
typename
dtype_b_t
>
__device__
__forceinline__
dtype_b_t
pack2
(
const
dtype_a_t
&
x
,
const
dtype_a_t
&
y
)
{
EP_STATIC_ASSERT
(
sizeof
(
dtype_a_t
)
*
2
==
sizeof
(
dtype_b_t
),
"Invalid dtypes"
);
dtype_b_t
packed
;
auto
unpacked_ptr
=
reinterpret_cast
<
dtype_a_t
*>
(
&
packed
);
unpacked_ptr
[
0
]
=
x
,
unpacked_ptr
[
1
]
=
y
;
return
packed
;
}
template
<
typename
dtype_a_t
,
typename
dtype_b_t
>
__device__
__forceinline__
void
unpack2
(
const
dtype_b_t
&
packed
,
dtype_a_t
&
x
,
dtype_a_t
&
y
)
{
EP_STATIC_ASSERT
(
sizeof
(
dtype_a_t
)
*
2
==
sizeof
(
dtype_b_t
),
"Invalid dtypes"
);
auto
unpacked_ptr
=
reinterpret_cast
<
const
dtype_a_t
*>
(
&
packed
);
x
=
unpacked_ptr
[
0
],
y
=
unpacked_ptr
[
1
];
}
template
<
typename
dtype_t
>
__device__
__forceinline__
dtype_t
broadcast
(
dtype_t
&
ptr
,
int
src_lane_idx
)
{
EP_STATIC_ASSERT
(
sizeof
(
dtype_t
)
%
sizeof
(
int
)
==
0
,
""
);
...
...
@@ -290,15 +327,47 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
}
__forceinline__
__device__
int
warp_reduce_sum
(
int
value
)
{
if
constexpr
(
kWarpSize
==
64
)
value
+=
shfl_xor
<
int
>
(
value
,
32
);
value
+=
shfl_xor
<
int
>
(
value
,
16
);
value
+=
shfl_xor
<
int
>
(
value
,
8
);
value
+=
shfl_xor
<
int
>
(
value
,
4
);
value
+=
shfl_xor
<
int
>
(
value
,
2
);
value
+=
shfl_xor
<
int
>
(
value
,
1
);
#ifdef USE_ROCM
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
#else
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFinfoAmaxE4M3
=
448.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
#endif
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
// We can ensure `-126 <= x and x <= 127`
uint32_t
bits_x
=
(
x
+
127
)
<<
23
;
return
*
reinterpret_cast
<
float
*>
(
&
bits_x
);
}
__forceinline__
__device__
int
fast_log2_ceil
(
float
x
)
{
auto
bits_x
=
*
reinterpret_cast
<
uint32_t
*>
(
&
x
);
auto
exp_x
=
(
bits_x
>>
23
)
&
0xff
;
auto
man_bits
=
bits_x
&
((
1
<<
23
)
-
1
);
return
exp_x
-
127
+
(
man_bits
!=
0
);
}
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
)
{
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
scale
=
kFinfoAmaxE4M3
/
amax
;
}
}
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
if
constexpr
(
kIsUE8M0
)
{
return
static_cast
<
uint8_t
>
((
*
reinterpret_cast
<
uint32_t
*>
(
&
value
))
>>
23
);
}
else
{
return
value
;
}
}
__forceinline__
__device__
int
get_lane_id
()
{
...
...
@@ -340,4 +409,95 @@ __forceinline__ __device__ void barrier_block(int **barrier_signal_ptrs, int ran
}
__syncthreads
();
}
// Operation functors
template
<
typename
T
>
struct
ReduceSum
{
__device__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
+
b
;
}
};
template
<
typename
T
>
struct
ReduceMax
{
__device__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
>
b
?
a
:
b
;
}
};
template
<
typename
T
>
struct
ReduceMin
{
__device__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
<
b
?
a
:
b
;
}
};
template
<
typename
T
>
struct
ReduceAnd
{
__device__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
&
b
;
}
};
template
<
typename
T
>
struct
ReduceOr
{
__device__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
|
b
;
}
};
// Unified reduction function
template
<
int
kNumLanesPerGroup
,
bool
kIntergroupReduce
,
typename
T
,
typename
Op
>
__forceinline__
__device__
T
warp_reduce
(
T
value
,
Op
op
)
{
EP_STATIC_ASSERT
(
kNumLanesPerGroup
==
kWarpSize
or
kNumLanesPerGroup
==
32
or
kNumLanesPerGroup
==
16
or
kNumLanesPerGroup
==
8
or
kNumLanesPerGroup
==
4
or
kNumLanesPerGroup
==
2
or
kNumLanesPerGroup
==
1
,
"Invalid number of lanes"
);
constexpr
uint32_t
mask
=
0xffffffff
;
if
constexpr
(
kIntergroupReduce
)
{
if
constexpr
(
kNumLanesPerGroup
<=
1
)
value
=
op
(
value
,
shfl_xor
(
value
,
1
));
if
constexpr
(
kNumLanesPerGroup
<=
2
)
value
=
op
(
value
,
shfl_xor
(
value
,
2
));
if
constexpr
(
kNumLanesPerGroup
<=
4
)
value
=
op
(
value
,
shfl_xor
(
value
,
4
));
if
constexpr
(
kNumLanesPerGroup
<=
8
)
value
=
op
(
value
,
shfl_xor
(
value
,
8
));
if
constexpr
(
kNumLanesPerGroup
<=
16
)
value
=
op
(
value
,
shfl_xor
(
value
,
16
));
if
constexpr
(
kWarpSize
==
64
){
if
constexpr
(
kNumLanesPerGroup
<=
32
)
value
=
op
(
value
,
shfl_xor
(
value
,
32
));
}
}
else
{
if
constexpr
(
kWarpSize
==
64
){
if
constexpr
(
kNumLanesPerGroup
>=
kWarpSize
)
value
=
op
(
value
,
shfl_xor
(
value
,
32
));
}
if
constexpr
(
kNumLanesPerGroup
>=
32
)
value
=
op
(
value
,
shfl_xor
(
value
,
16
));
if
constexpr
(
kNumLanesPerGroup
>=
16
)
value
=
op
(
value
,
shfl_xor
(
value
,
8
));
if
constexpr
(
kNumLanesPerGroup
>=
8
)
value
=
op
(
value
,
shfl_xor
(
value
,
4
));
if
constexpr
(
kNumLanesPerGroup
>=
4
)
value
=
op
(
value
,
shfl_xor
(
value
,
2
));
if
constexpr
(
kNumLanesPerGroup
>=
2
)
value
=
op
(
value
,
shfl_xor
(
value
,
1
));
}
return
value
;
}
// Convenience aliases
template
<
int
kNumLanesPerGroup
=
kWarpSize
,
bool
kIntergroupReduce
=
false
,
typename
T
>
__forceinline__
__device__
T
warp_reduce_sum
(
T
value
)
{
return
warp_reduce
<
kNumLanesPerGroup
,
kIntergroupReduce
,
T
>
(
value
,
ReduceSum
<
T
>
{});
}
template
<
int
kNumLanesPerGroup
=
kWarpSize
,
bool
kIntergroupReduce
=
false
,
typename
T
>
__forceinline__
__device__
T
warp_reduce_max
(
T
value
)
{
return
warp_reduce
<
kNumLanesPerGroup
,
kIntergroupReduce
,
T
>
(
value
,
ReduceMax
<
T
>
{});
}
template
<
int
kNumLanesPerGroup
=
kWarpSize
,
bool
kIntergroupReduce
=
false
,
typename
T
>
__forceinline__
__device__
T
warp_reduce_min
(
T
value
)
{
return
warp_reduce
<
kNumLanesPerGroup
,
kIntergroupReduce
,
T
>
(
value
,
ReduceMin
<
T
>
{});
}
template
<
int
kNumLanesPerGroup
=
kWarpSize
,
bool
kIntergroupReduce
=
false
,
typename
T
>
__forceinline__
__device__
T
warp_reduce_and
(
T
value
)
{
return
warp_reduce
<
kNumLanesPerGroup
,
kIntergroupReduce
,
T
>
(
value
,
ReduceAnd
<
T
>
{});
}
template
<
int
kNumLanesPerGroup
=
kWarpSize
,
bool
kIntergroupReduce
=
false
,
typename
T
>
__forceinline__
__device__
T
warp_reduce_or
(
T
value
)
{
return
warp_reduce
<
kNumLanesPerGroup
,
kIntergroupReduce
,
T
>
(
value
,
ReduceOr
<
T
>
{});
}
}
// namespace deep_ep
deep_ep/buffer.py
View file @
09cb2b03
...
...
@@ -39,7 +39,7 @@ class Buffer:
allow_nvlink_for_low_latency_mode
:
bool
=
True
,
allow_mnnvl
:
bool
=
False
,
explicitly_destroy
:
bool
=
False
,
use_default_stream_as_comm_stream
:
bool
=
Tru
e
,
enable_shrink
:
bool
=
Fals
e
,
)
->
None
:
"""
Initialize the communication buffer.
...
...
@@ -59,6 +59,7 @@ class Buffer:
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically.
"""
check_nvlink_connections
(
group
)
...
...
@@ -70,6 +71,7 @@ class Buffer:
self
.
num_rdma_bytes
=
num_rdma_bytes
self
.
low_latency_mode
=
low_latency_mode
self
.
explicitly_destroy
=
explicitly_destroy
self
.
enable_shrink
=
enable_shrink
self
.
runtime
=
deep_ep_cpp
.
Buffer
(
self
.
rank
,
self
.
group_size
,
...
...
@@ -77,7 +79,7 @@ class Buffer:
num_rdma_bytes
,
low_latency_mode
,
explicitly_destroy
,
use_default_stream_as_comm_stream
,
enable_shrink
)
# Synchronize device IDs
...
...
@@ -989,3 +991,31 @@ class Buffer:
return
self
.
runtime
.
get_next_low_latency_combine_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
def
low_latency_update_mask_buffer
(
self
,
rank_to_mask
:
int
,
mask
:
bool
=
False
):
"""
Mask (unmask) a rank during communication (dispatch, combine, and clean)
Arguments:
rank: the rank to mask (unmask).
mask: if True, will mask the rank (do not recvfrom/sendto the rank), otherwise will unmask the rank.
"""
self
.
runtime
.
low_latency_update_mask_buffer
(
rank_to_mask
,
mask
)
def
low_latency_query_mask_buffer
(
self
,
mask_status
:
torch
.
Tensor
):
"""
Query the mask status of all ranks
Arguments:
mask_status: `[num_ranks]` with `torch.int`, the mask status of each rank. `1` means mask and `0` means unmasked.
"""
self
.
runtime
.
low_latency_query_mask_buffer
(
mask_status
)
def
low_latency_clean_mask_buffer
(
self
):
"""
Clean the mask buffer
"""
self
.
runtime
.
low_latency_clean_mask_buffer
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment