Commit 5563b6d0 authored by lijian6's avatar lijian6
Browse files

Fitter for DCU.


Signed-off-by: lijian6's avatarlijian <lijian6@sugon.com>
parent da6ca24e
/******************************************************************************
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
/* #undef DEBUG */
/* #undef PROFILE */
/* #undef USE_RO */
/* #undef USE_IPC */
#define USE_GDA
/* #undef USE_THREADS */
/* #undef USE_SHARED_CTX */
/* #undef USE_WF_COAL */
#define USE_HEAP_DEVICE_FINEGRAIN
/* #undef USE_HEAP_DEVICE_UNCACHED */
/* #undef USE_HEAP_DEVICE_COARSEGRAIN */
/* #undef USE_HEAP_MANAGED */
/* #undef USE_HEAP_HOST_HIP */
/* #undef USE_HEAP_HOST */
#define USE_ALLOC_DLMALLOC
/* #undef USE_ALLOC_POW2BINS */
/* #undef USE_FUNC_CALL */
/* #undef USE_SINGLE_NODE */
/* #undef USE_HDP_FLUSH */
/* #undef USE_HDP_FLUSH_HOST_SIDE */
/* #undef GDA_IONIC */
/* #undef GDA_BNXT */
#define GDA_MLX5
/******************************************************************************
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef LIBRARY_INCLUDE_DEBUG_HPP
#define LIBRARY_INCLUDE_DEBUG_HPP
namespace rocshmem {
void debug_print_cq(int dest_pe, int src_wg, int cqe_index);
void debug_print_sq(int dest_pe, int src_wg, int index_wqe);
} // namespace rocshmem
#endif // LIBRARY_INCLUDE_DEBUG_HPP
# This is a basic version file for the Config-mode of find_package().
# It is used by write_basic_package_version_file() as input file for configure_file()
# to create a version-file which can be installed along a config.cmake file.
#
# The created file sets PACKAGE_VERSION_EXACT if the current version string and
# the requested version string are exactly the same and it sets
# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version,
# but only if the requested major version is the same as the current one.
# The variable CVF_VERSION must be set before calling configure_file().
set(PACKAGE_VERSION "3.0.0")
if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION)
set(PACKAGE_VERSION_COMPATIBLE FALSE)
else()
if("3.0.0" MATCHES "^([0-9]+)\\.")
set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}")
if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0)
string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}")
endif()
else()
set(CVF_VERSION_MAJOR "3.0.0")
endif()
if(PACKAGE_FIND_VERSION_RANGE)
# both endpoints of the range must have the expected major version
math (EXPR CVF_VERSION_MAJOR_NEXT "${CVF_VERSION_MAJOR} + 1")
if (NOT PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX_MAJOR STREQUAL CVF_VERSION_MAJOR)
OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX VERSION_LESS_EQUAL CVF_VERSION_MAJOR_NEXT)))
set(PACKAGE_VERSION_COMPATIBLE FALSE)
elseif(PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
AND ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS_EQUAL PACKAGE_FIND_VERSION_MAX)
OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MAX)))
set(PACKAGE_VERSION_COMPATIBLE TRUE)
else()
set(PACKAGE_VERSION_COMPATIBLE FALSE)
endif()
else()
if(PACKAGE_FIND_VERSION_MAJOR STREQUAL CVF_VERSION_MAJOR)
set(PACKAGE_VERSION_COMPATIBLE TRUE)
else()
set(PACKAGE_VERSION_COMPATIBLE FALSE)
endif()
if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION)
set(PACKAGE_VERSION_EXACT TRUE)
endif()
endif()
endif()
# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it:
if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "")
return()
endif()
# check that the installed version has the same 32/64bit-ness as the one which is currently searching:
if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8")
math(EXPR installedBits "8 * 8")
set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)")
set(PACKAGE_VERSION_UNSUITABLE TRUE)
endif()
####################################################################################
# Auto generated @PACKAGE_INIT@ by rocm_configure_package_config_file()
# from rocshmem-config.cmake.in
# Any changes to this file will be overwritten by the next CMake run
####################################################################################
get_filename_component(_ROCM_CMAKE_CURRENT_LIST_FILE_REAL "${CMAKE_CURRENT_LIST_FILE}" REALPATH)
get_filename_component(_ROCM_CMAKE_CURRENT_LIST_DIR_REAL "${_ROCM_CMAKE_CURRENT_LIST_FILE_REAL}" DIRECTORY)
get_filename_component(PACKAGE_PREFIX_DIR "${_ROCM_CMAKE_CURRENT_LIST_DIR_REAL}/../../../" ABSOLUTE)
macro(set_and_check _var _file)
set(${_var} "${_file}")
if(NOT EXISTS "${_file}")
message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !")
endif()
endmacro()
include(CMakeFindDependencyMacro OPTIONAL RESULT_VARIABLE _ROCMCMakeFindDependencyMacro_FOUND)
if (NOT _ROCMCMakeFindDependencyMacro_FOUND)
macro(find_dependency dep)
if (NOT ${dep}_FOUND)
set(rocm_fd_version)
if (${ARGC} GREATER 1)
set(rocm_fd_version ${ARGV1})
endif()
set(rocm_fd_exact_arg)
if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION_EXACT)
set(rocm_fd_exact_arg EXACT)
endif()
set(rocm_fd_quiet_arg)
if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
set(rocm_fd_quiet_arg QUIET)
endif()
set(rocm_fd_required_arg)
if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
set(rocm_fd_required_arg REQUIRED)
endif()
find_package(${dep} ${rocm_fd_version}
${rocm_fd_exact_arg}
${rocm_fd_quiet_arg}
${rocm_fd_required_arg}
)
string(TOUPPER ${dep} cmake_dep_upper)
if (NOT ${dep}_FOUND AND NOT ${cmake_dep_upper}_FOUND)
set(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE
"${CMAKE_FIND_PACKAGE_NAME} could not be found because dependency ${dep} could not be found.")
set(${CMAKE_FIND_PACKAGE_NAME}_FOUND False)
return()
endif()
set(rocm_fd_version)
set(rocm_fd_required_arg)
set(rocm_fd_quiet_arg)
set(rocm_fd_exact_arg)
endif()
endmacro()
endif()
macro(check_required_components _NAME)
foreach(comp ${${_NAME}_FIND_COMPONENTS})
if(NOT ${_NAME}_${comp}_FOUND)
if(${_NAME}_FIND_REQUIRED_${comp})
set(${_NAME}_FOUND FALSE)
endif()
endif()
endforeach()
endmacro()
####################################################################################
set_and_check(rocshmem_INCLUDE_DIR ${PACKAGE_PREFIX_DIR}/include)
set_and_check(rocshmem_INCLUDE_DIRS ${PACKAGE_PREFIX_DIR}/include)
set_and_check(ROCSHMEM_INCLUDE_DIR ${PACKAGE_PREFIX_DIR}/include)
set_and_check(ROCSHMEM_INCLUDE_DIRS ${PACKAGE_PREFIX_DIR}/include)
set_and_check(rocshmem_INCLUDE_DIR ${PACKAGE_PREFIX_DIR}/include)
set_and_check(rocshmem_INCLUDE_DIRS ${PACKAGE_PREFIX_DIR}/include)
set_and_check(rocshmem_TARGET_FILE ${PACKAGE_PREFIX_DIR}/lib/cmake/rocshmem/rocshmem-targets.cmake)
include(${rocshmem_TARGET_FILE})
set(rocshmem_LIBRARIES roc::rocshmem)
set(rocshmem_LIBRARY roc::rocshmem)
set(ROCSHMEM_LIBRARIES roc::rocshmem)
set(ROCSHMEM_LIBRARY roc::rocshmem)
set(rocshmem_LIBRARIES roc::rocshmem)
set(rocshmem_LIBRARY roc::rocshmem)
#----------------------------------------------------------------
# Generated CMake target import file for configuration "Release".
#----------------------------------------------------------------
# Commands may need to know the format version.
set(CMAKE_IMPORT_FILE_VERSION 1)
# Import target "roc::rocshmem" for configuration "Release"
set_property(TARGET roc::rocshmem APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
set_target_properties(roc::rocshmem PROPERTIES
IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "CXX"
IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/librocshmem.a"
)
list(APPEND _cmake_import_check_targets roc::rocshmem )
list(APPEND _cmake_import_check_files_for_roc::rocshmem "${_IMPORT_PREFIX}/lib/librocshmem.a" )
# Commands beyond this point should not need to know the version.
set(CMAKE_IMPORT_FILE_VERSION)
# Generated by CMake
if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
message(FATAL_ERROR "CMake >= 2.8.0 required")
endif()
if(CMAKE_VERSION VERSION_LESS "2.8.12")
message(FATAL_ERROR "CMake >= 2.8.12 required")
endif()
cmake_policy(PUSH)
cmake_policy(VERSION 2.8.12...3.27)
#----------------------------------------------------------------
# Generated CMake target import file.
#----------------------------------------------------------------
# Commands may need to know the format version.
set(CMAKE_IMPORT_FILE_VERSION 1)
# Protect against multiple inclusion, which would fail when already imported targets are added once more.
set(_cmake_targets_defined "")
set(_cmake_targets_not_defined "")
set(_cmake_expected_targets "")
foreach(_cmake_expected_target IN ITEMS roc::rocshmem)
list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
if(TARGET "${_cmake_expected_target}")
list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
else()
list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
endif()
endforeach()
unset(_cmake_expected_target)
if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
unset(_cmake_targets_defined)
unset(_cmake_targets_not_defined)
unset(_cmake_expected_targets)
unset(CMAKE_IMPORT_FILE_VERSION)
cmake_policy(POP)
return()
endif()
if(NOT _cmake_targets_defined STREQUAL "")
string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
endif()
unset(_cmake_targets_defined)
unset(_cmake_targets_not_defined)
unset(_cmake_expected_targets)
# Compute the installation prefix relative to this file.
get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
if(_IMPORT_PREFIX STREQUAL "/")
set(_IMPORT_PREFIX "")
endif()
# Create imported target roc::rocshmem
add_library(roc::rocshmem STATIC IMPORTED)
set_target_properties(roc::rocshmem PROPERTIES
INTERFACE_COMPILE_OPTIONS "-fgpu-rdc;-fgpu-rdc"
INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include;${_IMPORT_PREFIX}/include"
INTERFACE_LINK_LIBRARIES "IBVerbs::verbs;numa;Threads::Threads;MPI::MPI_CXX;hip::device;hip::host;hsa-runtime64::hsa-runtime64;-fgpu-rdc"
)
# Load information for each installed configuration.
file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/rocshmem-targets-*.cmake")
foreach(_cmake_config_file IN LISTS _cmake_config_files)
include("${_cmake_config_file}")
endforeach()
unset(_cmake_config_file)
unset(_cmake_config_files)
# Cleanup temporary variables.
set(_IMPORT_PREFIX)
# Loop over all imported files and verify that they actually exist
foreach(_cmake_target IN LISTS _cmake_import_check_targets)
if(CMAKE_VERSION VERSION_LESS "3.28"
OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target}
OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}")
foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
if(NOT EXISTS "${_cmake_file}")
message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
\"${_cmake_file}\"
but this file does not exist. Possible reasons include:
* The file was deleted, renamed, or moved to another location.
* An install or uninstall procedure did not complete successfully.
* The installation package was faulty and contained
\"${CMAKE_CURRENT_LIST_FILE}\"
but not all the files it references.
")
endif()
endforeach()
endif()
unset(_cmake_file)
unset("_cmake_import_check_files_for_${_cmake_target}")
endforeach()
unset(_cmake_target)
unset(_cmake_import_check_targets)
# This file does not depend on other imported targets which have
# been exported from the same project but in a separate export set.
# Commands beyond this point should not need to know the version.
set(CMAKE_IMPORT_FILE_VERSION)
cmake_policy(POP)
../../../lib/cmake/rocshmem/rocshmem-config-version.cmake
\ No newline at end of file
../../../lib/cmake/rocshmem/rocshmem-config.cmake
\ No newline at end of file
../../../lib/cmake/rocshmem/rocshmem-targets-release.cmake
\ No newline at end of file
../../../lib/cmake/rocshmem/rocshmem-targets.cmake
\ No newline at end of file
../../lib/librocshmem.a
\ No newline at end of file
MIT License
Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
SPDX-License-Identifier: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
import os
import subprocess
import setuptools
import importlib
from pathlib import Path
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Wheel specific: the wheels only include the soname of the host library `libnvshmem_host.so.X`
def get_nvshmem_host_lib_name(base_dir):
path = Path(base_dir).joinpath('lib')
for file in path.rglob('libnvshmem_host.so.*'):
return file.name
raise ModuleNotFoundError('libnvshmem_host.so not found')
if __name__ == '__main__':
disable_nvshmem = False
nvshmem_dir = os.getenv('NVSHMEM_DIR', None)
nvshmem_host_lib = 'libnvshmem_host.so'
if nvshmem_dir is None:
try:
nvshmem_dir = importlib.util.find_spec("nvidia.nvshmem").submodule_search_locations[0]
nvshmem_host_lib = get_nvshmem_host_lib_name(nvshmem_dir)
import nvidia.nvshmem as nvshmem
except (ModuleNotFoundError, AttributeError, IndexError):
print('Warning: `NVSHMEM_DIR` is not specified, and the NVSHMEM module is not installed. All internode and low-latency features are disabled\n')
disable_nvshmem = True
else:
disable_nvshmem = False
if not disable_nvshmem:
assert os.path.exists(nvshmem_dir), f'The specified NVSHMEM directory does not exist: {nvshmem_dir}'
cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable',
'-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes']
nvcc_flags = ['-O3', '-Xcompiler', '-O3']
sources = ['csrc/deep_ep.cpp', 'csrc/kernels/runtime.cu', 'csrc/kernels/layout.cu', 'csrc/kernels/intranode.cu']
include_dirs = ['csrc/']
library_dirs = []
nvcc_dlink = []
extra_link_args = []
# NVSHMEM flags
if disable_nvshmem:
cxx_flags.append('-DDISABLE_NVSHMEM')
nvcc_flags.append('-DDISABLE_NVSHMEM')
else:
sources.extend(['csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu'])
include_dirs.extend([f'{nvshmem_dir}/include'])
library_dirs.extend([f'{nvshmem_dir}/lib'])
nvcc_dlink.extend(['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem_device'])
extra_link_args.extend([f'-l:{nvshmem_host_lib}', '-l:libnvshmem_device.a', f'-Wl,-rpath,{nvshmem_dir}/lib'])
if int(os.getenv('DISABLE_SM90_FEATURES', 0)):
# Prefer A100
os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '8.0')
# Disable some SM90 features: FP8, launch methods, and TMA
cxx_flags.append('-DDISABLE_SM90_FEATURES')
nvcc_flags.append('-DDISABLE_SM90_FEATURES')
# Disable internode and low-latency kernels
assert disable_nvshmem
else:
# Prefer H800 series
os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '9.0')
# CUDA 12 flags
nvcc_flags.extend(['-rdc=true', '--ptxas-options=--register-usage-level=10'])
# Disable LD/ST tricks, as some CUDA version does not support `.L1::no_allocate`
if os.environ['TORCH_CUDA_ARCH_LIST'].strip() != '9.0':
assert int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', 1)) == 1
os.environ['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '1'
# Disable aggressive PTX instructions
if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '1')):
cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
# Bits of `topk_idx.dtype`, choices are 32 and 64
if "TOPK_IDX_BITS" in os.environ:
topk_idx_bits = int(os.environ['TOPK_IDX_BITS'])
cxx_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}')
nvcc_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}')
# Put them together
extra_compile_args = {
'cxx': cxx_flags,
'nvcc': nvcc_flags,
}
if len(nvcc_dlink) > 0:
extra_compile_args['nvcc_dlink'] = nvcc_dlink
# Summary
print(f'Build summary:')
print(f' > Sources: {sources}')
print(f' > Includes: {include_dirs}')
print(f' > Libraries: {library_dirs}')
print(f' > Compilation flags: {extra_compile_args}')
print(f' > Link flags: {extra_link_args}')
print(f' > Arch list: {os.environ["TORCH_CUDA_ARCH_LIST"]}')
print(f' > NVSHMEM path: {nvshmem_dir}')
print()
# noinspection PyBroadException
try:
cmd = ['git', 'rev-parse', '--short', 'HEAD']
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
......@@ -114,21 +12,9 @@ if __name__ == '__main__':
setuptools.setup(
name='deep_ep',
version='1.2.1' + revision,
packages=setuptools.find_packages(
include=['deep_ep']
),
ext_modules=[
CUDAExtension(
name='deep_ep_cpp',
include_dirs=include_dirs,
library_dirs=library_dirs,
sources=sources,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args
)
],
cmdclass={
'build_ext': BuildExtension
}
version='1.0.0' + revision,
packages=setuptools.find_packages(include=['deep_ep']),
include_package_data=True,
package_data={"deep_ep": ["deep_ep_cpp.cpython-310-x86_64-linux-gnu.so"]},
zip_safe=False,
)
......@@ -15,7 +15,7 @@ import test_low_latency
# noinspection PyShadowingNames
def test_main(args: argparse.Namespace, num_sms: int,
local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int,
buffer: deep_ep.Buffer, group: dist.ProcessGroup, skip_benchmark: bool = False):
buffer: deep_ep.Buffer, group: dist.ProcessGroup, skip_benchmark: bool = True):
# Settings
num_tokens, hidden = args.num_tokens, args.hidden
num_topk_groups, num_topk, num_experts = args.num_topk_groups, args.num_topk, args.num_experts
......@@ -35,7 +35,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_idx = topk_idx.to(deep_ep.topk_idx_t)
# topk_idx = topk_idx.to(deep_ep.topk_idx_t)
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
rank_idx = topk_idx // (num_experts // num_ranks)
......@@ -106,7 +106,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
check_start = check_end
for previous_mode in (False, True):
for async_mode in (False, True):
for async_mode in (False, False):
for current_x in (x_pure_rand, x, x_pure_rand_e4m3, x_e4m3):
for with_topk in (False, True):
is_rand = current_x is x_pure_rand or current_x is x_pure_rand_e4m3
......@@ -158,16 +158,18 @@ def test_main(args: argparse.Namespace, num_sms: int,
check_data(recv_x, recv_gbl_rank_prefix_sum)
# Test combine
# torch.cuda.synchronize()
# print("lijian test dipatch end and combine start.")
bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode}
combine_args = {'x': recv_x, 'handle': handle, 'config': config}
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
combine_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1)
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
ref_x = x_pure_rand if is_rand else x
assert calc_diff(check_x, ref_x) < 5e-4 if current_x is x_pure_rand_e4m3 else 5e-6
if with_topk:
......@@ -191,6 +193,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
if skip_benchmark:
return hash_value
# print("benchmark start:")
# Tune dispatch performance
best_dispatch_results = None
fp8_factor = (1 + 4 / 128) / 2
......@@ -262,7 +265,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
buffer = deep_ep.Buffer(group, int(2e9), int(1e9), low_latency_mode=args.test_ll_compatibility,
num_qps_per_rank=num_qps_per_rank, explicitly_destroy=True)
num_qps_per_rank=num_qps_per_rank, explicitly_destroy=True, use_default_stream_as_comm_stream=False)
assert num_local_ranks == 8 and num_ranks > 8
for seed in range(int(1e9)):
......
......@@ -25,11 +25,11 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
# Random data
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
x_e4m3 = None # per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
x_e4m3 = None # (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_idx = topk_idx.to(deep_ep.topk_idx_t)
# topk_idx = topk_idx.to(deep_ep.topk_idx_t)
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
rank_idx = topk_idx // (num_experts // num_ranks)
......
......@@ -37,7 +37,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_idx = topk_idx.to(deep_ep.topk_idx_t)
# topk_idx = topk_idx.to(int64_t)
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment