Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
da13c63a
Commit
da13c63a
authored
Nov 04, 2025
by
lishen
Browse files
完成低延迟接口功能
parent
09cb2b03
Changes
25
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
917 additions
and
1072 deletions
+917
-1072
1.sh
1.sh
+6
-3
2.sh
2.sh
+6
-3
build.sh
build.sh
+5
-3
csrc/config.hpp
csrc/config.hpp
+1
-1
csrc/deep_ep.cu
csrc/deep_ep.cu
+67
-174
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+16
-25
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+22
-33
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+351
-646
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+20
-4
deep_ep/buffer.py
deep_ep/buffer.py
+33
-141
rocshmem_dir/bin/rocshmem_info
rocshmem_dir/bin/rocshmem_info
+0
-0
rocshmem_dir/include/rocshmem/rocshmem.hpp
rocshmem_dir/include/rocshmem/rocshmem.hpp
+52
-13
rocshmem_dir/include/rocshmem/rocshmem_COLL.hpp
rocshmem_dir/include/rocshmem/rocshmem_COLL.hpp
+8
-22
rocshmem_dir/include/rocshmem/rocshmem_common.hpp
rocshmem_dir/include/rocshmem/rocshmem_common.hpp
+18
-1
rocshmem_dir/include/rocshmem/rocshmem_config.h
rocshmem_dir/include/rocshmem/rocshmem_config.h
+1
-0
rocshmem_dir/include/rocshmem/rocshmem_mpi.hpp
rocshmem_dir/include/rocshmem/rocshmem_mpi.hpp
+143
-0
rocshmem_dir/lib/cmake/rocshmem/rocshmem-targets.cmake
rocshmem_dir/lib/cmake/rocshmem/rocshmem-targets.cmake
+1
-1
rocshmem_dir/lib/librocshmem.a
rocshmem_dir/lib/librocshmem.a
+0
-0
rocshmem_dir/rocshmem/lib/cmake/rocshmem-config-version.cmake
...hmem_dir/rocshmem/lib/cmake/rocshmem-config-version.cmake
+65
-1
rocshmem_dir/rocshmem/lib/cmake/rocshmem-config.cmake
rocshmem_dir/rocshmem/lib/cmake/rocshmem-config.cmake
+102
-1
No files found.
1.sh
View file @
da13c63a
pgrep
-f
/usr/bin/python | xargs
kill
-9
export
OMPI_MCA_pml
=
ucx
export
OMPI_MCA_pml
=
ucx
export
OMPI_MCA_osc
=
ucx
export
OMPI_MCA_osc
=
ucx
export
OMPI_MCA_coll_hcoll_enable
=
0
export
OMPI_MCA_coll_hcoll_enable
=
0
...
@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
...
@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
32
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
32
export
UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS
=
16384
export
UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS
=
16384
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export
ROCSHMEM_HEAP_SIZE
=
53687
09
1
2
export
ROCSHMEM_HEAP_SIZE
=
288010
09
9
2
export
PYTHONPATH
=
/public/home/lishen/Tmp/DeepEP:
$PYTHONPATH
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_
internode
.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
0
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_
low_latency
.py
2.sh
View file @
da13c63a
pgrep
-f
/usr/bin/python | xargs
kill
-9
export
OMPI_MCA_pml
=
ucx
export
OMPI_MCA_pml
=
ucx
export
OMPI_MCA_osc
=
ucx
export
OMPI_MCA_osc
=
ucx
export
OMPI_MCA_coll_hcoll_enable
=
0
export
OMPI_MCA_coll_hcoll_enable
=
0
...
@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
...
@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
32
export
ROCSHMEM_MAX_NUM_CONTEXTS
=
32
export
UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS
=
16384
export
UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS
=
16384
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export
UCX_NET_DEVICES
=
mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export
ROCSHMEM_ALLOWED_IBV_DEVICES
=
mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
export
HIP_VISIBLE_DEVICES
=
0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export
ROCSHMEM_HEAP_SIZE
=
53687
09
1
2
export
ROCSHMEM_HEAP_SIZE
=
288010
09
9
2
export
PYTHONPATH
=
/public/home/lishen/Tmp/DeepEP:
$PYTHONPATH
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_
internode
.py
torchrun
--nproc-per-node
=
1
--nnodes
=
2
--node-rank
=
1
--master-addr
=
"10.16.1.37"
--master-port
=
1234 tests/test_
low_latency
.py
build.sh
View file @
da13c63a
...
@@ -8,8 +8,10 @@ fi
...
@@ -8,8 +8,10 @@ fi
PYTHON_INCLUDE
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['include'])"
)
PYTHON_INCLUDE
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['include'])"
)
PYTHON_PLATLIB
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['platlib'])"
)
PYTHON_PLATLIB
=
$(
python3
-c
"from sysconfig import get_paths; print(get_paths()['platlib'])"
)
INCLUDE_PATHS
=
${
INCLUDE_PATHS
:
=-Icsrc/ -I
$(
pwd
)
/rocshmem_dir/include/ -I/opt/mpi/include -I
${
PYTHON_PLATLIB
}
/torch/include -I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include -I
${
PYTHON_PLATLIB
}
/torch/include/TH -I
${
PYTHON_PLATLIB
}
/torch/include/THC -I
${
PYTHON_PLATLIB
}
/torch/include/THH -I/opt/dtk/include -I
${
PYTHON_INCLUDE
}}
ROCSHMEM_INSTALL_PREFIX
=
${
ROCSHMEM_INSTALL_PREFIX
:
=
$(
pwd
)
/rocshmem_dir
}
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
}
INCLUDE_PATHS
=
${
INCLUDE_PATHS
:
=-Icsrc/ -I
${
ROCSHMEM_INSTALL_PREFIX
}
/include/ -I/opt/mpi/include -I
${
PYTHON_PLATLIB
}
/torch/include -I
${
PYTHON_PLATLIB
}
/torch/include/torch/csrc/api/include -I
${
PYTHON_PLATLIB
}
/torch/include/TH -I
${
PYTHON_PLATLIB
}
/torch/include/THC -I
${
PYTHON_PLATLIB
}
/torch/include/THH -I/opt/dtk/include -I
${
PYTHON_INCLUDE
}}
COMPILE_OPTIONS
=
${
COMPILE_OPTIONS
:
= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H
'-DPYBIND11_COMPILER_TYPE="_gcc"'
'-DPYBIND11_STDLIB="_libstdcpp"'
'-DPYBIND11_BUILD_ABI="_cxxabi1014"'
-DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/runtime.cu
-o
build_/runtime.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/runtime.cu
-o
build_/runtime.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/layout.cu
-o
build_/layout.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/layout.cu
-o
build_/layout.o
${
COMPILE_OPTIONS
}
...
@@ -18,7 +20,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
...
@@ -18,7 +20,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode_ll.cu
-o
build_/internode_ll.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/kernels/internode_ll.cu
-o
build_/internode_ll.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/deep_ep.cu
-o
build_/deep_ep.o
${
COMPILE_OPTIONS
}
hipcc
${
INCLUDE_PATHS
}
-c
$(
pwd
)
/csrc/deep_ep.cu
-o
build_/deep_ep.o
${
COMPILE_OPTIONS
}
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o
-L
$
(
pwd
)
/rocshmem_dir
/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,
$
(
pwd
)
/rocshmem_dir
/lib/
-L
"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
-l
:librocshmem.a
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
hipcc
-Wno-unused-result
-Wsign-compare
-DNDEBUG
-g
-fwrapv
-O2
-Wall
-g
-fstack-protector-strong
-Wformat
-Werror
=
format-security
-g
-fwrapv
-O2
-shared
-Wl
,-O1
-Wl
,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o
-L
$
{
ROCSHMEM_INSTALL_PREFIX
}
/lib/
-L
/opt/mpi/lib
-L
/opt/dtk/hip/lib
-L
/usr/lib/x86_64-linux-gnu
-lhipblaslt
-lamdhip64
-o
deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,/opt/dtk/lib
-fgpu-rdc
--hip-link
--offload-arch
=
gfx936
-shared
-Wl
,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so
-Wl
,-rpath,
$
$
{
ROCSHMEM_INSTALL_PREFIX
}
/lib/
-L
"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux"
-lclang_rt
.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0
-L
${
PYTHON_PLATLIB
}
/torch/lib
-L
/opt/dtk/lib
-L
/opt/dtk/hip/lib
-L
/usr/local/lib
-lc10
-ltorch
-ltorch_cpu
-ltorch_python
-lamdhip64
-lc10_hip
-ltorch_hip
-lrocm-core
-lrocm_smi64
-l
:librocshmem.a
-fgpu-rdc
--hip-link
-lamdhip64
-lhsa-runtime64
-l
:libmpi.so
-Wl
,-rpath,/opt/mpi/lib/
-libverbs
-lmlx5
# build whl
# build whl
echo
"Using Python:
$(
which python3
)
"
echo
"Using Python:
$(
which python3
)
"
...
...
csrc/config.hpp
View file @
da13c63a
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
)
{
int
num_ranks
,
int
num_experts
)
{
const
int
num_scales
=
hidden
/
128
;
const
int
num_scales
=
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
;
// Dispatch and combine layout:
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even send buffer
...
...
csrc/deep_ep.cu
View file @
da13c63a
This diff is collapsed.
Click to expand it.
csrc/deep_ep.hpp
View file @
da13c63a
...
@@ -26,14 +26,17 @@ private:
...
@@ -26,14 +26,17 @@ private:
void
*
buffer_ptrs
[
NUM_MAX_NVL_PEERS
]
=
{
nullptr
};
void
*
buffer_ptrs
[
NUM_MAX_NVL_PEERS
]
=
{
nullptr
};
void
**
buffer_ptrs_gpu
=
nullptr
;
void
**
buffer_ptrs_gpu
=
nullptr
;
void
*
nvl_buffer_ptrs
[
NUM_MAX_NVL_PEERS
]
=
{
nullptr
};
void
**
nvl_buffer_ptrs_gpu
=
nullptr
;
// NVSHMEM Buffer
// NVSHMEM Buffer
int64_t
num_rdma_bytes
;
int64_t
num_rdma_bytes
;
void
*
rdma_buffer_ptr
=
nullptr
;
void
*
rdma_buffer_ptr
=
nullptr
;
// Shrink mode buffer
// Shrink mode buffer
bool
enable_shrink
=
false
;
bool
enable_shrink
=
false
;
int
*
mask_buffer_ptr
=
nullptr
;
int
*
mask_buffer_ptr
=
nullptr
;
int
*
sync_buffer_ptr
=
nullptr
;
int
*
sync_buffer_ptr
=
nullptr
;
// Device info and communication
// Device info and communication
int
device_id
;
int
device_id
;
...
@@ -171,31 +174,19 @@ public:
...
@@ -171,31 +174,19 @@ public:
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
void
clean_low_latency_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
int
num_experts
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
const
std
::
optional
<
torch
::
Tensor
>
&
cumulative_local_expert_recv_stats
,
bool
use_fp8
,
bool
async
,
bool
return_recv_hook
);
const
std
::
optional
<
torch
::
Tensor
>
&
dispatch_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
async
,
bool
return_recv_hook
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_combine
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
layout_range
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
const
std
::
optional
<
torch
::
Tensor
>
&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_logfmt
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>
&
out
=
std
::
nullopt
);
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
torch
::
Tensor
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
)
const
;
void
low_latency_update_mask_buffer
(
int
rank_to_mask
,
bool
mask
);
void
low_latency_query_mask_buffer
(
const
torch
::
Tensor
&
mask_status
);
void
low_latency_clean_mask_buffer
(
);
torch
::
Tensor
get_next_low_latency_combine_buffer
(
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_experts
);
};
};
}
// namespace deep_ep
}
// namespace deep_ep
csrc/kernels/api.cuh
View file @
da13c63a
...
@@ -134,42 +134,31 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
...
@@ -134,42 +134,31 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
// Internode low-latency kernels
// Internode low-latency kernels
namespace
internode_ll
{
namespace
internode_ll
{
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
void
clean_low_latency_buffer
(
int64_t
*
clean_0
,
int
num_clean_int_0
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
int64_t
*
clean_1
,
int
num_clean_int_1
,
int
rank
,
int
num_ranks
,
hipStream_t
stream
);
int
*
mask_buffer
,
int
*
sync_buffer
,
hipStream_t
stream
);
void
dispatch
(
void
*
packed_recv_x
,
float
*
packed_recv_x_scales
,
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int
*
mask_buffer
,
int
*
cumulative_local_expert_recv_stats
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
int64_t
*
dispatch_wait_recv_cost_stats
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
void
*
workspace
,
hipStream_t
stream
,
int
phases
);
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
float
*
topk_weights
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
const
int
*
src_info
,
const
int64_t
*
layout_range
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
int
*
mask_buffer
,
int64_t
*
combine_wait_recv_cost_stats
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_logfmt
,
void
*
workspace
,
hipStream_t
stream
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
);
int
phases
,
bool
zero_copy
);
void
query_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
int
*
output_mask_tensor
,
hipStream_t
stream
);
void
update_mask_buffer
(
int
*
mask_buffer_ptr
,
int
rank_to_mask
,
bool
mask
,
hipStream_t
stream
);
void
clean_mask_buffer
(
int
*
mask_buffer_ptr
,
int
num_ranks
,
hipStream_t
stream
);
}
// namespace internode_ll
}
// namespace internode_ll
...
...
csrc/kernels/internode_ll.cu
View file @
da13c63a
This diff is collapsed.
Click to expand it.
csrc/kernels/utils.cuh
View file @
da13c63a
...
@@ -31,6 +31,21 @@
...
@@ -31,6 +31,21 @@
} \
} \
}
}
#define UNROLLED_WARP_COPY_LL(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \
constexpr int kLoopStride = kWarpSize * (UNROLL_FACTOR); \
typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \
auto __src = (SRC); \
auto __dst = (DST); \
for(int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \
_Pragma("unroll") for(int __j = 0; __j < (UNROLL_FACTOR); ++__j) unrolled_values[__j] = LD_FUNC(__src + __i + __j * kWarpSize); \
_Pragma("unroll") for(int __j = 0; __j < (UNROLL_FACTOR); ++__j) ST_FUNC(__dst + __i + __j * kWarpSize, unrolled_values[__j]); \
} \
for(int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += kWarpSize) \
ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \
}
#define UNROLLED_WARP_COPY_EMULATED(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
#define UNROLLED_WARP_COPY_EMULATED(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \
{ \
constexpr int kLoopStride = kEmulatedWarpSize * (UNROLL_FACTOR); \
constexpr int kLoopStride = kEmulatedWarpSize * (UNROLL_FACTOR); \
...
@@ -329,8 +344,8 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
...
@@ -329,8 +344,8 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
#ifdef USE_ROCM
#ifdef USE_ROCM
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
#else
#else
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFinfoAmaxE4M3
=
448.0
f
;
constexpr
float
kFinfoAmaxE4M3
=
448.0
f
;
...
@@ -350,8 +365,9 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
...
@@ -350,8 +365,9 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return
exp_x
-
127
+
(
man_bits
!=
0
);
return
exp_x
-
127
+
(
man_bits
!=
0
);
}
}
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
)
{
template
<
bool
kRoundScale
>
if
(
round_scale
)
{
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
)
{
if
constexpr
(
kRoundScale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
...
...
deep_ep/buffer.py
View file @
da13c63a
...
@@ -802,26 +802,17 @@ class Buffer:
...
@@ -802,26 +802,17 @@ class Buffer:
self
.
runtime
.
clean_low_latency_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
self
.
runtime
.
clean_low_latency_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_dispatch
(
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
self
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
x
:
torch
.
Tensor
,
use_fp8
:
bool
=
True
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
topk_idx
:
torch
.
Tensor
,
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
cumulative_local_expert_recv_stats
:
Optional
[
torch
.
Tensor
]
=
None
,
dispatch_wait_recv_cost_stats
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp8
:
bool
=
True
,
round_scale
:
bool
=
False
,
use_ue8m0
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
"""
A low-latency implementation for dispatching with IBGDA.
A low-latency implementation for dispatching with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
(specifically, IBGDA must be enabled).
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
Even for ranks in the same node, NVLink are fully disabled for simplicity.
low-latency kernels' result tensors at a single moment.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Arguments:
Arguments:
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
...
@@ -830,105 +821,52 @@ class Buffer:
...
@@ -830,105 +821,52 @@ class Buffer:
are supported. `-1` indices (not selecting any expert) are supported.
are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
num_experts: the number of all experts.
cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring.
dispatch_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and pre-cisely localizing slow anomalies.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you
do
not set this flag, the kernel will ensure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
Returns:
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receive
s
. As mentioned before,
not all
tokens are valid in `recv_x`.
expert receive. As mentioned before,
all not
tokens are valid in `recv_x`.
handle: the communication handle to be used in the `low_latency_combine` function.
handle: the communication handle to be used in the `low_latency_combine` function.
event: the event after executing the kernel (valid only if `async_finish` is set).
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
"""
(
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
packed_recv_x
,
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
packed_recv_x_scales
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
packed_recv_count
,
use_fp8
,
async_finish
,
return_recv_hook
)
packed_recv_src_info
,
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
packed_recv_layout_range
,
tensors_to_record
=
(
x
,
topk_idx
,
event
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
hook
,
packed_recv_src_info
,
packed_recv_layout_range
)
)
=
self
.
runtime
.
low_latency_dispatch
(
return
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
\
x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
topk_idx
,
cumulative_local_expert_recv_stats
,
dispatch_wait_recv_cost_stats
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
round_scale
,
use_ue8m0
,
async_finish
,
return_recv_hook
,
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
,
)
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
cumulative_local_expert_recv_stats
,
)
return
(
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
,
)
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_combine
(
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
self
,
handle
:
tuple
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
x
:
torch
.
Tensor
,
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
topk_idx
:
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
topk_weights
:
torch
.
Tensor
,
handle
:
tuple
,
use_logfmt
:
bool
=
False
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
combine_wait_recv_cost_stats
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
"""
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
(specifically, IBGDA must be enabled).
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2
Even for ranks in the same node, NVLink are fully disabled for simplicity.
low-latency kernels' result tensors at a single moment.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Arguments:
Arguments:
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
...
@@ -939,39 +877,23 @@ class Buffer:
...
@@ -939,39 +877,23 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`.
with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you
do
not set this flag, the kernel will ensure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and pre-cisely localizing slow anomalies.
Returns:
Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens,
hidden
]` and type `torch.bfloat16`.
combined_x: the reduced token tensor, with shape `[num_combined_tokens,
num_topk
]` and type `torch.bfloat16`.
event: the event after executing the kernel (valid only if `async_finish` is set).
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
"""
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
x
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
topk_idx
,
zero_copy
,
async_finish
,
return_recv_hook
,
out
)
topk_weights
,
src_info
,
layout_range
,
combine_wait_recv_cost_stats
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_logfmt
,
zero_copy
,
async_finish
,
return_recv_hook
,
out
,
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
return
combined_x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
return
combined_x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
...
@@ -988,34 +910,4 @@ class Buffer:
...
@@ -988,34 +910,4 @@ class Buffer:
by yourself.
by yourself.
"""
"""
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
src_info
,
layout_range
,
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
=
handle
return
self
.
runtime
.
get_next_low_latency_combine_buffer
(
return
self
.
runtime
.
get_next_low_latency_combine_buffer
(
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
num_max_dispatch_tokens_per_rank
,
hidden
,
num_experts
)
def
low_latency_update_mask_buffer
(
self
,
rank_to_mask
:
int
,
mask
:
bool
=
False
):
"""
Mask (unmask) a rank during communication (dispatch, combine, and clean)
Arguments:
rank: the rank to mask (unmask).
mask: if True, will mask the rank (do not recvfrom/sendto the rank), otherwise will unmask the rank.
"""
self
.
runtime
.
low_latency_update_mask_buffer
(
rank_to_mask
,
mask
)
def
low_latency_query_mask_buffer
(
self
,
mask_status
:
torch
.
Tensor
):
"""
Query the mask status of all ranks
Arguments:
mask_status: `[num_ranks]` with `torch.int`, the mask status of each rank. `1` means mask and `0` means unmasked.
"""
self
.
runtime
.
low_latency_query_mask_buffer
(
mask_status
)
def
low_latency_clean_mask_buffer
(
self
):
"""
Clean the mask buffer
"""
self
.
runtime
.
low_latency_clean_mask_buffer
()
rocshmem_dir/bin/rocshmem_info
View file @
da13c63a
No preview for this file type
rocshmem_dir/include/rocshmem/rocshmem.hpp
View file @
da13c63a
...
@@ -26,7 +26,6 @@
...
@@ -26,7 +26,6 @@
#define LIBRARY_INCLUDE_ROCSHMEM_HPP
#define LIBRARY_INCLUDE_ROCSHMEM_HPP
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <mpi.h>
#include "rocshmem_config.h"
#include "rocshmem_config.h"
#include "rocshmem_common.hpp"
#include "rocshmem_common.hpp"
...
@@ -36,6 +35,10 @@
...
@@ -36,6 +35,10 @@
#include "rocshmem_COLL.hpp"
#include "rocshmem_COLL.hpp"
#include "rocshmem_P2P_SYNC.hpp"
#include "rocshmem_P2P_SYNC.hpp"
#include "rocshmem_RMA_X.hpp"
#include "rocshmem_RMA_X.hpp"
#if defined(HAVE_EXTERNAL_MPI)
#include <mpi.h>
#endif
/**
/**
* @file rocshmem.hpp
* @file rocshmem.hpp
* @brief Public header for rocSHMEM device and host libraries.
* @brief Public header for rocSHMEM device and host libraries.
...
@@ -57,20 +60,29 @@ constexpr char VERSION[] = "3.0.0";
...
@@ -57,20 +60,29 @@ constexpr char VERSION[] = "3.0.0";
/******************************************************************************
/******************************************************************************
**************************** HOST INTERFACE **********************************
**************************** HOST INTERFACE **********************************
*****************************************************************************/
*****************************************************************************/
#if defined(HAVE_EXTERNAL_MPI)
/**
/**
* @brief Initialize the rocSHMEM runtime and underlying transport layer.
* @brief Initialize the rocSHMEM runtime and underlying transport layer.
*
*
* @param[in] comm
(Optional)
MPI Communicator that rocSHMEM will be using
* @param[in] comm MPI Communicator that rocSHMEM will be using
* If MPI_COMM_NULL, rocSHMEM will be using MPI_COMM_WORLD
* If MPI_COMM_NULL, rocSHMEM will be using MPI_COMM_WORLD
*/
*/
__host__
void
rocshmem_init
(
MPI_Comm
comm
=
MPI_COMM_WORLD
);
[[
deprecated
]]
__host__
void
rocshmem_init
(
MPI_Comm
comm
);
#endif
/**
* @brief Initialize the rocSHMEM runtime and underlying transport layer.
* This is equivalent to the previous function, using implicitely
* MPI_COMM_WORLD for initialization
*/
__host__
void
rocshmem_init
(
void
);
/**
/**
* @brief Query rocSHMEM context from host API
* @brief Query rocSHMEM context from host API
*
*
* @param[out] ctx Returns ROCSHMEM_CTX_DEFAULT device pointer that users
* @param[out] ctx Returns ROCSHMEM_CTX_DEFAULT device pointer that users
* can query from one instance of rocshmem host library and
* can query from one instance of rocshmem host library and
* use use later for dynamic module initialization in
* use use later for dynamic module initialization in
* kernel bitcode device library in the same application
* kernel bitcode device library in the same application
*/
*/
__host__
void
*
rocshmem_get_device_ctx
();
__host__
void
*
rocshmem_get_device_ctx
();
...
@@ -79,15 +91,17 @@ __host__ void * rocshmem_get_device_ctx();
...
@@ -79,15 +91,17 @@ __host__ void * rocshmem_get_device_ctx();
* @brief Query rocSHMEM remote symmetric heap pointer
* @brief Query rocSHMEM remote symmetric heap pointer
*
*
* @param[in] dest local symmetric heap allocation pointer for current pe/device
* @param[in] dest local symmetric heap allocation pointer for current pe/device
*
*
* @param[in] pe remote PE
* @param[in] pe remote PE
*
*
* @param[out] ptr Returns remote symmetric heap device pointer from host-side API.
* @param[out] ptr Returns remote symmetric heap device pointer from host-side API.
* This can be used to issue load/store from custom kernels
* This can be used to issue load/store from custom kernels
* instead of using rocshmem device side get/put APIs for RMA operations.
* instead of using rocshmem device side get/put APIs for RMA operations.
*/
*/
__host__
void
*
rocshmem_ptr
(
void
*
dest
,
int
pe
);
__host__
void
*
rocshmem_ptr
(
const
void
*
dest
,
int
pe
);
__device__
ATTR_NO_INLINE
void
*
rocshmem_ptr
(
const
void
*
dest
,
int
pe
);
#if defined(HAVE_EXTERNAL_MPI)
/**
/**
* @brief Initialize the rocSHMEM runtime and underlying transport layer
* @brief Initialize the rocSHMEM runtime and underlying transport layer
* with an attempt to enable the requested thread support.
* with an attempt to enable the requested thread support.
...
@@ -102,8 +116,9 @@ __host__ void *rocshmem_ptr(void *dest, int pe);
...
@@ -102,8 +116,9 @@ __host__ void *rocshmem_ptr(void *dest, int pe);
* @return int returns 0 upon success; otherwise, it returns a nonzero
* @return int returns 0 upon success; otherwise, it returns a nonzero
* value
* value
*/
*/
__host__
int
rocshmem_init_thread
(
int
requested
,
int
*
provided
,
[[
deprecated
]]
__host__
int
rocshmem_init_thread
(
int
requested
,
int
*
provided
,
MPI_Comm
comm
=
MPI_COMM_WORLD
);
MPI_Comm
comm
);
#endif
/**
/**
* @brief Initialize the rocSHMEM runtime and underlying transport layer
* @brief Initialize the rocSHMEM runtime and underlying transport layer
...
@@ -327,6 +342,13 @@ __host__ void rocshmem_quiet();
...
@@ -327,6 +342,13 @@ __host__ void rocshmem_quiet();
*/
*/
__host__
void
rocshmem_barrier_all
();
__host__
void
rocshmem_barrier_all
();
/**
* @brief enqueues a collective barrier on given stream.
*
* @return void
*/
__host__
void
rocshmem_barrier_all_on_stream
(
hipStream_t
stream
);
/**
/**
* @brief registers the arrival of a PE at a barrier.
* @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved.
* The caller is blocked until the synchronization is resolved.
...
@@ -360,7 +382,7 @@ __host__ void rocshmem_global_exit(int status);
...
@@ -360,7 +382,7 @@ __host__ void rocshmem_global_exit(int status);
*
*
* @return void.
* @return void.
*/
*/
__device__
void
rocshmem_wg_init
();
[[
deprecated
]]
__device__
void
rocshmem_wg_init
();
/**
/**
* @brief Finalizes device-side rocSHMEM resources. Must be called before
* @brief Finalizes device-side rocSHMEM resources. Must be called before
...
@@ -370,7 +392,7 @@ __device__ void rocshmem_wg_init();
...
@@ -370,7 +392,7 @@ __device__ void rocshmem_wg_init();
*
*
* @return void.
* @return void.
*/
*/
__device__
void
rocshmem_wg_finalize
();
[[
deprecated
]]
__device__
void
rocshmem_wg_finalize
();
/**
/**
* @brief Initializes device-side rocSHMEM resources. Must be called before
* @brief Initializes device-side rocSHMEM resources. Must be called before
...
@@ -386,7 +408,7 @@ __device__ void rocshmem_wg_finalize();
...
@@ -386,7 +408,7 @@ __device__ void rocshmem_wg_finalize();
*
*
* @return void.
* @return void.
*/
*/
__device__
void
rocshmem_wg_init_thread
(
int
requested
,
int
*
provided
);
[[
deprecated
]]
__device__
void
rocshmem_wg_init_thread
(
int
requested
,
int
*
provided
);
/**
/**
* @brief Query the thread mode used by the runtime.
* @brief Query the thread mode used by the runtime.
...
@@ -476,6 +498,23 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_quiet(rocshmem_ctx_t ctx);
...
@@ -476,6 +498,23 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_quiet(rocshmem_ctx_t ctx);
__device__
ATTR_NO_INLINE
void
rocshmem_quiet
();
__device__
ATTR_NO_INLINE
void
rocshmem_quiet
();
/**
* @brief Completes all previous operations posted to this context for PEs in the
* `target_pes` array.
*
* @param[in] ctx Context with which to perform this operation.
*
* @param[in] target_pes Address of target PE array where the operations need to be completed.
*
* @param[in] npes The number of PEs in the target PE array.
*
* @return void.
*/
__device__
ATTR_NO_INLINE
void
rocshmem_ctx_pe_quiet
(
rocshmem_ctx_t
ctx
,
const
int
*
target_pes
,
size_t
npes
);
__device__
ATTR_NO_INLINE
void
rocshmem_pe_quiet
(
const
int
*
target_pes
,
size_t
npes
);
/**
/**
* @brief Query the total number of PEs.
* @brief Query the total number of PEs.
*
*
...
...
rocshmem_dir/include/rocshmem/rocshmem_COLL.hpp
View file @
da13c63a
...
@@ -599,6 +599,14 @@ __host__ int rocshmem_ctx_double_prod_reduce(
...
@@ -599,6 +599,14 @@ __host__ int rocshmem_ctx_double_prod_reduce(
rocshmem_ctx_t
ctx
,
rocshmem_team_t
team
,
double
*
dest
,
const
double
*
source
,
rocshmem_ctx_t
ctx
,
rocshmem_team_t
team
,
double
*
dest
,
const
double
*
source
,
int
nreduce
);
int
nreduce
);
/**
* @brief kernel for performing a barrier synchronization.
* Caller enqueues the kernel on given stream
*
* @return void
*/
__global__
ATTR_NO_INLINE
void
rocshmem_barrier_all_kernel
();
/**
/**
* @brief perform a collective barrier between all PEs in the system.
* @brief perform a collective barrier between all PEs in the system.
* The caller is blocked until the barrier is resolved.
* The caller is blocked until the barrier is resolved.
...
@@ -767,28 +775,6 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_wave(
...
@@ -767,28 +775,6 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_wave(
__device__
ATTR_NO_INLINE
void
rocshmem_ctx_sync_wg
(
__device__
ATTR_NO_INLINE
void
rocshmem_ctx_sync_wg
(
rocshmem_ctx_t
ctx
,
rocshmem_team_t
team
);
rocshmem_ctx_t
ctx
,
rocshmem_team_t
team
);
/**
* @brief Query a local pointer to a symmetric data object on the
* specified \pe . Returns an address that may be used to directly reference
* dest on the specified \pe. This address can be accesses with LD/ST ops.
*
* Can be called per thread with no performance penalty.
*/
__device__
ATTR_NO_INLINE
void
*
rocshmem_ptr
(
const
void
*
dest
,
int
pe
);
/**
* @brief Make all uncacheable GPU data visible to other agents in the sytem.
*
* This only works for data that was explicitly allocated uncacheable on the
* GPU!
*
* Can be called per thread with no performance penalty.
*
* @param[in] GPU-side handle.
*
* @return void
*/
}
// namespace rocshmem
}
// namespace rocshmem
#endif // LIBRARY_INCLUDE_ROCSHMEM_COLL_HPP
#endif // LIBRARY_INCLUDE_ROCSHMEM_COLL_HPP
rocshmem_dir/include/rocshmem/rocshmem_common.hpp
View file @
da13c63a
...
@@ -106,9 +106,18 @@ const int ROCSHMEM_CTX_SHARED = 8;
...
@@ -106,9 +106,18 @@ const int ROCSHMEM_CTX_SHARED = 8;
* @brief GPU side OpenSHMEM context created from each work-groups'
* @brief GPU side OpenSHMEM context created from each work-groups'
* rocshmem_wg_handle_t
* rocshmem_wg_handle_t
*/
*/
typedef
struct
{
typedef
struct
rocshmem_ctx
{
void
*
ctx_opaque
;
void
*
ctx_opaque
;
void
*
team_opaque
;
void
*
team_opaque
;
__host__
__device__
bool
operator
==
(
const
struct
rocshmem_ctx
&
other
)
const
{
return
(
ctx_opaque
==
other
.
ctx_opaque
&&
team_opaque
==
other
.
team_opaque
);
}
__host__
__device__
bool
operator
!=
(
const
struct
rocshmem_ctx
&
other
)
const
{
return
!
(
*
this
==
other
);
}
}
rocshmem_ctx_t
;
}
rocshmem_ctx_t
;
/**
/**
...
@@ -116,6 +125,14 @@ typedef struct {
...
@@ -116,6 +125,14 @@ typedef struct {
*/
*/
extern
"C"
__device__
rocshmem_ctx_t
__attribute__
((
visibility
(
"default"
)))
ROCSHMEM_CTX_DEFAULT
;
extern
"C"
__device__
rocshmem_ctx_t
__attribute__
((
visibility
(
"default"
)))
ROCSHMEM_CTX_DEFAULT
;
/**
* A value corresponding to an invalid communication context. This value can be
* used to initialize or update context handles to indicate that they do not
* reference a valid context. When managed in this way, applications can use an
* equality comparison to test whether a given context handle references a
* valid context.
*/
extern
__constant__
rocshmem_ctx_t
ROCSHMEM_CTX_INVALID
;
/**
/**
* Used internally to set default context.
* Used internally to set default context.
*/
*/
...
...
rocshmem_dir/include/rocshmem/rocshmem_config.h
View file @
da13c63a
...
@@ -45,3 +45,4 @@
...
@@ -45,3 +45,4 @@
/* #undef GDA_IONIC */
/* #undef GDA_IONIC */
/* #undef GDA_BNXT */
/* #undef GDA_BNXT */
#define GDA_MLX5
#define GDA_MLX5
#define HAVE_EXTERNAL_MPI
rocshmem_dir/include/rocshmem/rocshmem_mpi.hpp
0 → 100644
View file @
da13c63a
/******************************************************************************
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
#define LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
#if defined(HAVE_EXTERNAL_MPI)
#include <mpi.h>
#endif
#if defined(c_plusplus) || defined(__cplusplus)
extern
"C"
{
#endif
#if !defined(MPI_VERSION)
// Open MPI based values for the constants/handles etc.
// Even though we did not include an external MPI header file
// The includer may have (e.g., a unit test).
typedef
void
*
MPI_Comm
;
typedef
void
*
MPI_Win
;
typedef
void
*
MPI_Group
;
typedef
void
*
MPI_Op
;
typedef
void
*
MPI_Datatype
;
typedef
void
*
MPI_Request
;
typedef
void
*
MPI_Info
;
struct
ompi_status_public_t
{
int
MPI_SOURCE
;
int
MPI_TAG
;
int
MPI_ERROR
;
int
_cancelled
;
size_t
_ucount
;
};
typedef
struct
ompi_status_public_t
MPI_Status
;
#define MPI_Aint uint64_t
#define MPI_UNDEFINED -32766
#define MPI_THREAD_MULTIPLE 3
#define MPI_SUCCESS 0
#define MPI_IN_PLACE (void*)1
#define MPI_MODE_NOCHECK 1
#define MPI_COMM_TYPE_SHARED 0
#define MPI_Aint_diff(addr1, addr2) ((MPI_Aint) ((char *) (addr1) - (char *) (addr2)))
struct
ompi_internal_symbols_t
{
void
*
ompi_mpi_comm_world
;
void
*
ompi_mpi_comm_null
;
void
*
ompi_request_null
;
void
*
ompi_mpi_info_null
;
void
*
ompi_mpi_datatype_null
;
void
*
ompi_mpi_op_max
;
void
*
ompi_mpi_op_min
;
void
*
ompi_mpi_op_sum
;
void
*
ompi_mpi_op_prod
;
void
*
ompi_mpi_op_band
;
void
*
ompi_mpi_op_bor
;
void
*
ompi_mpi_op_bxor
;
void
*
ompi_mpi_op_replace
;
void
*
ompi_mpi_op_no_op
;
void
*
ompi_mpi_char
;
void
*
ompi_mpi_unsigned_char
;
void
*
ompi_mpi_signed_char
;
void
*
ompi_mpi_short
;
void
*
ompi_mpi_unsigned_short
;
void
*
ompi_mpi_int
;
void
*
ompi_mpi_unsigned
;
void
*
ompi_mpi_long
;
void
*
ompi_mpi_unsigned_long
;
void
*
ompi_mpi_long_long_int
;
void
*
ompi_mpi_unsigned_long_long
;
void
*
ompi_mpi_float
;
void
*
ompi_mpi_double
;
void
*
ompi_mpi_long_double
;
};
extern
struct
ompi_internal_symbols_t
ompi_symbols_
;
#define OMPI_PREDEFINED_GLOBAL(type, global) (static_cast<type> (global))
#define MPI_COMM_WORLD OMPI_PREDEFINED_GLOBAL(MPI_Comm, ompi_symbols_.ompi_mpi_comm_world)
#define MPI_COMM_NULL OMPI_PREDEFINED_GLOBAL(MPI_Comm, ompi_symbols_.ompi_mpi_comm_null)
#define MPI_REQUEST_NULL OMPI_PREDEFINED_GLOBAL(MPI_Request, ompi_symbols_.ompi_request_null)
#define MPI_WIN_NULL OMPI_PREDEFINED_GLOBAL(MPI_Win, ompi_symbols_.ompi_mpi_win_null)
#define MPI_INFO_NULL OMPI_PREDEFINED_GLOBAL(MPI_Info, ompi_symbols_.ompi_mpi_info_null)
#define MPI_MAX OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_max)
#define MPI_MIN OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_min)
#define MPI_SUM OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_sum)
#define MPI_PROD OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_prod)
#define MPI_BAND OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_band)
#define MPI_BOR OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_bor)
#define MPI_BXOR OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_bxor)
#define MPI_REPLACE OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_replace)
#define MPI_NO_OP OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_no_op)
#define MPI_DATATYPE_NULL OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_datatype_null)
#define MPI_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_char)
#define MPI_UNSIGNED_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_char)
#define MPI_SIGNED_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_signed_char)
#define MPI_SHORT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_short)
#define MPI_UNSIGNED_SHORT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_short)
#define MPI_INT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_int)
#define MPI_UNSIGNED OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned)
#define MPI_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long)
#define MPI_UNSIGNED_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_long)
#define MPI_LONG_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long_long_int)
#define MPI_UNSIGNED_LONG_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_long_long)
#define MPI_FLOAT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_float)
#define MPI_DOUBLE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_double)
#define MPI_LONG_DOUBLE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long_double)
#endif //!defined(MPI_VERSION)
#if defined(c_plusplus) || defined(__cplusplus)
}
#endif
#endif //LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
rocshmem_dir/lib/cmake/rocshmem/rocshmem-targets.cmake
View file @
da13c63a
...
@@ -61,7 +61,7 @@ add_library(roc::rocshmem STATIC IMPORTED)
...
@@ -61,7 +61,7 @@ add_library(roc::rocshmem STATIC IMPORTED)
set_target_properties
(
roc::rocshmem PROPERTIES
set_target_properties
(
roc::rocshmem PROPERTIES
INTERFACE_COMPILE_OPTIONS
"-fgpu-rdc;-fgpu-rdc"
INTERFACE_COMPILE_OPTIONS
"-fgpu-rdc;-fgpu-rdc"
INTERFACE_INCLUDE_DIRECTORIES
"
${
_IMPORT_PREFIX
}
/include;
${
_IMPORT_PREFIX
}
/include"
INTERFACE_INCLUDE_DIRECTORIES
"
${
_IMPORT_PREFIX
}
/include;
${
_IMPORT_PREFIX
}
/include"
INTERFACE_LINK_LIBRARIES
"IBVerbs::verbs;numa;Threads::Threads;
MPI::MPI_CXX;
hip::device;hip::host;hsa-runtime64::hsa-runtime64;-fgpu-rdc"
INTERFACE_LINK_LIBRARIES
"IBVerbs::verbs;numa;
\$
<
\$
<BOOL:ON>:MPI::MPI_CXX>;
Threads::Threads;hip::device;hip::host;
dl;
hsa-runtime64::hsa-runtime64;-fgpu-rdc"
)
)
# Load information for each installed configuration.
# Load information for each installed configuration.
...
...
rocshmem_dir/lib/librocshmem.a
View file @
da13c63a
No preview for this file type
rocshmem_dir/rocshmem/lib/cmake/rocshmem-config-version.cmake
View file @
da13c63a
../../../lib/cmake/rocshmem/rocshmem-config-version.cmake
# This is a basic version file for the Config-mode of find_package().
\ No newline at end of file
# It is used by write_basic_package_version_file() as input file for configure_file()
# to create a version-file which can be installed along a config.cmake file.
#
# The created file sets PACKAGE_VERSION_EXACT if the current version string and
# the requested version string are exactly the same and it sets
# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version,
# but only if the requested major version is the same as the current one.
# The variable CVF_VERSION must be set before calling configure_file().
set
(
PACKAGE_VERSION
"3.0.0"
)
if
(
PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION
)
set
(
PACKAGE_VERSION_COMPATIBLE FALSE
)
else
()
if
(
"3.0.0"
MATCHES
"^([0-9]+)
\\
."
)
set
(
CVF_VERSION_MAJOR
"
${
CMAKE_MATCH_1
}
"
)
if
(
NOT CVF_VERSION_MAJOR VERSION_EQUAL 0
)
string
(
REGEX REPLACE
"^0+"
""
CVF_VERSION_MAJOR
"
${
CVF_VERSION_MAJOR
}
"
)
endif
()
else
()
set
(
CVF_VERSION_MAJOR
"3.0.0"
)
endif
()
if
(
PACKAGE_FIND_VERSION_RANGE
)
# both endpoints of the range must have the expected major version
math
(
EXPR CVF_VERSION_MAJOR_NEXT
"
${
CVF_VERSION_MAJOR
}
+ 1"
)
if
(
NOT PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
OR
((
PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL
"INCLUDE"
AND NOT PACKAGE_FIND_VERSION_MAX_MAJOR STREQUAL CVF_VERSION_MAJOR
)
OR
(
PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL
"EXCLUDE"
AND NOT PACKAGE_FIND_VERSION_MAX VERSION_LESS_EQUAL CVF_VERSION_MAJOR_NEXT
)))
set
(
PACKAGE_VERSION_COMPATIBLE FALSE
)
elseif
(
PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
AND
((
PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL
"INCLUDE"
AND PACKAGE_VERSION VERSION_LESS_EQUAL PACKAGE_FIND_VERSION_MAX
)
OR
(
PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL
"EXCLUDE"
AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MAX
)))
set
(
PACKAGE_VERSION_COMPATIBLE TRUE
)
else
()
set
(
PACKAGE_VERSION_COMPATIBLE FALSE
)
endif
()
else
()
if
(
PACKAGE_FIND_VERSION_MAJOR STREQUAL CVF_VERSION_MAJOR
)
set
(
PACKAGE_VERSION_COMPATIBLE TRUE
)
else
()
set
(
PACKAGE_VERSION_COMPATIBLE FALSE
)
endif
()
if
(
PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION
)
set
(
PACKAGE_VERSION_EXACT TRUE
)
endif
()
endif
()
endif
()
# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it:
if
(
"
${
CMAKE_SIZEOF_VOID_P
}
"
STREQUAL
""
OR
"8"
STREQUAL
""
)
return
()
endif
()
# check that the installed version has the same 32/64bit-ness as the one which is currently searching:
if
(
NOT CMAKE_SIZEOF_VOID_P STREQUAL
"8"
)
math
(
EXPR installedBits
"8 * 8"
)
set
(
PACKAGE_VERSION
"
${
PACKAGE_VERSION
}
(
${
installedBits
}
bit)"
)
set
(
PACKAGE_VERSION_UNSUITABLE TRUE
)
endif
()
rocshmem_dir/rocshmem/lib/cmake/rocshmem-config.cmake
View file @
da13c63a
../../../lib/cmake/rocshmem/rocshmem-config.cmake
\ No newline at end of file
####################################################################################
# Auto generated @PACKAGE_INIT@ by rocm_configure_package_config_file()
# from rocshmem-config.cmake.in
# Any changes to this file will be overwritten by the next CMake run
####################################################################################
get_filename_component
(
_ROCM_CMAKE_CURRENT_LIST_FILE_REAL
"
${
CMAKE_CURRENT_LIST_FILE
}
"
REALPATH
)
get_filename_component
(
_ROCM_CMAKE_CURRENT_LIST_DIR_REAL
"
${
_ROCM_CMAKE_CURRENT_LIST_FILE_REAL
}
"
DIRECTORY
)
get_filename_component
(
PACKAGE_PREFIX_DIR
"
${
_ROCM_CMAKE_CURRENT_LIST_DIR_REAL
}
/../../../"
ABSOLUTE
)
macro
(
set_and_check _var _file
)
set
(
${
_var
}
"
${
_file
}
"
)
if
(
NOT EXISTS
"
${
_file
}
"
)
message
(
FATAL_ERROR
"File or directory
${
_file
}
referenced by variable
${
_var
}
does not exist !"
)
endif
()
endmacro
()
include
(
CMakeFindDependencyMacro OPTIONAL RESULT_VARIABLE _ROCMCMakeFindDependencyMacro_FOUND
)
if
(
NOT _ROCMCMakeFindDependencyMacro_FOUND
)
macro
(
find_dependency dep
)
if
(
NOT
${
dep
}
_FOUND
)
set
(
rocm_fd_version
)
if
(
${
ARGC
}
GREATER 1
)
set
(
rocm_fd_version
${
ARGV1
}
)
endif
()
set
(
rocm_fd_exact_arg
)
if
(
${
CMAKE_FIND_PACKAGE_NAME
}
_FIND_VERSION_EXACT
)
set
(
rocm_fd_exact_arg EXACT
)
endif
()
set
(
rocm_fd_quiet_arg
)
if
(
${
CMAKE_FIND_PACKAGE_NAME
}
_FIND_QUIETLY
)
set
(
rocm_fd_quiet_arg QUIET
)
endif
()
set
(
rocm_fd_required_arg
)
if
(
${
CMAKE_FIND_PACKAGE_NAME
}
_FIND_REQUIRED
)
set
(
rocm_fd_required_arg REQUIRED
)
endif
()
find_package
(
${
dep
}
${
rocm_fd_version
}
${
rocm_fd_exact_arg
}
${
rocm_fd_quiet_arg
}
${
rocm_fd_required_arg
}
)
string
(
TOUPPER
${
dep
}
cmake_dep_upper
)
if
(
NOT
${
dep
}
_FOUND AND NOT
${
cmake_dep_upper
}
_FOUND
)
set
(
${
CMAKE_FIND_PACKAGE_NAME
}
_NOT_FOUND_MESSAGE
"
${
CMAKE_FIND_PACKAGE_NAME
}
could not be found because dependency
${
dep
}
could not be found."
)
set
(
${
CMAKE_FIND_PACKAGE_NAME
}
_FOUND False
)
return
()
endif
()
set
(
rocm_fd_version
)
set
(
rocm_fd_required_arg
)
set
(
rocm_fd_quiet_arg
)
set
(
rocm_fd_exact_arg
)
endif
()
endmacro
()
endif
()
macro
(
check_required_components _NAME
)
foreach
(
comp
${${
_NAME
}
_FIND_COMPONENTS
}
)
if
(
NOT
${
_NAME
}
_
${
comp
}
_FOUND
)
if
(
${
_NAME
}
_FIND_REQUIRED_
${
comp
}
)
set
(
${
_NAME
}
_FOUND FALSE
)
endif
()
endif
()
endforeach
()
endmacro
()
####################################################################################
set_and_check
(
rocshmem_INCLUDE_DIR
${
PACKAGE_PREFIX_DIR
}
/include
)
set_and_check
(
rocshmem_INCLUDE_DIRS
${
PACKAGE_PREFIX_DIR
}
/include
)
set_and_check
(
ROCSHMEM_INCLUDE_DIR
${
PACKAGE_PREFIX_DIR
}
/include
)
set_and_check
(
ROCSHMEM_INCLUDE_DIRS
${
PACKAGE_PREFIX_DIR
}
/include
)
set_and_check
(
rocshmem_INCLUDE_DIR
${
PACKAGE_PREFIX_DIR
}
/include
)
set_and_check
(
rocshmem_INCLUDE_DIRS
${
PACKAGE_PREFIX_DIR
}
/include
)
set_and_check
(
rocshmem_TARGET_FILE
${
PACKAGE_PREFIX_DIR
}
/lib/cmake/rocshmem/rocshmem-targets.cmake
)
include
(
${
rocshmem_TARGET_FILE
}
)
set
(
rocshmem_LIBRARIES roc::rocshmem
)
set
(
rocshmem_LIBRARY roc::rocshmem
)
set
(
ROCSHMEM_LIBRARIES roc::rocshmem
)
set
(
ROCSHMEM_LIBRARY roc::rocshmem
)
set
(
rocshmem_LIBRARIES roc::rocshmem
)
set
(
rocshmem_LIBRARY roc::rocshmem
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment