Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
da13c63a
Commit
da13c63a
authored
Nov 04, 2025
by
lishen
Browse files
完成低延迟接口功能
parent
09cb2b03
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
917 additions
and
1072 deletions
+917
-1072
1.sh
1.sh
+6
-3
2.sh
2.sh
+6
-3
build.sh
build.sh
+5
-3
csrc/config.hpp
csrc/config.hpp
+1
-1
csrc/deep_ep.cu
csrc/deep_ep.cu
+67
-174
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+16
-25
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+22
-33
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+351
-646
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+20
-4
deep_ep/buffer.py
deep_ep/buffer.py
+33
-141
rocshmem_dir/bin/rocshmem_info
rocshmem_dir/bin/rocshmem_info
+0
-0
rocshmem_dir/include/rocshmem/rocshmem.hpp
rocshmem_dir/include/rocshmem/rocshmem.hpp
+52
-13
rocshmem_dir/include/rocshmem/rocshmem_COLL.hpp
rocshmem_dir/include/rocshmem/rocshmem_COLL.hpp
+8
-22
rocshmem_dir/include/rocshmem/rocshmem_common.hpp
rocshmem_dir/include/rocshmem/rocshmem_common.hpp
+18
-1
rocshmem_dir/include/rocshmem/rocshmem_config.h
rocshmem_dir/include/rocshmem/rocshmem_config.h
+1
-0
rocshmem_dir/include/rocshmem/rocshmem_mpi.hpp
rocshmem_dir/include/rocshmem/rocshmem_mpi.hpp
+143
-0
rocshmem_dir/lib/cmake/rocshmem/rocshmem-targets.cmake
rocshmem_dir/lib/cmake/rocshmem/rocshmem-targets.cmake
+1
-1
rocshmem_dir/lib/librocshmem.a
rocshmem_dir/lib/librocshmem.a
+0
-0
rocshmem_dir/rocshmem/lib/cmake/rocshmem-config-version.cmake
...hmem_dir/rocshmem/lib/cmake/rocshmem-config-version.cmake
+65
-1
rocshmem_dir/rocshmem/lib/cmake/rocshmem-config.cmake
rocshmem_dir/rocshmem/lib/cmake/rocshmem-config.cmake
+102
-1
No files found.
1.sh
View file @
da13c63a
pgrep
-f
/usr/bin/python | xargs
kill
-9
export
OMPI_MCA_pml
=
ucx
export
OMPI_MCA_pml
=
ucx
export
OMPI_MCA_osc
=
ucx
export
OMPI_MCA_osc
=
ucx
export
OMPI_MCA_coll_hcoll_enable
=
0
export
OMPI_MCA_coll_hcoll_enable
=
0
...
@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
...
@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
32
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
32
export
UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS
=
16384
export
UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS
=
16384
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export
ROCSHMEM_HEAP_SIZE
=
53687
09
1
2
export
ROCSHMEM_HEAP_SIZE
=
288010
09
9
2
export
PYTHONPATH
=
/public/home/lishen/Tmp/DeepEP:
$PYTHONPATH
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_
internode
.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_
low_latency
.py
2.sh
View file @
da13c63a
pgrep
-f
/usr/bin/python | xargs
kill
-9
export
OMPI_MCA_pml
=
ucx
export
OMPI_MCA_pml
=
ucx
export
OMPI_MCA_osc
=
ucx
export
OMPI_MCA_osc
=
ucx
export
OMPI_MCA_coll_hcoll_enable
=
0
export
OMPI_MCA_coll_hcoll_enable
=
0
...
@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
...
@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
32
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
32
export
UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS
=
16384
export
UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS
=
16384
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export
ROCSHMEM_HEAP_SIZE
=
53687
09
1
2
export
ROCSHMEM_HEAP_SIZE
=
288010
09
9
2
export
PYTHONPATH
=
/public/home/lishen/Tmp/DeepEP:
$PYTHONPATH
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_
internode
.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_
low_latency
.py
build.sh
View file @
da13c63a
...
@@ -8,8 +8,10 @@ fi
...
@@ -8,8 +8,10 @@ fi
PYTHON_INCLUDE
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['include'])"
)
PYTHON_INCLUDE
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['include'])"
)
PYTHON_PLATLIB
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['platlib'])"
)
PYTHON_PLATLIB
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['platlib'])"
)
INCLUDE_PATHS
=
${
INCLUDE_PATHS
:
=-Icsrc/ -I
$(
pwd
)
/rocshmem_dir/include/ -I/opt/mpi/include -I
${
PYTHON_PLATLIB
}
/torch/include -I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include -I
${
PYTHON_PLATLIB
}
/torch/include/TH -I
${
PYTHON_PLATLIB
}
/torch/include/THC -I
${
PYTHON_PLATLIB
}
/torch/include/THH -I/opt/dtk/include -I
${
PYTHON_INCLUDE
}}
ROCSHMEM_INSTALL_PREFIX
=
${
ROCSHMEM_INSTALL_PREFIX
:
=
$(
pwd
)
/rocshmem_dir
}
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
}
INCLUDE_PATHS
=
${
INCLUDE_PATHS
:
=-Icsrc/ -I
${
ROCSHMEM_INSTALL_PREFIX
}
/include/ -I/opt/mpi/include -I
${
PYTHON_PLATLIB
}
/torch/include -I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include -I
${
PYTHON_PLATLIB
}
/torch/include/TH -I
${
PYTHON_PLATLIB
}
/torch/include/THC -I
${
PYTHON_PLATLIB
}
/torch/include/THH -I/opt/dtk/include -I
${
PYTHON_INCLUDE
}}
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/runtime.cu
-o
build_/runtime.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/runtime.cu
-o
build_/runtime.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/layout.cu
-o
build_/layout.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/layout.cu
-o
build_/layout.o
${
COMPILE_OPTIONS
}
...
@@ -18,7 +20,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
...
@@ -18,7 +20,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode_ll.cu
-o
build_/internode_ll.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode_ll.cu
-o
build_/internode_ll.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/deep_ep.cu
-o
build_/deep_ep.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/deep_ep.cu
-o
build_/deep_ep.o
${
COMPILE_OPTIONS
}
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o
-L
$
(
pwd
)
/rocshmem_dir
/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,
$
(
pwd
)
/rocshmem_dir
/lib/
-L
"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
-l
:librocshmem.a
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o
-L
$
{
ROCSHMEM_INSTALL_PREFIX
}
/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,
$
$
{
ROCSHMEM_INSTALL_PREFIX
}
/lib/
-L
"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
-l
:librocshmem.a
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
# build whl
# build whl
echo
"Using Python:
$(
which python3
)
"
echo
"Using Python:
$(
which python3
)
"
...
...
csrc/config.hpp
View file @
da13c63a
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
)
{
int
num_ranks
,
int
num_experts
)
{
const
int
num_scales
=
hidden
/
128
;
const
int
num_scales
=
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
;
// Dispatch and combine layout:
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even send buffer
...
...
csrc/deep_ep.cu
View file @
da13c63a
...
@@ -42,7 +42,6 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
...
@@ -42,7 +42,6 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
rdma_rank
=
rank
/
NUM_MAX_NVL_PEERS
,
nvl_rank
=
rank
%
NUM_MAX_NVL_PEERS
;
rdma_rank
=
rank
/
NUM_MAX_NVL_PEERS
,
nvl_rank
=
rank
%
NUM_MAX_NVL_PEERS
;
num_rdma_ranks
=
::
max
(
1
,
num_ranks
/
NUM_MAX_NVL_PEERS
),
num_rdma_ranks
=
::
max
(
1
,
num_ranks
/
NUM_MAX_NVL_PEERS
),
num_nvl_ranks
=
::
min
(
num_ranks
,
NUM_MAX_NVL_PEERS
);
num_nvl_ranks
=
::
min
(
num_ranks
,
NUM_MAX_NVL_PEERS
);
#ifdef DISABLE_ROCSHMEM
#ifdef DISABLE_ROCSHMEM
EP_HOST_ASSERT
(
num_rdma_ranks
==
1
and
not
low_latency_mode
and
EP_HOST_ASSERT
(
num_rdma_ranks
==
1
and
not
low_latency_mode
and
"rocSHMEM is disabled during compilation, please install rocSHMEM by "
"rocSHMEM is disabled during compilation, please install rocSHMEM by "
...
@@ -269,8 +268,11 @@ void Buffer::sync(const std::vector<int> &device_
...
@@ -269,8 +268,11 @@ void Buffer::sync(const std::vector<int> &device_
// Allocate
// Allocate
rdma_buffer_ptr
=
internode
::
alloc
(
num_rdma_bytes
,
NUM_BUFFER_ALIGNMENT_BYTES
);
rdma_buffer_ptr
=
internode
::
alloc
(
num_rdma_bytes
,
NUM_BUFFER_ALIGNMENT_BYTES
);
// Clean buffer (mainly for low-latency mode)
auto
hip_check
=
hipMemset
(
rdma_buffer_ptr
,
0
,
num_rdma_bytes
);
CUDA_CHECK
(
hipMemset
(
rdma_buffer_ptr
,
0
,
num_rdma_bytes
));
if
(
hip_check
!=
hipSuccess
)
{
printf
(
"Error in hipMemset. Perhaps the value of ROCSHMEM_HEAP_SIZE needs to be greater than num_rdma_bytes(%ld)
\n
"
,
num_rdma_bytes
);
CUDA_CHECK
(
hip_check
);
}
// Allocate and clean shrink buffer
// Allocate and clean shrink buffer
if
(
enable_shrink
)
{
if
(
enable_shrink
)
{
...
@@ -1105,7 +1107,6 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
...
@@ -1105,7 +1107,6 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
send_rdma_head
,
send_rdma_head
,
send_nvl_head
,
send_nvl_head
,
event
};
event
};
#else
#else
EP_HOST_ASSERT
(
false
and
"rocSHMEM is disabled during compilation, please install rocSHMEM by "
EP_HOST_ASSERT
(
false
and
"rocSHMEM is disabled during compilation, please install rocSHMEM by "
"following docs/install_dependencies.md"
);
"following docs/install_dependencies.md"
);
...
@@ -1271,7 +1272,6 @@ Buffer::internode_combine(
...
@@ -1271,7 +1272,6 @@ Buffer::internode_combine(
}
}
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
auto
layout
=
LowLatencyLayout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
auto
layout
=
LowLatencyLayout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
...
@@ -1282,31 +1282,18 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
...
@@ -1282,31 +1282,18 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
auto
offset
=
reinterpret_cast
<
int64_t
>
(
ptr
)
-
reinterpret_cast
<
int64_t
>
(
rdma_buffer_ptr
);
auto
offset
=
reinterpret_cast
<
int64_t
>
(
ptr
)
-
reinterpret_cast
<
int64_t
>
(
rdma_buffer_ptr
);
EP_HOST_ASSERT
(
0
<=
offset
and
offset
+
num_bytes
<=
num_rdma_bytes
);
EP_HOST_ASSERT
(
0
<=
offset
and
offset
+
num_bytes
<=
num_rdma_bytes
);
};
};
check_boundary
(
clean_meta_0
.
first
,
clean_meta_0
.
second
*
sizeof
(
int64_t
));
check_boundary
(
clean_meta_0
.
first
,
clean_meta_0
.
second
*
sizeof
(
int
));
check_boundary
(
clean_meta_1
.
first
,
clean_meta_1
.
second
*
sizeof
(
int64_t
));
check_boundary
(
clean_meta_1
.
first
,
clean_meta_1
.
second
*
sizeof
(
int
));
internode_ll
::
clean_low_latency_buffer
(
clean_meta_0
.
first
,
internode_ll
::
clean_low_latency_buffer
(
clean_meta_0
.
first
,
clean_meta_0
.
second
,
clean_meta_0
.
second
,
clean_meta_1
.
first
,
clean_meta_1
.
second
,
clean_meta_1
.
first
,
clean_meta_1
.
second
,
rank
,
num_ranks
,
mask_buffer_ptr
,
sync_buffer_ptr
,
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
());
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
());
#else
EP_HOST_ASSERT
(
false
and
"ROCSHMEM is disabled during compilation"
);
#endif
}
}
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
const
std
::
optional
<
torch
::
Tensor
>
&
cumulative_local_expert_recv_stats
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
)
{
const
std
::
optional
<
torch
::
Tensor
>
&
dispatch_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
async
,
bool
return_recv_hook
)
{
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
// Tensor checks
// Tensor checks
...
@@ -1318,99 +1305,62 @@ Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_i
...
@@ -1318,99 +1305,62 @@ Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_i
EP_HOST_ASSERT
(
topk_idx
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
topk_idx
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
num_experts
%
num_ranks
==
0
);
EP_HOST_ASSERT
(
num_experts
%
num_ranks
==
0
);
// Diagnosis tensors
if
(
cumulative_local_expert_recv_stats
.
has_value
())
{
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
scalar_type
()
==
torch
::
kInt
);
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
dim
()
==
1
and
cumulative_local_expert_recv_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
size
(
0
)
==
num_experts
/
num_ranks
);
}
if
(
dispatch_wait_recv_cost_stats
.
has_value
())
{
EP_HOST_ASSERT
(
dispatch_wait_recv_cost_stats
->
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
dispatch_wait_recv_cost_stats
->
dim
()
==
1
and
dispatch_wait_recv_cost_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
dispatch_wait_recv_cost_stats
->
size
(
0
)
==
num_ranks
);
}
auto
num_tokens
=
static_cast
<
int
>
(
x
.
size
(
0
)),
hidden
=
static_cast
<
int
>
(
x
.
size
(
1
));
auto
num_tokens
=
static_cast
<
int
>
(
x
.
size
(
0
)),
hidden
=
static_cast
<
int
>
(
x
.
size
(
1
));
auto
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
auto
num_scales
=
hidden
/
128
,
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
auto
num_local_experts
=
num_experts
/
num_ranks
;
int
num_local_experts
=
num_experts
/
num_ranks
;
// Buffer control
// Buffer control
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
// 双buffer操作
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
// Buffer control
LowLatencyLayout
nvl_layout
(
nvl_buffer_ptrs
[
nvl_rank
],
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
EP_HOST_ASSERT
(
nvl_layout
.
total_bytes
<=
num_rdma_bytes
);
auto
nvl_buffer
=
nvl_layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
auto
nvl_next_buffer
=
nvl_layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
auto
global_atomic_counter
=
torch
::
zeros
({
1
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
global_atomic_counter
=
torch
::
zeros
({
1
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
// Wait previous tasks to be finished
// Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream
// NOTES: the hook mode will always use the default stream
auto
compute_stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
auto
compute_stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
auto
launch_stream
=
return_recv_hook
?
compute_stream
:
comm_stream
;
auto
launch_stream
=
return_recv_hook
?
compute_stream
:
comm_stream
;
EP_HOST_ASSERT
(
not
(
async
and
return_recv_hook
));
EP_HOST_ASSERT
(
not
(
async
and
return_recv_hook
));
if
(
not
return_recv_hook
)
if
(
not
return_recv_hook
)
stream_wait
(
launch_stream
,
compute_stream
);
stream_wait
(
launch_stream
,
compute_stream
);
// Allocate packed tensors
// Allocate packed tensors
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
use_fp8
?
torch
::
kFloat8_e4m3fn
:
torch
::
kBFloat16
));
x
.
options
().
dtype
(
use_fp8
?
torch
::
kFloat8_e4m3fnuz
:
torch
::
kBFloat16
));
auto
packed_recv_src_info
=
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
// Allocate column-majored scales
// Allocate column-majored scales
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
void
*
packed_recv_x_scales_ptr
=
nullptr
;
float
*
packed_recv_x_scales_ptr
=
nullptr
;
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
// TODO: support unaligned cases
EP_HOST_ASSERT
(
hidden
%
512
==
0
);
if
(
use_fp8
)
{
if
(
use_fp8
)
{
if
(
not
use_ue8m0
)
{
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
128
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
num_scales
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
else
{
EP_HOST_ASSERT
(
round_scale
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
512
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
}
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
();
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
<
float
>
();
}
}
// Kernel launch
// Kernel launch
auto
next_clean_meta
=
next_buffer
.
clean_meta
();
auto
next_clean_meta
=
next_buffer
.
clean_meta
();
auto
launcher
=
[
=
](
int
phases
)
{
auto
launcher
=
[
=
](
int
phases
)
{
internode_ll
::
dispatch
(
internode_ll
::
dispatch
(
packed_recv_x
.
data_ptr
(),
packed_recv_x_scales_ptr
,
packed_recv_x
.
data_ptr
(),
packed_recv_src_info
.
data_ptr
<
int
>
(),
packed_recv_layout_range
.
data_ptr
<
int64_t
>
(),
packed_recv_x_scales_ptr
,
packed_recv_src_info
.
data_ptr
<
int
>
(),
packed_recv_layout_range
.
data_ptr
<
int64_t
>
(),
packed_recv_count
.
data_ptr
<
int
>
(),
packed_recv_count
.
data_ptr
<
int
>
(),
global_atomic_counter
.
data_ptr
<
int
>
(),
global_atomic_counter
.
data_ptr
<
int
>
(),
mask_buffer_ptr
,
buffer
.
dispatch_rdma_recv_data_buffer
,
buffer
.
dispatch_rdma_recv_count_buffer
,
cumulative_local_expert_recv_stats
.
has_value
()
?
cumulative_local_expert_recv_stats
->
data_ptr
<
int
>
()
:
nullptr
,
dispatch_wait_recv_cost_stats
.
has_value
()
?
dispatch_wait_recv_cost_stats
->
data_ptr
<
int64_t
>
()
:
nullptr
,
buffer
.
dispatch_rdma_recv_data_buffer
,
buffer
.
dispatch_rdma_recv_count_buffer
,
buffer
.
dispatch_rdma_send_buffer
,
buffer
.
dispatch_rdma_send_buffer
,
x
.
data_ptr
(),
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
next_clean_meta
.
second
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
num_tokens
,
workspace
,
launch_stream
,
phases
);
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
round_scale
,
use_ue8m0
,
workspace
,
num_device_sms
,
launch_stream
,
phases
);
};
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
@@ -1431,20 +1381,14 @@ Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_i
...
@@ -1431,20 +1381,14 @@ Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_i
// Return values
// Return values
return
{
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
recv_hook
};
return
{
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
recv_hook
};
#else
EP_HOST_ASSERT
(
false
and
"ROCSHMEM is disabled during compilation"
);
return
{};
#endif
}
}
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_logfmt
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
// Tensor checks
// Tensor checks
...
@@ -1463,29 +1407,27 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1463,29 +1407,27 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT
(
layout_range
.
dim
()
==
2
and
layout_range
.
is_contiguous
());
EP_HOST_ASSERT
(
layout_range
.
dim
()
==
2
and
layout_range
.
is_contiguous
());
EP_HOST_ASSERT
(
layout_range
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
layout_range
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
layout_range
.
size
(
0
)
==
num_experts
/
num_ranks
and
layout_range
.
size
(
1
)
==
num_ranks
);
EP_HOST_ASSERT
(
layout_range
.
size
(
0
)
==
num_experts
/
num_ranks
and
layout_range
.
size
(
1
)
==
num_ranks
);
if
(
combine_wait_recv_cost_stats
.
has_value
())
{
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
dim
()
==
1
and
combine_wait_recv_cost_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
size
(
0
)
==
num_ranks
);
}
auto
hidden
=
static_cast
<
int
>
(
x
.
size
(
2
));
auto
hidden
=
static_cast
<
int
>
(
x
.
size
(
2
));
auto
num_topk
=
static_cast
<
int
>
(
topk_weights
.
size
(
1
));
auto
num_local_experts
=
num_experts
/
num_ranks
,
num_topk
=
static_cast
<
int
>
(
topk_weights
.
size
(
1
));
auto
num_combined_tokens
=
static_cast
<
int
>
(
topk_weights
.
size
(
0
));
auto
num_combined_tokens
=
static_cast
<
int
>
(
topk_weights
.
size
(
0
));
auto
global_atomic_counter
=
torch
::
zeros
({
1
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
global_atomic_counter
=
torch
::
zeros
({
1
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
// Buffer control
// Buffer control
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
// Buffer control
LowLatencyLayout
nvl_layout
(
nvl_buffer_ptrs
[
nvl_rank
],
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
EP_HOST_ASSERT
(
nvl_layout
.
total_bytes
<=
num_rdma_bytes
);
auto
nvl_buffer
=
nvl_layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
auto
nvl_next_buffer
=
nvl_layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
// Wait previous tasks to be finished
// Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream
// NOTES: the hook mode will always use the default stream
auto
compute_stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
auto
compute_stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
auto
launch_stream
=
return_recv_hook
?
compute_stream
:
comm_stream
;
auto
launch_stream
=
return_recv_hook
?
compute_stream
:
comm_stream
;
EP_HOST_ASSERT
(
not
(
async
and
return_recv_hook
));
EP_HOST_ASSERT
(
not
(
async
and
return_recv_hook
));
if
(
not
return_recv_hook
)
if
(
not
return_recv_hook
)
stream_wait
(
launch_stream
,
compute_stream
);
stream_wait
(
launch_stream
,
compute_stream
);
...
@@ -1504,32 +1446,16 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1504,32 +1446,16 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
auto
next_clean_meta
=
next_buffer
.
clean_meta
();
auto
next_clean_meta
=
next_buffer
.
clean_meta
();
auto
launcher
=
[
=
](
int
phases
)
{
auto
launcher
=
[
=
](
int
phases
)
{
internode_ll
::
combine
(
combined_x
.
data_ptr
(),
internode_ll
::
combine
(
combined_x
.
data_ptr
(),
buffer
.
combine_rdma_recv_data_buffer
,
buffer
.
combine_rdma_recv_data_buffer
,
buffer
.
combine_rdma_recv_flag_buffer
,
buffer
.
combine_rdma_recv_flag_buffer
,
buffer
.
combine_rdma_send_buffer
,
buffer
.
combine_rdma_send_buffer
,
x
.
data_ptr
(),
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
src_info
.
data_ptr
<
int
>
(),
layout_range
.
data_ptr
<
int64_t
>
(),
topk_weights
.
data_ptr
<
float
>
(),
src_info
.
data_ptr
<
int
>
(),
layout_range
.
data_ptr
<
int64_t
>
(),
global_atomic_counter
.
data_ptr
<
int
>
(),
global_atomic_counter
.
data_ptr
<
int
>
(),
mask_buffer_ptr
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
combine_wait_recv_cost_stats
.
has_value
()
?
combine_wait_recv_cost_stats
->
data_ptr
<
int64_t
>
()
:
nullptr
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
next_clean_meta
.
first
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
next_clean_meta
.
second
,
workspace
,
launch_stream
,
num_combined_tokens
,
phases
,
zero_copy
);
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_logfmt
,
workspace
,
num_device_sms
,
launch_stream
,
phases
,
zero_copy
);
};
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
@@ -1550,49 +1476,19 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1550,49 +1476,19 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
// Return values
// Return values
return
{
combined_x
,
event
,
recv_hook
};
return
{
combined_x
,
event
,
recv_hook
};
#else
EP_HOST_ASSERT
(
false
and
"ROCSHMEM is disabled during compilation"
);
return
{};
#endif
}
}
torch
::
Tensor
Buffer
::
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
const
{
torch
::
Tensor
Buffer
::
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
#ifndef DISABLE_ROCSHMEM
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
dtype
=
torch
::
kBFloat16
;
auto
dtype
=
torch
::
kBFloat16
;
auto
num_msg_elems
=
static_cast
<
int
>
(
buffer
.
num_bytes_per_combine_msg
/
elementSize
(
torch
::
kBFloat16
));
auto
num_msg_elems
=
static_cast
<
int
>
(
buffer
.
num_bytes_per_combine_msg
/
elementSize
(
torch
::
kBFloat16
));
// buffer.num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(hip_bfloat16);
EP_HOST_ASSERT
(
buffer
.
num_bytes_per_combine_msg
%
elementSize
(
torch
::
kBFloat16
)
==
0
);
EP_HOST_ASSERT
(
buffer
.
num_bytes_per_combine_msg
%
elementSize
(
torch
::
kBFloat16
)
==
0
);
return
torch
::
from_blob
(
buffer
.
combine_rdma_send_buffer_data_start
,
return
torch
::
from_blob
(
buffer
.
combine_rdma_send_buffer_data_start
,
{
num_experts
/
num_ranks
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
{
num_experts
/
num_ranks
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
{
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_msg_elems
,
num_msg_elems
,
1
},
{
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_msg_elems
,
num_msg_elems
,
1
},
torch
::
TensorOptions
().
dtype
(
dtype
).
device
(
torch
::
kCUDA
));
torch
::
TensorOptions
().
dtype
(
dtype
).
device
(
torch
::
kCUDA
));
#else
EP_HOST_ASSERT
(
false
and
"ROCSHMEM is disabled during compilation"
);
return
{};
#endif
}
void
Buffer
::
low_latency_update_mask_buffer
(
int
rank_to_mask
,
bool
mask
)
{
EP_HOST_ASSERT
(
mask_buffer_ptr
!=
nullptr
and
"Shrink mode must be enabled"
);
EP_HOST_ASSERT
(
rank_to_mask
>=
0
and
rank_to_mask
<
num_ranks
);
internode_ll
::
update_mask_buffer
(
mask_buffer_ptr
,
rank_to_mask
,
mask
,
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
());
}
void
Buffer
::
low_latency_query_mask_buffer
(
const
torch
::
Tensor
&
mask_status
)
{
EP_HOST_ASSERT
(
mask_buffer_ptr
!=
nullptr
and
"Shrink mode must be enabled"
);
EP_HOST_ASSERT
(
mask_status
.
numel
()
==
num_ranks
&&
mask_status
.
scalar_type
()
==
torch
::
kInt32
);
internode_ll
::
query_mask_buffer
(
mask_buffer_ptr
,
num_ranks
,
reinterpret_cast
<
int
*>
(
mask_status
.
data_ptr
()),
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
());
}
void
Buffer
::
low_latency_clean_mask_buffer
()
{
EP_HOST_ASSERT
(
mask_buffer_ptr
!=
nullptr
and
"Shrink mode must be enabled"
);
internode_ll
::
clean_mask_buffer
(
mask_buffer_ptr
,
num_ranks
,
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
());
}
}
}
// namespace deep_ep
}
// namespace deep_ep
...
@@ -1634,10 +1530,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -1634,10 +1530,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"clean_low_latency_buffer"
,
&
deep_ep
::
Buffer
::
clean_low_latency_buffer
)
.
def
(
"clean_low_latency_buffer"
,
&
deep_ep
::
Buffer
::
clean_low_latency_buffer
)
.
def
(
"low_latency_dispatch"
,
&
deep_ep
::
Buffer
::
low_latency_dispatch
)
.
def
(
"low_latency_dispatch"
,
&
deep_ep
::
Buffer
::
low_latency_dispatch
)
.
def
(
"low_latency_combine"
,
&
deep_ep
::
Buffer
::
low_latency_combine
)
.
def
(
"low_latency_combine"
,
&
deep_ep
::
Buffer
::
low_latency_combine
)
.
def
(
"get_next_low_latency_combine_buffer"
,
&
deep_ep
::
Buffer
::
get_next_low_latency_combine_buffer
)
.
def
(
"get_next_low_latency_combine_buffer"
,
&
deep_ep
::
Buffer
::
get_next_low_latency_combine_buffer
);
.
def
(
"low_latency_update_mask_buffer"
,
&
deep_ep
::
Buffer
::
low_latency_update_mask_buffer
)
.
def
(
"low_latency_query_mask_buffer"
,
&
deep_ep
::
Buffer
::
low_latency_query_mask_buffer
)
.
def
(
"low_latency_clean_mask_buffer"
,
&
deep_ep
::
Buffer
::
low_latency_clean_mask_buffer
);
// m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
// m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
// m.attr("int64_t") = py::cast(c10::CppTypeToScalarType<deep_ep::int64_t>::value);
// m.attr("int64_t") = py::cast(c10::CppTypeToScalarType<deep_ep::int64_t>::value);
...
...
csrc/deep_ep.hpp
View file @
da13c63a
...
@@ -26,6 +26,9 @@ private:
...
@@ -26,6 +26,9 @@ private:
void
*
buffer_ptrs
[
NUM_MAX_NVL_PEERS
]
=
{
nullptr
};
void
*
buffer_ptrs
[
NUM_MAX_NVL_PEERS
]
=
{
nullptr
};
void
**
buffer_ptrs_gpu
=
nullptr
;
void
**
buffer_ptrs_gpu
=
nullptr
;
void
*
nvl_buffer_ptrs
[
NUM_MAX_NVL_PEERS
]
=
{
nullptr
};
void
**
nvl_buffer_ptrs_gpu
=
nullptr
;
// NVSHMEM Buffer
// NVSHMEM Buffer
int64_t
num_rdma_bytes
;
int64_t
num_rdma_bytes
;
void
*
rdma_buffer_ptr
=
nullptr
;
void
*
rdma_buffer_ptr
=
nullptr
;
...
@@ -171,31 +174,19 @@ public:
...
@@ -171,31 +174,19 @@ public:
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
int
num_experts
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
const
std
::
optional
<
torch
::
Tensor
>
&
cumulative_local_expert_recv_stats
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
);
const
std
::
optional
<
torch
::
Tensor
>
&
dispatch_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
async
,
bool
return_recv_hook
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
layout_range
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
const
std
::
optional
<
torch
::
Tensor
>
&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_logfmt
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>
&
out
=
std
::
nullopt
);
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
torch
::
Tensor
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
const
;
void
low_latency_update_mask_buffer
(
int
rank_to_mask
,
bool
mask
);
void
low_latency_query_mask_buffer
(
const
torch
::
Tensor
&
mask_status
);
void
low_latency_clean_mask_buffer
(
);
torch
::
Tensor
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
};
};
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/api.cuh
View file @
da13c63a
...
@@ -134,43 +134,32 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
...
@@ -134,43 +134,32 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
// Internode low-latency kernels
// Internode low-latency kernels
namespace
internode_ll
{
namespace
internode_ll
{
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
int
rank
,
int
num_ranks
,
hipStream_t
stream
);
int
*
mask_buffer
,
int
*
sync_buffer
,
hipStream_t
stream
);
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int
*
mask_buffer
,
int
*
cumulative_local_expert_recv_stats
,
int64_t
*
dispatch_wait_recv_cost_stats
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
hipStream_t
stream
,
int
phases
);
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int
*
mask_buffer
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_logfmt
,
void
*
workspace
,
hipStream_t
stream
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
);
int
phases
,
bool
zero_copy
);
void
query_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
int
*
output_mask_tensor
,
hipStream_t
stream
);
void
update_mask_buffer
(
int
*
mask_buffer_ptr
,
int
rank_to_mask
,
bool
mask
,
hipStream_t
stream
);
void
clean_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
hipStream_t
stream
);
}
// namespace internode_ll
}
// namespace internode_ll
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/internode_ll.cu
View file @
da13c63a
...
@@ -5,63 +5,64 @@
...
@@ -5,63 +5,64 @@
#include "utils.cuh"
#include "utils.cuh"
// #include <cooperative_groups.h>
// #include <cooperative_groups.h>
#include <iostream>
#include <iostream>
#include "hip/hip_runtime.h"
// low latency+RocSHMEM has issue with CTX.
// low latency+RocSHMEM has issue with CTX.
#define ROCM_DISABLE_CTX
#define ROCM_DISABLE_CTX
#ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#include <rocshmem/rocshmem.hpp>
#include <rocshmem/rocshmem_COLL.hpp>
using
namespace
rocshmem
;
namespace
deep_ep
{
namespace
deep_ep
{
namespace
internode_ll
{
namespace
internode_ll
{
template
<
bool
use_warp_sync
=
false
>
template
<
typename
dtype_a_t
,
typename
dtype_b_t
>
__forceinline__
__device__
bool
is_rank_masked
(
int
*
mask_buffer_ptr
,
int
rank
)
{
__device__
__forceinline__
dtype_b_t
pack2
(
const
dtype_a_t
&
x
,
const
dtype_a_t
&
y
)
{
if
(
mask_buffer_ptr
==
nullptr
)
{
EP_STATIC_ASSERT
(
sizeof
(
dtype_a_t
)
*
2
==
sizeof
(
dtype_b_t
),
"Invalid dtypes"
);
return
false
;
dtype_b_t
packed
;
}
auto
unpacked_ptr
=
reinterpret_cast
<
dtype_a_t
*>
(
&
packed
);
if
constexpr
(
use_warp_sync
)
{
unpacked_ptr
[
0
]
=
x
,
unpacked_ptr
[
1
]
=
y
;
return
shfl_sync
(
ld_acquire_global
(
mask_buffer_ptr
+
rank
),
0
)
!=
0
;
return
packed
;
}
else
{
return
ld_acquire_global
(
mask_buffer_ptr
+
rank
)
!=
0
;
}
}
}
__device__
void
grid_barrier
(
int
*
global_counter
,
int
num_blocks
)
{
__device__
void
grid_barrier
(
int
*
global_counter
,
int
num_blocks
)
{
volatile
int
ret
;
volatile
int
ret
;
__syncthreads
();
__syncthreads
();
memory_
fence
_gpu
();
__thread
fence
();
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
ret
=
atomicAdd
((
int
*
)
&
global_counter
[
0
],
1
);
// ret = __hip_atomic_fetch_add(&global_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
ret
=
atomicAdd
(
&
global_counter
[
0
],
1
);
}
}
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
while
(
ld_relaxed_global
(
global_counter
)
!=
num_blocks
);
while
(
__hip_atomic_load
(
global_counter
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
)
!=
num_blocks
);
}
}
__syncthreads
();
__syncthreads
();
}
}
template
<
typename
dtype_t
>
__host__
__device__
dtype_t
ceil_div
(
dtype_t
a
,
dtype_t
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
template
<
typename
dtype_a_t
,
typename
dtype_b_t
>
__device__
__forceinline__
void
unpack2
(
const
dtype_b_t
&
packed
,
dtype_a_t
&
x
,
dtype_a_t
&
y
)
{
EP_STATIC_ASSERT
(
sizeof
(
dtype_a_t
)
*
2
==
sizeof
(
dtype_b_t
),
"Invalid dtypes"
);
auto
unpacked_ptr
=
reinterpret_cast
<
const
dtype_a_t
*>
(
&
packed
);
x
=
unpacked_ptr
[
0
],
y
=
unpacked_ptr
[
1
];
}
template
<
int
kNumThreads
>
__launch_bounds__
(
kNumThreads
,
1
)
template
<
int
kNumThreads
>
__launch_bounds__
(
kNumThreads
,
1
)
__global__
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
__global__
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
int64_t
*
clean_1
,
int
num_clean_int_1
)
{
int
rank
,
int
num_ranks
,
int
*
mask_buffer_ptr
,
int
*
sync_buffer_ptr
)
{
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
// Barrier before cleaning (in case of unfinished chunked EP)
// Barrier before cleaning (in case of unfinished chunked EP)
if
(
sync_buffer_ptr
==
nullptr
)
{
if
(
threadIdx
.
x
==
0
)
// rocshmem::rocshmem_barrier_all_wg();
if
(
thread_id
==
0
)
rocshmem
::
rocshmem_barrier_all
();
rocshmem
::
rocshmem_barrier_all
();
}
else
{
// barrier<kNumThreads>(thread_id, rank, num_ranks, mask_buffer_ptr, sync_buffer_ptr);
EP_DEVICE_ASSERT
(
0
);
}
// Clean
// Clean
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_clean_int_0
;
i
+=
kNumThreads
)
for
(
int
i
=
thread_id
;
i
<
num_clean_int_0
;
i
+=
kNumThreads
)
clean_0
[
i
]
=
0
;
clean_0
[
i
]
=
0
;
...
@@ -70,59 +71,33 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
...
@@ -70,59 +71,33 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
clean_1
[
i
]
=
0
;
clean_1
[
i
]
=
0
;
// Barrier after cleaning (make sure low-latency mode work
// Barrier after cleaning (make sure low-latency mode work
if
(
sync_buffer_ptr
==
nullptr
)
{
if
(
threadIdx
.
x
==
0
)
// rocshmem::rocshmem_barrier_all_wg();
if
(
thread_id
==
0
)
rocshmem
::
rocshmem_barrier_all
();
rocshmem
::
rocshmem_barrier_all
();
}
else
{
// barrier<kNumThreads>(thread_id, rank, num_ranks, mask_buffer_ptr, sync_buffer_ptr);
EP_DEVICE_ASSERT
(
0
);
}
}
}
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
int
rank
,
int
num_ranks
,
int
*
mask_buffer_ptr
,
int
*
sync_buffer_ptr
,
hipStream_t
stream
)
{
hipStream_t
stream
)
{
constexpr
int
kNumThreads
=
256
;
constexpr
int
kNumThreads
=
256
;
SETUP_LAUNCH_CONFIG
(
1
,
kNumThreads
,
stream
);
SETUP_LAUNCH_CONFIG
(
1
,
kNumThreads
,
stream
);
LAUNCH_KERNEL
(
&
cfg
,
clean_low_latency_buffer
<
kNumThreads
>
,
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
clean_low_latency_buffer
<
kNumThreads
>
,
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
,
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
);
rank
,
num_ranks
,
mask_buffer_ptr
,
sync_buffer_ptr
);
}
}
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
int
kHidden
>
template
<
bool
kUseFP8
,
int
kHidden
>
__launch_bounds__
(
1024
,
1
)
__global__
void
dispatch
(
void
*
packed_recv_x
,
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
void
*
packed_recv_x_scales
,
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int
*
mask_buffer_ptr
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
int
*
cumulative_local_expert_recv_stats
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int64_t
*
dispatch_wait_recv_cost_stats
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
void
*
rdma_recv_x
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
rdma_recv_count
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
void
*
rdma_x
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
const
void
*
x
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
phases
)
{
const
int64_t
*
topk_idx
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
bool
round_scale
,
int
phases
)
{
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
...
@@ -136,33 +111,21 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
...
@@ -136,33 +111,21 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
const
auto
num_local_experts
=
num_experts
/
num_ranks
;
const
auto
num_local_experts
=
num_experts
/
num_ranks
;
const
auto
warp_group_id
=
warp_id
/
num_warps_per_group
;
const
auto
warp_group_id
=
warp_id
/
num_warps_per_group
;
const
auto
sub_warp_id
=
warp_id
%
num_warps_per_group
;
const
auto
sub_warp_id
=
warp_id
%
num_warps_per_group
;
// 每个warp处理一个expert
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
// May extract UE8M0 from the scales
using
scale_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint8_t
,
float
>
;
using
packed_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint32_t
,
float
>
;
EP_STATIC_ASSERT
(
sizeof
(
packed_t
)
%
sizeof
(
scale_t
)
==
0
,
"Invalid vector length"
);
// FP8 staffs
// FP8 staffs
constexpr
int
kNumPerChannels
=
FP8_QUANTIZATION_NUM_PER_CHANNEL
;
constexpr
int
kNumPerChannels
=
FP8_QUANTIZATION_NUM_PER_CHANNEL
;
const
int
num_scales
=
kHidden
/
kNumPerChannels
;
const
int
num_scales
=
kHidden
/
kNumPerChannels
;
const
size_t
hidden_bytes
=
kHidden
*
(
kUseFP8
?
sizeof
(
__hip_fp8_storage_t
)
:
sizeof
(
hip_bfloat16
));
const
size_t
hidden_bytes
=
kHidden
*
(
kUseFP8
?
sizeof
(
__hip_fp8_storage_t
)
:
sizeof
(
hip_bfloat16
));
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
// Message package:
index at source (int), 3 reserved int fields,
hidden data, FP8 scales
// Message package: hidden data, FP8 scales
, index at source
// NOTES: currently we have 3 reserved int fields for future use
// NOTES: currently we have 3 reserved int fields for future use
using
vec_t
=
std
::
conditional
_t
<
kUseFP8
,
int2
,
int4
>
;
using
vec_t
=
typename
std
::
conditional
<
kUseFP8
,
int2
,
int4
>
::
type
;
const
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseFP8
?
(
kHidden
+
num_scales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
const
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseFP8
?
(
kHidden
+
num_scales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
const
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
const
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
EP_DEVICE_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
);
EP_DEVICE_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
);
// Expert counts
constexpr
int
kNumMaxWarpGroups
=
16
;
// 每个kernel最多warp group数量,即每个block负责的专家数
__shared__
int
shared_num_tokens_sent_per_expert
[
kNumMaxWarpGroups
];
#ifdef USE_ROCM
// 用于同步
// 16 is the max possible number of warps in AMD GPUs
// 16 is the max possible number of warps in AMD GPUs
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
constexpr
int
num_sync_large_iteration
=
kMaxNumWarps
;
constexpr
int
num_sync_large_iteration
=
kMaxNumWarps
;
...
@@ -173,57 +136,57 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
...
@@ -173,57 +136,57 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
sync_large_warp_counters
[
i
]
=
0
;
sync_large_warp_counters
[
i
]
=
0
;
}
}
__syncthreads
();
__syncthreads
();
#endif
// Sending phase,如果没有发送任务,则直接跳到接收阶段
// Expert counts
constexpr
int
kNumMaxWarpGroups
=
1024
/
kWarpSize
;
__shared__
int
shared_num_tokens_sent_per_expert
[
kNumMaxWarpGroups
];
// Sending phase
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
goto
LOW_LATENCY_DISPATCH_RECV
;
goto
LOW_LATENCY_DISPATCH_RECV
;
// There are 2 kinds of warps in this part:
// There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information
// 2. The last warp for reading `topk_idx` and count for per-expert information
if
(
warp_id
<
num_warps
-
1
)
{
if
(
warp_id
<
num_warps
)
{
constexpr
int
kNumElemsPerRead
=
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
);
// 128/16 = 8
constexpr
int
kNumElemsPerRead
=
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
);
EP_
STAT
IC_ASSERT
(
kHidden
%
(
kWarpSize
*
kNumElemsPerRead
)
==
0
,
"Invalid hidden"
);
EP_
DEV
IC
E
_ASSERT
(
kHidden
%
kNumElemsPerRead
==
0
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
const
size_t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerRead
;
const
size_t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerRead
;
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_tokens
;
token_idx
+=
num_sms
)
{
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_tokens
;
token_idx
+=
num_sms
)
{
const
auto
x_int4
=
static
_cast
<
const
int4
*>
(
x
)
+
token_idx
*
hidden_bf16_int4
;
const
auto
x_int4
=
reinterpret
_cast
<
const
int4
*>
(
x
)
+
token_idx
*
hidden_bf16_int4
;
const
auto
rdma_x_src_idx
=
reinterpret_cast
<
int
*>
(
static
_cast
<
uint8_t
*>
(
rdma_x
)
+
token_idx
*
num_bytes_per_msg
);
const
auto
rdma_x_src_idx
=
reinterpret_cast
<
int
*>
(
reinterpret
_cast
<
uint8_t
*>
(
rdma_x
)
+
token_idx
*
num_bytes_per_msg
);
const
auto
rdma_x_vec
=
reinterpret_cast
<
vec_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_src_idx
)
+
sizeof
(
int4
));
const
auto
rdma_x_vec
=
reinterpret_cast
<
vec_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_src_idx
)
+
sizeof
(
int4
));
const
auto
rdma_x_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_vec
)
+
hidden_bytes
);
const
auto
rdma_x_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_vec
)
+
hidden_bytes
);
// Overlap top-k index read and source token index write
s
// Overlap top-k index read and source token index write
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
// FP8 cast
// FP8 cast
EP_STATIC_ASSERT
(
hidden_bf16_int4
%
kWarpSize
==
0
,
"Must use the full warp to reduce"
);
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
hidden_bf16_int4
;
i
+=
num_threads
)
{
for
(
int
i
=
thread_id
;
i
<
hidden_bf16_int4
;
i
+=
num_threads
)
{
// Read
// Read
auto
int4_value
=
__ldg
(
x_int4
+
i
);
auto
int4_value
=
__ldg
(
x_int4
+
i
);
if
constexpr
(
kUseFP8
)
{
if
(
kUseFP8
)
{
// Calculate local amax
// Calculate local amax
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
float
fp32_values
[
kNumElemsPerRead
];
float
fp32_values
[
kNumElemsPerRead
];
float
amax
=
kFP8Margin
,
scale
,
scale_inv
;
float
amax
=
kFP8Margin
,
scale
,
scale_inv
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
fp32_values
[
j
]
=
static_cast
<
float
>
(
bf16_values
[
j
]);
fp32_values
[
j
]
=
static_cast
<
float
>
(
bf16_values
[
j
]);
amax
=
fmaxf
(
amax
,
fabsf
(
fp32_values
[
j
]));
amax
=
fmaxf
(
amax
,
fabsf
(
fp32_values
[
j
]));
}
}
// Reduce amax and scale
// Reduce amax and scale
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
/
kNumPerChannels
==
4
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
/
kNumPerChannels
==
4
,
"Invalid vectorization"
);
amax
=
warp_reduce_max
<
16
>
(
amax
);
amax
=
warp_reduce_max
<
16
>
(
amax
);
calculate_fp8_scales
(
amax
,
scale
,
scale_inv
,
round_scale
);
calculate_fp8_scales
<
/*round_scale*/
false
>
(
amax
,
scale
,
scale_inv
);
if
(
lane_id
%
16
==
0
)
if
(
lane_id
%
16
==
0
)
rdma_x_scales
[
i
*
kNumElemsPerRead
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
]
=
scale_inv
;
rdma_x_scales
[
i
*
kNumElemsPerRead
/
128
]
=
scale_inv
;
// Cast into send buffer
// Cast into send buffer
vec_t
int2_value
;
vec_t
int2_value
;
...
@@ -240,44 +203,38 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
...
@@ -240,44 +203,38 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
}
}
}
}
__syncthreads
();
__syncthreads
();
// Issue IBGDA sends
// Issue IBGDA sends
if
(
dst_expert_idx
>=
0
)
{
if
(
dst_expert_idx
>=
0
)
{
int
slot_idx
=
lane_id
==
0
?
atomicAdd
(
atomic_counter_per_expert
+
dst_expert_idx
,
1
)
:
0
;
int
slot_idx
=
lane_id
==
0
?
atomicAdd
(
atomic_counter_per_expert
+
dst_expert_idx
,
1
)
:
0
;
slot_idx
=
shfl_sync
(
slot_idx
,
0
);
slot_idx
=
shfl_sync
(
slot_idx
,
0
);
const
int
dst_rank
=
dst_expert_idx
/
num_local_experts
;
const
auto
dst_rank
=
dst_expert_idx
/
num_local_experts
;
const
int
dst_expert_local_idx
=
dst_expert_idx
%
num_local_experts
;
const
auto
dst_expert_local_idx
=
dst_expert_idx
%
num_local_experts
;
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_x_src_idx
);
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_x_src_idx
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
dst_expert_local_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
dst_expert_local_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
#if !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_schar_put_nbi_wave
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
rocshmem
::
rocshmem_ctx_schar_put_nbi_wave
(
ctx
,
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
rocshmem
::
rocshmem_ctx_quiet
(
ctx
);
#else
rocshmem
::
rocshmem_schar_put_nbi_wave
(
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
num_bytes_per_msg
,
dst_rank
);
rocshmem
::
rocshmem_fence
();
rocshmem
::
rocshmem_fence
();
#endif
}
else
{
}
else
{
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
_LL
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
}
// Increase counter after finishing
// Increase counter after finishing
syncwarp
();
syncwarp
();
lane_id
==
0
?
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
dst_expert_idx
,
1
)
:
0
;
lane_id
==
0
?
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
dst_expert_idx
,
1
)
:
0
;
}
}
}
}
}
else
if
(
warp_id
==
num_warps
-
1
)
{
}
if
(
warp_id
==
num_warps
-
1
)
{
EP_DEVICE_ASSERT
(
num_sms
>
1
);
EP_DEVICE_ASSERT
(
num_sms
>
1
);
if
(
sm_id
==
0
)
{
if
(
sm_id
==
0
)
{
// The first SM is also responsible for checking QPs
// The first SM is also responsible for cleaning the next buffer
// The first SM is also responsible for cleaning the next buffer
#pragma unroll
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
num_next_clean_int
;
i
+=
kWarpSize
)
for
(
int
i
=
lane_id
;
i
<
num_next_clean_int
;
i
+=
kWarpSize
)
...
@@ -289,7 +246,6 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
...
@@ -289,7 +246,6 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
for
(
int
i
=
lane_id
;
i
<
num_experts
;
i
+=
kWarpSize
)
for
(
int
i
=
lane_id
;
i
<
num_experts
;
i
+=
kWarpSize
)
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
);
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
);
}
}
// This SM should be responsible for some destination experts, read `topk_idx` for them
// This SM should be responsible for some destination experts, read `topk_idx` for them
int
expert_count
[
kNumMaxWarpGroups
]
=
{
0
};
int
expert_count
[
kNumMaxWarpGroups
]
=
{
0
};
const
auto
expert_begin_idx
=
sm_id
*
num_warp_groups
;
const
auto
expert_begin_idx
=
sm_id
*
num_warp_groups
;
...
@@ -300,12 +256,12 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
...
@@ -300,12 +256,12 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
for
(
int
i
=
lane_id
;
i
<
num_tokens
*
num_topk
;
i
+=
kWarpSize
)
{
for
(
int
i
=
lane_id
;
i
<
num_tokens
*
num_topk
;
i
+=
kWarpSize
)
{
auto
idx
=
static_cast
<
int
>
(
__ldg
(
topk_idx
+
i
));
auto
idx
=
static_cast
<
int
>
(
__ldg
(
topk_idx
+
i
));
if
(
idx
>=
expert_begin_idx
and
idx
<
expert_end_idx
)
if
(
idx
>=
expert_begin_idx
and
idx
<
expert_end_idx
)
expert_count
[
idx
-
expert_begin_idx
]
++
;
expert_count
[
idx
-
expert_begin_idx
]
++
;
}
}
// Warp reduce
// Warp reduce
#pragma unroll
#pragma unroll
for
(
int
i
=
expert_begin_idx
;
i
<
expert_end_idx
;
++
i
)
{
for
(
int
i
=
expert_begin_idx
;
i
<
expert_end_idx
;
++
i
)
{
auto
sum
=
warp_reduce_sum
(
expert_count
[
i
-
expert_begin_idx
]);
auto
sum
=
warp_reduce_sum
(
expert_count
[
i
-
expert_begin_idx
]);
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
{
shared_num_tokens_sent_per_expert
[
i
-
expert_begin_idx
]
=
sum
;
shared_num_tokens_sent_per_expert
[
i
-
expert_begin_idx
]
=
sum
;
...
@@ -314,6 +270,7 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
...
@@ -314,6 +270,7 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
}
}
}
}
//revert sync_large_warp_counters to 0 for next sync
__syncthreads
();
__syncthreads
();
// Issue count sends
// Issue count sends
...
@@ -324,17 +281,10 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
...
@@ -324,17 +281,10 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
// Wait local sends issued and send expert counts
// Wait local sends issued and send expert counts
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
if
(
not
is_rank_masked
(
mask_buffer_ptr
,
dst_rank
))
{
auto
dst_ptr
=
reinterpret_cast
<
int64_t
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
);
if
(
dst_rank
!=
rank
)
{
if
(
dst_rank
!=
rank
)
{
#if !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_long_atomic_add
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
,
-
num_tokens_sent
-
1
,
dst_rank
);
rocshmem
::
rocshmem_ctx_long_atomic_add
(
ctx
,
dst_ptr
,
-
num_tokens_sent
-
1
,
dst_rank
);
#else
rocshmem
::
rocshmem_long_atomic_add
(
dst_ptr
,
-
num_tokens_sent
-
1
,
dst_rank
);
#endif
}
else
{
}
else
{
st_release_sys_global
(
dst_ptr
,
-
num_tokens_sent
-
1
);
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
),
-
num_tokens_sent
-
1
);
}
}
}
// Clean workspace for next use
// Clean workspace for next use
...
@@ -347,10 +297,8 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
...
@@ -347,10 +297,8 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
}
}
syncwarp
();
syncwarp
();
// Receiving phase
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
LOW_LATENCY_DISPATCH_RECV:
// 如果没有接收直接返回
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
return
;
return
;
...
@@ -363,85 +311,40 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
...
@@ -363,85 +311,40 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
if
(
responsible_expert_idx
<
num_experts
)
{
if
(
responsible_expert_idx
<
num_experts
)
{
const
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
const
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
const
auto
local_expert_idx
=
responsible_expert_idx
%
num_local_experts
;
const
auto
local_expert_idx
=
responsible_expert_idx
%
num_local_experts
;
const
auto
rdma_recv_x_uint8
=
static
_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
const
auto
rdma_recv_x_uint8
=
reinterpret
_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
src_rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
;
src_rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
;
const
auto
recv_x_int4
=
const
auto
recv_x_int4
=
reinterpret_cast
<
int4
*>
(
packed_recv_x
)
+
static_cast
<
int4
*>
(
packed_recv_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
const
auto
recv_x_scales
=
packed_recv_x_scales
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_scales
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
const
auto
num_aligned_scales
=
ALIGN
<
int
>
(
num_scales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
recv_x_scales
=
static_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_aligned_scales
;
// Shared between sub-warps in warp groups
// Shared between sub-warps in warp groups
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
// Wait tokens to arrive
// Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0
// NOTES: using sub-warp 1 to overlap with sub-warp 0
int64_t
num_recv_tokens
;
int
num_recv_tokens
,
recv_token_begin_idx
;
int
recv_token_begin_idx
;
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
);
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
and
num_warp_groups
<
15
);
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
auto
start_time
=
wall_clock64
();
while
((
num_recv_tokens
=
ld_acquire_global
(
reinterpret_cast
<
int
*>
(
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
)))
==
0
);
int64_t
wait_recv_cost
=
0
;
int
offset
=
local_expert_idx
*
num_ranks
+
src_rank
;
if
(
not
is_rank_masked
(
mask_buffer_ptr
,
src_rank
))
{
while
((
wait_recv_cost
=
wall_clock64
()
-
start_time
)
<=
NUM_TIMEOUT_CYCLES
)
{
// not timeout
if
((
num_recv_tokens
=
ld_acquire_global
(
reinterpret_cast
<
int64_t
*>
(
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
)))
!=
0
)
{
break
;
}
}
}
// Mask rank if timeout
if
(
wait_recv_cost
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"Warning: DeepEP timeout for dispatch receive, rank %d, local_expert_idx %d, src_rank %d
\n
"
,
rank
,
local_expert_idx
,
src_rank
);
if
(
mask_buffer_ptr
==
nullptr
)
trap
();
atomicExch
(
mask_buffer_ptr
+
src_rank
,
1
);
}
// Do not receive tokens if rank timeout or masked
if
(
num_recv_tokens
==
0
)
num_recv_tokens
=
-
1
;
#if 1
num_recv_tokens
=
-
num_recv_tokens
-
1
;
num_recv_tokens
=
-
num_recv_tokens
-
1
;
int
num_recv_tokens_int32
=
static_cast
<
int
>
(
num_recv_tokens
);
recv_token_begin_idx
=
atomicAdd
(
packed_recv_count
+
local_expert_idx
,
num_recv_tokens
);
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens
;
recv_token_begin_idx
=
atomicAdd
(
packed_recv_count
+
local_expert_idx
,
num_recv_tokens_int32
);
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens_int32
;
shared_recv_token_begin_idx
[
warp_group_id
]
=
recv_token_begin_idx
;
shared_recv_token_begin_idx
[
warp_group_id
]
=
recv_token_begin_idx
;
recv_range
[
src_rank
]
=
pack2
<
int
,
int64_t
>
(
num_recv_tokens_int32
,
recv_token_begin_idx
);
recv_range
[
src_rank
]
=
pack2
<
int
,
int64_t
>
(
num_recv_tokens
,
recv_token_begin_idx
);
// Add stats for diagnosis
if
(
cumulative_local_expert_recv_stats
!=
nullptr
)
atomicAdd
(
cumulative_local_expert_recv_stats
+
local_expert_idx
,
num_recv_tokens_int32
);
if
(
dispatch_wait_recv_cost_stats
!=
nullptr
)
{
atomicAdd
(
reinterpret_cast
<
uint64_t
*>
(
dispatch_wait_recv_cost_stats
+
src_rank
),
static_cast
<
uint64_t
>
(
wait_recv_cost
));
}
#endif
}
}
#if 1
#ifdef USE_ROCM
// no needs to reset because there is no iteration
// no needs to reset because there is no iteration
if
(
lane_id
==
0
){
if
(
lane_id
==
0
){
// volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
warp_group_id
],
1
);
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
warp_group_id
],
1
);
}
}
syncwarp
();
syncwarp
();
while
(
sync_large_warp_counters
[
warp_group_id
]
<
num_warps_per_group
)
{}
#else
// asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
#endif
while
(
sync_large_warp_counters
[
warp_group_id
]
<
num_warps_per_group
);
num_recv_tokens
=
shared_num_recv_tokens
[
warp_group_id
];
num_recv_tokens
=
shared_num_recv_tokens
[
warp_group_id
];
recv_token_begin_idx
=
shared_recv_token_begin_idx
[
warp_group_id
];
recv_token_begin_idx
=
shared_recv_token_begin_idx
[
warp_group_id
];
...
@@ -458,506 +361,308 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
...
@@ -458,506 +361,308 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
const
auto
src_data
=
reinterpret_cast
<
int4
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_src_idx
)
+
sizeof
(
int4
));
const
auto
src_data
=
reinterpret_cast
<
int4
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_src_idx
)
+
sizeof
(
int4
));
const
auto
dst_data
=
recv_x_int4
+
(
recv_token_begin_idx
+
i
)
*
hidden_int4
;
const
auto
dst_data
=
recv_x_int4
+
(
recv_token_begin_idx
+
i
)
*
hidden_int4
;
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY
_LL
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
// Copy scales
// Copy scales
if
constexpr
(
kUseFP8
)
{
if
(
kUseFP8
)
{
// Equivalent CuTe layout:
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
packed_t
)
/
sizeof
(
scale_t
));
const
auto
dst_scales
=
reinterpret_cast
<
float
*>
(
recv_x_scales
+
recv_token_begin_idx
+
i
);
const
auto
token_idx
=
recv_token_begin_idx
+
i
;
const
auto
scale_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
token_stride
=
num_elems_per_pack
;
auto
scale_0
=
lane_id
<
num_scales
?
ld_nc_global
(
src_scales
+
lane_id
)
:
0
;
const
auto
pack_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_elems_per_pack
;
auto
scale_1
=
(
lane_id
+
kWarpSize
)
<
num_scales
?
ld_nc_global
(
src_scales
+
lane_id
+
kWarpSize
)
:
0
;
if
(
lane_id
<
num_scales
)
{
lane_id
<
num_scales
?
dst_scales
[
lane_id
*
scale_stride
]
=
scale_0
:
0.0
f
;
const
auto
pack_idx
=
lane_id
/
num_elems_per_pack
;
(
lane_id
+
kWarpSize
)
<
num_scales
?
dst_scales
[(
lane_id
+
kWarpSize
)
*
scale_stride
]
=
scale_1
:
0.0
f
;
const
auto
elem_idx
=
lane_id
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
if
(
lane_id
+
kWarpSize
<
num_scales
)
{
const
auto
pack_idx
=
(
lane_id
+
kWarpSize
)
/
num_elems_per_pack
;
const
auto
elem_idx
=
(
lane_id
+
kWarpSize
)
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
+
kWarpSize
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
}
}
}
}
#endif
}
}
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_wg_ctx_destroy
(
&
ctx
);
rocshmem
::
rocshmem_wg_ctx_destroy
(
&
ctx
);
#endif
#endif
}
}
void
dispatch
(
void
*
packed_recv_x
,
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int
*
mask_buffer_ptr
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
int
*
cumulative_local_expert_recv_stats
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int64_t
*
dispatch_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
void
*
rdma_recv_x
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int64_t
*
rdma_recv_count
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
void
*
rdma_x
,
void
*
workspace
,
hipStream_t
stream
,
int
phases
)
{
const
void
*
x
,
const
int64_t
*
topk_idx
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
)
{
constexpr
int
kNumMaxTopK
=
11
;
constexpr
int
kNumMaxTopK
=
11
;
const
int
num_warp_groups
=
DIVUP
(
num_experts
,
num_device_sms
);
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
/*num_device_sms*/
80
);
EP_HOST_ASSERT
(
num_warp_groups
<=
16
);
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
// 每个kernel最大16个warp
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
);
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
);
EP_HOST_ASSERT
(
kNumMaxTopK
+
1
<=
num_warp_groups
*
num_warps_per_group
);
EP_HOST_ASSERT
(
kNumMaxTopK
+
1
<=
num_warp_groups
*
num_warps_per_group
);
const
auto
num_warps
=
num_warp_groups
*
num_warps_per_group
;
const
auto
num_warps
=
num_warp_groups
*
num_warps_per_group
;
const
auto
num_sms
=
DIVUP
(
num_experts
,
num_warp_groups
);
const
auto
num_sms
=
ceil_div
(
num_experts
,
num_warp_groups
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopK
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopK
);
// Workspace checks
// Workspace checks
auto
atomic_counter_per_expert
=
static
_cast
<
int
*>
(
workspace
);
auto
atomic_counter_per_expert
=
reinterpret
_cast
<
int
*>
(
workspace
);
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
#define DISPATCH_LAUNCH_CASE(hidden) \
#define DISPATCH_LAUNCH_CASE(hidden) { \
{ \
auto dispatch_func = use_fp8 ? dispatch<true, hidden> : \
auto dispatch_func = dispatch<false, false, hidden>; \
dispatch<false, hidden>; \
if(use_fp8 and not use_ue8m0) \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
dispatch_func = dispatch<true, false, hidden>; \
packed_recv_x, packed_recv_x_scales, \
if(use_fp8 and use_ue8m0) \
packed_recv_src_info, packed_recv_layout_range, \
dispatch_func = dispatch<true, true, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, \
dispatch_func, \
packed_recv_x, \
packed_recv_x_scales, \
packed_recv_src_info, \
packed_recv_layout_range, \
packed_recv_count, \
packed_recv_count, \
global_atomic_counter, \
global_atomic_counter, \
mask_buffer_ptr, \
rdma_recv_x, rdma_recv_count, rdma_x, \
cumulative_local_expert_recv_stats, \
x, topk_idx, \
dispatch_wait_recv_cost_stats, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
rdma_recv_x, \
next_clean, num_next_clean_int, \
rdma_recv_count, \
num_tokens, num_max_dispatch_tokens_per_rank, \
rdma_x, \
num_topk, num_experts, rank, num_ranks, \
x, \
num_warp_groups, num_warps_per_group, phases); } break
topk_idx, \
atomic_counter_per_expert, \
atomic_finish_counter_per_expert, \
next_clean, \
num_next_clean_int, \
num_tokens, \
num_max_dispatch_tokens_per_rank, \
num_topk, \
num_experts, \
rank, \
num_ranks, \
num_warp_groups, \
num_warps_per_group, \
round_scale, \
phases); \
} \
break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
#undef DISPATCH_LAUNCH_CASE
#undef DISPATCH_LAUNCH_CASE
}
}
template
<
bool
kUseLogFMT
,
int
kHidden
,
int
kNumMaxTopk
,
int
kNumMaxUnrolls
>
template
<
int
kNumWarpGroups
,
int
kNumWarpsPerGroup
,
int
kHidden
,
int
kNumMaxTopk
>
__launch_bounds__
(
1024
,
1
)
__global__
void
combine
(
void
*
combined_x
,
__global__
__launch_bounds__
(
kNumWarpGroups
*
kNumWarpsPerGroup
*
kWarpSize
,
1
)
void
void
*
rdma_recv_x
,
combine
(
void
*
combined_x
,
int
*
rdma_recv_flag
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
void
*
rdma_send_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
void
*
x
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int
*
mask_buffer_ptr
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
*
atomic_clean_flag
,
int
*
atomic_clean_flag
,
int
num_combined_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
hidden
,
int
num_topk
,
int
num_max_dispatch_tokens_per_rank
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
rank
,
int
phases
,
bool
zero_copy
)
{
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
phases
,
bool
zero_copy
)
{
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX)
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
#endif
#endif
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
num_threads
=
static_cast
<
int
>
(
blockDim
.
x
);
const
auto
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
const
auto
num_local_experts
=
num_experts
/
num_ranks
;
const
auto
warp_group_id
=
warp_id
/
kNumWarpsPerGroup
;
const
auto
sub_warp_id
=
warp_id
%
kNumWarpsPerGroup
;
const
auto
responsible_expert_idx
=
sm_id
*
kNumWarpGroups
+
warp_group_id
;
// Data type staffs
constexpr
int
kNumElemsPerInt4
=
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
);
const
size_t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerInt4
;
// Message package
// BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
constexpr
size_t
num_bytes_per_slot
=
sizeof
(
int4
)
+
kHidden
*
sizeof
(
hip_bfloat16
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
__syncthreads
();
// 16 is the max possible number of warps in AMD GPUs
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
__shared__
volatile
int
sync_large_warp_counters
[
kMaxNumWarps
];
if
(
threadIdx
.
x
==
0
){
#pragma unroll
for
(
int
i
=
0
;
i
<
kMaxNumWarps
;
++
i
)
{
sync_large_warp_counters
[
i
]
=
0
;
}
}
__syncthreads
();
// const auto sm_id = static_cast<int>(blockIdx.x);
// Sending phase
// const auto num_sms = static_cast<int>(gridDim.x);
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
// const auto thread_id = static_cast<int>(threadIdx.x);
goto
LOW_LATENCY_COMBINE_RECV
;
// const auto num_threads = static_cast<int>(blockDim.x);
// const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
// const auto num_local_experts = num_experts / num_ranks;
// const auto warp_group_id = warp_id / kNumWarpsPerGroup;
// const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
// const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
// // Data type staffs
// constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(gpu_bfloat16_t);
// const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// // Message package
// // BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
// constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(gpu_bfloat16_t);
// EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// __syncthreads();
// #ifdef USE_ROCM
// // 16 is the max possible number of warps in AMD GPUs
// constexpr int kMaxNumWarps = 1024 / kWarpSize;
// __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
// if (threadIdx.x==0){
// // printf("combine");
// #pragma unroll
// for (int i = 0; i < kMaxNumWarps; ++i) {
// sync_large_warp_counters[i] = 0;
// }
// }
// __syncthreads();
// #endif
// // Sending phase
// if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
// goto LOW_LATENCY_COMBINE_RECV;
// // Clean up next buffer
// if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
// #pragma unroll
// for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
// next_clean[i] = 0;
// // Notify before executing `int_p`
// syncwarp();
// if (lane_id == 0)
// atomic_add_release_global(atomic_clean_flag, num_experts);
// }
// // Issue IBGDA sends
// if (responsible_expert_idx < num_experts) {
// const auto dst_rank = responsible_expert_idx / num_local_experts;
// const auto local_expert_idx = responsible_expert_idx % num_local_experts;
// const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
// const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
// const auto local_x = reinterpret_cast<const int4*>(x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
// const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
// const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// // Unpack layout
// int offset, num_tokens_to_send;
// unpack2(layout, num_tokens_to_send, offset);
// // Issue IBGDA send
// for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
// const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
// const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
// const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
// // Copy directly to local rank, or copy to buffer and issue RDMA
// auto src_idx = __ldg(local_src_info + token_idx);
// const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
// const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
// if (dst_rank == rank) {
// const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
// UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
// } else {
// const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
// if (not zero_copy)
// UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
// //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(gpu_bfloat16_t), dst_rank, local_expert_idx, lane_id, token_idx - offset);
// #if defined(ROCM_DISABLE_CTX)
// internode::shmemx_int8_put_nbi_warp(
// #else
// internode::shmem_ctx_schar_put_nbi_warp(ctx,
// #endif
// reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank);
// #if defined(ROCM_DISABLE_CTX)
// internode::shmem_fence();
// #else
// internode::shmem_ctx_quiet(ctx);
// #endif
// }
// }
// // Put finishing flag
// EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
// #ifdef USE_ROCM
// if (lane_id == 0){
// volatile int ret = __hip_atomic_fetch_add(
// &sync_large_warp_counters[warp_group_id], 1,
// __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
// }
// syncwarp();
// while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup));
// #else
// asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
// #endif
// if (sub_warp_id == 1 and lane_id == 0) {
// while (ld_acquire_global(atomic_clean_flag) == 0);
// if (dst_rank != rank) {
// #ifdef USE_ROCM
// #if defined(ROCM_DISABLE_CTX)
// internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank);
// #else
// internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank);
// #endif
// #else
// nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
// #endif
// } else {
// st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
// }
// atomic_add_release_global(atomic_clean_flag, -1);
// }
// syncwarp();
// }
// // Receiving phase
// LOW_LATENCY_COMBINE_RECV:
// if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
// return;
// // Wait all ranks to arrive and notify PCIe usage
// if (responsible_expert_idx < num_experts) {
// EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
// if (sub_warp_id == 0 and lane_id == 0){
// while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0);
// }
// }
// grid_barrier(global_atomic_counter, num_sms);
// // Reduce tokens with FP8 cast
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
// EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
// if (thread_id < hidden_bf16_int4) {
// for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// // Read top-k indices and weights
// int reg_topk_idx[kNumMaxTopk];
// float reg_topk_weights[kNumMaxTopk];
// #pragma unroll
// for (int i = 0; i < num_topk; ++ i) {
// reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
// reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
// }
// float combined_values[kNumElemsPerInt4] = {0.0f};
// #pragma unroll
// for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// // Read from sources
// auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
// auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
// // Reduce
// auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
// const auto x_bf16 = reinterpret_cast<gpu_bfloat16_t*>(&x_vec);
// #pragma unroll
// for (int j = 0; j < kNumElemsPerInt4; ++ j)
// combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
// }
// // Write results
// int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
// auto combined_bf16 = reinterpret_cast<gpu_bfloat16_t*>(&combined_values);
// #pragma unroll
// for (int j = 0; j < kNumElemsPerInt4; ++ j)
// combined_bf16[j] = static_cast<gpu_bfloat16_t>(combined_values[j]);
// (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
// }
// }
#if !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_wg_ctx_destroy
(
&
ctx
);
#endif
}
void
combine
(
void
*
combined_x
,
// Clean up next buffer
void
*
rdma_recv_x
,
if
(
sm_id
==
0
and
warp_group_id
==
0
and
sub_warp_id
==
0
)
{
int64_t
*
rdma_recv_flag
,
#pragma unroll
void
*
rdma_send_x
,
for
(
int
i
=
lane_id
;
i
<
num_next_clean_int
;
i
+=
kWarpSize
)
const
void
*
x
,
next_clean
[
i
]
=
0
;
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
mask_buffer_ptr
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_logfmt
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
constexpr
int
kNumMaxTopk
=
11
;
const
int
num_warp_groups
=
DIVUP
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
const
int
num_recv_per_sm
=
DIVUP
(
num_combined_tokens
,
num_device_sms
);
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
and
num_recv_per_sm
>=
0
);
const
auto
num_warps
=
num_warp_groups
*
num_warps_per_group
;
// Notify before executing `int_p`
const
auto
num_sms
=
max
(
DIVUP
(
num_experts
,
num_warp_groups
),
num_recv_per_sm
==
0
?
1
:
DIVUP
(
num_combined_tokens
,
num_recv_per_sm
));
syncwarp
();
if
(
lane_id
==
0
)
atomic_add_release_global
(
atomic_clean_flag
,
num_experts
);
}
// Check workspace
// Issue IBGDA sends
auto
atomic_clean_flag
=
static_cast
<
int
*>
(
workspace
);
if
(
responsible_expert_idx
<
num_experts
)
{
EP_HOST_ASSERT
(
sizeof
(
int
)
<=
NUM_WORKSPACE_BYTES
);
const
auto
dst_rank
=
responsible_expert_idx
/
num_local_experts
;
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopk
);
const
auto
local_expert_idx
=
responsible_expert_idx
%
num_local_experts
;
const
auto
global_expert_idx
=
rank
*
num_local_experts
+
local_expert_idx
;
const
auto
layout
=
__ldg
(
layout_range
+
local_expert_idx
*
num_ranks
+
dst_rank
);
const
auto
local_x
=
reinterpret_cast
<
const
int4
*>
(
x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_bf16_int4
;
const
auto
local_src_info
=
src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
rdma_send_x_vec
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_slot
;
// Unpack layout
int
offset
,
num_tokens_to_send
;
unpack2
(
layout
,
num_tokens_to_send
,
offset
);
// Issue IBGDA send
for
(
int
token_idx
=
offset
+
sub_warp_id
;
token_idx
<
offset
+
num_tokens_to_send
;
token_idx
+=
kNumWarpsPerGroup
)
{
const
auto
x_int4
=
local_x
+
token_idx
*
hidden_bf16_int4
;
const
auto
rdma_send_type_row
=
reinterpret_cast
<
int
*>
(
rdma_send_x_vec
+
token_idx
*
num_bytes_per_slot
);
const
auto
rdma_send_x_vec_row
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_type_row
+
4
);
// Copy directly to local rank, or copy to buffer and issue RDMA
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
+
sizeof
(
int4
);
if
(
dst_rank
==
rank
)
{
const
auto
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
}
else
{
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
if
(
not
zero_copy
)
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
// Online cast cannot use zero-copy
//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
EP_HOST_ASSERT
(
not
(
zero_copy
and
use_logfmt
));
#if defined(ROCM_DISABLE_CTX)
EP_HOST_ASSERT
(
use_logfmt
==
0
);
rocshmem
::
rocshmem_schar_put_nbi_wave
(
#else
constexpr
int
kNumMaxUnrolls
=
4
;
rocshmem
::
rocshmem_ctx_schar_put_nbi_wave
(
ctx
,
#endif
#ifdef USEING_TMA
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
buf_ptr
),
hidden
*
sizeof
(
hip_bfloat16
),
dst_rank
);
constexpr
int
kNumStages
=
3
;
constexpr
int
kMaxNumGroups
=
2
;
// Send buffer size
const
int
num_meta_bytes
=
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
;
const
int
num_send_tma_bytes
=
32
*
sizeof
(
int4
)
*
kNumMaxUnrolls
+
16
;
const
int
smem_send_size
=
num_warps
*
(
kNumStages
*
num_send_tma_bytes
+
num_meta_bytes
);
// Receive buffer size
const
int
num_recv_tma_bytes
=
16
+
hidden
*
2
;
const
int
smem_recv_size
=
kMaxNumGroups
*
(
kNumStages
*
num_recv_tma_bytes
+
hidden
*
2
+
kNumStages
*
num_meta_bytes
*
3
);
// Total requirement
const
int
smem_size
=
max
(
smem_send_size
,
smem_recv_size
);
#endif
// #define COMBINE_LAUNCH_CASE(hidden) \
// { \
// auto combine_func = combine<false, hidden, kNumMaxTopk, kNumMaxUnrolls>; \
// LAUNCH_KERNEL(&cfg, \
// combine_func, \
// combined_x, \
// rdma_recv_x, \
// rdma_recv_flag, \
// rdma_send_x, \
// x, \
// topk_idx, \
// topk_weights, \
// src_info, \
// layout_range, \
// global_atomic_counter, \
// mask_buffer_ptr, \
// combine_wait_recv_cost_stats, \
// next_clean, \
// num_next_clean_int, \
// atomic_clean_flag, \
// num_combined_tokens, \
// hidden, \
// num_topk, \
// num_max_dispatch_tokens_per_rank, \
// num_experts, \
// rank, \
// num_ranks, \
// num_warp_groups, \
// num_warps_per_group, \
// phases, \
// zero_copy); \
// } \
// break
// SETUP_LAUNCH_CONFIG(num_sms, num_warps* kWarpSize, stream);
// SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
// #undef COMBINE_LAUNCH_CASE
}
template
<
int
kNumThreads
>
#if defined(ROCM_DISABLE_CTX)
__launch_bounds__
(
kNumThreads
,
1
)
__global__
void
query_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
int
*
mask_tensor
)
{
rocshmem
::
rocshmem_fence
();
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
#else
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
rocshmem
::
rocshmem_ctx_quiet
(
ctx
);
const
auto
num_threads
=
num_sms
*
kNumThreads
;
#endif
const
auto
thread_id
=
sm_id
*
kNumThreads
+
static_cast
<
int
>
(
threadIdx
.
x
);
}
for
(
int
rank_id
=
thread_id
;
rank_id
<
num_ranks
;
rank_id
+=
num_threads
)
{
mask_tensor
[
rank_id
]
=
mask_buffer_ptr
[
rank_id
];
}
}
}
void
query_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
int
*
mask_tensor
,
hipStream_t
stream
)
{
// Put finishing flag
constexpr
int
num_sms
=
1
;
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Requires more than one warp per group"
);
constexpr
int
kNumThreads
=
1024
;
if
(
lane_id
==
0
){
SETUP_LAUNCH_CONFIG
(
num_sms
,
kNumThreads
,
stream
);
// volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
query_mask_buffer
<
kNumThreads
>
,
mask_buffer_ptr
,
num_ranks
,
mask_tensor
);
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
warp_group_id
],
1
);
}
}
syncwarp
();
while
(
sync_large_warp_counters
[
warp_group_id
]
<
(
kNumWarpsPerGroup
));
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
while
(
ld_acquire_global
(
atomic_clean_flag
)
==
0
);
if
(
dst_rank
!=
rank
)
{
#if defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_long_atomic_add
(
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
#else
rocshmem
::
rocshmem_ctx_long_atomic_add
(
ctx
,
rdma_recv_flag
+
global_expert_idx
,
1
,
dst_rank
);
#endif
}
else
{
st_na_release
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
global_expert_idx
),
1
);
}
atomic_add_release_global
(
atomic_clean_flag
,
-
1
);
}
syncwarp
();
}
template
<
int
kNumThreads
>
// Receiving phase
__launch_bounds__
(
kNumThreads
,
1
)
__global__
void
update_mask_buffer
(
int
*
mask_buffer_ptr
,
int
rank_to_mask
,
bool
mask
)
{
LOW_LATENCY_COMBINE_RECV:
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
return
;
if
(
sm_id
==
0
&&
thread_id
==
0
)
{
atomicExch
(
mask_buffer_ptr
+
rank_to_mask
,
mask
?
1
:
0
);
// Wait all ranks to arrive and notify PCIe usage
if
(
responsible_expert_idx
<
num_experts
)
{
EP_STATIC_ASSERT
(
kNumWarpsPerGroup
>
1
,
"Invalid number of warps per group"
);
if
(
sub_warp_id
==
0
and
lane_id
==
0
){
while
(
ld_acquire_global
(
reinterpret_cast
<
int
*>
(
rdma_recv_flag
+
responsible_expert_idx
))
==
0
);
}
}
}
}
grid_barrier
(
global_atomic_counter
,
num_sms
);
void
update_mask_buffer
(
int
*
mask_buffer_ptr
,
int
rank
,
bool
mask
,
hipStream_t
stream
)
{
// Reduce tokens with FP8 cast
constexpr
int
num_sms
=
1
;
EP_DEVICE_ASSERT
(
num_topk
<=
kWarpSize
and
hidden_bf16_int4
<=
num_threads
);
constexpr
int
kNumThreads
=
64
;
EP_STATIC_ASSERT
(
kHidden
%
(
kWarpSize
*
kNumElemsPerInt4
)
==
0
,
"Invalid vectorization"
);
SETUP_LAUNCH_CONFIG
(
num_sms
,
kNumThreads
,
stream
);
if
(
thread_id
<
hidden_bf16_int4
)
{
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
update_mask_buffer
<
kNumThreads
>
,
mask_buffer_ptr
,
rank
,
mask
);
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_combined_tokens
;
token_idx
+=
num_sms
)
{
}
// Read top-k indices and weights
int
reg_topk_idx
[
kNumMaxTopk
];
float
reg_topk_weights
[
kNumMaxTopk
];
#pragma unroll
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
{
reg_topk_idx
[
i
]
=
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
i
));
reg_topk_weights
[
i
]
=
__ldg
(
topk_weights
+
token_idx
*
num_topk
+
i
);
}
template
<
int
kNumThreads
>
float
combined_values
[
kNumElemsPerInt4
]
=
{
0.0
f
};
__launch_bounds__
(
kNumThreads
,
1
)
__global__
void
clean_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
)
{
#pragma unroll
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
if
(
reg_topk_idx
[
i
]
>=
0
)
{
// Read from sources
auto
rdma_buffer_type
=
reinterpret_cast
<
const
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
(
reg_topk_idx
[
i
]
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
+
4
);
// Reduce
auto
x_vec
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_buffer_row
)
+
thread_id
);
const
auto
x_bf16
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
x_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerInt4
;
++
j
)
combined_values
[
j
]
+=
static_cast
<
float
>
(
x_bf16
[
j
])
*
reg_topk_weights
[
i
];
}
// Write results
int4
&
combined_int4
=
*
reinterpret_cast
<
int4
*>
(
combined_values
);
auto
combined_bf16
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
combined_values
);
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_ranks
;
i
+=
kNumThreads
)
for
(
int
j
=
0
;
j
<
kNumElemsPerInt4
;
++
j
)
mask_buffer_ptr
[
i
]
=
0
;
combined_bf16
[
j
]
=
static_cast
<
hip_bfloat16
>
(
combined_values
[
j
]);
(
reinterpret_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
)[
thread_id
]
=
combined_int4
;
}
}
#if !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_wg_ctx_destroy
(
&
ctx
);
#endif
}
}
void
clean_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
hipStream_t
stream
)
{
void
combine
(
void
*
combined_x
,
constexpr
int
num_sms
=
1
;
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
constexpr
int
kNumThreads
=
64
;
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
SETUP_LAUNCH_CONFIG
(
num_sms
,
kNumThreads
,
stream
);
const
int
*
src_info
,
const
int64_t
*
layout_range
,
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
clean_mask_buffer
<
kNumThreads
>
,
mask_buffer_ptr
,
num_ranks
);
int
*
global_atomic_counter
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
constexpr
int
kNumWarpsPerGroup
=
4
;
constexpr
int
kNumWarpGroups
=
4
;
constexpr
int
kNumMaxTopk
=
9
;
const
auto
num_warps
=
kNumWarpGroups
*
kNumWarpsPerGroup
;
const
auto
num_sms
=
ceil_div
(
num_experts
,
kNumWarpGroups
);
// Check workspace
auto
atomic_clean_flag
=
reinterpret_cast
<
int
*>
(
workspace
);
EP_HOST_ASSERT
(
sizeof
(
int
)
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopk
);
#define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden, kNumMaxTopk>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
phases, zero_copy); } break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SWITCH_HIDDEN
(
COMBINE_LAUNCH_CASE
);
#undef COMBINE_LAUNCH_CASE
}
}
}
// namespace internode_ll
}
// namespace internode_ll
}
// namespace deep_ep
}
// namespace deep_ep
#endif
csrc/kernels/utils.cuh
View file @
da13c63a
...
@@ -31,6 +31,21 @@
...
@@ -31,6 +31,21 @@
} \
} \
}
}
#define UNROLLED_WARP_COPY_LL(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \
constexpr int kLoopStride = kWarpSize * (UNROLL_FACTOR); \
typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \
auto __src = (SRC); \
auto __dst = (DST); \
for(int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \
_Pragma("unroll") for(int __j = 0; __j < (UNROLL_FACTOR); ++__j) unrolled_values[__j] = LD_FUNC(__src + __i + __j * kWarpSize); \
_Pragma("unroll") for(int __j = 0; __j < (UNROLL_FACTOR); ++__j) ST_FUNC(__dst + __i + __j * kWarpSize, unrolled_values[__j]); \
} \
for(int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += kWarpSize) \
ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \
}
#define UNROLLED_WARP_COPY_EMULATED(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
#define UNROLLED_WARP_COPY_EMULATED(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \
{ \
constexpr int kLoopStride = kEmulatedWarpSize * (UNROLL_FACTOR); \
constexpr int kLoopStride = kEmulatedWarpSize * (UNROLL_FACTOR); \
...
@@ -329,8 +344,8 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
...
@@ -329,8 +344,8 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
#ifdef USE_ROCM
#ifdef USE_ROCM
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
#else
#else
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFinfoAmaxE4M3
=
448.0
f
;
constexpr
float
kFinfoAmaxE4M3
=
448.0
f
;
...
@@ -350,8 +365,9 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
...
@@ -350,8 +365,9 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return
exp_x
-
127
+
(
man_bits
!=
0
);
return
exp_x
-
127
+
(
man_bits
!=
0
);
}
}
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
)
{
template
<
bool
kRoundScale
>
if
(
round_scale
)
{
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
)
{
if
constexpr
(
kRoundScale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
...
...
deep_ep/buffer.py
View file @
da13c63a
...
@@ -802,26 +802,17 @@ class Buffer:
...
@@ -802,26 +802,17 @@ class Buffer:
self
.
runtime
.
clean_low_latency_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
self
.
runtime
.
clean_low_latency_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_dispatch
(
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
self
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
x
:
torch
.
Tensor
,
use_fp8
:
bool
=
True
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
topk_idx
:
torch
.
Tensor
,
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
cumulative_local_expert_recv_stats
:
Optional
[
torch
.
Tensor
]
=
None
,
dispatch_wait_recv_cost_stats
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp8
:
bool
=
True
,
round_scale
:
bool
=
False
,
use_ue8m0
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
"""
A low-latency implementation for dispatching with IBGDA.
A low-latency implementation for dispatching with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
(specifically, IBGDA must be enabled).
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
Even for ranks in the same node, NVLink are fully disabled for simplicity.
low-latency kernels' result tensors at a single moment.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Arguments:
Arguments:
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
...
@@ -830,105 +821,52 @@ class Buffer:
...
@@ -830,105 +821,52 @@ class Buffer:
are supported. `-1` indices (not selecting any expert) are supported.
are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
num_experts: the number of all experts.
cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring.
dispatch_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and pre-cisely localizing slow anomalies.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you
do
not set this flag, the kernel will ensure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
Returns:
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receive
s
. As mentioned before,
not all
tokens are valid in `recv_x`.
expert receive. As mentioned before,
all not
tokens are valid in `recv_x`.
handle: the communication handle to be used in the `low_latency_combine` function.
handle: the communication handle to be used in the `low_latency_combine` function.
event: the event after executing the kernel (valid only if `async_finish` is set).
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
"""
(
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
packed_recv_x
,
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
packed_recv_x_scales
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
packed_recv_count
,
use_fp8
,
async_finish
,
return_recv_hook
)
packed_recv_src_info
,
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
packed_recv_layout_range
,
tensors_to_record
=
(
x
,
topk_idx
,
event
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
hook
,
packed_recv_src_info
,
packed_recv_layout_range
)
)
=
self
.
runtime
.
low_latency_dispatch
(
return
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
\
x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
topk_idx
,
cumulative_local_expert_recv_stats
,
dispatch_wait_recv_cost_stats
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
round_scale
,
use_ue8m0
,
async_finish
,
return_recv_hook
,
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
,
)
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
cumulative_local_expert_recv_stats
,
)
return
(
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
,
)
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_combine
(
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
self
,
handle
:
tuple
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
x
:
torch
.
Tensor
,
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
topk_idx
:
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
topk_weights
:
torch
.
Tensor
,
handle
:
tuple
,
use_logfmt
:
bool
=
False
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
combine_wait_recv_cost_stats
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
"""
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
(specifically, IBGDA must be enabled).
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
Even for ranks in the same node, NVLink are fully disabled for simplicity.
low-latency kernels' result tensors at a single moment.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Arguments:
Arguments:
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
...
@@ -939,39 +877,23 @@ class Buffer:
...
@@ -939,39 +877,23 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`.
with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you
do
not set this flag, the kernel will ensure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and pre-cisely localizing slow anomalies.
Returns:
Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens,
hidden
]` and type `torch.bfloat16`.
combined_x: the reduced token tensor, with shape `[num_combined_tokens,
num_topk
]` and type `torch.bfloat16`.
event: the event after executing the kernel (valid only if `async_finish` is set).
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
"""
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
x
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
topk_idx
,
zero_copy
,
async_finish
,
return_recv_hook
,
out
)
topk_weights
,
src_info
,
layout_range
,
combine_wait_recv_cost_stats
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_logfmt
,
zero_copy
,
async_finish
,
return_recv_hook
,
out
,
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
return
combined_x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
return
combined_x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
...
@@ -988,34 +910,4 @@ class Buffer:
...
@@ -988,34 +910,4 @@ class Buffer:
by yourself.
by yourself.
"""
"""
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
return
self
.
runtime
.
get_next_low_latency_combine_buffer
(
return
self
.
runtime
.
get_next_low_latency_combine_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
def
low_latency_update_mask_buffer
(
self
,
rank_to_mask
:
int
,
mask
:
bool
=
False
):
"""
Mask (unmask) a rank during communication (dispatch, combine, and clean)
Arguments:
rank: the rank to mask (unmask).
mask: if True, will mask the rank (do not recvfrom/sendto the rank), otherwise will unmask the rank.
"""
self
.
runtime
.
low_latency_update_mask_buffer
(
rank_to_mask
,
mask
)
def
low_latency_query_mask_buffer
(
self
,
mask_status
:
torch
.
Tensor
):
"""
Query the mask status of all ranks
Arguments:
mask_status: `[num_ranks]` with `torch.int`, the mask status of each rank. `1` means mask and `0` means unmasked.
"""
self
.
runtime
.
low_latency_query_mask_buffer
(
mask_status
)
def
low_latency_clean_mask_buffer
(
self
):
"""
Clean the mask buffer
"""
self
.
runtime
.
low_latency_clean_mask_buffer
()
rocshmem_dir/bin/rocshmem_info
View file @
da13c63a
No preview for this file type
rocshmem_dir/include/rocshmem/rocshmem.hpp
View file @
da13c63a
...
@@ -26,7 +26,6 @@
...
@@ -26,7 +26,6 @@
#define LIBRARY_INCLUDE_ROCSHMEM_HPP
#define LIBRARY_INCLUDE_ROCSHMEM_HPP
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <mpi.h>
#include "rocshmem_config.h"
#include "rocshmem_config.h"
#include "rocshmem_common.hpp"
#include "rocshmem_common.hpp"
...
@@ -36,6 +35,10 @@
...
@@ -36,6 +35,10 @@
#include "rocshmem_COLL.hpp"
#include "rocshmem_COLL.hpp"
#include "rocshmem_P2P_SYNC.hpp"
#include "rocshmem_P2P_SYNC.hpp"
#include "rocshmem_RMA_X.hpp"
#include "rocshmem_RMA_X.hpp"
#if defined(HAVE_EXTERNAL_MPI)
#include <mpi.h>
#endif
/**
/**
* @file rocshmem.hpp
* @file rocshmem.hpp
* @brief Public header for rocSHMEM device and host libraries.
* @brief Public header for rocSHMEM device and host libraries.
...
@@ -57,13 +60,22 @@ constexpr char VERSION[] = "3.0.0";
...
@@ -57,13 +60,22 @@ constexpr char VERSION[] = "3.0.0";
/******************************************************************************
/******************************************************************************
**************************** HOST INTERFACE **********************************
**************************** HOST INTERFACE **********************************
*****************************************************************************/
*****************************************************************************/
#if defined(HAVE_EXTERNAL_MPI)
/**
/**
* @brief Initialize the rocSHMEM runtime and underlying transport layer.
* @brief Initialize the rocSHMEM runtime and underlying transport layer.
*
*
* @param[in] comm
(Optional)
MPI Communicator that rocSHMEM will be using
* @param[in] comm MPI Communicator that rocSHMEM will be using
* If MPI_COMM_NULL, rocSHMEM will be using MPI_COMM_WORLD
* If MPI_COMM_NULL, rocSHMEM will be using MPI_COMM_WORLD
*/
*/
__host__
void
rocshmem_init
(
MPI_Comm
comm
=
MPI_COMM_WORLD
);
[[
deprecated
]]
__host__
void
rocshmem_init
(
MPI_Comm
comm
);
#endif
/**
* @brief Initialize the rocSHMEM runtime and underlying transport layer.
* This is equivalent to the previous function, using implicitely
* MPI_COMM_WORLD for initialization
*/
__host__
void
rocshmem_init
(
void
);
/**
/**
* @brief Query rocSHMEM context from host API
* @brief Query rocSHMEM context from host API
...
@@ -86,8 +98,10 @@ __host__ void * rocshmem_get_device_ctx();
...
@@ -86,8 +98,10 @@ __host__ void * rocshmem_get_device_ctx();
* This can be used to issue load/store from custom kernels
* This can be used to issue load/store from custom kernels
* instead of using rocshmem device side get/put APIs for RMA operations.
* instead of using rocshmem device side get/put APIs for RMA operations.
*/
*/
__host__
void
*
rocshmem_ptr
(
void
*
dest
,
int
pe
);
__host__
void
*
rocshmem_ptr
(
const
void
*
dest
,
int
pe
);
__device__
ATTR_NO_INLINE
void
*
rocshmem_ptr
(
const
void
*
dest
,
int
pe
);
#if defined(HAVE_EXTERNAL_MPI)
/**
/**
* @brief Initialize the rocSHMEM runtime and underlying transport layer
* @brief Initialize the rocSHMEM runtime and underlying transport layer
* with an attempt to enable the requested thread support.
* with an attempt to enable the requested thread support.
...
@@ -102,8 +116,9 @@ __host__ void *rocshmem_ptr(void *dest, int pe);
...
@@ -102,8 +116,9 @@ __host__ void *rocshmem_ptr(void *dest, int pe);
* @return int returns 0 upon success; otherwise, it returns a nonzero
* @return int returns 0 upon success; otherwise, it returns a nonzero
* value
* value
*/
*/
__host__
int
rocshmem_init_thread
(
int
requested
,
int
*
provided
,
[[
deprecated
]]
__host__
int
rocshmem_init_thread
(
int
requested
,
int
*
provided
,
MPI_Comm
comm
=
MPI_COMM_WORLD
);
MPI_Comm
comm
);
#endif
/**
/**
* @brief Initialize the rocSHMEM runtime and underlying transport layer
* @brief Initialize the rocSHMEM runtime and underlying transport layer
...
@@ -327,6 +342,13 @@ __host__ void rocshmem_quiet();
...
@@ -327,6 +342,13 @@ __host__ void rocshmem_quiet();
*/
*/
__host__
void
rocshmem_barrier_all
();
__host__
void
rocshmem_barrier_all
();
/**
* @brief enqueues a collective barrier on given stream.
*
* @return void
*/
__host__
void
rocshmem_barrier_all_on_stream
(
hipStream_t
stream
);
/**
/**
* @brief registers the arrival of a PE at a barrier.
* @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved.
* The caller is blocked until the synchronization is resolved.
...
@@ -360,7 +382,7 @@ __host__ void rocshmem_global_exit(int status);
...
@@ -360,7 +382,7 @@ __host__ void rocshmem_global_exit(int status);
*
*
* @return void.
* @return void.
*/
*/
__device__
void
rocshmem_wg_init
();
[[
deprecated
]]
__device__
void
rocshmem_wg_init
();
/**
/**
* @brief Finalizes device-side rocSHMEM resources. Must be called before
* @brief Finalizes device-side rocSHMEM resources. Must be called before
...
@@ -370,7 +392,7 @@ __device__ void rocshmem_wg_init();
...
@@ -370,7 +392,7 @@ __device__ void rocshmem_wg_init();
*
*
* @return void.
* @return void.
*/
*/
__device__
void
rocshmem_wg_finalize
();
[[
deprecated
]]
__device__
void
rocshmem_wg_finalize
();
/**
/**
* @brief Initializes device-side rocSHMEM resources. Must be called before
* @brief Initializes device-side rocSHMEM resources. Must be called before
...
@@ -386,7 +408,7 @@ __device__ void rocshmem_wg_finalize();
...
@@ -386,7 +408,7 @@ __device__ void rocshmem_wg_finalize();
*
*
* @return void.
* @return void.
*/
*/
__device__
void
rocshmem_wg_init_thread
(
int
requested
,
int
*
provided
);
[[
deprecated
]]
__device__
void
rocshmem_wg_init_thread
(
int
requested
,
int
*
provided
);
/**
/**
* @brief Query the thread mode used by the runtime.
* @brief Query the thread mode used by the runtime.
...
@@ -476,6 +498,23 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_quiet(rocshmem_ctx_t ctx);
...
@@ -476,6 +498,23 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_quiet(rocshmem_ctx_t ctx);
__device__
ATTR_NO_INLINE
void
rocshmem_quiet
();
__device__
ATTR_NO_INLINE
void
rocshmem_quiet
();
/**
* @brief Completes all previous operations posted to this context for PEs in the
* `target_pes` array.
*
* @param[in] ctx Context with which to perform this operation.
*
* @param[in] target_pes Address of target PE array where the operations need to be completed.
*
* @param[in] npes The number of PEs in the target PE array.
*
* @return void.
*/
__device__
ATTR_NO_INLINE
void
rocshmem_ctx_pe_quiet
(
rocshmem_ctx_t
ctx
,
const
int
*
target_pes
,
size_t
npes
);
__device__
ATTR_NO_INLINE
void
rocshmem_pe_quiet
(
const
int
*
target_pes
,
size_t
npes
);
/**
/**
* @brief Query the total number of PEs.
* @brief Query the total number of PEs.
*
*
...
...
rocshmem_dir/include/rocshmem/rocshmem_COLL.hpp
View file @
da13c63a
...
@@ -599,6 +599,14 @@ __host__ int rocshmem_ctx_double_prod_reduce(
...
@@ -599,6 +599,14 @@ __host__ int rocshmem_ctx_double_prod_reduce(
rocshmem_ctx_t
ctx
,
rocshmem_team_t
team
,
double
*
dest
,
const
double
*
source
,
rocshmem_ctx_t
ctx
,
rocshmem_team_t
team
,
double
*
dest
,
const
double
*
source
,
int
nreduce
);
int
nreduce
);
/**
* @brief kernel for performing a barrier synchronization.
* Caller enqueues the kernel on given stream
*
* @return void
*/
__global__
ATTR_NO_INLINE
void
rocshmem_barrier_all_kernel
();
/**
/**
* @brief perform a collective barrier between all PEs in the system.
* @brief perform a collective barrier between all PEs in the system.
* The caller is blocked until the barrier is resolved.
* The caller is blocked until the barrier is resolved.
...
@@ -767,28 +775,6 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_wave(
...
@@ -767,28 +775,6 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_wave(
__device__
ATTR_NO_INLINE
void
rocshmem_ctx_sync_wg
(
__device__
ATTR_NO_INLINE
void
rocshmem_ctx_sync_wg
(
rocshmem_ctx_t
ctx
,
rocshmem_team_t
team
);
rocshmem_ctx_t
ctx
,
rocshmem_team_t
team
);
/**
* @brief Query a local pointer to a symmetric data object on the
* specified \pe . Returns an address that may be used to directly reference
* dest on the specified \pe. This address can be accesses with LD/ST ops.
*
* Can be called per thread with no performance penalty.
*/
__device__
ATTR_NO_INLINE
void
*
rocshmem_ptr
(
const
void
*
dest
,
int
pe
);
/**
* @brief Make all uncacheable GPU data visible to other agents in the sytem.
*
* This only works for data that was explicitly allocated uncacheable on the
* GPU!
*
* Can be called per thread with no performance penalty.
*
* @param[in] GPU-side handle.
*
* @return void
*/
}
// namespace rocshmem
}
// namespace rocshmem
#endif // LIBRARY_INCLUDE_ROCSHMEM_COLL_HPP
#endif // LIBRARY_INCLUDE_ROCSHMEM_COLL_HPP
rocshmem_dir/include/rocshmem/rocshmem_common.hpp
View file @
da13c63a
...
@@ -106,9 +106,18 @@ const int ROCSHMEM_CTX_SHARED = 8;
...
@@ -106,9 +106,18 @@ const int ROCSHMEM_CTX_SHARED = 8;
* @brief GPU side OpenSHMEM context created from each work-groups'
* @brief GPU side OpenSHMEM context created from each work-groups'
* rocshmem_wg_handle_t
* rocshmem_wg_handle_t
*/
*/
typedef
struct
{
typedef
struct
rocshmem_ctx
{
void
*
ctx_opaque
;
void
*
ctx_opaque
;
void
*
team_opaque
;
void
*
team_opaque
;
__host__
__device__
bool
operator
==
(
const
struct
rocshmem_ctx
&
other
)
const
{
return
(
ctx_opaque
==
other
.
ctx_opaque
&&
team_opaque
==
other
.
team_opaque
);
}
__host__
__device__
bool
operator
!=
(
const
struct
rocshmem_ctx
&
other
)
const
{
return
!
(
*
this
==
other
);
}
}
rocshmem_ctx_t
;
}
rocshmem_ctx_t
;
/**
/**
...
@@ -116,6 +125,14 @@ typedef struct {
...
@@ -116,6 +125,14 @@ typedef struct {
*/
*/
extern
"C"
__device__
rocshmem_ctx_t
__attribute__
((
visibility
(
"default"
)))
ROCSHMEM_CTX_DEFAULT
;
extern
"C"
__device__
rocshmem_ctx_t
__attribute__
((
visibility
(
"default"
)))
ROCSHMEM_CTX_DEFAULT
;
/**
* A value corresponding to an invalid communication context. This value can be
* used to initialize or update context handles to indicate that they do not
* reference a valid context. When managed in this way, applications can use an
* equality comparison to test whether a given context handle references a
* valid context.
*/
extern
__constant__
rocshmem_ctx_t
ROCSHMEM_CTX_INVALID
;
/**
/**
* Used internally to set default context.
* Used internally to set default context.
*/
*/
...
...
rocshmem_dir/include/rocshmem/rocshmem_config.h
View file @
da13c63a
...
@@ -45,3 +45,4 @@
...
@@ -45,3 +45,4 @@
/* #undef GDA_IONIC */
/* #undef GDA_IONIC */
/* #undef GDA_BNXT */
/* #undef GDA_BNXT */
#define GDA_MLX5
#define GDA_MLX5
#define HAVE_EXTERNAL_MPI
rocshmem_dir/include/rocshmem/rocshmem_mpi.hpp
0 → 100644
View file @
da13c63a
/******************************************************************************
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
#define LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
#if defined(HAVE_EXTERNAL_MPI)
#include <mpi.h>
#endif
#if defined(c_plusplus) || defined(__cplusplus)
extern
"C"
{
#endif
#if !defined(MPI_VERSION)
// Open MPI based values for the constants/handles etc.
// Even though we did not include an external MPI header file
// The includer may have (e.g., a unit test).
typedef
void
*
MPI_Comm
;
typedef
void
*
MPI_Win
;
typedef
void
*
MPI_Group
;
typedef
void
*
MPI_Op
;
typedef
void
*
MPI_Datatype
;
typedef
void
*
MPI_Request
;
typedef
void
*
MPI_Info
;
struct
ompi_status_public_t
{
int
MPI_SOURCE
;
int
MPI_TAG
;
int
MPI_ERROR
;
int
_cancelled
;
size_t
_ucount
;
};
typedef
struct
ompi_status_public_t
MPI_Status
;
#define MPI_Aint uint64_t
#define MPI_UNDEFINED -32766
#define MPI_THREAD_MULTIPLE 3
#define MPI_SUCCESS 0
#define MPI_IN_PLACE (void*)1
#define MPI_MODE_NOCHECK 1
#define MPI_COMM_TYPE_SHARED 0
#define MPI_Aint_diff(addr1, addr2) ((MPI_Aint) ((char *) (addr1) - (char *) (addr2)))
struct
ompi_internal_symbols_t
{
void
*
ompi_mpi_comm_world
;
void
*
ompi_mpi_comm_null
;
void
*
ompi_request_null
;
void
*
ompi_mpi_info_null
;
void
*
ompi_mpi_datatype_null
;
void
*
ompi_mpi_op_max
;
void
*
ompi_mpi_op_min
;
void
*
ompi_mpi_op_sum
;
void
*
ompi_mpi_op_prod
;
void
*
ompi_mpi_op_band
;
void
*
ompi_mpi_op_bor
;
void
*
ompi_mpi_op_bxor
;
void
*
ompi_mpi_op_replace
;
void
*
ompi_mpi_op_no_op
;
void
*
ompi_mpi_char
;
void
*
ompi_mpi_unsigned_char
;
void
*
ompi_mpi_signed_char
;
void
*
ompi_mpi_short
;
void
*
ompi_mpi_unsigned_short
;
void
*
ompi_mpi_int
;
void
*
ompi_mpi_unsigned
;
void
*
ompi_mpi_long
;
void
*
ompi_mpi_unsigned_long
;
void
*
ompi_mpi_long_long_int
;
void
*
ompi_mpi_unsigned_long_long
;
void
*
ompi_mpi_float
;
void
*
ompi_mpi_double
;
void
*
ompi_mpi_long_double
;
};
extern
struct
ompi_internal_symbols_t
ompi_symbols_
;
#define OMPI_PREDEFINED_GLOBAL(type, global) (static_cast<type> (global))
#define MPI_COMM_WORLD OMPI_PREDEFINED_GLOBAL(MPI_Comm, ompi_symbols_.ompi_mpi_comm_world)
#define MPI_COMM_NULL OMPI_PREDEFINED_GLOBAL(MPI_Comm, ompi_symbols_.ompi_mpi_comm_null)
#define MPI_REQUEST_NULL OMPI_PREDEFINED_GLOBAL(MPI_Request, ompi_symbols_.ompi_request_null)
#define MPI_WIN_NULL OMPI_PREDEFINED_GLOBAL(MPI_Win, ompi_symbols_.ompi_mpi_win_null)
#define MPI_INFO_NULL OMPI_PREDEFINED_GLOBAL(MPI_Info, ompi_symbols_.ompi_mpi_info_null)
#define MPI_MAX OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_max)
#define MPI_MIN OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_min)
#define MPI_SUM OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_sum)
#define MPI_PROD OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_prod)
#define MPI_BAND OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_band)
#define MPI_BOR OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_bor)
#define MPI_BXOR OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_bxor)
#define MPI_REPLACE OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_replace)
#define MPI_NO_OP OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_no_op)
#define MPI_DATATYPE_NULL OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_datatype_null)
#define MPI_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_char)
#define MPI_UNSIGNED_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_char)
#define MPI_SIGNED_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_signed_char)
#define MPI_SHORT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_short)
#define MPI_UNSIGNED_SHORT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_short)
#define MPI_INT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_int)
#define MPI_UNSIGNED OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned)
#define MPI_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long)
#define MPI_UNSIGNED_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_long)
#define MPI_LONG_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long_long_int)
#define MPI_UNSIGNED_LONG_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_long_long)
#define MPI_FLOAT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_float)
#define MPI_DOUBLE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_double)
#define MPI_LONG_DOUBLE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long_double)
#endif //!defined(MPI_VERSION)
#if defined(c_plusplus) || defined(__cplusplus)
}
#endif
#endif //LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
rocshmem_dir/lib/cmake/rocshmem/rocshmem-targets.cmake
View file @
da13c63a
...
@@ -61,7 +61,7 @@ add_library(roc::rocshmem STATIC IMPORTED)
...
@@ -61,7 +61,7 @@ add_library(roc::rocshmem STATIC IMPORTED)
set_target_properties
(
roc::rocshmem PROPERTIES
set_target_properties
(
roc::rocshmem PROPERTIES
INTERFACE_COMPILE_OPTIONS
"-fgpu-rdc;-fgpu-rdc"
INTERFACE_COMPILE_OPTIONS
"-fgpu-rdc;-fgpu-rdc"
INTERFACE_INCLUDE_DIRECTORIES
"
${
_IMPORT_PREFIX
}
/include;
${
_IMPORT_PREFIX
}
/include"
INTERFACE_INCLUDE_DIRECTORIES
"
${
_IMPORT_PREFIX
}
/include;
${
_IMPORT_PREFIX
}
/include"
INTERFACE_LINK_LIBRARIES
"IBVerbs::verbs;numa;Threads::Threads;
MPI::MPI_CXX;
hip::device;hip::host;hsa-runtime64::hsa-runtime64;-fgpu-rdc"
INTERFACE_LINK_LIBRARIES
"IBVerbs::verbs;numa;
\$
<
\$
<BOOL:ON>:MPI::MPI_CXX>;
Threads::Threads;hip::device;hip::host;
dl;
hsa-runtime64::hsa-runtime64;-fgpu-rdc"
)
)
# Load information for each installed configuration.
# Load information for each installed configuration.
...
...
rocshmem_dir/lib/librocshmem.a
View file @
da13c63a
No preview for this file type
rocshmem_dir/rocshmem/lib/cmake/rocshmem-config-version.cmake
View file @
da13c63a
../../../lib/cmake/rocshmem/rocshmem-config-version.cmake
# This is a basic version file for the Config-mode of find_package().
\ No newline at end of file
# It is used by write_basic_package_version_file() as input file for configure_file()
# to create a version-file which can be installed along a config.cmake file.
#
# The created file sets PACKAGE_VERSION_EXACT if the current version string and
# the requested version string are exactly the same and it sets
# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version,
# but only if the requested major version is the same as the current one.
# The variable CVF_VERSION must be set before calling configure_file().
set
(
PACKAGE_VERSION
"3.0.0"
)
if
(
PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION
)
set
(
PACKAGE_VERSION_COMPATIBLE FALSE
)
else
()
if
(
"3.0.0"
MATCHES
"^([0-9]+)
\\
."
)
set
(
CVF_VERSION_MAJOR
"
${
CMAKE_MATCH_1
}
"
)
if
(
NOT CVF_VERSION_MAJOR VERSION_EQUAL 0
)
string
(
REGEX REPLACE
"^0+"
""
CVF_VERSION_MAJOR
"
${
CVF_VERSION_MAJOR
}
"
)
endif
()
else
()
set
(
CVF_VERSION_MAJOR
"3.0.0"
)
endif
()
if
(
PACKAGE_FIND_VERSION_RANGE
)
# both endpoints of the range must have the expected major version
math
(
EXPR CVF_VERSION_MAJOR_NEXT
"
${
CVF_VERSION_MAJOR
}
+ 1"
)
if
(
NOT PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
OR
((
PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL
"INCLUDE"
AND NOT PACKAGE_FIND_VERSION_MAX_MAJOR STREQUAL CVF_VERSION_MAJOR
)
OR
(
PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL
"EXCLUDE"
AND NOT PACKAGE_FIND_VERSION_MAX VERSION_LESS_EQUAL CVF_VERSION_MAJOR_NEXT
)))
set
(
PACKAGE_VERSION_COMPATIBLE FALSE
)
elseif
(
PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
AND
((
PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL
"INCLUDE"
AND PACKAGE_VERSION VERSION_LESS_EQUAL PACKAGE_FIND_VERSION_MAX
)
OR
(
PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL
"EXCLUDE"
AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MAX
)))
set
(
PACKAGE_VERSION_COMPATIBLE TRUE
)
else
()
set
(
PACKAGE_VERSION_COMPATIBLE FALSE
)
endif
()
else
()
if
(
PACKAGE_FIND_VERSION_MAJOR STREQUAL CVF_VERSION_MAJOR
)
set
(
PACKAGE_VERSION_COMPATIBLE TRUE
)
else
()
set
(
PACKAGE_VERSION_COMPATIBLE FALSE
)
endif
()
if
(
PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION
)
set
(
PACKAGE_VERSION_EXACT TRUE
)
endif
()
endif
()
endif
()
# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it:
if
(
"
${
CMAKE_SIZEOF_VOID_P
}
"
STREQUAL
""
OR
"8"
STREQUAL
""
)
return
()
endif
()
# check that the installed version has the same 32/64bit-ness as the one which is currently searching:
if
(
NOT CMAKE_SIZEOF_VOID_P STREQUAL
"8"
)
math
(
EXPR installedBits
"8 * 8"
)
set
(
PACKAGE_VERSION
"
${
PACKAGE_VERSION
}
(
${
installedBits
}
bit)"
)
set
(
PACKAGE_VERSION_UNSUITABLE TRUE
)
endif
()
rocshmem_dir/rocshmem/lib/cmake/rocshmem-config.cmake
View file @
da13c63a
../../../lib/cmake/rocshmem/rocshmem-config.cmake
\ No newline at end of file
####################################################################################
# Auto generated @PACKAGE_INIT@ by rocm_configure_package_config_file()
# from rocshmem-config.cmake.in
# Any changes to this file will be overwritten by the next CMake run
####################################################################################
get_filename_component
(
_ROCM_CMAKE_CURRENT_LIST_FILE_REAL
"
${
CMAKE_CURRENT_LIST_FILE
}
"
REALPATH
)
get_filename_component
(
_ROCM_CMAKE_CURRENT_LIST_DIR_REAL
"
${
_ROCM_CMAKE_CURRENT_LIST_FILE_REAL
}
"
DIRECTORY
)
get_filename_component
(
PACKAGE_PREFIX_DIR
"
${
_ROCM_CMAKE_CURRENT_LIST_DIR_REAL
}
/../../../"
ABSOLUTE
)
macro
(
set_and_check _var _file
)
set
(
${
_var
}
"
${
_file
}
"
)
if
(
NOT EXISTS
"
${
_file
}
"
)
message
(
FATAL_ERROR
"File or directory
${
_file
}
referenced by variable
${
_var
}
does not exist !"
)
endif
()
endmacro
()
include
(
CMakeFindDependencyMacro OPTIONAL RESULT_VARIABLE _ROCMCMakeFindDependencyMacro_FOUND
)
if
(
NOT _ROCMCMakeFindDependencyMacro_FOUND
)
macro
(
find_dependency dep
)
if
(
NOT
${
dep
}
_FOUND
)
set
(
rocm_fd_version
)
if
(
${
ARGC
}
GREATER 1
)
set
(
rocm_fd_version
${
ARGV1
}
)
endif
()
set
(
rocm_fd_exact_arg
)
if
(
${
CMAKE_FIND_PACKAGE_NAME
}
_FIND_VERSION_EXACT
)
set
(
rocm_fd_exact_arg EXACT
)
endif
()
set
(
rocm_fd_quiet_arg
)
if
(
${
CMAKE_FIND_PACKAGE_NAME
}
_FIND_QUIETLY
)
set
(
rocm_fd_quiet_arg QUIET
)
endif
()
set
(
rocm_fd_required_arg
)
if
(
${
CMAKE_FIND_PACKAGE_NAME
}
_FIND_REQUIRED
)
set
(
rocm_fd_required_arg REQUIRED
)
endif
()
find_package
(
${
dep
}
${
rocm_fd_version
}
${
rocm_fd_exact_arg
}
${
rocm_fd_quiet_arg
}
${
rocm_fd_required_arg
}
)
string
(
TOUPPER
${
dep
}
cmake_dep_upper
)
if
(
NOT
${
dep
}
_FOUND AND NOT
${
cmake_dep_upper
}
_FOUND
)
set
(
${
CMAKE_FIND_PACKAGE_NAME
}
_NOT_FOUND_MESSAGE
"
${
CMAKE_FIND_PACKAGE_NAME
}
could not be found because dependency
${
dep
}
could not be found."
)
set
(
${
CMAKE_FIND_PACKAGE_NAME
}
_FOUND False
)
return
()
endif
()
set
(
rocm_fd_version
)
set
(
rocm_fd_required_arg
)
set
(
rocm_fd_quiet_arg
)
set
(
rocm_fd_exact_arg
)
endif
()
endmacro
()
endif
()
macro
(
check_required_components _NAME
)
foreach
(
comp
${${
_NAME
}
_FIND_COMPONENTS
}
)
if
(
NOT
${
_NAME
}
_
${
comp
}
_FOUND
)
if
(
${
_NAME
}
_FIND_REQUIRED_
${
comp
}
)
set
(
${
_NAME
}
_FOUND FALSE
)
endif
()
endif
()
endforeach
()
endmacro
()
####################################################################################
set_and_check
(
rocshmem_INCLUDE_DIR
${
PACKAGE_PREFIX_DIR
}
/include
)
set_and_check
(
rocshmem_INCLUDE_DIRS
${
PACKAGE_PREFIX_DIR
}
/include
)
set_and_check
(
ROCSHMEM_INCLUDE_DIR
${
PACKAGE_PREFIX_DIR
}
/include
)
set_and_check
(
ROCSHMEM_INCLUDE_DIRS
${
PACKAGE_PREFIX_DIR
}
/include
)
set_and_check
(
rocshmem_INCLUDE_DIR
${
PACKAGE_PREFIX_DIR
}
/include
)
set_and_check
(
rocshmem_INCLUDE_DIRS
${
PACKAGE_PREFIX_DIR
}
/include
)
set_and_check
(
rocshmem_TARGET_FILE
${
PACKAGE_PREFIX_DIR
}
/lib/cmake/rocshmem/rocshmem-targets.cmake
)
include
(
${
rocshmem_TARGET_FILE
}
)
set
(
rocshmem_LIBRARIES roc::rocshmem
)
set
(
rocshmem_LIBRARY roc::rocshmem
)
set
(
ROCSHMEM_LIBRARIES roc::rocshmem
)
set
(
ROCSHMEM_LIBRARY roc::rocshmem
)
set
(
rocshmem_LIBRARIES roc::rocshmem
)
set
(
rocshmem_LIBRARY roc::rocshmem
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment