Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
09cb2b03
"vscode:/vscode.git/clone" did not exist on "abbf0b6784e240e30325993377564929c7be4db4"
Commit
09cb2b03
authored
Oct 30, 2025
by
lishen
Browse files
添加low latency接口,正确性需补充
parent
0b14d3b2
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1579 additions
and
1059 deletions
+1579
-1059
build.sh
build.sh
+11
-6
csrc/config.hpp
csrc/config.hpp
+29
-26
csrc/deep_ep.cu
csrc/deep_ep.cu
+355
-64
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+12
-3
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+42
-0
csrc/kernels/configs.cuh
csrc/kernels/configs.cuh
+2
-0
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+927
-949
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+169
-9
deep_ep/buffer.py
deep_ep/buffer.py
+32
-2
No files found.
build.sh
View file @
09cb2b03
...
@@ -8,12 +8,17 @@ fi
...
@@ -8,12 +8,17 @@ fi
PYTHON_INCLUDE
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['include'])"
)
PYTHON_INCLUDE
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['include'])"
)
PYTHON_PLATLIB
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['platlib'])"
)
PYTHON_PLATLIB
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['platlib'])"
)
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/kernels/intranode.cu
-o
build_/intranode.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
INCLUDE_PATHS
=
${
INCLUDE_PATHS
:
=-Icsrc/ -I
$(
pwd
)
/rocshmem_dir/include/ -I/opt/mpi/include -I
${
PYTHON_PLATLIB
}
/torch/include -I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include -I
${
PYTHON_PLATLIB
}
/torch/include/TH -I
${
PYTHON_PLATLIB
}
/torch/include/THC -I
${
PYTHON_PLATLIB
}
/torch/include/THH -I/opt/dtk/include -I
${
PYTHON_INCLUDE
}}
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/kernels/runtime.cu
-o
build_/runtime.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
}
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/kernels/layout.cu
-o
build_/layout.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/deep_ep.cu
-o
build_/deep_ep.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/runtime.cu
-o
build_/runtime.o
${
COMPILE_OPTIONS
}
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/kernels/internode.cu
-o
build_/internode.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/layout.cu
-o
build_/layout.o
${
COMPILE_OPTIONS
}
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o
-L
$(
pwd
)
/rocshmem_dir/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,
$(
pwd
)
/rocshmem_dir/lib/
-L
"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
-l
:librocshmem.a
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/intranode.cu
-o
build_/intranode.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode.cu
-o
build_/internode.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode_ll.cu
-o
build_/internode_ll.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/deep_ep.cu
-o
build_/deep_ep.o
${
COMPILE_OPTIONS
}
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o
-L
$(
pwd
)
/rocshmem_dir/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,
$(
pwd
)
/rocshmem_dir/lib/
-L
"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
-l
:librocshmem.a
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
# build whl
# build whl
echo
"Using Python:
$(
which python3
)
"
echo
"Using Python:
$(
which python3
)
"
...
...
csrc/config.hpp
View file @
09cb2b03
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once
#pragma once
#include "kernels/api.cuh"
#include "kernels/api.cuh"
...
@@ -105,18 +107,18 @@ struct Config {
...
@@ -105,18 +107,18 @@ struct Config {
struct
LowLatencyBuffer
{
struct
LowLatencyBuffer
{
int
num_clean_int
=
0
;
int
num_clean_int
=
0
;
void
*
dispatch_rdma_send_buffer
=
nullptr
;
void
*
dispatch_rdma_send_buffer
=
nullptr
;
void
*
dispatch_rdma_recv_data_buffer
=
nullptr
;
void
*
dispatch_rdma_recv_data_buffer
=
nullptr
;
int
*
dispatch_rdma_recv_count_buffer
=
nullptr
;
int
64_t
*
dispatch_rdma_recv_count_buffer
=
nullptr
;
void
*
combine_rdma_send_buffer
=
nullptr
;
void
*
combine_rdma_send_buffer
=
nullptr
;
void
*
combine_rdma_recv_data_buffer
=
nullptr
;
void
*
combine_rdma_recv_data_buffer
=
nullptr
;
int
*
combine_rdma_recv_flag_buffer
=
nullptr
;
int
64_t
*
combine_rdma_recv_flag_buffer
=
nullptr
;
void
*
combine_rdma_send_buffer_data_start
=
nullptr
;
void
*
combine_rdma_send_buffer_data_start
=
nullptr
;
size_t
num_bytes_per_combine_msg
=
0
;
size_t
num_bytes_per_combine_msg
=
0
;
std
::
pair
<
int
*
,
int
>
clean_meta
()
{
std
::
pair
<
int
64_t
*
,
int
>
clean_meta
()
{
EP_HOST_ASSERT
(
dispatch_rdma_recv_count_buffer
==
combine_rdma_recv_flag_buffer
);
EP_HOST_ASSERT
(
dispatch_rdma_recv_count_buffer
==
combine_rdma_recv_flag_buffer
);
return
{
dispatch_rdma_recv_count_buffer
,
num_clean_int
};
return
{
dispatch_rdma_recv_count_buffer
,
num_clean_int
};
}
}
...
@@ -171,29 +173,30 @@ struct LowLatencyLayout {
...
@@ -171,29 +173,30 @@ struct LowLatencyLayout {
total_bytes
+=
recv_buffer_bytes
*
2
;
total_bytes
+=
recv_buffer_bytes
*
2
;
// Symmetric signaling buffers
// Symmetric signaling buffers
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int
);
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int
64_t
);
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
signaling_buffer_bytes
=
size_t
signaling_buffer_bytes
=
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
total_bytes
+=
signaling_buffer_bytes
*
2
;
size_t
signaling_buffer_bytes_aligned
=
ALIGN
<
size_t
>
(
signaling_buffer_bytes
,
128
);
total_bytes
+=
signaling_buffer_bytes_aligned
*
2
;
// Assign pointers
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// so you may see some parameters are duplicated
// so you may see some parameters are duplicated
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
buffers
[
i
]
=
{
buffers
[
i
]
=
{
static_cast
<
int
>
(
signaling_buffer_bytes
/
sizeof
(
int
)),
static_cast
<
int
>
(
signaling_buffer_bytes
/
sizeof
(
int64_t
)),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
// dispatch:send_buffer + recv_buffer + recv_count
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
2
+
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
// combine:send_buffer + recv_buffer + recv_count
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
2
+
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
// combine_rdma_send_buffer_data_start
num_bytes_per_combine_msg
};
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
+
sizeof
(
int4
)),
//
num_bytes_per_combine_msg
};
}
}
}
}
};
};
...
...
csrc/deep_ep.cu
View file @
09cb2b03
//
#include <ATen/dtk_macros.h>
#include <ATen/dtk_macros.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/HIPDataType.h>
#include <ATen/hip/HIPDataType.h>
#include <chrono>
#include <chrono>
...
@@ -13,20 +13,19 @@
...
@@ -13,20 +13,19 @@
namespace
deep_ep
{
namespace
deep_ep
{
Buffer
::
Buffer
(
int
rank
,
int
num_ranks
,
int64_t
num_nvl_bytes
,
int64_t
num_rdma_bytes
,
Buffer
::
Buffer
(
int
rank
,
int
num_ranks
,
int64_t
num_nvl_bytes
,
int64_t
num_rdma_bytes
,
bool
low_latency_mode
,
bool
explicitly_destroy
,
bool
low_latency_mode
,
bool
explicitly_destroy
,
bool
enable_shrink
)
bool
use_default_stream_as_comm_stream
)
:
rank
(
rank
),
num_ranks
(
num_ranks
),
num_nvl_bytes
(
num_nvl_bytes
),
:
rank
(
rank
),
num_ranks
(
num_ranks
),
num_nvl_bytes
(
num_nvl_bytes
),
num_rdma_bytes
(
num_rdma_bytes
),
low_latency_mode
(
low_latency_mode
),
num_rdma_bytes
(
num_rdma_bytes
),
low_latency_mode
(
low_latency_mode
),
explicitly_destroy
(
explicitly_destroy
),
explicitly_destroy
(
explicitly_destroy
),
use_default_stream_as_comm_stream
(
use_default_stream_as_comm_stream
),
enable_shrink
(
enable_shrink
),
comm_stream
(
use_default_stream_as_comm_stream
comm_stream
(
at
::
hip
::
getStreamFromPoolMasqueradingAsCUDA
(
true
))
{
?
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
()
:
at
::
hip
::
getStreamFromPoolMasqueradingAsCUDA
(
true
))
{
// Metadata memory
// Metadata memory
int64_t
barrier_signal_bytes
=
NUM_MAX_NVL_PEERS
*
sizeof
(
int
);
int64_t
barrier_signal_bytes
=
NUM_MAX_NVL_PEERS
*
sizeof
(
int
);
int64_t
buffer_ptr_bytes
=
NUM_MAX_NVL_PEERS
*
sizeof
(
void
*
);
int64_t
buffer_ptr_bytes
=
NUM_MAX_NVL_PEERS
*
sizeof
(
void
*
);
int64_t
barrier_signal_ptr_bytes
=
NUM_MAX_NVL_PEERS
*
sizeof
(
int
*
);
int64_t
barrier_signal_ptr_bytes
=
NUM_MAX_NVL_PEERS
*
sizeof
(
int
*
);
EP_HOST_ASSERT
(
enable_shrink
==
false
);
// Common checks
// Common checks
EP_HOST_ASSERT
(
num_nvl_bytes
%
NUM_BUFFER_ALIGNMENT_BYTES
==
0
and
EP_HOST_ASSERT
(
num_nvl_bytes
%
NUM_BUFFER_ALIGNMENT_BYTES
==
0
and
(
num_nvl_bytes
<=
std
::
numeric_limits
<
int
>::
max
()
or
num_rdma_bytes
==
0
));
(
num_nvl_bytes
<=
std
::
numeric_limits
<
int
>::
max
()
or
num_rdma_bytes
==
0
));
...
@@ -77,7 +76,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
...
@@ -77,7 +76,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
}
}
// Create 32 MiB workspace
// Create 32 MiB workspace
CUDA_CHECK
(
hipMalloc
(
&
workspace
,
NUM_WORKSPACE_BYTES
));
CUDA_CHECK
(
hip
Ext
Malloc
WithFlags
(
&
workspace
,
NUM_WORKSPACE_BYTES
,
hipDeviceMallocUncached
));
CUDA_CHECK
(
hipMemsetAsync
(
workspace
,
0
,
NUM_WORKSPACE_BYTES
,
comm_stream
));
CUDA_CHECK
(
hipMemsetAsync
(
workspace
,
0
,
NUM_WORKSPACE_BYTES
,
comm_stream
));
// MoE counter
// MoE counter
...
@@ -200,6 +199,10 @@ void Buffer::destroy() {
...
@@ -200,6 +199,10 @@ void Buffer::destroy() {
CUDA_CHECK
(
hipDeviceSynchronize
());
CUDA_CHECK
(
hipDeviceSynchronize
());
internode
::
barrier
();
internode
::
barrier
();
internode
::
free
(
rdma_buffer_ptr
);
internode
::
free
(
rdma_buffer_ptr
);
if
(
enable_shrink
)
{
internode
::
free
(
mask_buffer_ptr
);
internode
::
free
(
sync_buffer_ptr
);
}
internode
::
finalize
();
internode
::
finalize
();
}
}
#endif
#endif
...
@@ -253,25 +256,32 @@ void Buffer::sync(const std::vector<int> &device_
...
@@ -253,25 +256,32 @@ void Buffer::sync(const std::vector<int> &device_
// Sync ROCSHMEM handles and allocate memory
// Sync ROCSHMEM handles and allocate memory
if
(
num_rdma_bytes
>
0
)
{
if
(
num_rdma_bytes
>
0
)
{
// Initialize
NV
SHMEM
// Initialize
ROC
SHMEM
EP_HOST_ASSERT
(
root_unique_id_opt
.
has_value
());
EP_HOST_ASSERT
(
root_unique_id_opt
.
has_value
());
std
::
vector
<
uint8_t
>
root_unique_id
(
root_unique_id_opt
->
size
());
std
::
vector
<
uint8_t
>
root_unique_id
(
root_unique_id_opt
->
size
());
auto
root_unique_id_str
=
root_unique_id_opt
->
cast
<
std
::
string
>
();
auto
root_unique_id_str
=
root_unique_id_opt
->
cast
<
std
::
string
>
();
std
::
memcpy
(
root_unique_id
.
data
(),
root_unique_id_str
.
c_str
(),
root_unique_id_opt
->
size
());
std
::
memcpy
(
root_unique_id
.
data
(),
root_unique_id_str
.
c_str
(),
root_unique_id_opt
->
size
());
auto
nvshmem_rank
=
low_latency_mode
?
rank
:
rdma_rank
;
auto
nvshmem_rank
=
low_latency_mode
?
rank
:
rdma_rank
;
auto
num_nvshmem_ranks
=
low_latency_mode
?
num_ranks
:
num_rdma_ranks
;
auto
num_nvshmem_ranks
=
low_latency_mode
?
num_ranks
:
num_rdma_ranks
;
EP_HOST_ASSERT
(
nvshmem_rank
==
EP_HOST_ASSERT
(
nvshmem_rank
==
internode
::
init
(
root_unique_id
,
nvshmem_rank
,
num_nvshmem_ranks
,
low_latency_mode
));
internode
::
init
(
root_unique_id
,
nvshmem_rank
,
num_nvshmem_ranks
,
low_latency_mode
));
internode
::
barrier
();
internode
::
barrier
();
// Allocate
// Allocate
rdma_buffer_ptr
=
rdma_buffer_ptr
=
internode
::
alloc
(
num_rdma_bytes
,
NUM_BUFFER_ALIGNMENT_BYTES
);
internode
::
alloc
(
num_rdma_bytes
,
NUM_BUFFER_ALIGNMENT_BYTES
);
// Clean buffer (mainly for low-latency mode)
// Clean buffer (mainly for low-latency mode)
CUDA_CHECK
(
hipMemset
(
rdma_buffer_ptr
,
0
,
num_rdma_bytes
));
CUDA_CHECK
(
hipMemset
(
rdma_buffer_ptr
,
0
,
num_rdma_bytes
));
// Allocate and clean shrink buffer
if
(
enable_shrink
)
{
int
num_mask_buffer_bytes
=
num_ranks
*
sizeof
(
int
);
int
num_sync_buffer_bytes
=
num_ranks
*
sizeof
(
int
);
mask_buffer_ptr
=
reinterpret_cast
<
int
*>
(
internode
::
alloc
(
num_mask_buffer_bytes
,
NUM_BUFFER_ALIGNMENT_BYTES
));
sync_buffer_ptr
=
reinterpret_cast
<
int
*>
(
internode
::
alloc
(
num_sync_buffer_bytes
,
NUM_BUFFER_ALIGNMENT_BYTES
));
CUDA_CHECK
(
hipMemset
(
mask_buffer_ptr
,
0
,
num_mask_buffer_bytes
));
CUDA_CHECK
(
hipMemset
(
sync_buffer_ptr
,
0
,
num_sync_buffer_bytes
));
}
// Barrier
// Barrier
internode
::
barrier
();
internode
::
barrier
();
CUDA_CHECK
(
hipDeviceSynchronize
());
CUDA_CHECK
(
hipDeviceSynchronize
());
...
@@ -298,14 +308,12 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
...
@@ -298,14 +308,12 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
at
::
hip
::
setCurrentHIPStreamMasqueradingAsCUDA
(
comm_stream
);
at
::
hip
::
setCurrentHIPStreamMasqueradingAsCUDA
(
comm_stream
);
}
}
if
(
not
use_default_stream_as_comm_stream
)
{
// Wait previous tasks to be finished
// Wait previous tasks to be finished
if
(
previous_event
.
has_value
())
{
if
(
previous_event
.
has_value
())
{
stream_wait
(
comm_stream
,
previous_event
.
value
());
stream_wait
(
comm_stream
,
previous_event
.
value
());
}
else
{
}
else
{
stream_wait
(
comm_stream
,
compute_stream
);
stream_wait
(
comm_stream
,
compute_stream
);
}
}
}
auto
num_tokens
=
static_cast
<
int
>
(
topk_idx
.
size
(
0
)),
auto
num_tokens
=
static_cast
<
int
>
(
topk_idx
.
size
(
0
)),
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
...
@@ -342,10 +350,8 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
...
@@ -342,10 +350,8 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
}
}
}
else
{
}
else
{
if
(
not
use_default_stream_as_comm_stream
)
{
stream_wait
(
compute_stream
,
comm_stream
);
stream_wait
(
compute_stream
,
comm_stream
);
}
}
}
// Switch back compute stream
// Switch back compute stream
if
(
allocate_on_comm_stream
)
if
(
allocate_on_comm_stream
)
...
@@ -461,13 +467,11 @@ Buffer::intranode_dispatch(
...
@@ -461,13 +467,11 @@ Buffer::intranode_dispatch(
}
}
// Wait previous tasks to be finished
// Wait previous tasks to be finished
if
(
not
use_default_stream_as_comm_stream
)
{
if
(
previous_event
.
has_value
())
{
if
(
previous_event
.
has_value
())
{
stream_wait
(
comm_stream
,
previous_event
.
value
());
stream_wait
(
comm_stream
,
previous_event
.
value
());
}
else
{
}
else
{
stream_wait
(
comm_stream
,
compute_stream
);
stream_wait
(
comm_stream
,
compute_stream
);
}
}
}
// Create handles (only return for non-cached mode)
// Create handles (only return for non-cached mode)
int
num_recv_tokens
=
-
1
;
int
num_recv_tokens
=
-
1
;
...
@@ -623,10 +627,8 @@ Buffer::intranode_dispatch(
...
@@ -623,10 +627,8 @@ Buffer::intranode_dispatch(
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
}
}
}
else
{
}
else
{
if
(
not
use_default_stream_as_comm_stream
)
{
stream_wait
(
compute_stream
,
comm_stream
);
stream_wait
(
compute_stream
,
comm_stream
);
}
}
}
// Switch back compute stream
// Switch back compute stream
if
(
allocate_on_comm_stream
)
if
(
allocate_on_comm_stream
)
...
@@ -691,13 +693,11 @@ Buffer::intranode_combine(const torch::Tensor &x, const std::optional<torch::Ten
...
@@ -691,13 +693,11 @@ Buffer::intranode_combine(const torch::Tensor &x, const std::optional<torch::Ten
}
}
// Wait previous tasks to be finished
// Wait previous tasks to be finished
if
(
not
use_default_stream_as_comm_stream
)
{
if
(
previous_event
.
has_value
())
{
if
(
previous_event
.
has_value
())
{
stream_wait
(
comm_stream
,
previous_event
.
value
());
stream_wait
(
comm_stream
,
previous_event
.
value
());
}
else
{
}
else
{
stream_wait
(
comm_stream
,
compute_stream
);
stream_wait
(
comm_stream
,
compute_stream
);
}
}
}
int
num_topk
=
0
;
int
num_topk
=
0
;
auto
recv_topk_weights
=
std
::
optional
<
torch
::
Tensor
>
();
auto
recv_topk_weights
=
std
::
optional
<
torch
::
Tensor
>
();
...
@@ -765,10 +765,8 @@ Buffer::intranode_combine(const torch::Tensor &x, const std::optional<torch::Ten
...
@@ -765,10 +765,8 @@ Buffer::intranode_combine(const torch::Tensor &x, const std::optional<torch::Ten
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
}
}
}
else
{
}
else
{
if
(
not
use_default_stream_as_comm_stream
)
{
stream_wait
(
compute_stream
,
comm_stream
);
stream_wait
(
compute_stream
,
comm_stream
);
}
}
}
// Switch back compute stream
// Switch back compute stream
if
(
allocate_on_comm_stream
)
if
(
allocate_on_comm_stream
)
...
@@ -804,8 +802,8 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
...
@@ -804,8 +802,8 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
// here.
// here.
pybind11
::
gil_scoped_release
release
;
pybind11
::
gil_scoped_release
release
;
const
int
num_channels
=
config
.
num_sms
/
3
;
const
int
num_channels
=
config
.
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
EP_HOST_ASSERT
(
config
.
num_sms
%
3
==
0
);
EP_HOST_ASSERT
(
config
.
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
EP_HOST_ASSERT
(
0
<
get_num_rdma_ranks
()
and
get_num_rdma_ranks
()
<=
NUM_MAX_RDMA_PEERS
);
EP_HOST_ASSERT
(
0
<
get_num_rdma_ranks
()
and
get_num_rdma_ranks
()
<=
NUM_MAX_RDMA_PEERS
);
bool
cached_mode
=
cached_rdma_channel_prefix_matrix
.
has_value
();
bool
cached_mode
=
cached_rdma_channel_prefix_matrix
.
has_value
();
...
@@ -1125,8 +1123,8 @@ Buffer::internode_combine(
...
@@ -1125,8 +1123,8 @@ Buffer::internode_combine(
const
torch
::
Tensor
&
combined_nvl_head
,
const
Config
&
config
,
const
torch
::
Tensor
&
combined_nvl_head
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>
&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
)
{
std
::
optional
<
EventHandle
>
&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
)
{
#ifndef DISABLE_ROCSHMEM
#ifndef DISABLE_ROCSHMEM
const
int
num_channels
=
config
.
num_sms
/
3
;
const
int
num_channels
=
config
.
num_sms
/
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
;
EP_HOST_ASSERT
(
config
.
num_sms
%
3
==
0
);
EP_HOST_ASSERT
(
config
.
num_sms
%
NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL
==
0
);
// Shape and contiguous checks
// Shape and contiguous checks
EP_HOST_ASSERT
(
x
.
dim
()
==
2
and
x
.
is_contiguous
());
EP_HOST_ASSERT
(
x
.
dim
()
==
2
and
x
.
is_contiguous
());
...
@@ -1272,39 +1270,329 @@ Buffer::internode_combine(
...
@@ -1272,39 +1270,329 @@ Buffer::internode_combine(
#endif
#endif
}
}
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
void
Buffer
::
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
{
int
num_experts
)
{
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT
(
false
and
"not support low latency"
);
EP_HOST_ASSERT
(
low_latency_mode
);
auto
layout
=
LowLatencyLayout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
auto
clean_meta_0
=
layout
.
buffers
[
0
].
clean_meta
();
auto
clean_meta_1
=
layout
.
buffers
[
1
].
clean_meta
();
auto
check_boundary
=
[
=
](
void
*
ptr
,
size_t
num_bytes
)
{
auto
offset
=
reinterpret_cast
<
int64_t
>
(
ptr
)
-
reinterpret_cast
<
int64_t
>
(
rdma_buffer_ptr
);
EP_HOST_ASSERT
(
0
<=
offset
and
offset
+
num_bytes
<=
num_rdma_bytes
);
};
check_boundary
(
clean_meta_0
.
first
,
clean_meta_0
.
second
*
sizeof
(
int64_t
));
check_boundary
(
clean_meta_1
.
first
,
clean_meta_1
.
second
*
sizeof
(
int64_t
));
internode_ll
::
clean_low_latency_buffer
(
clean_meta_0
.
first
,
clean_meta_0
.
second
,
clean_meta_1
.
first
,
clean_meta_1
.
second
,
rank
,
num_ranks
,
mask_buffer_ptr
,
sync_buffer_ptr
,
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
());
#else
EP_HOST_ASSERT
(
false
and
"ROCSHMEM is disabled during compilation"
);
#endif
}
}
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
std
::
optional
<
torch
::
Tensor
>
&
cumulative_local_expert_recv_stats
,
const
std
::
optional
<
torch
::
Tensor
>
&
cumulative_local_expert_recv_stats
,
const
std
::
optional
<
torch
::
Tensor
>
&
dispatch_wait_recv_cost_stats
,
const
std
::
optional
<
torch
::
Tensor
>
&
dispatch_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
async
,
bool
return_recv_hook
)
{
bool
round_scale
,
bool
use_ue8m0
,
bool
async
,
bool
return_recv_hook
)
{
EP_HOST_ASSERT
(
false
and
"not support low latency"
);
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT
(
low_latency_mode
);
// Tensor checks
// By default using `ptp128c` FP8 cast
EP_HOST_ASSERT
(
x
.
dim
()
==
2
and
x
.
is_contiguous
()
and
x
.
scalar_type
()
==
torch
::
kBFloat16
);
EP_HOST_ASSERT
(
x
.
size
(
1
)
%
sizeof
(
int4
)
==
0
and
x
.
size
(
1
)
%
128
==
0
);
EP_HOST_ASSERT
(
topk_idx
.
dim
()
==
2
and
topk_idx
.
is_contiguous
());
EP_HOST_ASSERT
(
x
.
size
(
0
)
==
topk_idx
.
size
(
0
)
and
x
.
size
(
0
)
<=
num_max_dispatch_tokens_per_rank
);
EP_HOST_ASSERT
(
topk_idx
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
num_experts
%
num_ranks
==
0
);
// Diagnosis tensors
if
(
cumulative_local_expert_recv_stats
.
has_value
())
{
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
scalar_type
()
==
torch
::
kInt
);
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
dim
()
==
1
and
cumulative_local_expert_recv_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
cumulative_local_expert_recv_stats
->
size
(
0
)
==
num_experts
/
num_ranks
);
}
if
(
dispatch_wait_recv_cost_stats
.
has_value
())
{
EP_HOST_ASSERT
(
dispatch_wait_recv_cost_stats
->
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
dispatch_wait_recv_cost_stats
->
dim
()
==
1
and
dispatch_wait_recv_cost_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
dispatch_wait_recv_cost_stats
->
size
(
0
)
==
num_ranks
);
}
auto
num_tokens
=
static_cast
<
int
>
(
x
.
size
(
0
)),
hidden
=
static_cast
<
int
>
(
x
.
size
(
1
));
auto
num_topk
=
static_cast
<
int
>
(
topk_idx
.
size
(
1
));
auto
num_local_experts
=
num_experts
/
num_ranks
;
// Buffer control
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
// 双buffer操作
auto
global_atomic_counter
=
torch
::
zeros
({
1
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
// Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream
auto
compute_stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
auto
launch_stream
=
return_recv_hook
?
compute_stream
:
comm_stream
;
EP_HOST_ASSERT
(
not
(
async
and
return_recv_hook
));
if
(
not
return_recv_hook
)
stream_wait
(
launch_stream
,
compute_stream
);
// Allocate packed tensors
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
use_fp8
?
torch
::
kFloat8_e4m3fn
:
torch
::
kBFloat16
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
// Allocate column-majored scales
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
void
*
packed_recv_x_scales_ptr
=
nullptr
;
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
// TODO: support unaligned cases
EP_HOST_ASSERT
(
hidden
%
512
==
0
);
if
(
use_fp8
)
{
if
(
not
use_ue8m0
)
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
128
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
else
{
EP_HOST_ASSERT
(
round_scale
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
512
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
}
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales_ptr
=
packed_recv_x_scales
->
data_ptr
();
}
// Kernel launch
auto
next_clean_meta
=
next_buffer
.
clean_meta
();
auto
launcher
=
[
=
](
int
phases
)
{
internode_ll
::
dispatch
(
packed_recv_x
.
data_ptr
(),
packed_recv_x_scales_ptr
,
packed_recv_src_info
.
data_ptr
<
int
>
(),
packed_recv_layout_range
.
data_ptr
<
int64_t
>
(),
packed_recv_count
.
data_ptr
<
int
>
(),
global_atomic_counter
.
data_ptr
<
int
>
(),
mask_buffer_ptr
,
cumulative_local_expert_recv_stats
.
has_value
()
?
cumulative_local_expert_recv_stats
->
data_ptr
<
int
>
()
:
nullptr
,
dispatch_wait_recv_cost_stats
.
has_value
()
?
dispatch_wait_recv_cost_stats
->
data_ptr
<
int64_t
>
()
:
nullptr
,
buffer
.
dispatch_rdma_recv_data_buffer
,
buffer
.
dispatch_rdma_recv_count_buffer
,
buffer
.
dispatch_rdma_send_buffer
,
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
round_scale
,
use_ue8m0
,
workspace
,
num_device_sms
,
launch_stream
,
phases
);
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
// Wait streams
std
::
optional
<
EventHandle
>
event
;
if
(
async
)
{
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event
=
EventHandle
(
launch_stream
);
}
else
if
(
not
return_recv_hook
)
{
stream_wait
(
compute_stream
,
launch_stream
);
}
// Receiver callback
std
::
optional
<
std
::
function
<
void
()
>>
recv_hook
=
std
::
nullopt
;
if
(
return_recv_hook
)
recv_hook
=
[
=
]()
{
launcher
(
LOW_LATENCY_RECV_PHASE
);
};
// Return values
return
{
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
recv_hook
};
#else
EP_HOST_ASSERT
(
false
and
"ROCSHMEM is disabled during compilation"
);
return
{};
return
{};
#endif
}
}
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
Buffer
::
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
layout_range
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
const
std
::
optional
<
torch
::
Tensor
>
&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_logfmt
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_logfmt
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>
&
out
)
{
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
EP_HOST_ASSERT
(
false
and
"not support low latency"
);
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT
(
low_latency_mode
);
// Tensor checks
EP_HOST_ASSERT
(
x
.
dim
()
==
3
and
x
.
is_contiguous
()
and
x
.
scalar_type
()
==
torch
::
kBFloat16
);
EP_HOST_ASSERT
(
x
.
size
(
0
)
==
num_experts
/
num_ranks
);
EP_HOST_ASSERT
(
x
.
size
(
1
)
==
num_ranks
*
num_max_dispatch_tokens_per_rank
);
EP_HOST_ASSERT
(
x
.
size
(
2
)
%
sizeof
(
int4
)
==
0
and
x
.
size
(
2
)
%
128
==
0
);
EP_HOST_ASSERT
(
topk_idx
.
dim
()
==
2
and
topk_idx
.
is_contiguous
());
EP_HOST_ASSERT
(
topk_idx
.
size
(
0
)
==
topk_weights
.
size
(
0
)
and
topk_idx
.
size
(
1
)
==
topk_weights
.
size
(
1
));
EP_HOST_ASSERT
(
topk_idx
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
topk_weights
.
dim
()
==
2
and
topk_weights
.
is_contiguous
());
EP_HOST_ASSERT
(
topk_weights
.
size
(
0
)
<=
num_max_dispatch_tokens_per_rank
);
EP_HOST_ASSERT
(
topk_weights
.
scalar_type
()
==
torch
::
kFloat32
);
EP_HOST_ASSERT
(
src_info
.
dim
()
==
2
and
src_info
.
is_contiguous
());
EP_HOST_ASSERT
(
src_info
.
scalar_type
()
==
torch
::
kInt32
and
x
.
size
(
0
)
==
src_info
.
size
(
0
));
EP_HOST_ASSERT
(
layout_range
.
dim
()
==
2
and
layout_range
.
is_contiguous
());
EP_HOST_ASSERT
(
layout_range
.
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
layout_range
.
size
(
0
)
==
num_experts
/
num_ranks
and
layout_range
.
size
(
1
)
==
num_ranks
);
if
(
combine_wait_recv_cost_stats
.
has_value
())
{
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
scalar_type
()
==
torch
::
kInt64
);
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
dim
()
==
1
and
combine_wait_recv_cost_stats
->
is_contiguous
());
EP_HOST_ASSERT
(
combine_wait_recv_cost_stats
->
size
(
0
)
==
num_ranks
);
}
auto
hidden
=
static_cast
<
int
>
(
x
.
size
(
2
));
auto
num_topk
=
static_cast
<
int
>
(
topk_weights
.
size
(
1
));
auto
num_combined_tokens
=
static_cast
<
int
>
(
topk_weights
.
size
(
0
));
auto
global_atomic_counter
=
torch
::
zeros
({
1
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
// Buffer control
LowLatencyLayout
layout
(
rdma_buffer_ptr
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_ranks
,
num_experts
);
EP_HOST_ASSERT
(
layout
.
total_bytes
<=
num_rdma_bytes
);
auto
buffer
=
layout
.
buffers
[
low_latency_buffer_idx
];
auto
next_buffer
=
layout
.
buffers
[
low_latency_buffer_idx
^=
1
];
// Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream
auto
compute_stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
auto
launch_stream
=
return_recv_hook
?
compute_stream
:
comm_stream
;
EP_HOST_ASSERT
(
not
(
async
and
return_recv_hook
));
if
(
not
return_recv_hook
)
stream_wait
(
launch_stream
,
compute_stream
);
// Allocate output tensor
torch
::
Tensor
combined_x
;
if
(
out
.
has_value
())
{
EP_HOST_ASSERT
(
out
->
dim
()
==
2
and
out
->
is_contiguous
());
EP_HOST_ASSERT
(
out
->
size
(
0
)
==
num_combined_tokens
and
out
->
size
(
1
)
==
hidden
);
EP_HOST_ASSERT
(
out
->
scalar_type
()
==
x
.
scalar_type
());
combined_x
=
out
.
value
();
}
else
{
combined_x
=
torch
::
empty
({
num_combined_tokens
,
hidden
},
x
.
options
());
}
// Kernel launch
auto
next_clean_meta
=
next_buffer
.
clean_meta
();
auto
launcher
=
[
=
](
int
phases
)
{
internode_ll
::
combine
(
combined_x
.
data_ptr
(),
buffer
.
combine_rdma_recv_data_buffer
,
buffer
.
combine_rdma_recv_flag_buffer
,
buffer
.
combine_rdma_send_buffer
,
x
.
data_ptr
(),
topk_idx
.
data_ptr
<
int64_t
>
(),
topk_weights
.
data_ptr
<
float
>
(),
src_info
.
data_ptr
<
int
>
(),
layout_range
.
data_ptr
<
int64_t
>
(),
global_atomic_counter
.
data_ptr
<
int
>
(),
mask_buffer_ptr
,
combine_wait_recv_cost_stats
.
has_value
()
?
combine_wait_recv_cost_stats
->
data_ptr
<
int64_t
>
()
:
nullptr
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_logfmt
,
workspace
,
num_device_sms
,
launch_stream
,
phases
,
zero_copy
);
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
// Wait streams
std
::
optional
<
EventHandle
>
event
;
if
(
async
)
{
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event
=
EventHandle
(
launch_stream
);
}
else
if
(
not
return_recv_hook
)
{
stream_wait
(
compute_stream
,
launch_stream
);
}
// Receiver callback
std
::
optional
<
std
::
function
<
void
()
>>
recv_hook
=
std
::
nullopt
;
if
(
return_recv_hook
)
recv_hook
=
[
=
]()
{
launcher
(
LOW_LATENCY_RECV_PHASE
);
};
// Return values
return
{
combined_x
,
event
,
recv_hook
};
#else
EP_HOST_ASSERT
(
false
and
"ROCSHMEM is disabled during compilation"
);
return
{};
return
{};
#endif
}
}
/**
 * Returns a BF16 tensor view over the combine send buffer of the layout buffer currently
 * selected by `low_latency_buffer_idx`.
 *
 * The view has shape `[num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`
 * and is strided by whole combine messages (`num_bytes_per_combine_msg`), so it is not
 * contiguous whenever a message is larger than `hidden` BF16 elements. The tensor is created
 * with `torch::from_blob`, i.e. it does not own the memory.
 */
torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
                                                          int hidden, int num_experts) const {
#ifndef DISABLE_ROCSHMEM
    LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
    auto buffer = layout.buffers[low_latency_buffer_idx];
    auto dtype = torch::kBFloat16;

    // A message must be a whole number of elements for the element stride below to be exact.
    // (Original author's note: num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(hip_bfloat16).)
    EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0);
    const auto num_msg_elems = static_cast<int64_t>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));

    // Compute sizes/strides in 64-bit so the products cannot overflow `int`
    const auto num_local_experts = static_cast<int64_t>(num_experts / num_ranks);
    const auto num_slots = static_cast<int64_t>(num_ranks) * num_max_dispatch_tokens_per_rank;

    return torch::from_blob(buffer.combine_rdma_send_buffer_data_start,
                            {num_local_experts, num_slots, static_cast<int64_t>(hidden)},
                            {num_slots * num_msg_elems, num_msg_elems, static_cast<int64_t>(1)},
                            torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
#else
    EP_HOST_ASSERT(false and "ROCSHMEM is disabled during compilation");
    return {};
#endif
}
/**
 * Forwards a mask update for `rank_to_mask` to `internode_ll::update_mask_buffer`
 * on the current HIP stream.
 *
 * Requires the mask buffer to exist (shrink mode) and `rank_to_mask` to be in `[0, num_ranks)`.
 */
void Buffer::low_latency_update_mask_buffer(int rank_to_mask, bool mask) {
    EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled");
    EP_HOST_ASSERT(rank_to_mask >= 0 and rank_to_mask < num_ranks);

    const auto current_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
    internode_ll::update_mask_buffer(mask_buffer_ptr, rank_to_mask, mask, current_stream);
}
/**
 * Has `internode_ll::query_mask_buffer` write the per-rank mask state into `mask_status`
 * on the current HIP stream.
 *
 * @param mask_status  int32 tensor with exactly `num_ranks` elements; must be contiguous because
 *                     its raw data pointer is handed to the kernel launcher.
 *                     NOTE(review): presumably it must also live on the device - confirm against the kernel.
 */
void Buffer::low_latency_query_mask_buffer(const torch::Tensor& mask_status) {
    EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled");
    EP_HOST_ASSERT(mask_status.numel() == num_ranks && mask_status.scalar_type() == torch::kInt32);
    // A strided view with the right element count would otherwise be filled at the wrong offsets
    EP_HOST_ASSERT(mask_status.is_contiguous());

    internode_ll::query_mask_buffer(mask_buffer_ptr, num_ranks,
                                    mask_status.data_ptr<int>(),
                                    at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
/**
 * Forwards to `internode_ll::clean_mask_buffer` for all `num_ranks` entries of the mask
 * buffer, on the current HIP stream. Requires the mask buffer to exist (shrink mode).
 */
void Buffer::low_latency_clean_mask_buffer() {
    EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled");

    const auto current_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
    internode_ll::clean_mask_buffer(mask_buffer_ptr, num_ranks, current_stream);
}
}
}
// namespace deep_ep
}
// namespace deep_ep
...
@@ -1346,8 +1634,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -1346,8 +1634,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.
def
(
"clean_low_latency_buffer"
,
&
deep_ep
::
Buffer
::
clean_low_latency_buffer
)
.
def
(
"clean_low_latency_buffer"
,
&
deep_ep
::
Buffer
::
clean_low_latency_buffer
)
.
def
(
"low_latency_dispatch"
,
&
deep_ep
::
Buffer
::
low_latency_dispatch
)
.
def
(
"low_latency_dispatch"
,
&
deep_ep
::
Buffer
::
low_latency_dispatch
)
.
def
(
"low_latency_combine"
,
&
deep_ep
::
Buffer
::
low_latency_combine
)
.
def
(
"low_latency_combine"
,
&
deep_ep
::
Buffer
::
low_latency_combine
)
.
def
(
"get_next_low_latency_combine_buffer"
,
&
deep_ep
::
Buffer
::
get_next_low_latency_combine_buffer
);
.
def
(
"get_next_low_latency_combine_buffer"
,
&
deep_ep
::
Buffer
::
get_next_low_latency_combine_buffer
)
.
def
(
"low_latency_update_mask_buffer"
,
&
deep_ep
::
Buffer
::
low_latency_update_mask_buffer
)
.
def
(
"low_latency_query_mask_buffer"
,
&
deep_ep
::
Buffer
::
low_latency_query_mask_buffer
)
.
def
(
"low_latency_clean_mask_buffer"
,
&
deep_ep
::
Buffer
::
low_latency_clean_mask_buffer
);
// m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
// m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
// m.attr("
topk_idx
_t") = py::cast(c10::CppTypeToScalarType<deep_ep::
topk_idx
_t>::value);
// m.attr("
int64
_t") = py::cast(c10::CppTypeToScalarType<deep_ep::
int64
_t>::value);
}
}
csrc/deep_ep.hpp
View file @
09cb2b03
...
@@ -30,6 +30,11 @@ private:
...
@@ -30,6 +30,11 @@ private:
int64_t
num_rdma_bytes
;
int64_t
num_rdma_bytes
;
void
*
rdma_buffer_ptr
=
nullptr
;
void
*
rdma_buffer_ptr
=
nullptr
;
// Shrink mode buffer
bool
enable_shrink
=
false
;
int
*
mask_buffer_ptr
=
nullptr
;
int
*
sync_buffer_ptr
=
nullptr
;
// Device info and communication
// Device info and communication
int
device_id
;
int
device_id
;
int
num_device_sms
;
int
num_device_sms
;
...
@@ -67,11 +72,9 @@ private:
...
@@ -67,11 +72,9 @@ private:
volatile
int
*
moe_recv_rdma_counter
=
nullptr
;
volatile
int
*
moe_recv_rdma_counter
=
nullptr
;
int
*
moe_recv_rdma_counter_mapped
=
nullptr
;
int
*
moe_recv_rdma_counter_mapped
=
nullptr
;
bool
use_default_stream_as_comm_stream
=
false
;
public:
public:
Buffer
(
int
rank
,
int
num_ranks
,
int64_t
num_nvl_bytes
,
int64_t
num_rdma_bytes
,
Buffer
(
int
rank
,
int
num_ranks
,
int64_t
num_nvl_bytes
,
int64_t
num_rdma_bytes
,
bool
low_latency_mode
,
bool
explicitly_destroy
,
bool
use_default_stream_as_comm_stream
);
bool
low_latency_mode
,
bool
explicitly_destroy
,
bool
enable_shrink
);
~
Buffer
()
noexcept
(
false
);
~
Buffer
()
noexcept
(
false
);
...
@@ -187,6 +190,12 @@ public:
...
@@ -187,6 +190,12 @@ public:
torch
::
Tensor
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
torch
::
Tensor
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
const
;
int
hidden
,
int
num_experts
)
const
;
void
low_latency_update_mask_buffer
(
int
rank_to_mask
,
bool
mask
);
void
low_latency_query_mask_buffer
(
const
torch
::
Tensor
&
mask_status
);
void
low_latency_clean_mask_buffer
();
};
};
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/api.cuh
View file @
09cb2b03
...
@@ -131,4 +131,46 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
...
@@ -131,4 +131,46 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
int
num_ranks
,
hipStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
);
int
num_ranks
,
hipStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
);
}
// namespace internode
}
// namespace internode
// Internode low-latency kernels
// Host-side launchers; every function takes the HIP stream to launch on.
namespace internode_ll {

// Zeroes `num_clean_int_0` elements at `clean_0` and `num_clean_int_1` elements at `clean_1`.
// `mask_buffer` / `sync_buffer` may be null when shrink mode is not in use
// (NOTE(review): inferred from the null checks in the kernel - confirm).
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                              int64_t* clean_1, int num_clean_int_1,
                              int rank, int num_ranks,
                              int* mask_buffer, int* sync_buffer,
                              hipStream_t stream);

// Low-latency dispatch.
// `phases` is a bitmask of LOW_LATENCY_SEND_PHASE / LOW_LATENCY_RECV_PHASE.
// `cumulative_local_expert_recv_stats` and `dispatch_wait_recv_cost_stats` are optional (callers pass nullptr when absent).
// `next_clean` / `num_next_clean_int` describe the region of the other double-buffer half to clean.
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
              int* packed_recv_src_info, int64_t* packed_recv_layout_range,
              int* packed_recv_count,
              int* global_atomic_counter,
              int* mask_buffer,
              int* cumulative_local_expert_recv_stats,
              int64_t* dispatch_wait_recv_cost_stats,
              void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
              const void* x, const int64_t* topk_idx,
              int64_t* next_clean, int num_next_clean_int,
              int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
              bool use_fp8, bool round_scale, bool use_ue8m0,
              void* workspace, int num_device_sms,
              hipStream_t stream, int phases);

// Low-latency combine.
// `phases` is a bitmask of LOW_LATENCY_SEND_PHASE / LOW_LATENCY_RECV_PHASE.
// `combine_wait_recv_cost_stats` is optional (callers pass nullptr when absent).
void combine(void* combined_x,
             void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
             const void* x, const int64_t* topk_idx, const float* topk_weights,
             const int* src_info, const int64_t* layout_range,
             int* global_atomic_counter,
             int* mask_buffer,
             int64_t* combine_wait_recv_cost_stats,
             int64_t* next_clean, int num_next_clean_int,
             int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
             int num_topk, int num_experts, int rank, int num_ranks,
             bool use_logfmt,
             void* workspace, int num_device_sms,
             hipStream_t stream, int phases, bool zero_copy);

// Writes the mask state of `num_ranks` ranks into `output_mask_tensor` (int32 storage).
void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, hipStream_t stream);

// Sets or clears (per `mask`) the mask entry of `rank_to_mask`.
void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask, hipStream_t stream);

// Cleans the mask entries of `num_ranks` ranks.
void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, hipStream_t stream);

}
// namespace internode_ll
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/configs.cuh
View file @
09cb2b03
...
@@ -22,6 +22,8 @@
...
@@ -22,6 +22,8 @@
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
#define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
...
...
csrc/kernels/internode_ll.cu
View file @
09cb2b03
#include "configs.cuh"
#include "configs.cuh"
#include "exception.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "launch.cuh"
// #include "ibgda_device.cuh"
#include "buffer.cuh"
#include "utils.cuh"
// #include <cooperative_groups.h>
#include <iostream>
// low latency+RocSHMEM has issue with CTX.
#define ROCM_DISABLE_CTX
#ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#include <rocshmem/rocshmem_COLL.hpp>
namespace
deep_ep
{
namespace
deep_ep
{
namespace
internode_ll
{
namespace
internode_ll
{
// Returns whether `rank` has a non-zero entry in the mask buffer.
// A null `mask_buffer_ptr` means no rank is masked.
// With `use_warp_sync`, every lane loads the entry and the value held by lane 0 is
// broadcast through `shfl_sync`, so all lanes return the same answer.
template <bool use_warp_sync = false>
__forceinline__ __device__ bool is_rank_masked(int* mask_buffer_ptr, int rank) {
    if (mask_buffer_ptr == nullptr)
        return false;

    const auto mask_value = ld_acquire_global(mask_buffer_ptr + rank);
    if constexpr (use_warp_sync)
        return shfl_sync(mask_value, 0) != 0;
    else
        return mask_value != 0;
}
// Barrier across the thread blocks of one kernel launch, built on a counter in global memory.
//
// Thread 0 of every block increments `global_counter[0]` once and then spins until the counter
// equals `num_blocks`; the remaining threads wait for it at the trailing `__syncthreads()`.
//
// The counter is never reset here, so it must be zero on entry and the barrier can be crossed
// only once per counter value. All `num_blocks` blocks must be resident at the same time or the
// spin never terminates.
__device__ void grid_barrier(int* global_counter, int num_blocks) {
    // Make sure every thread of this block has finished its prior work and publish its writes
    __syncthreads();
    memory_fence_gpu();

    if (threadIdx.x == 0) {
        // Arrive (the previous value is not needed)
        atomicAdd(global_counter, 1);
        // Wait for all other blocks to arrive
        while (ld_relaxed_global(global_counter) != num_blocks);
        // The spin uses a relaxed load; fence afterwards so that data written by the other
        // blocks before they arrived is not read stale after the barrier
        memory_fence_gpu();
    }
    __syncthreads();
}
// Zeroes the two low-latency clean regions, bracketed by rocSHMEM barriers across all PEs.
//
// Launched as a single block of `kNumThreads` threads. Only the path without a sync buffer
// (`sync_buffer_ptr == nullptr`) is implemented; the shrink-mode barrier traps via
// `EP_DEVICE_ASSERT(0)`. `rank`, `num_ranks` and `mask_buffer_ptr` are only needed by that
// unimplemented path.
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
__global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                                         int64_t* clean_1, int num_clean_int_1,
                                         int rank, int num_ranks,
                                         int* mask_buffer_ptr, int* sync_buffer_ptr) {
    auto thread_id = static_cast<int>(threadIdx.x);

    // Barrier before cleaning (in case of unfinished chunked EP)
    if (sync_buffer_ptr == nullptr) {
        // The barrier is issued by a single thread (the workgroup-collective variant is not used here)
        if (thread_id == 0)
            rocshmem::rocshmem_barrier_all();
    } else {
        // Shrink-mode barrier is not ported yet
        EP_DEVICE_ASSERT(0);
    }
    // The other threads must not start zeroing until thread 0 has left the barrier
    __syncthreads();

    // Clean
    #pragma unroll
    for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
        clean_0[i] = 0;
    #pragma unroll
    for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
        clean_1[i] = 0;

    // Every thread's zero stores must be issued and visible before thread 0 announces
    // completion to the other PEs
    __threadfence_system();
    __syncthreads();

    // Barrier after cleaning (make sure the low-latency mode works fine)
    if (sync_buffer_ptr == nullptr) {
        if (thread_id == 0)
            rocshmem::rocshmem_barrier_all();
    } else {
        // Shrink-mode barrier is not ported yet
        EP_DEVICE_ASSERT(0);
    }
}
void
clean_low_latency_buffer
(
int
*
clean_0
,
int
num_clean_int_0
,
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
int
*
clean_1
,
int
num_clean_int_1
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
cudaStream_t
stream
)
{
int
rank
,
int
num_ranks
,
// constexpr int kNumThreads = 256;
int
*
mask_buffer_ptr
,
int
*
sync_buffer_ptr
,
hipStream_t
stream
)
{
// SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
constexpr
int
kNumThreads
=
256
;
// LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
// clean_0, num_clean_int_0, clean_1, num_clean_int_1);
SETUP_LAUNCH_CONFIG
(
1
,
kNumThreads
,
stream
);
LAUNCH_KERNEL
(
&
cfg
,
clean_low_latency_buffer
<
kNumThreads
>
,
clean_0
,
num_clean_int_0
,
clean_1
,
num_clean_int_1
,
rank
,
num_ranks
,
mask_buffer_ptr
,
sync_buffer_ptr
);
}
}
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
int
kHidden
>
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
int
kHidden
>
__global__
__launch_bounds__
(
1024
,
1
)
void
__launch_bounds__
(
1024
,
1
)
__global__
void
dispatch
(
void
*
packed_recv_x
,
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
mask_buffer_ptr
,
int
*
cumulative_local_expert_recv_stats
,
int
*
cumulative_local_expert_recv_stats
,
int64_t
*
dispatch_wait_recv_cost_stats
,
int64_t
*
dispatch_wait_recv_cost_stats
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
void
*
rdma_recv_x
,
const
void
*
x
,
const
topk_idx_t
*
topk_idx
,
int64_t
*
rdma_recv_count
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
void
*
rdma_x
,
int
*
next_clean
,
int
num_next_clean_int
,
const
void
*
x
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
const
int64_t
*
topk_idx
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
*
atomic_counter_per_expert
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
*
atomic_finish_counter_per_expert
,
bool
round_scale
,
int
phases
)
{
int64_t
*
next_clean
,
// const auto sm_id = static_cast<int>(blockIdx.x);
int
num_next_clean_int
,
// const auto thread_id = static_cast<int>(threadIdx.x);
int
num_tokens
,
// const auto warp_id = thread_id / 32, lane_id = get_lane_id();
int
num_max_dispatch_tokens_per_rank
,
// const auto num_sms = static_cast<int>(gridDim.x);
int
num_topk
,
// const auto num_warps = num_warp_groups * num_warps_per_group;
int
num_experts
,
// const auto num_local_experts = num_experts / num_ranks;
int
rank
,
// const auto warp_group_id = warp_id / num_warps_per_group;
int
num_ranks
,
// const auto sub_warp_id = warp_id % num_warps_per_group;
int
num_warp_groups
,
// const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
int
num_warps_per_group
,
bool
round_scale
,
// // May extract UE8M0 from the scales
int
phases
)
{
// using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
#if !defined(ROCM_DISABLE_CTX)
// using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
// EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
#endif
// // FP8 staffs
// constexpr int kNumPerChannels = 128;
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
// const int num_scales = kHidden / kNumPerChannels;
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
// const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
const
auto
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
// const size_t hidden_int4 = hidden_bytes / sizeof(int4);
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
const
auto
num_warps
=
num_warp_groups
*
num_warps_per_group
;
// // Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales
const
auto
num_local_experts
=
num_experts
/
num_ranks
;
// // NOTES: currently we have 3 reserved int fields for future use
const
auto
warp_group_id
=
warp_id
/
num_warps_per_group
;
// using vec_t = std::conditional_t<kUseFP8, int2, int4>;
const
auto
sub_warp_id
=
warp_id
%
num_warps_per_group
;
// const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
// 每个warp处理一个expert
// const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
// EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
// May extract UE8M0 from the scales
// // Expert counts
using
scale_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint8_t
,
float
>
;
// constexpr int kNumMaxWarpGroups = 32;
using
packed_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint32_t
,
float
>
;
// __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
EP_STATIC_ASSERT
(
sizeof
(
packed_t
)
%
sizeof
(
scale_t
)
==
0
,
"Invalid vector length"
);
// // Sending phase
// FP8 staffs
// if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
constexpr
int
kNumPerChannels
=
FP8_QUANTIZATION_NUM_PER_CHANNEL
;
// goto LOW_LATENCY_DISPATCH_RECV;
const
int
num_scales
=
kHidden
/
kNumPerChannels
;
const
size_t
hidden_bytes
=
kHidden
*
(
kUseFP8
?
sizeof
(
__hip_fp8_storage_t
)
:
sizeof
(
hip_bfloat16
));
// // There are 2 kinds of warps in this part:
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
// // 1. The first-kind warps for FP8 cast and sending top-k tokens
// // 2. The last warp for reading `topk_idx` and count for per-expert information
// Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales
// if (warp_id < num_warps - 1) {
// NOTES: currently we have 3 reserved int fields for future use
// constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
using
vec_t
=
std
::
conditional_t
<
kUseFP8
,
int2
,
int4
>
;
// EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerRead) == 0, "Invalid hidden");
const
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUseFP8
?
(
kHidden
+
num_scales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
// EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
const
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
// const auto num_threads = (num_warps - 1) * 32;
EP_DEVICE_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
);
// const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
// Expert counts
// for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
constexpr
int
kNumMaxWarpGroups
=
16
;
// 每个kernel最多warp group数量,即每个block负责的专家数
// const auto x_int4 = static_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
__shared__
int
shared_num_tokens_sent_per_expert
[
kNumMaxWarpGroups
];
// const auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
// const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
#ifdef USE_ROCM
// const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
// 用于同步
// 16 is the max possible number of warps in AMD GPUs
// // Overlap top-k index read and source token index writes
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
// auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
constexpr
int
num_sync_large_iteration
=
kMaxNumWarps
;
// thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
__shared__
volatile
int
sync_large_warp_counters
[
num_sync_large_iteration
];
// // FP8 cast
#pragma unroll
// EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce");
for
(
int
i
=
thread_id
;
i
<
num_sync_large_iteration
;
i
+=
blockDim
.
x
)
{
// #pragma unroll
sync_large_warp_counters
[
i
]
=
0
;
// for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
}
// // Read
__syncthreads
();
// auto int4_value = __ldg(x_int4 + i);
#endif
// if constexpr (kUseFP8) {
// Sending phase,如果没有发送任务,则直接跳到接收阶段
// // Calculate local amax
if
((
phases
&
LOW_LATENCY_SEND_PHASE
)
==
0
)
// auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
goto
LOW_LATENCY_DISPATCH_RECV
;
// float fp32_values[kNumElemsPerRead];
// float amax = kFP8Margin, scale, scale_inv;
// There are 2 kinds of warps in this part:
// #pragma unroll
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// for (int j = 0; j < kNumElemsPerRead; ++ j) {
// 2. The last warp for reading `topk_idx` and count for per-expert information
// fp32_values[j] = static_cast<float>(bf16_values[j]);
if
(
warp_id
<
num_warps
-
1
)
{
// amax = fmaxf(amax, fabsf(fp32_values[j]));
constexpr
int
kNumElemsPerRead
=
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
);
// 128/16 = 8
// }
EP_STATIC_ASSERT
(
kHidden
%
(
kWarpSize
*
kNumElemsPerRead
)
==
0
,
"Invalid hidden"
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
// // Reduce amax and scale
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
// EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
const
size_t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerRead
;
// amax = warp_reduce_max<16>(amax);
// calculate_fp8_scales(amax, scale, scale_inv, round_scale);
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_tokens
;
token_idx
+=
num_sms
)
{
// if (lane_id == 0 or lane_id == 16)
const
auto
x_int4
=
static_cast
<
const
int4
*>
(
x
)
+
token_idx
*
hidden_bf16_int4
;
// rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
const
auto
rdma_x_src_idx
=
reinterpret_cast
<
int
*>
(
static_cast
<
uint8_t
*>
(
rdma_x
)
+
token_idx
*
num_bytes_per_msg
);
const
auto
rdma_x_vec
=
reinterpret_cast
<
vec_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_src_idx
)
+
sizeof
(
int4
));
// // Cast into send buffer
const
auto
rdma_x_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_x_vec
)
+
hidden_bytes
);
// vec_t int2_value;
// auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
// Overlap top-k index read and source token index writes
// #pragma unroll
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
// for (int j = 0; j < kNumElemsPerRead; j += 2) {
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
// float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
// fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
// FP8 cast
// }
EP_STATIC_ASSERT
(
hidden_bf16_int4
%
kWarpSize
==
0
,
"Must use the full warp to reduce"
);
// rdma_x_vec[i] = int2_value;
#pragma unroll
// } else {
for
(
int
i
=
thread_id
;
i
<
hidden_bf16_int4
;
i
+=
num_threads
)
{
// // Reinterpret-cast is for C++14 compatibility
// Read
// rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
auto
int4_value
=
__ldg
(
x_int4
+
i
);
// }
// }
if
constexpr
(
kUseFP8
)
{
// asm volatile("bar.sync 1, %0;" :: "r"(num_threads));
// Calculate local amax
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
// // Issue IBGDA sends
float
fp32_values
[
kNumElemsPerRead
];
// if (dst_expert_idx >= 0) {
float
amax
=
kFP8Margin
,
scale
,
scale_inv
;
// int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
#pragma unroll
// slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
// const auto dst_rank = dst_expert_idx / num_local_experts;
fp32_values
[
j
]
=
static_cast
<
float
>
(
bf16_values
[
j
]);
// const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
amax
=
fmaxf
(
amax
,
fabsf
(
fp32_values
[
j
]));
// const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
}
// const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
// dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
// Reduce amax and scale
// rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
/
kNumPerChannels
==
4
,
"Invalid vectorization"
);
// slot_idx * num_bytes_per_msg;
amax
=
warp_reduce_max
<
16
>
(
amax
);
// const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
calculate_fp8_scales
(
amax
,
scale
,
scale_inv
,
round_scale
);
// if (dst_p2p_ptr == 0) {
// nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
if
(
lane_id
%
16
==
0
)
// } else {
rdma_x_scales
[
i
*
kNumElemsPerRead
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
]
=
scale_inv
;
// // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
// Cast into send buffer
// const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
vec_t
int2_value
;
// UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
auto
fp8x2_values
=
reinterpret_cast
<
__hip_fp8x2_storage_t
*>
(
&
int2_value
);
// }
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
j
+=
2
)
{
// // Increase counter after finishing
float2
fp32x2
=
{
fp32_values
[
j
]
*
scale
,
fp32_values
[
j
+
1
]
*
scale
};
// __syncwarp();
fp8x2_values
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E4M3_FNUZ
);
// lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
}
// }
rdma_x_vec
[
i
]
=
int2_value
;
// }
}
else
{
// } else if (warp_id == num_warps - 1) {
// Reinterpret-cast is for C++14 compatibility
// EP_DEVICE_ASSERT(num_sms > 1);
rdma_x_vec
[
i
]
=
*
reinterpret_cast
<
vec_t
*>
(
&
int4_value
);
// if (sm_id == 0) {
}
// // The first SM is also responsible for checking QPs
}
// EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts);
__syncthreads
();
// // The first SM is also responsible for cleaning the next buffer
// Issue IBGDA sends
// #pragma unroll
if
(
dst_expert_idx
>=
0
)
{
// for (int i = lane_id; i < num_next_clean_int; i += 32)
int
slot_idx
=
lane_id
==
0
?
atomicAdd
(
atomic_counter_per_expert
+
dst_expert_idx
,
1
)
:
0
;
// next_clean[i] = 0;
slot_idx
=
shfl_sync
(
slot_idx
,
0
);
const
int
dst_rank
=
dst_expert_idx
/
num_local_experts
;
// // Notify before executing `int_p`
const
int
dst_expert_local_idx
=
dst_expert_idx
%
num_local_experts
;
// __syncwarp();
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_x_src_idx
);
// #pragma unroll
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
// for (int i = lane_id; i < num_experts; i += 32)
dst_expert_local_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
// atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
slot_idx
*
num_bytes_per_msg
;
// }
if
(
dst_rank
!=
rank
)
{
// // This SM should be responsible for some destination experts, read `topk_idx` for them
#if !defined(ROCM_DISABLE_CTX)
// int expert_count[kNumMaxWarpGroups] = {0};
rocshmem
::
rocshmem_ctx_schar_put_nbi_wave
(
ctx
,
// const auto expert_begin_idx = sm_id * num_warp_groups;
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
// const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);
num_bytes_per_msg
,
dst_rank
);
rocshmem
::
rocshmem_ctx_quiet
(
ctx
);
// // Per lane count
#else
// #pragma unroll 8
rocshmem
::
rocshmem_schar_put_nbi_wave
(
// for (int i = lane_id; i < num_tokens * num_topk; i += 32) {
reinterpret_cast
<
signed
char
*>
(
dst_ptr
),
reinterpret_cast
<
signed
char
*>
(
src_ptr
),
// auto idx = static_cast<int>(__ldg(topk_idx + i));
num_bytes_per_msg
,
dst_rank
);
// if (idx >= expert_begin_idx and idx < expert_end_idx)
rocshmem
::
rocshmem_fence
();
// expert_count[idx - expert_begin_idx] ++;
#endif
// }
}
else
{
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// // Warp reduce
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
// #pragma unroll
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
dst_ptr
);
// for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
UNROLLED_WARP_COPY
(
8
,
lane_id
,
num_int4_per_msg
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
// auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
}
// if (lane_id == 0) {
// Increase counter after finishing
// shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
syncwarp
();
// atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
lane_id
==
0
?
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
dst_expert_idx
,
1
)
:
0
;
// }
}
// }
}
// }
}
else
if
(
warp_id
==
num_warps
-
1
)
{
// __syncthreads();
EP_DEVICE_ASSERT
(
num_sms
>
1
);
if
(
sm_id
==
0
)
{
// // Issue count sends
// The first SM is also responsible for cleaning the next buffer
// if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
#pragma unroll
// const auto dst_rank = responsible_expert_idx / num_local_experts;
for
(
int
i
=
lane_id
;
i
<
num_next_clean_int
;
i
+=
kWarpSize
)
// const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
next_clean
[
i
]
=
0
;
// const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];
// Notify before executing `int_p`
// // Wait local sends issued and send expert counts
syncwarp
();
// while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
#pragma unroll
// auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
for
(
int
i
=
lane_id
;
i
<
num_experts
;
i
+=
kWarpSize
)
// auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
);
// if (dst_p2p_ptr == 0) {
}
// nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
// } else {
// This SM should be responsible for some destination experts, read `topk_idx` for them
// st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), -num_tokens_sent - 1);
int
expert_count
[
kNumMaxWarpGroups
]
=
{
0
};
// }
const
auto
expert_begin_idx
=
sm_id
*
num_warp_groups
;
const
auto
expert_end_idx
=
min
(
expert_begin_idx
+
num_warp_groups
,
num_experts
);
// // Clean workspace for next use
// atomic_counter_per_expert[responsible_expert_idx] = 0;
// Per lane count
// atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
#pragma unroll 8
for
(
int
i
=
lane_id
;
i
<
num_tokens
*
num_topk
;
i
+=
kWarpSize
)
{
// // Clean `packed_recv_count`
auto
idx
=
static_cast
<
int
>
(
__ldg
(
topk_idx
+
i
));
// if (dst_rank == 0)
if
(
idx
>=
expert_begin_idx
and
idx
<
expert_end_idx
)
// packed_recv_count[dst_expert_local_idx] = 0;
expert_count
[
idx
-
expert_begin_idx
]
++
;
// }
}
// __syncwarp();
// Warp reduce
// // Receiving phase
#pragma unroll
// LOW_LATENCY_DISPATCH_RECV:
for
(
int
i
=
expert_begin_idx
;
i
<
expert_end_idx
;
++
i
)
{
// if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
auto
sum
=
warp_reduce_sum
(
expert_count
[
i
-
expert_begin_idx
]);
// return;
if
(
lane_id
==
0
)
{
shared_num_tokens_sent_per_expert
[
i
-
expert_begin_idx
]
=
sum
;
// // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
atomic_add_release_global
(
atomic_finish_counter_per_expert
+
i
,
FINISHED_SUM_TAG
-
sum
);
// if (phases & LOW_LATENCY_SEND_PHASE)
}
// cg::this_grid().sync();
}
}
// // Receiving and packing
// if (responsible_expert_idx < num_experts) {
__syncthreads
();
// const auto src_rank = responsible_expert_idx / num_local_experts;
// const auto local_expert_idx = responsible_expert_idx % num_local_experts;
// Issue count sends
// const auto rdma_recv_x_uint8 = static_cast<uint8_t*>(rdma_recv_x) +
if
(
responsible_expert_idx
<
num_experts
and
sub_warp_id
==
0
and
lane_id
==
0
)
{
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
const
auto
dst_rank
=
responsible_expert_idx
/
num_local_experts
;
// src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
const
auto
dst_expert_local_idx
=
responsible_expert_idx
%
num_local_experts
;
// const auto recv_x_int4 = static_cast<int4*>(packed_recv_x) +
const
auto
num_tokens_sent
=
shared_num_tokens_sent_per_expert
[
responsible_expert_idx
-
sm_id
*
num_warp_groups
];
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
// const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
// Wait local sends issued and send expert counts
// const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
while
(
ld_acquire_global
(
atomic_finish_counter_per_expert
+
responsible_expert_idx
)
!=
FINISHED_SUM_TAG
*
2
);
// const auto num_aligned_scales = align_up<int>(num_scales, sizeof(float) / sizeof(scale_t));
if
(
not
is_rank_masked
(
mask_buffer_ptr
,
dst_rank
))
{
// const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
auto
dst_ptr
=
reinterpret_cast
<
int64_t
*>
(
rdma_recv_count
+
dst_expert_local_idx
*
num_ranks
+
rank
);
if
(
dst_rank
!=
rank
)
{
// // Shared between sub-warps in warp groups
#if !defined(ROCM_DISABLE_CTX)
// __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
rocshmem
::
rocshmem_ctx_long_atomic_add
(
ctx
,
dst_ptr
,
-
num_tokens_sent
-
1
,
dst_rank
);
#else
// // Wait tokens to arrive
rocshmem
::
rocshmem_long_atomic_add
(
dst_ptr
,
-
num_tokens_sent
-
1
,
dst_rank
);
// // NOTES: using sub-warp 1 to overlap with sub-warp 0
#endif
// int num_recv_tokens, recv_token_begin_idx;
}
else
{
// EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
st_release_sys_global
(
dst_ptr
,
-
num_tokens_sent
-
1
);
// if (sub_warp_id == 1 and lane_id == 0) {
}
// auto start_time = clock64();
}
// while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
// auto wait_recv_cost = clock64() - start_time;
// Clean workspace for next use
// num_recv_tokens = -num_recv_tokens - 1;
atomic_counter_per_expert
[
responsible_expert_idx
]
=
0
;
// recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
atomic_finish_counter_per_expert
[
responsible_expert_idx
]
=
0
;
// shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
// shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
// Clean `packed_recv_count`
// recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
if
(
dst_rank
==
0
)
packed_recv_count
[
dst_expert_local_idx
]
=
0
;
// // Add stats for diagnosis
}
// if (cumulative_local_expert_recv_stats != nullptr)
syncwarp
();
// atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
// if (dispatch_wait_recv_cost_stats != nullptr)
// atomicAdd(reinterpret_cast<unsigned long long*>(dispatch_wait_recv_cost_stats + src_rank), wait_recv_cost);
// Receiving phase
// }
LOW_LATENCY_DISPATCH_RECV:
// asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
// Return immediately when the receive phase is not requested
// num_recv_tokens = shared_num_recv_tokens[warp_group_id];
if
((
phases
&
LOW_LATENCY_RECV_PHASE
)
==
0
)
// recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
return
;
// // Copy tokens
// For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
// EP_DEVICE_ASSERT(num_scales <= 64);
if
(
phases
&
LOW_LATENCY_SEND_PHASE
){
// for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
grid_barrier
(
global_atomic_counter
,
num_sms
);
// // Copy source info
}
// const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
// if (lane_id == 0)
// Receiving and packing
// recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
if
(
responsible_expert_idx
<
num_experts
)
{
// __syncwarp();
const
auto
src_rank
=
responsible_expert_idx
/
num_local_experts
;
const
auto
local_expert_idx
=
responsible_expert_idx
%
num_local_experts
;
// // Copy data
const
auto
rdma_recv_x_uint8
=
static_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
// // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
+
// const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
src_rank
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_msg
;
// const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
const
auto
recv_x_int4
=
// UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
static_cast
<
int4
*>
(
packed_recv_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
hidden_int4
;
const
auto
recv_src_info
=
packed_recv_src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
// // Copy scales
const
auto
recv_range
=
packed_recv_layout_range
+
local_expert_idx
*
num_ranks
;
// if constexpr (kUseFP8) {
const
auto
num_aligned_scales
=
ALIGN
<
int
>
(
num_scales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
// // Equivalent CuTe layout:
const
auto
recv_x_scales
=
static_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
// // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_aligned_scales
;
// const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
// const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
// Shared between sub-warps in warp groups
// const auto token_idx = recv_token_begin_idx + i;
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
// const auto token_stride = num_elems_per_pack;
// const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
// Wait tokens to arrive
// if (lane_id < num_scales) {
// NOTES: using sub-warp 1 to overlap with sub-warp 0
// const auto pack_idx = lane_id / num_elems_per_pack;
int64_t
num_recv_tokens
;
// const auto elem_idx = lane_id % num_elems_per_pack;
int
recv_token_begin_idx
;
// auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
EP_DEVICE_ASSERT
(
num_warps_per_group
>
1
and
num_warp_groups
<
15
);
// recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
// }
// if (lane_id + 32 < num_scales) {
if
(
sub_warp_id
==
1
and
lane_id
==
0
)
{
// const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
auto
start_time
=
wall_clock64
();
// const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
int64_t
wait_recv_cost
=
0
;
// auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
// recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
int
offset
=
local_expert_idx
*
num_ranks
+
src_rank
;
// }
if
(
not
is_rank_masked
(
mask_buffer_ptr
,
src_rank
))
{
// }
while
((
wait_recv_cost
=
wall_clock64
()
-
start_time
)
<=
NUM_TIMEOUT_CYCLES
)
{
// not timeout
// }
if
((
num_recv_tokens
=
ld_acquire_global
(
reinterpret_cast
<
int64_t
*>
(
// }
rdma_recv_count
+
local_expert_idx
*
num_ranks
+
src_rank
)))
!=
0
)
{
break
;
}
}
}
// Mask rank if timeout
if
(
wait_recv_cost
>
NUM_TIMEOUT_CYCLES
)
{
printf
(
"Warning: DeepEP timeout for dispatch receive, rank %d, local_expert_idx %d, src_rank %d
\n
"
,
rank
,
local_expert_idx
,
src_rank
);
if
(
mask_buffer_ptr
==
nullptr
)
trap
();
atomicExch
(
mask_buffer_ptr
+
src_rank
,
1
);
}
// Do not receive tokens if rank timeout or masked
if
(
num_recv_tokens
==
0
)
num_recv_tokens
=
-
1
;
#if 1
num_recv_tokens
=
-
num_recv_tokens
-
1
;
int
num_recv_tokens_int32
=
static_cast
<
int
>
(
num_recv_tokens
);
recv_token_begin_idx
=
atomicAdd
(
packed_recv_count
+
local_expert_idx
,
num_recv_tokens_int32
);
shared_num_recv_tokens
[
warp_group_id
]
=
num_recv_tokens_int32
;
shared_recv_token_begin_idx
[
warp_group_id
]
=
recv_token_begin_idx
;
recv_range
[
src_rank
]
=
pack2
<
int
,
int64_t
>
(
num_recv_tokens_int32
,
recv_token_begin_idx
);
// Add stats for diagnosis
if
(
cumulative_local_expert_recv_stats
!=
nullptr
)
atomicAdd
(
cumulative_local_expert_recv_stats
+
local_expert_idx
,
num_recv_tokens_int32
);
if
(
dispatch_wait_recv_cost_stats
!=
nullptr
)
{
atomicAdd
(
reinterpret_cast
<
uint64_t
*>
(
dispatch_wait_recv_cost_stats
+
src_rank
),
static_cast
<
uint64_t
>
(
wait_recv_cost
));
}
#endif
}
#if 1
#ifdef USE_ROCM
// no needs to reset because there is no iteration
if
(
lane_id
==
0
){
volatile
int
ret
=
atomicAdd
((
int
*
)
&
sync_large_warp_counters
[
warp_group_id
],
1
);
}
syncwarp
();
while
(
sync_large_warp_counters
[
warp_group_id
]
<
num_warps_per_group
)
{}
#else
// asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
#endif
num_recv_tokens
=
shared_num_recv_tokens
[
warp_group_id
];
recv_token_begin_idx
=
shared_recv_token_begin_idx
[
warp_group_id
];
// Copy tokens
EP_DEVICE_ASSERT
(
num_scales
<=
64
);
for
(
int
i
=
sub_warp_id
;
i
<
num_recv_tokens
;
i
+=
num_warps_per_group
)
{
// Copy source info
const
auto
src_src_idx
=
reinterpret_cast
<
int
*>
(
rdma_recv_x_uint8
+
i
*
num_bytes_per_msg
);
if
(
lane_id
==
0
)
recv_src_info
[
recv_token_begin_idx
+
i
]
=
ld_nc_global
(
src_src_idx
);
syncwarp
();
// Copy data
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
const
auto
src_data
=
reinterpret_cast
<
int4
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_src_idx
)
+
sizeof
(
int4
));
const
auto
dst_data
=
recv_x_int4
+
(
recv_token_begin_idx
+
i
)
*
hidden_int4
;
UNROLLED_WARP_COPY
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
// Copy scales
if
constexpr
(
kUseFP8
)
{
// Equivalent CuTe layout:
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
packed_t
)
/
sizeof
(
scale_t
));
const
auto
token_idx
=
recv_token_begin_idx
+
i
;
const
auto
token_stride
=
num_elems_per_pack
;
const
auto
pack_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_elems_per_pack
;
if
(
lane_id
<
num_scales
)
{
const
auto
pack_idx
=
lane_id
/
num_elems_per_pack
;
const
auto
elem_idx
=
lane_id
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
if
(
lane_id
+
kWarpSize
<
num_scales
)
{
const
auto
pack_idx
=
(
lane_id
+
kWarpSize
)
/
num_elems_per_pack
;
const
auto
elem_idx
=
(
lane_id
+
kWarpSize
)
%
num_elems_per_pack
;
auto
scale
=
extract_required_scale_format
<
kUseUE8M0
>
(
ld_nc_global
(
src_scales
+
lane_id
+
kWarpSize
));
recv_x_scales
[
token_idx
*
token_stride
+
pack_idx
*
pack_stride
+
elem_idx
]
=
scale
;
}
}
}
#endif
}
#if !defined(ROCM_DISABLE_CTX)
rocshmem
::
rocshmem_wg_ctx_destroy
(
&
ctx
);
#endif
}
}
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
mask_buffer_ptr
,
int
*
cumulative_local_expert_recv_stats
,
int
*
cumulative_local_expert_recv_stats
,
int64_t
*
dispatch_wait_recv_cost_stats
,
int64_t
*
dispatch_wait_recv_cost_stats
,
void
*
rdma_recv_x
,
int
*
rdma_recv_count
,
void
*
rdma_x
,
void
*
rdma_recv_x
,
const
void
*
x
,
const
topk_idx_t
*
topk_idx
,
int64_t
*
rdma_recv_count
,
int
*
next_clean
,
int
num_next_clean_int
,
void
*
rdma_x
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
const
void
*
x
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
const
int64_t
*
topk_idx
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
int64_t
*
next_clean
,
void
*
workspace
,
int
num_device_sms
,
int
num_next_clean_int
,
cudaStream_t
stream
,
int
phases
)
{
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
)
{
constexpr
int
kNumMaxTopK
=
11
;
constexpr
int
kNumMaxTopK
=
11
;
// const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const
int
num_warp_groups
=
DIVUP
(
num_experts
,
num_device_sms
);
// const int num_warps_per_group = 32 / num_warp_groups;
EP_HOST_ASSERT
(
num_warp_groups
<=
16
);
// EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
// Each kernel uses at most 16 warps
// EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
);
EP_HOST_ASSERT
(
kNumMaxTopK
+
1
<=
num_warp_groups
*
num_warps_per_group
);
// const auto num_warps = num_warp_groups * num_warps_per_group;
// const auto num_sms = ceil_div(num_experts, num_warp_groups);
const
auto
num_warps
=
num_warp_groups
*
num_warps_per_group
;
// EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
const
auto
num_sms
=
DIVUP
(
num_experts
,
num_warp_groups
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopK
);
// // Workspace checks
// auto atomic_counter_per_expert = static_cast<int*>(workspace);
// Workspace checks
// auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
auto
atomic_counter_per_expert
=
static_cast
<
int
*>
(
workspace
);
// EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
// // FP8 checks
// if (use_ue8m0)
#define DISPATCH_LAUNCH_CASE(hidden) \
// EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
{ \
auto dispatch_func = dispatch<false, false, hidden>; \
// #define DISPATCH_LAUNCH_CASE(hidden) { \
if(use_fp8 and not use_ue8m0) \
// auto dispatch_func = dispatch<false, false, hidden>; \
dispatch_func = dispatch<true, false, hidden>; \
// if (use_fp8 and not use_ue8m0) \
if(use_fp8 and use_ue8m0) \
// dispatch_func = dispatch<true, false, hidden>; \
dispatch_func = dispatch<true, true, hidden>; \
// if (use_fp8 and use_ue8m0) \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, \
// dispatch_func = dispatch<true, true, hidden>; \
dispatch_func, \
// LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, \
// packed_recv_x, packed_recv_x_scales, \
packed_recv_x_scales, \
// packed_recv_src_info, packed_recv_layout_range, \
packed_recv_src_info, \
// packed_recv_count, \
packed_recv_layout_range, \
// cumulative_local_expert_recv_stats, \
packed_recv_count, \
// dispatch_wait_recv_cost_stats, \
global_atomic_counter, \
// rdma_recv_x, rdma_recv_count, rdma_x, \
mask_buffer_ptr, \
// x, topk_idx, \
cumulative_local_expert_recv_stats, \
// atomic_counter_per_expert, atomic_finish_counter_per_expert, \
dispatch_wait_recv_cost_stats, \
// next_clean, num_next_clean_int, \
rdma_recv_x, \
// num_tokens, num_max_dispatch_tokens_per_rank, \
rdma_recv_count, \
// num_topk, num_experts, rank, num_ranks, \
rdma_x, \
// num_warp_groups, num_warps_per_group, \
x, \
// round_scale, phases); } break
topk_idx, \
atomic_counter_per_expert, \
// SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
atomic_finish_counter_per_expert, \
// SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
next_clean, \
// #undef DISPATCH_LAUNCH_CASE
num_next_clean_int, \
}
num_tokens, \
num_max_dispatch_tokens_per_rank, \
template
<
int
kNumSendUnrolls
>
num_topk, \
__forceinline__
__device__
int
logfmt_encode
(
void
*
buffer
,
nv_bfloat162
*
shared_amaxmin
,
const
int
&
lane_id
)
{
num_experts, \
// constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
rank, \
// constexpr float kLogThreshold = 0;
num_ranks, \
// constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))`
num_warp_groups, \
// constexpr int kNumBits = 10;
num_warps_per_group, \
// constexpr int kNumValues = 1 << (kNumBits - 1);
round_scale, \
phases); \
// int4 int4_values[kNumSendUnrolls];
} \
// const auto& uint32_values = reinterpret_cast<uint32_t*>(int4_values);
break
// const auto& bf162_values = reinterpret_cast<nv_bfloat162*>(int4_values);
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
// // Calculate lane offset
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
// const auto& ld_buffer = reinterpret_cast<uint32_t*>(static_cast<uint8_t*>(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4)));
#undef DISPATCH_LAUNCH_CASE
// const auto& st_buffer = reinterpret_cast<uint32_t*>(static_cast<uint8_t*>(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4) * 10 / 16));
// // Local log amax
// auto bf162_amax = __nv_bfloat162(CUDART_ZERO_BF16, CUDART_ZERO_BF16);
// auto bf162_amin = __nv_bfloat162(CUDART_INF_BF16, CUDART_INF_BF16);
// uint32_t local_signs = 0;
// #pragma unroll
// for (int k = 0; k < kNumSendUnrolls * kNumElemsPerInt4 / 2; ++ k) {
// // TODO: eliminate bank conflicts
// uint32_values[k] = ld_buffer[k];
// local_signs |= ((uint32_values[k] >> 15) & 1) << (k * 2);
// local_signs |= ((uint32_values[k] >> 31) & 1) << (k * 2 + 1);
// uint32_values[k] &= 0x7fff7fff;
// bf162_amax = __hmax2(bf162_amax, bf162_values[k]);
// bf162_amin = __hmin2(bf162_amin, bf162_values[k]);
// }
// // Reduce per 128 channels
// // TODO: figure out how hardware do 2-byte min/max
// auto amax = std::max(static_cast<float>(bf162_amax.x), static_cast<float>(bf162_amax.y));
// auto amin = std::min(static_cast<float>(bf162_amin.x), static_cast<float>(bf162_amin.y));
// constexpr static int kNumLanesToReduce = 128 * sizeof(nv_bfloat16) / (kNumSendUnrolls * sizeof(int4));
// amax = warp_reduce_max<kNumLanesToReduce>(amax);
// amin = warp_reduce_min<kNumLanesToReduce>(amin);
// // Write min/max into the shared memory
// if (shared_amaxmin != nullptr)
// *shared_amaxmin = __nv_bfloat162(amax, amin);
// __syncwarp();
// // Calculate log amin/amax float
// const auto& log_amax = log2f_approx(amax);
// const auto& log_amin = fmaxf(log2f_approx(amin), log_amax - kMinClip);
// const bool& enable_cast = warp_reduce_and<kNumLanesToReduce, true>(log_amax < kLogThreshold and log_amin < log_amax);
// // Case into LogFMT-10 if satisfied
// if (enable_cast) {
// const auto step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
// const auto step_inv = 1.0f / step;
// const auto rounding = 2.0f - log2f_approx((1.0f + exp2f_approx(step)) * 0.5f) * step_inv;
// const auto fused_rounding = rounding - log_amin * step_inv;
// // Pack every 256 bits into 160 bits
// EP_STATIC_ASSERT(kNumSendUnrolls == 2 or kNumSendUnrolls == 4, "kNumSendUnrolls == 2 or 4 only");
// uint32_t encoded[kNumElemsPerInt4 * 2];
// #pragma unroll 1
// for (int i = 0; i < kNumSendUnrolls / 2; ++ i) {
// #pragma unroll
// for (int k = 0; k < kNumElemsPerInt4; ++ k) {
// const auto& [x, y] = __bfloat1622float2(bf162_values[i * kNumElemsPerInt4 + k]);
// encoded[k * 2 + 0] = __float2uint_rd(fmaxf(log2f_approx(x) * step_inv + fused_rounding, 0));
// encoded[k * 2 + 1] = __float2uint_rd(fmaxf(log2f_approx(y) * step_inv + fused_rounding, 0));
// }
// st_buffer[i * 5 + 0] = (encoded[ 0] >> 0) | (encoded[ 1] << 9) | (encoded[ 2] << 18) | (encoded[ 3] << 27);
// st_buffer[i * 5 + 1] = (encoded[ 3] >> 5) | (encoded[ 4] << 4) | (encoded[ 5] << 13) | (encoded[ 6] << 22) | (encoded[7] << 31);
// st_buffer[i * 5 + 2] = (encoded[ 7] >> 1) | (encoded[ 8] << 8) | (encoded[ 9] << 17) | (encoded[10] << 26);
// st_buffer[i * 5 + 3] = (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | (encoded[14] << 30);
// st_buffer[i * 5 + 4] = (encoded[14] >> 2) | (encoded[15] << 7) | ((i == 0) ? (local_signs << 16) : (local_signs & 0xffff0000u));
// }
// tma_store_fence();
// __syncwarp();
// }
// // Return TMA copy bytes
// return enable_cast ? (32 * (kNumSendUnrolls * sizeof(int4) * 8 * 10 / 16 / 8)):
// (32 * (kNumSendUnrolls * sizeof(int4)));
}
// logfmt_check_amaxmin: device helper for the LogFMT receive path.
//
// NOTE: every statement in the body below is commented out, so as written this
// function performs no work and leaves all of its arguments untouched.
//
// The disabled reference code would, for lanes below kNumLanes, read one 64-bit
// word per lane from `meta_buffer`, interpret it as two bf16 (amax, amin) pairs,
// store their log2 values into `shared_log_amax[lane_id]` / `shared_log_amin[lane_id]`,
// and write a packed "(casted-prefix-count << 1) | casted" word into
// `shared_cast_info[lane_id / kNumRecvUnrolls]`.
// NOTE(review): the reference code uses `__nv_bfloat162`, `__popc` and `__syncwarp`;
// presumably it is disabled pending a HIP port - confirm before relying on this helper.
template
<
int
kNumLanes
,
int
kNumSendUnrolls
,
int
kNumRecvUnrolls
>
__forceinline__
__device__
void
logfmt_check_amaxmin
(
uint8_t
*
meta_buffer
,
float2
*
shared_log_amax
,
float2
*
shared_log_amin
,
int
*
shared_cast_info
,
const
int
lane_id
)
{
// --- Disabled reference implementation (kept for the pending port) ---
// constexpr float kLogThreshold = 0;
// constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))`
// bool enable_cast = true;
// if (lane_id < kNumLanes) {
// // Calculate log amin/amax float
// auto amaxmin2 = reinterpret_cast<uint64_t*>(meta_buffer)[lane_id];
// const auto& bf162_amaxmin = reinterpret_cast<__nv_bfloat162*>(&amaxmin2);
// float log_amax[2], log_amin[2];
// #pragma unroll
// for (int i = 0; i < 2; ++ i) {
// auto amax = static_cast<float>(bf162_amaxmin[i].x);
// auto amin = static_cast<float>(bf162_amaxmin[i].y);
// log_amax[i] = log2f_approx(amax);
// log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : fmaxf(log2f_approx(amin), log_amax[i] - kMinClip);
// enable_cast = enable_cast and log_amax[i] < kLogThreshold and log_amin[i] < log_amax[i];
// }
// shared_log_amax[lane_id] = make_float2(log_amax[0], log_amax[1]);
// shared_log_amin[lane_id] = make_float2(log_amin[0], log_amin[1]);
// }
// const auto& casted = warp_reduce_and<kNumSendUnrolls>(enable_cast) ? 1u << (lane_id / kNumRecvUnrolls): 0u;
// const auto& num_casted_prefix = __popc(warp_reduce_or<kNumRecvUnrolls, true>(casted) & ((1u << (lane_id / kNumRecvUnrolls)) - 1));
// if (lane_id < kNumLanes and lane_id % kNumRecvUnrolls == 0)
// shared_cast_info[lane_id / kNumRecvUnrolls] = (num_casted_prefix << 1) | (casted ? 1u : 0u);
// __syncwarp();
}
// decode_and_accumulate: device helper that is meant to fold one received chunk
// into a float accumulator, scaled by `weight`.
//
// NOTE: every statement in the body below is commented out, so as written this
// function performs no work; `accum` is left unchanged for any input.
//
// The disabled reference code has two paths:
//   * `enable_cast == true`: unpack 9-bit magnitudes plus sign bits from groups of
//     five 32-bit words in `ld_buffer`, map each code back to a float via
//     exp2((code - 1) * step + log_amin) (code 0 decodes to 0), and add
//     `decoded * weight` into `accum` (16 values per group).
//   * otherwise: reinterpret each 32-bit word of `ld_buffer` as a bf16 pair and add
//     `value * weight` into `accum[k * 2]` / `accum[k * 2 + 1]`.
// NOTE(review): the reference code depends on `__nv_bfloat162` and `exp2f_approx`;
// presumably it is disabled pending a HIP port - confirm before relying on this helper.
template
<
int
kNumRecvUnrolls
>
__forceinline__
__device__
void
decode_and_accumulate
(
uint32_t
*
ld_buffer
,
float
*
accum
,
const
float
&
log_amax
,
const
float
&
log_amin
,
const
bool
&
enable_cast
,
const
float
&
weight
)
{
// --- Disabled reference implementation (kept for the pending port) ---
// if (enable_cast) {
// constexpr int kNumBits = 10;
// constexpr int kNumValues = 1 << (kNumBits - 1);
// const auto& step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
// auto decode = [=](const uint32_t &encoded, const uint32_t &sign) {
// const auto decoded = encoded == 0 ? .0f : exp2f_approx((encoded - 1) * step + log_amin);
// return sign ? -decoded : decoded;
// };
// EP_STATIC_ASSERT(kNumRecvUnrolls == 2 or kNumRecvUnrolls == 4, "kNumRecvUnrolls == 2 or 4 only");
// #pragma unroll
// for (int i = 0; i < kNumRecvUnrolls / 2; ++ i) {
// uint32_t concat[6];
// concat[0] = ld_buffer[i * 5];
// #pragma unroll
// for (int k = 1; k < 5; ++ k)
// concat[k] = (ld_buffer[i * 5 + k - 1] >> (32 - k * 5)) | (ld_buffer[i * 5 + k] << (k * 5));
// concat[5] = ld_buffer[i * 5 + 4] >> 7;
// const uint32_t& local_signs = ld_buffer[i * 5 + 4] >> 16;
// #pragma unroll
// for (int k = 0; k < 5; ++ k) {
// accum[i * 16 + k * 3 + 0] += decode((concat[k] >> 0) & 0x1ff, (local_signs >> (k * 3 + 0)) & 1) * weight;
// accum[i * 16 + k * 3 + 1] += decode((concat[k] >> 9) & 0x1ff, (local_signs >> (k * 3 + 1)) & 1) * weight;
// accum[i * 16 + k * 3 + 2] += decode((concat[k] >> 18) & 0x1ff, (local_signs >> (k * 3 + 2)) & 1) * weight;
// }
// accum[i * 16 + 15] += decode(concat[5] & 0x1ff, (local_signs >> 15) & 1) * weight;
// }
// } else {
// #pragma unroll
// for (int k = 0; k < kNumRecvUnrolls * 4; ++ k) {
// auto bf16_pack = *reinterpret_cast<__nv_bfloat162*>(ld_buffer + k);
// accum[k * 2 + 0] += static_cast<float>(bf16_pack.x) * weight;
// accum[k * 2 + 1] += static_cast<float>(bf16_pack.y) * weight;
// }
// }
}
}
template
<
bool
kUseLogFMT
,
int
kHidden
,
int
kNumMaxTopk
,
int
kNumMaxUnrolls
>
template
<
bool
kUseLogFMT
,
int
kHidden
,
int
kNumMaxTopk
,
int
kNumMaxUnrolls
>
__global__
__launch_bounds__
(
1024
,
1
)
void
__launch_bounds__
(
1024
,
1
)
__global__
void
combine
(
void
*
combined_x
,
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_flag
,
void
*
rdma_send_x
,
int
*
rdma_recv_flag
,
const
void
*
x
,
const
topk_idx_t
*
topk_idx
,
const
float
*
topk_weights
,
void
*
rdma_send_x
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
mask_buffer_ptr
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
combine_wait_recv_cost_stats
,
int
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
*
atomic_clean_flag
,
int
*
atomic_clean_flag
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_max_dispatch_tokens_per_rank
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_experts
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
rank
,
int
phases
,
bool
zero_copy
)
{
int
num_ranks
,
// const auto sm_id = __shfl_sync(0xffffffff, static_cast<int>(blockIdx.x), 0);
int
num_warp_groups
,
// const auto num_sms = __shfl_sync(0xffffffff, static_cast<int>(gridDim.x), 0);
int
num_warps_per_group
,
// const auto thread_id = static_cast<int>(threadIdx.x);
int
phases
,
// const auto num_threads = __shfl_sync(0xffffffff, static_cast<int>(blockDim.x), 0);
bool
zero_copy
)
{
// const auto warp_id = __shfl_sync(0xffffffff, thread_id / 32, 0), lane_id = get_lane_id();
#if !defined(ROCM_DISABLE_CTX)
// const auto num_local_experts = num_experts / num_ranks;
__shared__
rocshmem
::
rocshmem_ctx_t
ctx
;
// const auto warp_group_id = warp_id / num_warps_per_group;
rocshmem
::
rocshmem_wg_ctx_create
(
0
,
&
ctx
);
// const auto sub_warp_id = warp_id % num_warps_per_group;
#endif
// const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// const auto sm_id = static_cast<int>(blockIdx.x);
// extern __shared__ __align__(1024) uint8_t smem_buffer[];
// const auto num_sms = static_cast<int>(gridDim.x);
// const auto thread_id = static_cast<int>(threadIdx.x);
// // Data type stuff
// const auto num_threads = static_cast<int>(blockDim.x);
// constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
// const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
// constexpr int64_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// const auto num_local_experts = num_experts / num_ranks;
// const auto warp_group_id = warp_id / kNumWarpsPerGroup;
// // Use different unroll factors for send and recv phases
// const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
// constexpr int kNumSendUnrolls = kHidden % (32 * 4 * sizeof(int4) / sizeof(nv_bfloat16)) == 0 ? 4 : 2;
// const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
// constexpr int kNumRecvUnrolls = 2;
// constexpr int hidden_bf16_int4_pad = align_up(static_cast<int>(hidden_bf16_int4), 32 * kNumSendUnrolls);
// // Data type stuff
// EP_STATIC_ASSERT(kHidden % (32 * 2 * sizeof(int4) / sizeof(nv_bfloat16)) == 0, "Invalid hidden");
// constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(gpu_bfloat16_t);
// EP_STATIC_ASSERT(kNumSendUnrolls <= kNumMaxUnrolls and kNumRecvUnrolls <= kNumMaxUnrolls, "Invalid unrolls");
// const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// EP_STATIC_ASSERT(hidden_bf16_int4 % kNumSendUnrolls == 0, "Invalid hidden");
// EP_STATIC_ASSERT(kNumSendUnrolls >= kNumRecvUnrolls, "Invalid unroll factors");
// // Message package
// // BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
// // Message package
// constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(gpu_bfloat16_t);
// EP_STATIC_ASSERT(kHidden % 128 == 0, "Invalid hidden");
// EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// constexpr int kNumDivisions = kHidden / 128;
// __syncthreads();
// constexpr int kNumMetaBytes = kNumDivisions * sizeof(nv_bfloat162);
// #ifdef USE_ROCM
// constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16) + kNumMetaBytes;
// // 16 is the max possible number of warps in AMD GPUs
// EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// constexpr int kMaxNumWarps = 1024 / kWarpSize;
// __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
// // Sending phase
// if (threadIdx.x==0){
// if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
// // printf("combine");
// goto LOW_LATENCY_COMBINE_RECV;
// #pragma unroll
// for (int i = 0; i < kMaxNumWarps; ++i) {
// // Clean up next buffer
// sync_large_warp_counters[i] = 0;
// if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
// }
// #pragma unroll
// }
// for (int i = lane_id; i < num_next_clean_int; i += 32)
// __syncthreads();
// next_clean[i] = 0;
// #endif
// // Notify before executing `int_p`
// // Sending phase
// __syncwarp();
// if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
// if (lane_id == 0)
// goto LOW_LATENCY_COMBINE_RECV;
// atomic_add_release_global(atomic_clean_flag, num_experts);
// }
// // Clean up next buffer
// if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
// // Issue IBGDA sends
// #pragma unroll
// if (responsible_expert_idx < num_experts) {
// for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
// const auto dst_rank = responsible_expert_idx / num_local_experts;
// next_clean[i] = 0;
// const auto local_expert_idx = responsible_expert_idx % num_local_experts;
// const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
// // Notify before executing `int_p`
// const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
// syncwarp();
// const auto local_x = static_cast<const int4*>(x) +
// if (lane_id == 0)
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
// atomic_add_release_global(atomic_clean_flag, num_experts);
// const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
// }
// const auto rdma_send_x_vec = static_cast<uint8_t*>(rdma_send_x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// // Issue IBGDA sends
// if (responsible_expert_idx < num_experts) {
// // Unpack layout
// const auto dst_rank = responsible_expert_idx / num_local_experts;
// int offset, num_tokens_to_send;
// const auto local_expert_idx = responsible_expert_idx % num_local_experts;
// unpack2(layout, num_tokens_to_send, offset);
// const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
// const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
// // TMA stuffs
// const auto local_x = reinterpret_cast<const int4*>(x) +
// constexpr int kNumTMABufferBytes = sizeof(int4) * 32 * kNumSendUnrolls;
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
// constexpr int kNumStages = 3;
// const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
// constexpr int kNumPrefetch = 1;
// const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
// EP_STATIC_ASSERT(kNumStages == 3 and kNumPrefetch == 1, "Invalid stages");
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// auto smem_ptr = smem_buffer + warp_id * (kNumStages * (kNumTMABufferBytes + 16) + kNumMetaBytes);
// // Unpack layout
// uint32_t tma_phase = 0;
// int offset, num_tokens_to_send;
// auto tma_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<int4*>(smem_ptr + i * (kNumTMABufferBytes + 16)); });
// unpack2(layout, num_tokens_to_send, offset);
// auto full_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_ptr + i * (kNumTMABufferBytes + 16) + kNumTMABufferBytes); });
// auto meta_buffers = kUseLogFMT ? reinterpret_cast<nv_bfloat162*>(smem_ptr + kNumStages * (kNumTMABufferBytes + 16)) : nullptr;
// // Issue IBGDA send
// EP_STATIC_ASSERT(kNumSendUnrolls * kNumStages <= 12, "TMA buffer size exceed limit");
// for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
// const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
// // Initialize m-barriers
// const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
// if (lane_id < kNumStages) {
// const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
// mbarrier_init(full_barriers[lane_id], 1);
// fence_barrier_init();
// // Copy directly to local rank, or copy to buffer and issue RDMA
// }
// auto src_idx = __ldg(local_src_info + token_idx);
// __syncwarp();
// const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
// const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
// constexpr int kNumIters = hidden_bf16_int4_pad / (32 * kNumSendUnrolls);
// if (dst_rank == rank) {
// auto tma_load_and_arrive = [&](const int& stage_idx, const int4* gmem_ptr, const int& num_bytes) {
// const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
// tma_load_1d(tma_buffers[stage_idx], gmem_ptr, full_barriers[stage_idx], num_bytes);
// UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
// mbarrier_arrive_and_expect_tx(full_barriers[stage_idx], num_bytes);
// } else {
// };
// const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
// auto get_num_tma_bytes = [&](const int& offset_int4) {
// if (not zero_copy)
// return min(kNumTMABufferBytes, static_cast<int>((hidden_bf16_int4 - offset_int4) * sizeof(int4)));
// UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
// };
// //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(gpu_bfloat16_t), dst_rank, local_expert_idx, lane_id, token_idx - offset);
// // Issue IBGDA send
// #if defined(ROCM_DISABLE_CTX)
// for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
// internode::shmemx_int8_put_nbi_warp(
// const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
// #else
// const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
// internode::shmem_ctx_schar_put_nbi_warp(ctx,
// const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
// #endif
// reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank);
// // Copy directly to local rank, or copy to buffer and issue RDMA
// const auto src_idx = __shfl_sync(0xffffffff, __ldg(local_src_info + token_idx), 0);
// #if defined(ROCM_DISABLE_CTX)
// const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
// internode::shmem_fence();
// const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
// #else
// const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
// internode::shmem_ctx_quiet(ctx);
// int num_send_bytes = hidden * sizeof(nv_bfloat16);
// #endif
// }
// if (not zero_copy or dst_p2p_ptr != 0) {
// }
// // Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
// const auto cpy_src_int4_ptr = zero_copy ? reinterpret_cast<int4*>(buf_ptr) : x_int4;
// // Put finishing flag
// const auto cpy_dst_int4_ptr = dst_p2p_ptr == 0 ? reinterpret_cast<int4*>(buf_ptr) : reinterpret_cast<int4*>(dst_p2p_ptr);
// EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
// #ifdef USE_ROCM
// // Prefetch
// if (lane_id == 0){
// if (elect_one_sync())
// volatile int ret = __hip_atomic_fetch_add(
// tma_load_and_arrive(0, cpy_src_int4_ptr, get_num_tma_bytes(0));
// &sync_large_warp_counters[warp_group_id], 1,
// __syncwarp();
// __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
// }
// int tma_offset_bytes = kNumMetaBytes;
// syncwarp();
// #pragma unroll
// while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup));
// for (int i = lane_id * kNumSendUnrolls, iter_idx = 0; i < hidden_bf16_int4_pad; i += 32 * kNumSendUnrolls, ++ iter_idx) {
// #else
// // Load the next iteration
// asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
// const int& stage_idx = iter_idx % kNumStages;
// #endif
// const int& next_stage_idx = (iter_idx + 1) % kNumStages;
// if (sub_warp_id == 1 and lane_id == 0) {
// if (iter_idx + 1 < kNumIters and elect_one_sync()) {
// while (ld_acquire_global(atomic_clean_flag) == 0);
// tma_store_wait<kNumStages - kNumPrefetch - 1>();
// if (dst_rank != rank) {
// const auto& offset_int4 = i + 32 * kNumSendUnrolls;
// #ifdef USE_ROCM
// tma_load_and_arrive(next_stage_idx, cpy_src_int4_ptr + offset_int4, get_num_tma_bytes(offset_int4));
// #if defined(ROCM_DISABLE_CTX)
// }
// internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank);
// __syncwarp();
// #else
// internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank);
// // Wait the current TMA arrival
// #endif
// EP_STATIC_ASSERT(kNumStages < 32, "Too many stages");
// #else
// mbarrier_wait<true>(full_barriers[stage_idx], tma_phase, stage_idx);
// nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
// if constexpr (kUseLogFMT) {
// #endif
// // Cast if possible
// } else {
// constexpr int kNumInt4PerDivision = 128 / kNumElemsPerInt4;
// st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
// int num_tma_bytes = logfmt_encode<kNumSendUnrolls>(
// }
// tma_buffers[stage_idx],
// atomic_add_release_global(atomic_clean_flag, -1);
// // NOTES: only the leader lane will write the result
// }
// (i % kNumInt4PerDivision == 0) ? meta_buffers + i / kNumInt4PerDivision : nullptr,
// syncwarp();
// lane_id);
// }
// if (elect_one_sync())
// tma_store_1d(tma_buffers[stage_idx], reinterpret_cast<uint8_t*>(cpy_dst_int4_ptr) + tma_offset_bytes, num_tma_bytes);
// // Receiving phase
// tma_offset_bytes += num_tma_bytes;
// LOW_LATENCY_COMBINE_RECV:
// } else {
// if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
// // BF16 original values
// return;
// if (elect_one_sync())
// tma_store_1d(tma_buffers[stage_idx], cpy_dst_int4_ptr + i, get_num_tma_bytes(i));
// // Wait all ranks to arrive and notify PCIe usage
// }
// if (responsible_expert_idx < num_experts) {
// __syncwarp();
// EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
// }
// if (sub_warp_id == 0 and lane_id == 0){
// while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0);
// // Store metadata (min/max values) for LogFMT
// }
// if constexpr (kUseLogFMT) {
// }
// num_send_bytes = tma_offset_bytes;
// grid_barrier(global_atomic_counter, num_sms);
// if (elect_one_sync())
// tma_store_1d(meta_buffers, cpy_dst_int4_ptr, kNumMetaBytes);
// // Reduce tokens with FP8 cast
// }
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
// EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
// // Flush all stores
// if (thread_id < hidden_bf16_int4) {
// tma_store_wait<0>();
// for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// __syncwarp();
// // Read top-k indices and weights
// }
// int reg_topk_idx[kNumMaxTopk];
// float reg_topk_weights[kNumMaxTopk];
// // Issue RDMA
// #pragma unroll
// // NOTES: for zero-copy mode, we assume the data is already in the send buffer
// for (int i = 0; i < num_topk; ++ i) {
// if (dst_p2p_ptr == 0)
// reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
// nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, num_send_bytes, dst_rank, local_expert_idx, lane_id, token_idx - offset);
// reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
// }
// }
// // Put the finishing flag
// float combined_values[kNumElemsPerInt4] = {0.0f};
// EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16);
// #pragma unroll
// asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(num_warps_per_group * 32));
// for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// if (sub_warp_id == 1 and lane_id == 0) {
// // Read from sources
// while (ld_acquire_global(atomic_clean_flag) == 0);
// auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
// auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx);
// auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
// auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
// if (dst_p2p_ptr == 0) {
// // Reduce
// nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), 1, dst_rank, local_expert_idx);
// auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
// } else {
// const auto x_bf16 = reinterpret_cast<gpu_bfloat16_t*>(&x_vec);
// st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), 1);
// #pragma unroll
// }
// for (int j = 0; j < kNumElemsPerInt4; ++ j)
// atomic_add_release_global(atomic_clean_flag, -1);
// combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
// }
// }
// __syncwarp();
// // Write results
// // Destroy m-barriers
// int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
// if (lane_id < kNumStages) {
// auto combined_bf16 = reinterpret_cast<gpu_bfloat16_t*>(&combined_values);
// mbarrier_inval(full_barriers[lane_id]);
// #pragma unroll
// fence_barrier_init();
// for (int j = 0; j < kNumElemsPerInt4; ++ j)
// }
// combined_bf16[j] = static_cast<gpu_bfloat16_t>(combined_values[j]);
// __syncwarp();
// (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
// }
// }
// }
// // Receiving phase
#if !defined(ROCM_DISABLE_CTX)
// LOW_LATENCY_COMBINE_RECV:
rocshmem
::
rocshmem_wg_ctx_destroy
(
&
ctx
);
// if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
#endif
// return;
// // Wait all ranks to arrive
// if (responsible_expert_idx < num_experts) {
// EP_DEVICE_ASSERT(num_warps_per_group > 1);
// if (sub_warp_id == 0 and lane_id == 0) {
// auto start_time = clock64();
// while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
// auto wait_recv_cost = clock64() - start_time;
// if (combine_wait_recv_cost_stats != nullptr) {
// const auto& src_rank = responsible_expert_idx / num_local_experts;
// atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
// }
// }
// }
// cg::this_grid().sync();
// // Reassign warp groups
// constexpr int kMaxNumGroups = 2;
// const int num_decode_warps = hidden_bf16_int4_pad / (kNumRecvUnrolls * 32);
// const int num_groups = min(kMaxNumGroups, (num_threads / 32) / (num_decode_warps + 1));
// const int decode_warp_idx = __shfl_sync(0xffffffff, warp_id % (num_decode_warps + 1), 0);
// const int group_idx = __shfl_sync(0xffffffff, warp_id / (num_decode_warps + 1), 0);
// EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
// EP_DEVICE_ASSERT(num_topk <= 32);
// EP_DEVICE_ASSERT(num_groups > 0);
// if (group_idx < num_groups) {
// constexpr int kNumStages = 3;
// constexpr int kNumTMABufferBytes = 16 * 2 + kHidden * 2;
// constexpr int kNumBF16PerWarpBytes = 32 * kNumRecvUnrolls * kNumElemsPerInt4 * 2;
// constexpr int kNumLogFMTPerWarpBytes = kNumBF16PerWarpBytes / 16 * 10;
// constexpr int kNumDivisionBytes = kNumDivisions * sizeof(uint32_t);
// constexpr int kNumBytesPerGroup = kNumStages * kNumTMABufferBytes + kHidden * 2 + kNumStages * kNumDivisionBytes * 3;
// // Reallocate shared memory
// const auto smem_group_buffer = smem_buffer + kNumBytesPerGroup * group_idx;
// auto full_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_group_buffer + i * kNumTMABufferBytes); });
// auto empty_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_group_buffer + i * kNumTMABufferBytes + 8); });
// auto tma_ld_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint8_t* >(smem_group_buffer + i * kNumTMABufferBytes + 16); });
// auto tma_st_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint32_t*>(smem_group_buffer + kNumStages * kNumTMABufferBytes + i * kNumBF16PerWarpBytes); });
// // Redundant when logfmt is disabled
// const auto smem_group_ptr = smem_group_buffer + kNumStages * kNumTMABufferBytes + kHidden * 2;
// auto log_amax_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<float*>(smem_group_ptr + i * kNumDivisionBytes); });
// auto log_amin_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<float*>(smem_group_ptr + kNumStages * kNumDivisionBytes + i * kNumDivisionBytes); });
// auto cast_info_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<int*> (smem_group_ptr + kNumStages * kNumDivisionBytes * 2 + i * kNumDivisionBytes); });
// uint32_t tma_phase = 0;
// EP_STATIC_ASSERT(kNumStages < 32, "Too many stages");
// if (decode_warp_idx == num_decode_warps)
// tma_phase = (1 << kNumStages) - 1;
// // Initialize m-barriers
// if (decode_warp_idx == num_decode_warps and lane_id < kNumStages) {
// mbarrier_init(full_barriers[lane_id], 1);
// mbarrier_init(empty_barriers[lane_id], num_decode_warps);
// }
// asm volatile("bar.sync %0, %1;" :: "r"(group_idx + 1), "r"((num_decode_warps + 1) * 32));
// int stage_idx = 0, topk_idx_by_lane = 0;
// EP_STATIC_ASSERT(kNumMaxTopk <= 32, "Invalid number of topks");
// if (decode_warp_idx == num_decode_warps) {
// // TMA load warp
// for (int token_idx = sm_id + num_sms * group_idx; token_idx < num_combined_tokens; token_idx += num_sms * num_groups) {
// if (lane_id < num_topk)
// topk_idx_by_lane = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id));
// for (int i = 0; i < num_topk; ++ i) {
// int topk_idx_reg = __shfl_sync(0xffffffff, topk_idx_by_lane, i);
// if (topk_idx_reg < 0)
// continue;
// mbarrier_wait<true>(empty_barriers[stage_idx], tma_phase, stage_idx);
// auto buffer = static_cast<uint8_t*>(rdma_recv_x) + (topk_idx_reg * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot;
// if constexpr (kUseLogFMT) {
// logfmt_check_amaxmin<kNumDivisions / 2, kNumSendUnrolls, kNumRecvUnrolls>(
// buffer, reinterpret_cast<float2*>(log_amax_buffers[stage_idx]),
// reinterpret_cast<float2*>(log_amin_buffers[stage_idx]), cast_info_buffers[stage_idx], lane_id);
// }
// if (elect_one_sync()) {
// int num_casted = 0;
// if constexpr (kUseLogFMT) {
// const auto& info = cast_info_buffers[stage_idx][num_decode_warps - 1];
// num_casted = (info >> 1) + (info & 1);
// }
// int num_tma_bytes = num_casted * kNumLogFMTPerWarpBytes + (num_decode_warps - num_casted) * kNumBF16PerWarpBytes;
// tma_load_1d(tma_ld_buffers[stage_idx], buffer + (kUseLogFMT ? kNumMetaBytes : 0), full_barriers[stage_idx], num_tma_bytes);
// mbarrier_arrive_and_expect_tx(full_barriers[stage_idx], num_tma_bytes);
// }
// __syncwarp();
// stage_idx = (stage_idx + 1) % kNumStages;
// }
// }
// } else {
// // Reduction warps
// float topk_weights_by_lane;
// for (int token_idx = sm_id + num_sms * group_idx; token_idx < num_combined_tokens; token_idx += num_sms * num_groups) {
// if (lane_id < num_topk) {
// topk_idx_by_lane = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id));
// topk_weights_by_lane = __ldg(topk_weights + token_idx * num_topk + lane_id);
// }
// __syncwarp();
// float combined_values[kNumElemsPerInt4 * kNumRecvUnrolls] = {0.0f};
// for (int i = 0; i < num_topk; ++ i) {
// if (__shfl_sync(0xffffffff, topk_idx_by_lane, i) < 0)
// continue;
// const auto& topk_weight = __shfl_sync(0xffffffff, topk_weights_by_lane, i);
// mbarrier_wait<true>(full_barriers[stage_idx], tma_phase, stage_idx);
// if constexpr (kUseLogFMT) {
// const auto& info = cast_info_buffers[stage_idx][decode_warp_idx];
// bool enable_cast = info & 1;
// int num_casted_prefix = info >> 1;
// int tma_offset = kNumLogFMTPerWarpBytes * num_casted_prefix + kNumBF16PerWarpBytes * (decode_warp_idx - num_casted_prefix);
// int division_idx = decode_warp_idx * (kNumRecvUnrolls * 2) + lane_id * kNumRecvUnrolls / 16;
// decode_and_accumulate<kNumRecvUnrolls>(
// reinterpret_cast<uint32_t*>(tma_ld_buffers[stage_idx] + tma_offset + (enable_cast ? kNumLogFMTPerWarpBytes : kNumBF16PerWarpBytes) / 32 * lane_id),
// combined_values, log_amax_buffers[stage_idx][division_idx], log_amin_buffers[stage_idx][division_idx], enable_cast, topk_weight
// );
// } else {
// int tma_offset = kNumBF16PerWarpBytes * decode_warp_idx;
// decode_and_accumulate<kNumRecvUnrolls>(
// reinterpret_cast<uint32_t*>(tma_ld_buffers[stage_idx] + tma_offset + kNumBF16PerWarpBytes / 32 * lane_id),
// combined_values, 0, 0, false, topk_weight
// );
// }
// if (elect_one_sync())
// mbarrier_arrive(empty_barriers[stage_idx]);
// stage_idx = (stage_idx + 1) % kNumStages;
// }
// tma_store_wait<0>();
// #pragma unroll
// for (int k = 0; k < kNumRecvUnrolls * 4; ++ k) {
// auto combined_pack = __nv_bfloat162(combined_values[k * 2], combined_values[k * 2 + 1]);
// tma_st_buffers[decode_warp_idx][kNumRecvUnrolls * 4 * lane_id + k] = *reinterpret_cast<uint32_t*>(&combined_pack);
// }
// tma_store_fence();
// if (elect_one_sync()) {
// tma_store_1d(tma_st_buffers[decode_warp_idx],
// static_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4 + decode_warp_idx * kNumRecvUnrolls * 32,
// kNumBF16PerWarpBytes);
// }
// __syncwarp();
// }
// }
// }
}
}
void
combine
(
void
*
combined_x
,
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int
*
rdma_recv_flag
,
void
*
rdma_send_x
,
void
*
rdma_recv_x
,
const
void
*
x
,
const
topk_idx_t
*
topk_idx
,
const
float
*
topk_weights
,
int64_t
*
rdma_recv_flag
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
void
*
rdma_send_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
mask_buffer_ptr
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
combine_wait_recv_cost_stats
,
int
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_next_clean_int
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_logfmt
,
bool
use_logfmt
,
void
*
workspace
,
int
num_device_sms
,
void
*
workspace
,
cudaStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
constexpr
int
kNumMaxTopk
=
11
;
constexpr
int
kNumMaxTopk
=
11
;
// const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const
int
num_warp_groups
=
DIVUP
(
num_experts
,
num_device_sms
);
// const int num_warps_per_group = 32 / num_warp_groups;
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
// const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
const
int
num_recv_per_sm
=
DIVUP
(
num_combined_tokens
,
num_device_sms
);
// EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
and
num_recv_per_sm
>=
0
);
// const auto num_warps = num_warp_groups * num_warps_per_group;
const
auto
num_warps
=
num_warp_groups
*
num_warps_per_group
;
// const auto num_sms = max(ceil_div(num_experts, num_warp_groups),
const
auto
num_sms
=
max
(
DIVUP
(
num_experts
,
num_warp_groups
),
num_recv_per_sm
==
0
?
1
:
DIVUP
(
num_combined_tokens
,
num_recv_per_sm
));
// num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));
// Check workspace
// // Check workspace
auto
atomic_clean_flag
=
static_cast
<
int
*>
(
workspace
);
// auto atomic_clean_flag = static_cast<int*>(workspace);
EP_HOST_ASSERT
(
sizeof
(
int
)
<=
NUM_WORKSPACE_BYTES
);
// EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopk
);
// EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
// Online cast cannot use zero-copy
// // Online cast cannot use zero-copy
EP_HOST_ASSERT
(
not
(
zero_copy
and
use_logfmt
));
// EP_HOST_ASSERT(not (zero_copy and use_logfmt));
EP_HOST_ASSERT
(
use_logfmt
==
0
);
// constexpr int kNumStages = 3;
constexpr
int
kNumMaxUnrolls
=
4
;
// constexpr int kNumMaxUnrolls = 4;
// constexpr int kMaxNumGroups = 2;
#ifdef USEING_TMA
constexpr
int
kNumStages
=
3
;
// // Send buffer size
constexpr
int
kMaxNumGroups
=
2
;
// const int num_meta_bytes = hidden / 128 * 4;
// const int num_send_tma_bytes = 32 * sizeof(int4) * kNumMaxUnrolls + 16;
// Send buffer size
// const int smem_send_size = num_warps * (kNumStages * num_send_tma_bytes + num_meta_bytes);
const
int
num_meta_bytes
=
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
;
const
int
num_send_tma_bytes
=
32
*
sizeof
(
int4
)
*
kNumMaxUnrolls
+
16
;
// // Receive buffer size
const
int
smem_send_size
=
num_warps
*
(
kNumStages
*
num_send_tma_bytes
+
num_meta_bytes
);
// const int num_recv_tma_bytes = 16 + hidden * 2;
// const int smem_recv_size = kMaxNumGroups * (kNumStages * num_recv_tma_bytes + hidden * 2 + kNumStages * num_meta_bytes * 3);
// Receive buffer size
const
int
num_recv_tma_bytes
=
16
+
hidden
*
2
;
// // Total requirement
const
int
smem_recv_size
=
kMaxNumGroups
*
(
kNumStages
*
num_recv_tma_bytes
+
hidden
*
2
+
kNumStages
*
num_meta_bytes
*
3
);
// const int smem_size = max(smem_send_size, smem_recv_size);
// Total requirement
// #define COMBINE_LAUNCH_CASE(hidden) { \
const
int
smem_size
=
max
(
smem_send_size
,
smem_recv_size
);
// auto combine_func = use_logfmt ? \
#endif
// combine<true, hidden, kNumMaxTopk, kNumMaxUnrolls> : \
// combine<false, hidden, kNumMaxTopk, kNumMaxUnrolls>; \
// #define COMBINE_LAUNCH_CASE(hidden) \
// SET_SHARED_MEMORY_FOR_TMA(combine_func); \
// { \
// LAUNCH_KERNEL(&cfg, combine_func, \
// auto combine_func = combine<false, hidden, kNumMaxTopk, kNumMaxUnrolls>; \
// LAUNCH_KERNEL(&cfg, \
// combine_func, \
// combined_x, \
// combined_x, \
// rdma_recv_x, rdma_recv_flag, rdma_send_x, \
// rdma_recv_x, \
// x, topk_idx, topk_weights, src_info, layout_range, \
// rdma_recv_flag, \
// rdma_send_x, \
// x, \
// topk_idx, \
// topk_weights, \
// src_info, \
// layout_range, \
// global_atomic_counter, \
// mask_buffer_ptr, \
// combine_wait_recv_cost_stats, \
// combine_wait_recv_cost_stats, \
// next_clean, num_next_clean_int, \
// next_clean, \
// num_next_clean_int, \
// atomic_clean_flag, \
// atomic_clean_flag, \
// num_combined_tokens, hidden, num_topk, \
// num_combined_tokens, \
// hidden, \
// num_topk, \
// num_max_dispatch_tokens_per_rank, \
// num_max_dispatch_tokens_per_rank, \
// num_experts, rank, num_ranks, \
// num_experts, \
// num_warp_groups, num_warps_per_group, \
// rank, \
// phases, zero_copy); } break
// num_ranks, \
// num_warp_groups, \
// SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
// num_warps_per_group, \
// phases, \
// zero_copy); \
// } \
// break
// SETUP_LAUNCH_CONFIG(num_sms, num_warps* kWarpSize, stream);
// SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
// SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
// #undef COMBINE_LAUNCH_CASE
// #undef COMBINE_LAUNCH_CASE
}
}
template
<
int
kNumThreads
>
__launch_bounds__
(
kNumThreads
,
1
)
__global__
void
query_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
int
*
mask_tensor
)
{
const
auto
num_sms
=
static_cast
<
int
>
(
gridDim
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
num_threads
=
num_sms
*
kNumThreads
;
const
auto
thread_id
=
sm_id
*
kNumThreads
+
static_cast
<
int
>
(
threadIdx
.
x
);
for
(
int
rank_id
=
thread_id
;
rank_id
<
num_ranks
;
rank_id
+=
num_threads
)
{
mask_tensor
[
rank_id
]
=
mask_buffer_ptr
[
rank_id
];
}
}
void
query_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
int
*
mask_tensor
,
hipStream_t
stream
)
{
constexpr
int
num_sms
=
1
;
constexpr
int
kNumThreads
=
1024
;
SETUP_LAUNCH_CONFIG
(
num_sms
,
kNumThreads
,
stream
);
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
query_mask_buffer
<
kNumThreads
>
,
mask_buffer_ptr
,
num_ranks
,
mask_tensor
);
}
template
<
int
kNumThreads
>
__launch_bounds__
(
kNumThreads
,
1
)
__global__
void
update_mask_buffer
(
int
*
mask_buffer_ptr
,
int
rank_to_mask
,
bool
mask
)
{
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
if
(
sm_id
==
0
&&
thread_id
==
0
)
{
atomicExch
(
mask_buffer_ptr
+
rank_to_mask
,
mask
?
1
:
0
);
}
}
void
update_mask_buffer
(
int
*
mask_buffer_ptr
,
int
rank
,
bool
mask
,
hipStream_t
stream
)
{
constexpr
int
num_sms
=
1
;
constexpr
int
kNumThreads
=
64
;
SETUP_LAUNCH_CONFIG
(
num_sms
,
kNumThreads
,
stream
);
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
update_mask_buffer
<
kNumThreads
>
,
mask_buffer_ptr
,
rank
,
mask
);
}
template
<
int
kNumThreads
>
__launch_bounds__
(
kNumThreads
,
1
)
__global__
void
clean_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
)
{
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_ranks
;
i
+=
kNumThreads
)
mask_buffer_ptr
[
i
]
=
0
;
}
void
clean_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
hipStream_t
stream
)
{
constexpr
int
num_sms
=
1
;
constexpr
int
kNumThreads
=
64
;
SETUP_LAUNCH_CONFIG
(
num_sms
,
kNumThreads
,
stream
);
LAUNCH_KERNEL_NON_COOPERATIVE
(
&
cfg
,
clean_mask_buffer
<
kNumThreads
>
,
mask_buffer_ptr
,
num_ranks
);
}
}
// namespace internode_ll
}
// namespace internode_ll
}
// namespace deep_ep
}
// namespace deep_ep
#endif
csrc/kernels/utils.cuh
View file @
09cb2b03
...
@@ -125,6 +125,10 @@ __device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
...
@@ -125,6 +125,10 @@ __device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
__hip_atomic_store
(
const_cast
<
int
*>
(
ptr
),
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_SYSTEM
);
__hip_atomic_store
(
const_cast
<
int
*>
(
ptr
),
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_SYSTEM
);
}
}
__device__
__forceinline__
void
st_release_sys_global
(
const
int64_t
*
ptr
,
int64_t
val
)
{
__hip_atomic_store
(
const_cast
<
int64_t
*>
(
ptr
),
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_SYSTEM
);
}
__device__
__forceinline__
void
st_release_cta
(
const
int
*
ptr
,
int
val
)
{
__device__
__forceinline__
void
st_release_cta
(
const
int
*
ptr
,
int
val
)
{
__hip_atomic_store
(
const_cast
<
int
*>
(
ptr
),
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
__hip_atomic_store
(
const_cast
<
int
*>
(
ptr
),
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
}
}
...
@@ -157,6 +161,12 @@ __device__ __forceinline__ int ld_acquire_global(const int *ptr) {
...
@@ -157,6 +161,12 @@ __device__ __forceinline__ int ld_acquire_global(const int *ptr) {
return
ret
;
return
ret
;
}
}
__device__
__forceinline__
int64_t
ld_acquire_global
(
const
int64_t
*
ptr
)
{
int64_t
ret
;
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_ACQUIRE
,
__HIP_MEMORY_SCOPE_AGENT
);
return
ret
;
}
__device__
__forceinline__
int
atomic_add_release_global
(
const
int
*
ptr
,
int
value
)
{
__device__
__forceinline__
int
atomic_add_release_global
(
const
int
*
ptr
,
int
value
)
{
int
ret
;
int
ret
;
// ret = __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE,
// ret = __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE,
...
@@ -165,6 +175,12 @@ __device__ __forceinline__ int atomic_add_release_global(const int *ptr, int val
...
@@ -165,6 +175,12 @@ __device__ __forceinline__ int atomic_add_release_global(const int *ptr, int val
return
ret
;
return
ret
;
}
}
__device__
__forceinline__
int
ld_relaxed_global
(
const
int
*
ptr
)
{
int
ret
;
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
return
ret
;
}
__device__
__forceinline__
int
ld_acquire_cta
(
const
int
*
ptr
)
{
__device__
__forceinline__
int
ld_acquire_cta
(
const
int
*
ptr
)
{
int
ret
;
int
ret
;
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_ACQUIRE
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_ACQUIRE
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
...
@@ -245,6 +261,11 @@ __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val)
...
@@ -245,6 +261,11 @@ __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val)
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
}
}
__device__
__forceinline__
void
st_na_release
(
const
int64_t
*
ptr
,
int64_t
val
)
{
int64_t
*
non_const_ptr
=
const_cast
<
int64_t
*>
(
ptr
);
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
}
// TODO:: apply "st.global.L1::no_allocate" in ROCM
// TODO:: apply "st.global.L1::no_allocate" in ROCM
template
<
typename
dtype_t
>
template
<
typename
dtype_t
>
__device__
__forceinline__
void
st_na_global
(
const
dtype_t
*
ptr
,
const
dtype_t
&
value
)
{
__device__
__forceinline__
void
st_na_global
(
const
dtype_t
*
ptr
,
const
dtype_t
&
value
)
{
...
@@ -279,6 +300,22 @@ __forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_s
...
@@ -279,6 +300,22 @@ __forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_s
token_end_idx
=
min
(
token_start_idx
+
num_tokens_per_sm
,
num_tokens
);
token_end_idx
=
min
(
token_start_idx
+
num_tokens_per_sm
,
num_tokens
);
}
}
template
<
typename
dtype_a_t
,
typename
dtype_b_t
>
__device__
__forceinline__
dtype_b_t
pack2
(
const
dtype_a_t
&
x
,
const
dtype_a_t
&
y
)
{
EP_STATIC_ASSERT
(
sizeof
(
dtype_a_t
)
*
2
==
sizeof
(
dtype_b_t
),
"Invalid dtypes"
);
dtype_b_t
packed
;
auto
unpacked_ptr
=
reinterpret_cast
<
dtype_a_t
*>
(
&
packed
);
unpacked_ptr
[
0
]
=
x
,
unpacked_ptr
[
1
]
=
y
;
return
packed
;
}
template
<
typename
dtype_a_t
,
typename
dtype_b_t
>
__device__
__forceinline__
void
unpack2
(
const
dtype_b_t
&
packed
,
dtype_a_t
&
x
,
dtype_a_t
&
y
)
{
EP_STATIC_ASSERT
(
sizeof
(
dtype_a_t
)
*
2
==
sizeof
(
dtype_b_t
),
"Invalid dtypes"
);
auto
unpacked_ptr
=
reinterpret_cast
<
const
dtype_a_t
*>
(
&
packed
);
x
=
unpacked_ptr
[
0
],
y
=
unpacked_ptr
[
1
];
}
template
<
typename
dtype_t
>
template
<
typename
dtype_t
>
__device__
__forceinline__
dtype_t
broadcast
(
dtype_t
&
ptr
,
int
src_lane_idx
)
{
__device__
__forceinline__
dtype_t
broadcast
(
dtype_t
&
ptr
,
int
src_lane_idx
)
{
EP_STATIC_ASSERT
(
sizeof
(
dtype_t
)
%
sizeof
(
int
)
==
0
,
""
);
EP_STATIC_ASSERT
(
sizeof
(
dtype_t
)
%
sizeof
(
int
)
==
0
,
""
);
...
@@ -290,15 +327,47 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
...
@@ -290,15 +327,47 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
}
}
__forceinline__
__device__
int
warp_reduce_sum
(
int
value
)
{
#ifdef USE_ROCM
if
constexpr
(
kWarpSize
==
64
)
constexpr
float
kFP8Margin
=
1e-4
;
value
+=
shfl_xor
<
int
>
(
value
,
32
);
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
value
+=
shfl_xor
<
int
>
(
value
,
16
);
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
value
+=
shfl_xor
<
int
>
(
value
,
8
);
#else
value
+=
shfl_xor
<
int
>
(
value
,
4
);
constexpr
float
kFP8Margin
=
1e-4
;
value
+=
shfl_xor
<
int
>
(
value
,
2
);
constexpr
float
kFinfoAmaxE4M3
=
448.0
f
;
value
+=
shfl_xor
<
int
>
(
value
,
1
);
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
#endif
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
// We can ensure `-126 <= x and x <= 127`
uint32_t
bits_x
=
(
x
+
127
)
<<
23
;
return
*
reinterpret_cast
<
float
*>
(
&
bits_x
);
}
__forceinline__
__device__
int
fast_log2_ceil
(
float
x
)
{
auto
bits_x
=
*
reinterpret_cast
<
uint32_t
*>
(
&
x
);
auto
exp_x
=
(
bits_x
>>
23
)
&
0xff
;
auto
man_bits
=
bits_x
&
((
1
<<
23
)
-
1
);
return
exp_x
-
127
+
(
man_bits
!=
0
);
}
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
)
{
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
scale
=
kFinfoAmaxE4M3
/
amax
;
}
}
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
if
constexpr
(
kIsUE8M0
)
{
return
static_cast
<
uint8_t
>
((
*
reinterpret_cast
<
uint32_t
*>
(
&
value
))
>>
23
);
}
else
{
return
value
;
return
value
;
}
}
}
__forceinline__
__device__
int
get_lane_id
()
{
__forceinline__
__device__
int
get_lane_id
()
{
...
@@ -340,4 +409,95 @@ __forceinline__ __device__ void barrier_block(int **barrier_signal_ptrs, int ran
...
@@ -340,4 +409,95 @@ __forceinline__ __device__ void barrier_block(int **barrier_signal_ptrs, int ran
}
}
__syncthreads
();
__syncthreads
();
}
}
// Operation functors
template
<
typename
T
>
struct
ReduceSum
{
__device__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
+
b
;
}
};
template
<
typename
T
>
struct
ReduceMax
{
__device__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
>
b
?
a
:
b
;
}
};
template
<
typename
T
>
struct
ReduceMin
{
__device__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
<
b
?
a
:
b
;
}
};
template
<
typename
T
>
struct
ReduceAnd
{
__device__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
&
b
;
}
};
template
<
typename
T
>
struct
ReduceOr
{
__device__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
|
b
;
}
};
// Unified reduction function
template
<
int
kNumLanesPerGroup
,
bool
kIntergroupReduce
,
typename
T
,
typename
Op
>
__forceinline__
__device__
T
warp_reduce
(
T
value
,
Op
op
)
{
EP_STATIC_ASSERT
(
kNumLanesPerGroup
==
kWarpSize
or
kNumLanesPerGroup
==
32
or
kNumLanesPerGroup
==
16
or
kNumLanesPerGroup
==
8
or
kNumLanesPerGroup
==
4
or
kNumLanesPerGroup
==
2
or
kNumLanesPerGroup
==
1
,
"Invalid number of lanes"
);
constexpr
uint32_t
mask
=
0xffffffff
;
if
constexpr
(
kIntergroupReduce
)
{
if
constexpr
(
kNumLanesPerGroup
<=
1
)
value
=
op
(
value
,
shfl_xor
(
value
,
1
));
if
constexpr
(
kNumLanesPerGroup
<=
2
)
value
=
op
(
value
,
shfl_xor
(
value
,
2
));
if
constexpr
(
kNumLanesPerGroup
<=
4
)
value
=
op
(
value
,
shfl_xor
(
value
,
4
));
if
constexpr
(
kNumLanesPerGroup
<=
8
)
value
=
op
(
value
,
shfl_xor
(
value
,
8
));
if
constexpr
(
kNumLanesPerGroup
<=
16
)
value
=
op
(
value
,
shfl_xor
(
value
,
16
));
if
constexpr
(
kWarpSize
==
64
){
if
constexpr
(
kNumLanesPerGroup
<=
32
)
value
=
op
(
value
,
shfl_xor
(
value
,
32
));
}
}
else
{
if
constexpr
(
kWarpSize
==
64
){
if
constexpr
(
kNumLanesPerGroup
>=
kWarpSize
)
value
=
op
(
value
,
shfl_xor
(
value
,
32
));
}
if
constexpr
(
kNumLanesPerGroup
>=
32
)
value
=
op
(
value
,
shfl_xor
(
value
,
16
));
if
constexpr
(
kNumLanesPerGroup
>=
16
)
value
=
op
(
value
,
shfl_xor
(
value
,
8
));
if
constexpr
(
kNumLanesPerGroup
>=
8
)
value
=
op
(
value
,
shfl_xor
(
value
,
4
));
if
constexpr
(
kNumLanesPerGroup
>=
4
)
value
=
op
(
value
,
shfl_xor
(
value
,
2
));
if
constexpr
(
kNumLanesPerGroup
>=
2
)
value
=
op
(
value
,
shfl_xor
(
value
,
1
));
}
return
value
;
}
// Convenience aliases
template
<
int
kNumLanesPerGroup
=
kWarpSize
,
bool
kIntergroupReduce
=
false
,
typename
T
>
__forceinline__
__device__
T
warp_reduce_sum
(
T
value
)
{
return
warp_reduce
<
kNumLanesPerGroup
,
kIntergroupReduce
,
T
>
(
value
,
ReduceSum
<
T
>
{});
}
template
<
int
kNumLanesPerGroup
=
kWarpSize
,
bool
kIntergroupReduce
=
false
,
typename
T
>
__forceinline__
__device__
T
warp_reduce_max
(
T
value
)
{
return
warp_reduce
<
kNumLanesPerGroup
,
kIntergroupReduce
,
T
>
(
value
,
ReduceMax
<
T
>
{});
}
template
<
int
kNumLanesPerGroup
=
kWarpSize
,
bool
kIntergroupReduce
=
false
,
typename
T
>
__forceinline__
__device__
T
warp_reduce_min
(
T
value
)
{
return
warp_reduce
<
kNumLanesPerGroup
,
kIntergroupReduce
,
T
>
(
value
,
ReduceMin
<
T
>
{});
}
template
<
int
kNumLanesPerGroup
=
kWarpSize
,
bool
kIntergroupReduce
=
false
,
typename
T
>
__forceinline__
__device__
T
warp_reduce_and
(
T
value
)
{
return
warp_reduce
<
kNumLanesPerGroup
,
kIntergroupReduce
,
T
>
(
value
,
ReduceAnd
<
T
>
{});
}
template
<
int
kNumLanesPerGroup
=
kWarpSize
,
bool
kIntergroupReduce
=
false
,
typename
T
>
__forceinline__
__device__
T
warp_reduce_or
(
T
value
)
{
return
warp_reduce
<
kNumLanesPerGroup
,
kIntergroupReduce
,
T
>
(
value
,
ReduceOr
<
T
>
{});
}
}
// namespace deep_ep
}
// namespace deep_ep
deep_ep/buffer.py
View file @
09cb2b03
...
@@ -39,7 +39,7 @@ class Buffer:
...
@@ -39,7 +39,7 @@ class Buffer:
allow_nvlink_for_low_latency_mode
:
bool
=
True
,
allow_nvlink_for_low_latency_mode
:
bool
=
True
,
allow_mnnvl
:
bool
=
False
,
allow_mnnvl
:
bool
=
False
,
explicitly_destroy
:
bool
=
False
,
explicitly_destroy
:
bool
=
False
,
use_default_stream_as_comm_stream
:
bool
=
Tru
e
,
enable_shrink
:
bool
=
Fals
e
,
)
->
None
:
)
->
None
:
"""
"""
Initialize the communication buffer.
Initialize the communication buffer.
...
@@ -59,6 +59,7 @@ class Buffer:
...
@@ -59,6 +59,7 @@ class Buffer:
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
otherwise, the resources will be released by the destructor.
otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically.
"""
"""
check_nvlink_connections
(
group
)
check_nvlink_connections
(
group
)
...
@@ -70,6 +71,7 @@ class Buffer:
...
@@ -70,6 +71,7 @@ class Buffer:
self
.
num_rdma_bytes
=
num_rdma_bytes
self
.
num_rdma_bytes
=
num_rdma_bytes
self
.
low_latency_mode
=
low_latency_mode
self
.
low_latency_mode
=
low_latency_mode
self
.
explicitly_destroy
=
explicitly_destroy
self
.
explicitly_destroy
=
explicitly_destroy
self
.
enable_shrink
=
enable_shrink
self
.
runtime
=
deep_ep_cpp
.
Buffer
(
self
.
runtime
=
deep_ep_cpp
.
Buffer
(
self
.
rank
,
self
.
rank
,
self
.
group_size
,
self
.
group_size
,
...
@@ -77,7 +79,7 @@ class Buffer:
...
@@ -77,7 +79,7 @@ class Buffer:
num_rdma_bytes
,
num_rdma_bytes
,
low_latency_mode
,
low_latency_mode
,
explicitly_destroy
,
explicitly_destroy
,
use_default_stream_as_comm_stream
,
enable_shrink
)
)
# Synchronize device IDs
# Synchronize device IDs
...
@@ -989,3 +991,31 @@ class Buffer:
...
@@ -989,3 +991,31 @@ class Buffer:
return
self
.
runtime
.
get_next_low_latency_combine_buffer
(
return
self
.
runtime
.
get_next_low_latency_combine_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
)
def
low_latency_update_mask_buffer
(
self
,
rank_to_mask
:
int
,
mask
:
bool
=
False
):
"""
Mask (unmask) a rank during communication (dispatch, combine, and clean)
Arguments:
rank: the rank to mask (unmask).
mask: if True, will mask the rank (do not recvfrom/sendto the rank), otherwise will unmask the rank.
"""
self
.
runtime
.
low_latency_update_mask_buffer
(
rank_to_mask
,
mask
)
def
low_latency_query_mask_buffer
(
self
,
mask_status
:
torch
.
Tensor
):
"""
Query the mask status of all ranks
Arguments:
mask_status: `[num_ranks]` with `torch.int`, the mask status of each rank. `1` means mask and `0` means unmasked.
"""
self
.
runtime
.
low_latency_query_mask_buffer
(
mask_status
)
def
low_latency_clean_mask_buffer
(
self
):
"""
Clean the mask buffer
"""
self
.
runtime
.
low_latency_clean_mask_buffer
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment