Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
09cb2b03
Commit
09cb2b03
authored
Oct 30, 2025
by
lishen
Browse files
添加low latency接口,正确性需补充
parent
0b14d3b2
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1579 additions
and
1059 deletions
+1579
-1059
build.sh
build.sh
+11
-6
csrc/config.hpp
csrc/config.hpp
+29
-26
csrc/deep_ep.cu
csrc/deep_ep.cu
+355
-64
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+12
-3
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+42
-0
csrc/kernels/configs.cuh
csrc/kernels/configs.cuh
+2
-0
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+927
-949
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+169
-9
deep_ep/buffer.py
deep_ep/buffer.py
+32
-2
No files found.
build.sh
View file @
09cb2b03
...
@@ -8,12 +8,17 @@ fi
...
@@ -8,12 +8,17 @@ fi
PYTHON_INCLUDE
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['include'])"
)
PYTHON_INCLUDE
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['include'])"
)
PYTHON_PLATLIB
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['platlib'])"
)
PYTHON_PLATLIB
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['platlib'])"
)
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/kernels/intranode.cu
-o
build_/intranode.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
INCLUDE_PATHS
=
${
INCLUDE_PATHS
:
=-Icsrc/ -I
$(
pwd
)
/rocshmem_dir/include/ -I/opt/mpi/include -I
${
PYTHON_PLATLIB
}
/torch/include -I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include -I
${
PYTHON_PLATLIB
}
/torch/include/TH -I
${
PYTHON_PLATLIB
}
/torch/include/THC -I
${
PYTHON_PLATLIB
}
/torch/include/THH -I/opt/dtk/include -I
${
PYTHON_INCLUDE
}}
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/kernels/runtime.cu
-o
build_/runtime.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
}
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/kernels/layout.cu
-o
build_/layout.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/deep_ep.cu
-o
build_/deep_ep.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/runtime.cu
-o
build_/runtime.o
${
COMPILE_OPTIONS
}
/opt/dtk/bin/hipcc
-Icsrc
/
-I
$(
pwd
)
/rocshmem_dir/include/
-I
/opt/mpi/include
-I
${
PYTHON_PLATLIB
}
/torch/include
-I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include
-I
${
PYTHON_PLATLIB
}
/torch/include/TH
-I
${
PYTHON_PLATLIB
}
/torch/include/THC
-I
${
PYTHON_PLATLIB
}
/torch/include/THH
-I
/opt/dtk/include
-I
${
PYTHON_INCLUDE
}
-c
-c
./csrc/kernels/internode.cu
-o
build_/internode.o
-fPIC
-D__HIP_PLATFORM_AMD__
=
1
-DUSE_ROCM
=
1
-DHIPBLAS_V2
-DCUDA_HAS_FP16
=
1
-D__HIP_NO_HALF_OPERATORS__
=
1
-D__HIP_NO_HALF_CONVERSIONS__
=
1
-O3
-fgpu-rdc
-DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME
=
deep_ep_cpp
-D_GLIBCXX_USE_CXX11_ABI
=
1
--offload-arch
=
gfx936
-std
=
c++17
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/layout.cu
-o
build_/layout.o
${
COMPILE_OPTIONS
}
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o
-L
$(
pwd
)
/rocshmem_dir/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,
$(
pwd
)
/rocshmem_dir/lib/
-L
"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
-l
:librocshmem.a
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/intranode.cu
-o
build_/intranode.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode.cu
-o
build_/internode.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode_ll.cu
-o
build_/internode_ll.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/deep_ep.cu
-o
build_/deep_ep.o
${
COMPILE_OPTIONS
}
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o
-L
$(
pwd
)
/rocshmem_dir/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,
$(
pwd
)
/rocshmem_dir/lib/
-L
"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
-l
:librocshmem.a
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
# build whl
# build whl
echo
"Using Python:
$(
which python3
)
"
echo
"Using Python:
$(
which python3
)
"
...
...
csrc/config.hpp
View file @
09cb2b03
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once
#pragma once
#include "kernels/api.cuh"
#include "kernels/api.cuh"
...
@@ -105,18 +107,18 @@ struct Config {
...
@@ -105,18 +107,18 @@ struct Config {
struct
LowLatencyBuffer
{
struct
LowLatencyBuffer
{
int
num_clean_int
=
0
;
int
num_clean_int
=
0
;
void
*
dispatch_rdma_send_buffer
=
nullptr
;
void
*
dispatch_rdma_send_buffer
=
nullptr
;
void
*
dispatch_rdma_recv_data_buffer
=
nullptr
;
void
*
dispatch_rdma_recv_data_buffer
=
nullptr
;
int
*
dispatch_rdma_recv_count_buffer
=
nullptr
;
int
64_t
*
dispatch_rdma_recv_count_buffer
=
nullptr
;
void
*
combine_rdma_send_buffer
=
nullptr
;
void
*
combine_rdma_send_buffer
=
nullptr
;
void
*
combine_rdma_recv_data_buffer
=
nullptr
;
void
*
combine_rdma_recv_data_buffer
=
nullptr
;
int
*
combine_rdma_recv_flag_buffer
=
nullptr
;
int
64_t
*
combine_rdma_recv_flag_buffer
=
nullptr
;
void
*
combine_rdma_send_buffer_data_start
=
nullptr
;
void
*
combine_rdma_send_buffer_data_start
=
nullptr
;
size_t
num_bytes_per_combine_msg
=
0
;
size_t
num_bytes_per_combine_msg
=
0
;
std
::
pair
<
int
*
,
int
>
clean_meta
()
{
std
::
pair
<
int
64_t
*
,
int
>
clean_meta
()
{
EP_HOST_ASSERT
(
dispatch_rdma_recv_count_buffer
==
combine_rdma_recv_flag_buffer
);
EP_HOST_ASSERT
(
dispatch_rdma_recv_count_buffer
==
combine_rdma_recv_flag_buffer
);
return
{
dispatch_rdma_recv_count_buffer
,
num_clean_int
};
return
{
dispatch_rdma_recv_count_buffer
,
num_clean_int
};
}
}
...
@@ -171,29 +173,30 @@ struct LowLatencyLayout {
...
@@ -171,29 +173,30 @@ struct LowLatencyLayout {
total_bytes
+=
recv_buffer_bytes
*
2
;
total_bytes
+=
recv_buffer_bytes
*
2
;
// Symmetric signaling buffers
// Symmetric signaling buffers
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int
);
size_t
dispatch_recv_count_buffer_bytes
=
num_experts
*
sizeof
(
int64_t
);
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
combine_recv_flag_buffer_bytes
=
dispatch_recv_count_buffer_bytes
;
size_t
signaling_buffer_bytes
=
size_t
signaling_buffer_bytes
=
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
std
::
max
(
dispatch_recv_count_buffer_bytes
,
combine_recv_flag_buffer_bytes
);
total_bytes
+=
signaling_buffer_bytes
*
2
;
size_t
signaling_buffer_bytes_aligned
=
ALIGN
<
size_t
>
(
signaling_buffer_bytes
,
128
);
total_bytes
+=
signaling_buffer_bytes_aligned
*
2
;
// Assign pointers
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// so you may see some parameters are duplicated
// so you may see some parameters are duplicated
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
buffers
[
i
]
=
{
buffers
[
i
]
=
{
static_cast
<
int
>
(
signaling_buffer_bytes
/
sizeof
(
int
)),
static_cast
<
int
>
(
signaling_buffer_bytes
/
sizeof
(
int64_t
)),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
// dispatch:send_buffer + recv_buffer + recv_count
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
2
+
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
// combine:send_buffer + recv_buffer + recv_count
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
2
+
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
),
recv_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
i
),
advance
<
int
*>
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
i
),
advance
<
int64_t
*>
(
rdma_buffer
,
send_buffer_bytes
*
2
+
recv_buffer_bytes
*
2
+
signaling_buffer_bytes
*
i
),
advance
(
rdma_buffer
,
signaling_buffer_bytes_aligned
*
2
+
send_buffer_bytes
*
i
),
// combine_rdma_send_buffer_data_start
num_bytes_per_combine_msg
};
advance
(
rdma_buffer
,
send_buffer_bytes
*
i
+
sizeof
(
int4
)),
//
num_bytes_per_combine_msg
};
}
}
}
}
};
};
...
...
csrc/deep_ep.cu
View file @
09cb2b03
This diff is collapsed.
Click to expand it.
csrc/deep_ep.hpp
View file @
09cb2b03
...
@@ -30,6 +30,11 @@ private:
...
@@ -30,6 +30,11 @@ private:
int64_t
num_rdma_bytes
;
int64_t
num_rdma_bytes
;
void
*
rdma_buffer_ptr
=
nullptr
;
void
*
rdma_buffer_ptr
=
nullptr
;
// Shrink mode buffer
bool
enable_shrink
=
false
;
int
*
mask_buffer_ptr
=
nullptr
;
int
*
sync_buffer_ptr
=
nullptr
;
// Device info and communication
// Device info and communication
int
device_id
;
int
device_id
;
int
num_device_sms
;
int
num_device_sms
;
...
@@ -67,11 +72,9 @@ private:
...
@@ -67,11 +72,9 @@ private:
volatile
int
*
moe_recv_rdma_counter
=
nullptr
;
volatile
int
*
moe_recv_rdma_counter
=
nullptr
;
int
*
moe_recv_rdma_counter_mapped
=
nullptr
;
int
*
moe_recv_rdma_counter_mapped
=
nullptr
;
bool
use_default_stream_as_comm_stream
=
false
;
public:
public:
Buffer
(
int
rank
,
int
num_ranks
,
int64_t
num_nvl_bytes
,
int64_t
num_rdma_bytes
,
Buffer
(
int
rank
,
int
num_ranks
,
int64_t
num_nvl_bytes
,
int64_t
num_rdma_bytes
,
bool
low_latency_mode
,
bool
explicitly_destroy
,
bool
use_default_stream_as_comm_stream
);
bool
low_latency_mode
,
bool
explicitly_destroy
,
bool
enable_shrink
);
~
Buffer
()
noexcept
(
false
);
~
Buffer
()
noexcept
(
false
);
...
@@ -187,6 +190,12 @@ public:
...
@@ -187,6 +190,12 @@ public:
torch
::
Tensor
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
torch
::
Tensor
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
const
;
int
hidden
,
int
num_experts
)
const
;
void
low_latency_update_mask_buffer
(
int
rank_to_mask
,
bool
mask
);
void
low_latency_query_mask_buffer
(
const
torch
::
Tensor
&
mask_status
);
void
low_latency_clean_mask_buffer
();
};
};
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/api.cuh
View file @
09cb2b03
...
@@ -131,4 +131,46 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
...
@@ -131,4 +131,46 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
int
num_ranks
,
hipStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
);
int
num_ranks
,
hipStream_t
stream
,
int
num_channels
,
bool
low_latency_mode
);
}
// namespace internode
}
// namespace internode
// Internode low-latency kernels
namespace
internode_ll
{
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
int
rank
,
int
num_ranks
,
int
*
mask_buffer
,
int
*
sync_buffer
,
hipStream_t
stream
);
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
mask_buffer
,
int
*
cumulative_local_expert_recv_stats
,
int64_t
*
dispatch_wait_recv_cost_stats
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
mask_buffer
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_logfmt
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
);
void
query_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
int
*
output_mask_tensor
,
hipStream_t
stream
);
void
update_mask_buffer
(
int
*
mask_buffer_ptr
,
int
rank_to_mask
,
bool
mask
,
hipStream_t
stream
);
void
clean_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
hipStream_t
stream
);
}
// namespace internode_ll
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/configs.cuh
View file @
09cb2b03
...
@@ -22,6 +22,8 @@
...
@@ -22,6 +22,8 @@
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
#define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
...
...
csrc/kernels/internode_ll.cu
View file @
09cb2b03
This diff is collapsed.
Click to expand it.
csrc/kernels/utils.cuh
View file @
09cb2b03
...
@@ -125,6 +125,10 @@ __device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
...
@@ -125,6 +125,10 @@ __device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
__hip_atomic_store
(
const_cast
<
int
*>
(
ptr
),
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_SYSTEM
);
__hip_atomic_store
(
const_cast
<
int
*>
(
ptr
),
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_SYSTEM
);
}
}
// 64-bit overload: atomically stores `val` into `*ptr` with release ordering
// at system memory scope. The pointer parameter is const-qualified in the
// signature, so constness is cast away before the store.
__device__ __forceinline__ void st_release_sys_global(const int64_t *ptr, int64_t val) {
    int64_t *dst = const_cast<int64_t *>(ptr);
    __hip_atomic_store(dst, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM);
}
__device__
__forceinline__
void
st_release_cta
(
const
int
*
ptr
,
int
val
)
{
__device__
__forceinline__
void
st_release_cta
(
const
int
*
ptr
,
int
val
)
{
__hip_atomic_store
(
const_cast
<
int
*>
(
ptr
),
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
__hip_atomic_store
(
const_cast
<
int
*>
(
ptr
),
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
}
}
...
@@ -157,6 +161,12 @@ __device__ __forceinline__ int ld_acquire_global(const int *ptr) {
...
@@ -157,6 +161,12 @@ __device__ __forceinline__ int ld_acquire_global(const int *ptr) {
return
ret
;
return
ret
;
}
}
// 64-bit overload: atomically loads `*ptr` with acquire ordering at agent
// memory scope and returns the loaded value.
__device__ __forceinline__ int64_t ld_acquire_global(const int64_t *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
}
__device__
__forceinline__
int
atomic_add_release_global
(
const
int
*
ptr
,
int
value
)
{
__device__
__forceinline__
int
atomic_add_release_global
(
const
int
*
ptr
,
int
value
)
{
int
ret
;
int
ret
;
// ret = __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE,
// ret = __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE,
...
@@ -165,6 +175,12 @@ __device__ __forceinline__ int atomic_add_release_global(const int *ptr, int val
...
@@ -165,6 +175,12 @@ __device__ __forceinline__ int atomic_add_release_global(const int *ptr, int val
return
ret
;
return
ret
;
}
}
// Atomically loads `*ptr` with relaxed ordering at agent memory scope and
// returns the loaded value.
__device__ __forceinline__ int ld_relaxed_global(const int *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
}
__device__
__forceinline__
int
ld_acquire_cta
(
const
int
*
ptr
)
{
__device__
__forceinline__
int
ld_acquire_cta
(
const
int
*
ptr
)
{
int
ret
;
int
ret
;
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_ACQUIRE
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
ret
=
__hip_atomic_load
(
ptr
,
__ATOMIC_ACQUIRE
,
__HIP_MEMORY_SCOPE_WORKGROUP
);
...
@@ -245,6 +261,11 @@ __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val)
...
@@ -245,6 +261,11 @@ __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val)
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
__hip_atomic_store
(
non_const_ptr
,
val
,
__ATOMIC_RELEASE
,
__HIP_MEMORY_SCOPE_AGENT
);
}
}
// Signed 64-bit overload: atomically stores `val` into `*ptr` with release
// ordering at agent memory scope. Constness of the pointer is cast away
// because the store needs a mutable destination.
__device__ __forceinline__ void st_na_release(const int64_t *ptr, int64_t val) {
    __hip_atomic_store(const_cast<int64_t *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
}
// TODO:: apply "st.global.L1::no_allocate" in ROCM
// TODO:: apply "st.global.L1::no_allocate" in ROCM
template
<
typename
dtype_t
>
template
<
typename
dtype_t
>
__device__
__forceinline__
void
st_na_global
(
const
dtype_t
*
ptr
,
const
dtype_t
&
value
)
{
__device__
__forceinline__
void
st_na_global
(
const
dtype_t
*
ptr
,
const
dtype_t
&
value
)
{
...
@@ -279,6 +300,22 @@ __forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_s
...
@@ -279,6 +300,22 @@ __forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_s
token_end_idx
=
min
(
token_start_idx
+
num_tokens_per_sm
,
num_tokens
);
token_end_idx
=
min
(
token_start_idx
+
num_tokens_per_sm
,
num_tokens
);
}
}
// Packs two `dtype_a_t` values into one `dtype_b_t`: `x` occupies the first
// half of the result's storage and `y` the second half. The static assert
// requires `dtype_b_t` to be exactly twice the size of `dtype_a_t`.
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t &x, const dtype_a_t &y) {
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    dtype_b_t result;
    // View the packed storage as an array of two `dtype_a_t` slots
    dtype_a_t *halves = reinterpret_cast<dtype_a_t *>(&result);
    halves[0] = x;
    halves[1] = y;
    return result;
}
// Splits one `dtype_b_t` into two `dtype_a_t` values: `x` receives the first
// half of `packed`'s storage and `y` the second half. The static assert
// requires `dtype_b_t` to be exactly twice the size of `dtype_a_t`.
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ void unpack2(const dtype_b_t &packed, dtype_a_t &x, dtype_a_t &y) {
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    // View the packed storage as a read-only array of two `dtype_a_t` slots
    const dtype_a_t *halves = reinterpret_cast<const dtype_a_t *>(&packed);
    x = halves[0];
    y = halves[1];
}
template
<
typename
dtype_t
>
template
<
typename
dtype_t
>
__device__
__forceinline__
dtype_t
broadcast
(
dtype_t
&
ptr
,
int
src_lane_idx
)
{
__device__
__forceinline__
dtype_t
broadcast
(
dtype_t
&
ptr
,
int
src_lane_idx
)
{
EP_STATIC_ASSERT
(
sizeof
(
dtype_t
)
%
sizeof
(
int
)
==
0
,
""
);
EP_STATIC_ASSERT
(
sizeof
(
dtype_t
)
%
sizeof
(
int
)
==
0
,
""
);
...
@@ -290,15 +327,47 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
...
@@ -290,15 +327,47 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
}
}
__forceinline__
__device__
int
warp_reduce_sum
(
int
value
)
{
#ifdef USE_ROCM
if
constexpr
(
kWarpSize
==
64
)
constexpr
float
kFP8Margin
=
1e-4
;
value
+=
shfl_xor
<
int
>
(
value
,
32
);
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
value
+=
shfl_xor
<
int
>
(
value
,
16
);
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
value
+=
shfl_xor
<
int
>
(
value
,
8
);
#else
value
+=
shfl_xor
<
int
>
(
value
,
4
);
constexpr
float
kFP8Margin
=
1e-4
;
value
+=
shfl_xor
<
int
>
(
value
,
2
);
constexpr
float
kFinfoAmaxE4M3
=
448.0
f
;
value
+=
shfl_xor
<
int
>
(
value
,
1
);
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
return
value
;
#endif
// Builds the float 2^x directly from its bit pattern: the biased exponent
// (x + 127) is shifted into bits 23..30 and the sign and mantissa bits stay
// zero. No range check is performed here.
__forceinline__ __device__ float fast_pow2(int x) {
    // We can ensure `-126 <= x and x <= 127`
    uint32_t float_bits = (x + 127) << 23;
    float *as_float = reinterpret_cast<float *>(&float_bits);
    return *as_float;
}
// Computes the unbiased exponent of `x` from its bit pattern, plus one when
// any mantissa bit is set, i.e. ceil(log2(x)) for a positive normal float.
// The sign bit is ignored, and zero, subnormal, infinite and NaN inputs are
// not special-cased.
__forceinline__ __device__ int fast_log2_ceil(float x) {
    const uint32_t raw_bits = *reinterpret_cast<uint32_t *>(&x);
    const uint32_t biased_exponent = (raw_bits >> 23) & 0xff;
    const uint32_t mantissa = raw_bits & ((1 << 23) - 1);
    // Evaluated in unsigned arithmetic, then converted to int on return
    return biased_exponent - 127 + (mantissa != 0);
}
// Derives the quantization scale pair from the absolute maximum `amax`.
//   round_scale == false: scale_inv = amax * kFinfoAmaxInvE4M3 and
//                         scale = kFinfoAmaxE4M3 / amax.
//   round_scale == true:  both are powers of two, built from
//                         e = fast_log2_ceil(amax * kFinfoAmaxInvE4M3) as
//                         scale = 2^-e and scale_inv = 2^e.
// `amax == 0` is not guarded against in this function.
__forceinline__ __device__ void calculate_fp8_scales(float amax, float &scale, float &scale_inv, bool round_scale) {
    if (!round_scale) {
        scale_inv = amax * kFinfoAmaxInvE4M3;
        scale = kFinfoAmaxE4M3 / amax;
        return;
    }

    const auto rounded_exponent = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
    scale = fast_pow2(-rounded_exponent);
    scale_inv = fast_pow2(rounded_exponent);
}
// Converts a float scale into the requested storage format.
//   kIsUE8M0 == true:  returns the float's bit pattern shifted right by 23 and
//                      truncated to uint8_t (the 8 exponent bits).
//   kIsUE8M0 == false: returns `value` unchanged as a float.
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
    if constexpr (kIsUE8M0) {
        const uint32_t raw_bits = *reinterpret_cast<uint32_t *>(&value);
        return static_cast<uint8_t>(raw_bits >> 23);
    } else {
        return value;
    }
}
}
__forceinline__
__device__
int
get_lane_id
()
{
__forceinline__
__device__
int
get_lane_id
()
{
...
@@ -340,4 +409,95 @@ __forceinline__ __device__ void barrier_block(int **barrier_signal_ptrs, int ran
...
@@ -340,4 +409,95 @@ __forceinline__ __device__ void barrier_block(int **barrier_signal_ptrs, int ran
}
}
__syncthreads
();
__syncthreads
();
}
}
// Operation functors
// Binary functor returning the sum of its two arguments.
template <typename T>
struct ReduceSum {
    __device__ T operator()(T a, T b) const {
        T total = a + b;
        return total;
    }
};
// Binary functor returning `a` when `a > b`, otherwise `b`.
template <typename T>
struct ReduceMax {
    __device__ T operator()(T a, T b) const {
        if (a > b)
            return a;
        return b;
    }
};
// Binary functor returning `a` when `a < b`, otherwise `b`.
template <typename T>
struct ReduceMin {
    __device__ T operator()(T a, T b) const {
        if (a < b)
            return a;
        return b;
    }
};
// Binary functor returning the bitwise AND of its two arguments.
template <typename T>
struct ReduceAnd {
    __device__ T operator()(T a, T b) const {
        T common_bits = a & b;
        return common_bits;
    }
};
// Binary functor returning the bitwise OR of its two arguments.
template <typename T>
struct ReduceOr {
    __device__ T operator()(T a, T b) const {
        T any_bits = a | b;
        return any_bits;
    }
};
// Unified reduction function
//
// Folds `value` across lanes with the binary functor `op`, using a sequence
// of `shfl_xor` exchanges at power-of-two distances.
// (Assumes `shfl_xor(v, d)` yields the value held by lane `lane_id ^ d` --
// it is defined outside this block; confirm against its definition.)
//
//   kIntergroupReduce == false: distances run from kNumLanesPerGroup / 2 down
//       to 1, so the fold stays inside each aligned group of
//       kNumLanesPerGroup lanes.
//   kIntergroupReduce == true:  distances run from kNumLanesPerGroup up to
//       kWarpSize / 2, so the fold combines lanes holding the same position
//       in different groups.
//
// The distance-32 step is only compiled when kWarpSize == 64.
// kNumLanesPerGroup must be kWarpSize or one of 32, 16, 8, 4, 2, 1.
template <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
__forceinline__ __device__ T warp_reduce(T value, Op op) {
    EP_STATIC_ASSERT(kNumLanesPerGroup == kWarpSize or kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or
                     kNumLanesPerGroup == 8 or kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or
                     kNumLanesPerGroup == 1,
                     "Invalid number of lanes");
    // NOTES: the former `mask` constant was dropped, no `shfl_xor` call consumed it
    if constexpr (kIntergroupReduce) {
        if constexpr (kNumLanesPerGroup <= 1)
            value = op(value, shfl_xor(value, 1));
        if constexpr (kNumLanesPerGroup <= 2)
            value = op(value, shfl_xor(value, 2));
        if constexpr (kNumLanesPerGroup <= 4)
            value = op(value, shfl_xor(value, 4));
        if constexpr (kNumLanesPerGroup <= 8)
            value = op(value, shfl_xor(value, 8));
        if constexpr (kNumLanesPerGroup <= 16)
            value = op(value, shfl_xor(value, 16));
        if constexpr (kWarpSize == 64) {
            if constexpr (kNumLanesPerGroup <= 32)
                value = op(value, shfl_xor(value, 32));
        }
    } else {
        if constexpr (kWarpSize == 64) {
            if constexpr (kNumLanesPerGroup >= kWarpSize)
                value = op(value, shfl_xor(value, 32));
        }
        if constexpr (kNumLanesPerGroup >= 32)
            value = op(value, shfl_xor(value, 16));
        if constexpr (kNumLanesPerGroup >= 16)
            value = op(value, shfl_xor(value, 8));
        if constexpr (kNumLanesPerGroup >= 8)
            value = op(value, shfl_xor(value, 4));
        if constexpr (kNumLanesPerGroup >= 4)
            value = op(value, shfl_xor(value, 2));
        if constexpr (kNumLanesPerGroup >= 2)
            value = op(value, shfl_xor(value, 1));
    }
    return value;
}
// Convenience aliases
// Sum reduction: forwards to `warp_reduce` with a `ReduceSum<T>` functor.
// Defaults to a full-warp, intra-group reduction.
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_sum(T value) {
    const ReduceSum<T> sum_op{};
    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, sum_op);
}
// Max reduction: forwards to `warp_reduce` with a `ReduceMax<T>` functor.
// Defaults to a full-warp, intra-group reduction.
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_max(T value) {
    const ReduceMax<T> max_op{};
    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, max_op);
}
// Min reduction: forwards to `warp_reduce` with a `ReduceMin<T>` functor.
// Defaults to a full-warp, intra-group reduction.
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_min(T value) {
    const ReduceMin<T> min_op{};
    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, min_op);
}
// Bitwise-AND reduction: forwards to `warp_reduce` with a `ReduceAnd<T>` functor.
// Defaults to a full-warp, intra-group reduction.
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_and(T value) {
    const ReduceAnd<T> and_op{};
    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, and_op);
}
// Bitwise-OR reduction: forwards to `warp_reduce` with a `ReduceOr<T>` functor.
// Defaults to a full-warp, intra-group reduction.
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_or(T value) {
    const ReduceOr<T> or_op{};
    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, or_op);
}
}
// namespace deep_ep
}
// namespace deep_ep
deep_ep/buffer.py
View file @
09cb2b03
...
@@ -39,7 +39,7 @@ class Buffer:
...
@@ -39,7 +39,7 @@ class Buffer:
allow_nvlink_for_low_latency_mode
:
bool
=
True
,
allow_nvlink_for_low_latency_mode
:
bool
=
True
,
allow_mnnvl
:
bool
=
False
,
allow_mnnvl
:
bool
=
False
,
explicitly_destroy
:
bool
=
False
,
explicitly_destroy
:
bool
=
False
,
use_default_stream_as_comm_stream
:
bool
=
Tru
e
,
enable_shrink
:
bool
=
Fals
e
,
)
->
None
:
)
->
None
:
"""
"""
Initialize the communication buffer.
Initialize the communication buffer.
...
@@ -59,6 +59,7 @@ class Buffer:
...
@@ -59,6 +59,7 @@ class Buffer:
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
otherwise, the resources will be released by the destructor.
otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
enable_shrink: whether to enable shrink mode. When enabled, a mask buffer is allocated to support masking ranks dynamically.
"""
"""
check_nvlink_connections
(
group
)
check_nvlink_connections
(
group
)
...
@@ -70,6 +71,7 @@ class Buffer:
...
@@ -70,6 +71,7 @@ class Buffer:
self
.
num_rdma_bytes
=
num_rdma_bytes
self
.
num_rdma_bytes
=
num_rdma_bytes
self
.
low_latency_mode
=
low_latency_mode
self
.
low_latency_mode
=
low_latency_mode
self
.
explicitly_destroy
=
explicitly_destroy
self
.
explicitly_destroy
=
explicitly_destroy
self
.
enable_shrink
=
enable_shrink
self
.
runtime
=
deep_ep_cpp
.
Buffer
(
self
.
runtime
=
deep_ep_cpp
.
Buffer
(
self
.
rank
,
self
.
rank
,
self
.
group_size
,
self
.
group_size
,
...
@@ -77,7 +79,7 @@ class Buffer:
...
@@ -77,7 +79,7 @@ class Buffer:
num_rdma_bytes
,
num_rdma_bytes
,
low_latency_mode
,
low_latency_mode
,
explicitly_destroy
,
explicitly_destroy
,
use_default_stream_as_comm_stream
,
enable_shrink
)
)
# Synchronize device IDs
# Synchronize device IDs
...
@@ -989,3 +991,31 @@ class Buffer:
...
@@ -989,3 +991,31 @@ class Buffer:
return
self
.
runtime
.
get_next_low_latency_combine_buffer
(
return
self
.
runtime
.
get_next_low_latency_combine_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
)
def low_latency_update_mask_buffer(self, rank_to_mask: int, mask: bool = False) -> None:
    """
    Mask (unmask) a rank during communication (dispatch, combine, and clean).

    Arguments:
        rank_to_mask: the rank to mask (unmask).
        mask: if True, will mask the rank (do not recvfrom/sendto the rank), otherwise will unmask the rank.
    """
    # Thin wrapper: the work is delegated to the runtime object's method of the same name
    self.runtime.low_latency_update_mask_buffer(rank_to_mask, mask)
def low_latency_query_mask_buffer(self, mask_status: torch.Tensor) -> None:
    """
    Query the mask status of all ranks.

    Arguments:
        mask_status: `[num_ranks]` with `torch.int`, the mask status of each rank. `1` means masked and `0` means unmasked.
    """
    # Thin wrapper: `mask_status` is handed to the runtime object's method of the same name
    self.runtime.low_latency_query_mask_buffer(mask_status)
def low_latency_clean_mask_buffer(self) -> None:
    """
    Clean the mask buffer.
    """
    # Thin wrapper: the work is delegated to the runtime object's method of the same name
    self.runtime.low_latency_clean_mask_buffer()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment