"...composable_kernel.git" did not exist on "41cdd3801a873def1c1220da9860f5202f36edbd"
Commit f3bbfe3e authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of https://github.com/ROCm/composable_kernel into update_cka8w8

parents 2b840f5a efb34741
@@ -145,20 +145,20 @@ message("hip_version_flat=${hip_VERSION_FLAT}")
 message("checking which targets are supported")
 #In order to build just the CK library (without tests and examples) for all supported GPU targets
-#use -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
+#use -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
 #the GPU_TARGETS flag will be reset in this case in order to avoid conflicts.
 #
 #In order to build CK along with all tests and examples it should be OK to set GPU_TARGETS to just 1 or 2 similar architectures.
 if(NOT ENABLE_ASAN_PACKAGING)
   if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000)
     # WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above
-    set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
+    set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
   else()
-    set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
+    set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
   endif()
 else()
   #build CK only for xnack-supported targets when using ASAN
-  set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+")
+  set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx942:xnack+")
 endif()
 #if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list
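For reference, a minimal configure invocation that exercises the `GPU_ARCHS` override described in the comments above might look like this (build directory and target list are hypothetical, not part of the commit):

```
# hypothetical configure step: GPU_ARCHS on the command line overrides the
# CK_GPU_TARGETS default computed above and builds only the CK library
mkdir build && cd build
cmake -D CMAKE_BUILD_TYPE=Release \
      -D GPU_ARCHS="gfx90a;gfx942" \
      ..
```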
@@ -183,12 +183,14 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
 if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
     message("Enabling XDL instances")
     add_definitions(-DCK_USE_XDL)
-    set(CK_USE_XDL "ON")
 endif()
+if (SUPPORTED_GPU_TARGETS MATCHES "gfx94")
+    message("Enabling FP8 gemms in ckProfiler")
+    add_definitions(-DCK_USE_GFX94)
+endif()
 if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
     message("Enabling WMMA instances")
     add_definitions(-DCK_USE_WMMA)
-    set(CK_USE_WMMA "ON")
 endif()
 option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
 if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
@@ -221,7 +223,7 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090)
 endif()
 set(check-coerce)
 check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)
-if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132 AND ${hip_VERSION_FLAT} LESS 600300000)
+if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132)
     message("Adding the amdgpu-coerce-illegal-types=1")
     add_compile_options("SHELL: -mllvm -amdgpu-coerce-illegal-types=1")
 endif()
......
@@ -24,10 +24,10 @@ RUN if [ "$ROCMVERSION" != "6.3" ]; then \
         sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \
         sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \
     elif [ "$ROCMVERSION" = "6.3" ] && [ "$compiler_version" = "rc1" ]; then \
-        sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.3.0.1-20.04-1_all.deb --no-check-certificate" && \
-        apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.3.0.1-20.04-1_all.deb && \
-        sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.3.0.1 rel-5 > /etc/apt/sources.list.d/rocm-build.list' && \
-        amdgpu-repo --amdgpu-build=2033700; \
+        sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.3-20.04-1_all.deb --no-check-certificate" && \
+        apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.3-20.04-1_all.deb && \
+        sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.3 rel-20 > /etc/apt/sources.list.d/rocm-build.list' && \
+        amdgpu-repo --amdgpu-build=2074281; \
     fi
 RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list"
......
@@ -1101,11 +1101,11 @@ pipeline {
         agent{ label rocmnode("gfx90a") }
         environment{
             setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \
-                             -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \
+                             -DGPU_TARGETS="gfx908;gfx90a;gfx942" \
                              -DCMAKE_CXX_FLAGS=" -O3 " """
             execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
                                cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-                               -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \
+                               -DGPU_TARGETS="gfx908;gfx90a;gfx942" \
                                -DCMAKE_CXX_COMPILER="${build_compiler()}" \
                                -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
         }
@@ -1165,7 +1165,7 @@ pipeline {
             execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
                                -D CMAKE_CXX_COMPILER="${build_compiler()}" \
                                -D CMAKE_BUILD_TYPE=Release \
-                               -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \
+                               -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \
                                -D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
         }
         steps{
......
-rocm-docs-core==1.8.3
+rocm-docs-core==1.8.4
 sphinxcontrib-bibtex==2.6.3
@@ -103,7 +103,7 @@ requests==2.32.3
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==1.8.3
+rocm-docs-core==1.8.4
     # via -r requirements.in
 six==1.16.0
     # via pybtex
......
@@ -80,7 +80,7 @@ using RLayout = typename LayoutSettingSelector<NDimSpatial>::RLayout;
 struct ExecutionConfig final
 {
     bool do_verification = true;
-    int init_method      = 1;
+    int init_method      = 2;
     bool time_kernel     = false;
 };
......
@@ -73,16 +73,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
     Tensor<EDataType> conv_output_device(conv_output_g_n_k_wos_desc);
     Tensor<R0DataType> r0_device(r0_desc);
 
+    std::cout << "input: " << conv_input.mDesc << std::endl;
+    std::cout << "weight: " << conv_weight.mDesc << std::endl;
+    std::cout << "output: " << conv_output_device.mDesc << std::endl;
+    std::cout << "reduction: " << r0_device.mDesc << std::endl << std::endl;
+
     switch(config.init_method)
     {
     case 0: break;
     case 1:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-8, 7}(conv_input);
-        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-8, 7}(conv_weight);
+        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-1, 1}(conv_weight);
+        break;
+    case 2:
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-8, 7}(conv_input);
+        ck::utils::FillUniformDistribution<BDataType>{-1, 1}(conv_weight);
         break;
     default:
-        ck::utils::FillUniformDistribution<ADataType>{-5, 5}(conv_input);
-        ck::utils::FillUniformDistribution<BDataType>{-5, 5}(conv_weight);
+        ck::utils::FillUniformDistribution<ADataType>{-8, 7}(conv_input);
+        ck::utils::FillUniformDistribution<BDataType>{-1, 1}(conv_weight);
     }
 
     DeviceMem conv_input_device_buf(sizeof(ADataType) * conv_input.mDesc.GetElementSpaceSize());
@@ -161,15 +170,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
         return false;
     }
 
+    // XXX: DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle will not initialize r0.
+    r0_device_buf.SetValue(ck::NumericLimits<R0DataType>::Lowest());
+
     const float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
 
-    const std::size_t flop      = problem_size.GetFlops();
-    const std::size_t num_btype = problem_size.GetByte<ADataType, BDataType, EDataType>();
+    if(config.time_kernel)
+    {
+        const std::size_t flop      = problem_size.GetFlops();
+        const std::size_t num_btype = problem_size.GetByte<ADataType, BDataType, EDataType>();
 
-    const float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
-    const float gb_per_sec = num_btype / 1.E6 / avg_time;
+        const float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
+        const float gb_per_sec = num_btype / 1.E6 / avg_time;
 
-    std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
-              << conv.GetTypeString() << std::endl;
+        std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec
+                  << " GB/s, " << conv.GetTypeString() << std::endl;
+    }
+    else
+    {
+        std::cout << "FINISHED: " << conv.GetTypeString() << std::endl;
+    }
 
     if(config.do_verification)
     {
@@ -189,6 +208,7 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
                                                   BElementOp{},
                                                   PassThrough{});
 
+        std::cout << "\nRunning verification on CPU." << std::endl;
         ref_invoker.Run(ref_argument);
 
         Tensor<R0DataType> r0_host(r0_device.mDesc);
@@ -273,13 +293,18 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
         conv_output_device_buf.FromDevice(conv_output_device.mData.data());
         r0_device_buf.FromDevice(r0_device.mData.data());
 
-        return ck::utils::check_err(conv_output_device,
-                                    conv_output_host,
-                                    "Error: incorrect results! (Matrix E)",
-                                    1e-5f,
-                                    1e-4f) &&
-               ck::utils::check_err(
-                   r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-5f, 1e-4f);
+        auto pass = ck::utils::check_err(conv_output_device,
+                                         conv_output_host,
+                                         "Error: incorrect results! (Matrix E)",
+                                         1e-3f,
+                                         1e-3f);
+        pass =
+            pass && ck::utils::check_err(
+                        r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-3f, 1e-3f);
+
+        if(pass)
+            std::cout << "Verification on CPU: PASS" << std::endl;
+        return pass;
     }
 
     return true;
......
@@ -198,7 +198,7 @@ int main()
         throw std::runtime_error("wrong! this device_op instance does not support this problem");
     }
 
-    // init reducetion buffer to 0
+    // init reduction buffer to 0
     r0_device_buf.SetZero();
     r1_device_buf.SetZero();
......
@@ -68,7 +68,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
 using DeviceReduceInstance =
     ck::tensor_operation::device::DeviceReduceMultiBlock<OutputDataType,
-                                                         OutputDataType,
+                                                         ScaleDataType,
                                                          OutputDataType,
                                                          NumDim,
                                                          NumDim,
@@ -108,7 +108,8 @@ void reference_scale_permute_amax(Tensor<InputDataType>& input,
             host_output_scaled_casted_transposed(m, k) = y1;
             const OutputDataType y_fabs =
                 ck::type_convert<OutputDataType>(ck::math::abs(ck::type_convert<float>(y0)));
-            host_output_amax(0) = ck::math::max(y_fabs, host_output_amax(0));
+            host_output_amax(0) = ck::type_convert<OutputDataType>(ck::math::max(
+                ck::type_convert<float>(y_fabs), ck::type_convert<float>(host_output_amax(0))));
         }
     }
 }
......
@@ -85,9 +85,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
     #only continue if there are some source files left on the list
     if(FILE_NAME)
         if(FILE_NAME MATCHES "_xdl")
-            list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
+            list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
         elseif(FILE_NAME MATCHES "_wmma")
-            list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+            list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
         endif()
         set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
         add_executable(${EXAMPLE_NAME} ${FILE_NAME})
@@ -169,9 +169,9 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
     #only continue if there are some source files left on the list
     if(FILE_NAME)
         if(FILE_NAME MATCHES "_xdl")
-            list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
+            list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
         elseif(FILE_NAME MATCHES "_wmma")
-            list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+            list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
         endif()
         set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
         add_executable(${EXAMPLE_NAME} ${FILE_NAME})
......
@@ -29,14 +29,14 @@ while getopts ":sa" opt; do
 done
 
 run_fp16_bf16_tests() {
-    local NUM_SPLITS=(1)
-    local PAGE_BLOCK_SIZE=(0)
-    local CACHE_BATCH_IDX=(0)
+    local NUM_SPLITS="1"
+    local PAGE_BLOCK_SIZE="0"
+    local CACHE_BATCH_IDX="0"
 
    if [ $TEST_SPLITKV -eq 1 ] ; then
-        NUM_SPLITS+=(2 3)
-        PAGE_BLOCK_SIZE+=(128)
-        CACHE_BATCH_IDX+=(1)
+        NUM_SPLITS="$NUM_SPLITS 2 3"
+        PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
+        CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
    fi
 
    for prec in "fp16" "bf16" ; do
@@ -47,9 +47,9 @@ run_fp16_bf16_tests() {
     for lse in 0 1 ; do
     for bias in "n" "e" "a" ; do
     for p_drop in 0.0 0.2 ; do
-    for num_splits in "${NUM_SPLITS[@]}" ; do
-    for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do
-    for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do
+    for num_splits in $NUM_SPLITS ; do
+    for page_block_size in $PAGE_BLOCK_SIZE ; do
+    for cache_batch_idx in $CACHE_BATCH_IDX ; do
        # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
        $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
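The change above swaps bash arrays for plain space-separated strings so the script also runs under a strict POSIX `sh` (for example dash), where `local VAR=(...)` and `"${VAR[@]}"` are not available. A minimal sketch of the pattern:

```
# minimal sketch of the POSIX-sh pattern adopted above: plain strings split on
# whitespace at the unquoted expansion, no bash-only array syntax required
NUM_SPLITS="1"
NUM_SPLITS="$NUM_SPLITS 2 3"   # append two more values
for num_splits in $NUM_SPLITS ; do
    echo "num_splits=$num_splits"
done
```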
@@ -103,4 +103,4 @@ if [ $TEST_APPENDKV -eq 1 ] ; then
     run_fp16_appendkv_tests
 fi
 
-set +x
\ No newline at end of file
+set +x
@@ -57,6 +57,7 @@ template <typename XDataType_,
           ck_tile::index_t Vector_N_, // vector size along N
           bool kPadN_,
           bool kSaveMeanInvStd_,
+          bool kFastFDiv_,
           bool kTwoPass_,
           ck_tile::index_t kFusedAdd_   = 0,
           ck_tile::index_t kFusedQuant_ = 0>
@@ -118,6 +119,7 @@ struct layernorm2d_fwd_traits_
     static constexpr bool kPadN           = kPadN_;
     static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
+    static constexpr bool kFastFDiv       = kFastFDiv_;
     static constexpr bool kTwoPass        = kTwoPass_;
     static constexpr ck_tile::index_t kFusedAdd   = kFusedAdd_;
     static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
@@ -134,6 +136,7 @@ template <typename XDataType_,
           ck_tile::index_t Vector_N_, // vector size along N
           bool kPadN_,
           bool kSaveMeanInvStd_,
+          bool kFastFDiv_,
           bool kTwoPass_,
           int kFusedAdd_,
           int kFusedQuant_>
@@ -148,6 +151,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
                                         Vector_N_,
                                         kPadN_,
                                         kSaveMeanInvStd_,
+                                        kFastFDiv_,
                                         kTwoPass_,
                                         kFusedAdd_,
                                         kFusedQuant_>;
@@ -179,6 +183,7 @@ float layernorm2d_fwd_(const S& s, A a)
     using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
                                                          Traits_::kSaveMeanInvStd,
+                                                         Traits_::kFastFDiv,
                                                          Traits_::kTwoPass,
                                         static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
                                         static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
@@ -269,7 +274,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
 #include "layernorm2d_fwd_api_common.hpp"
 
 // clang-format off
-// prec_i prec_o prec_sy rm rn tm tn vn pd mv 2p add sweep
+// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep
 {F_instance_def}
 // clang-format on
@@ -356,6 +361,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
     F_Vector_N : int
     F_kPadN : bool
     F_kSaveMeanInvStd_ : bool
+    F_kFastFDiv_ : bool
     F_kTwoPass_ : bool
     F_kFusedAdd : int
     F_kFusedQuant : int
@@ -363,7 +369,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
     @property
     def trait_name(self) ->str:
         t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
-        t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}'
+        t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}'
         t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
         return t_
@@ -483,52 +489,55 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
     fused_add_list = [0, 1]
     fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
 
-    # rm rn tm tn vn pd mv 2p add sweep
-    h_trait_dict = {'64'  : [ h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,  64, 1, True, False, False, 0, 0)],
-                    '128' : [ h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,  64, 2, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 4,  64, 1, True, False, False, 0, 0)],
-                    '256' : [ h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,  64, 4, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 4,  64, 2, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 4,  64, 1, True, False, False, 0, 0)],
-                    '512' : [ h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,  64, 8, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 4,  64, 4, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 4,  64, 2, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  8, 4,  64, 1, True, False, False, 0, 0)],
-                    '768' : [ h_traits('x', 'y', 'xs', 'ys', 1,  3, 4,  64, 4, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  6, 4,  64, 2, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1, 12, 4,  64, 1, True, False, False, 0, 0)],
-                    '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1,  1, 2, 128, 8, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 2, 128, 4, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 2, 128, 2, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 256, 1, True, False, False, 0, 0)],
-                    '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1,  3, 4,  64, 8, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  3, 2, 128, 4, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 256, 2, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  6, 1, 256, 1, True, False, False, 0, 0)],
-                    '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1,  1, 1, 256, 8, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 1, 256, 4, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 256, 2, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  8, 1, 256, 1, True, False, False, 0, 0)],
-                    '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 128, 8, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 256, 4, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  6, 1, 256, 2, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,1024, 1, True, False, False, 0, 0)],
-                    '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1,  2, 1, 256, 8, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 256, 4, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 1,1024, 2, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,1024, 1, True, False, False, 0, 0)],
-                    '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 256, 8, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 512, 4, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,1024, 2, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  6, 1,1024, 1, True, False, False, 0, 0)],
-                    '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 256, 8, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 512, 4, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,1024, 2, True, False, False, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  8, 1,1024, 1, True, False, False, 0, 0)],
-                    'big'  :[ h_traits('x', 'y', 'xs', 'ys', 1,  2, 1, 256, 8, True, False, True, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 256, 4, True, False, True, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 1,1024, 2, True, False, True, 0, 0),
-                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,1024, 1, True, False, True, 0, 0)]}
+    # rm rn tm tn vn pd mv fdiv 2p add sweep
+    h_trait_dict = {'64'  : [ h_traits('x', 'y', 'xs', 'ys', 1,  1, 8,   8, 8, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,  16, 4, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,  64, 1, True, False, True, False, 0, 0)],
+                    '128' : [ h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,  16, 8, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,  64, 2, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 4,  64, 1, True, False, True, False, 0, 0)],
+                    '256' : [ h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,  64, 4, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 4,  64, 2, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 4,  64, 1, True, False, True, False, 0, 0)],
+                    '512' : [ h_traits('x', 'y', 'xs', 'ys', 1,  1, 4,  64, 8, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 4,  64, 4, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 4,  64, 2, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  8, 4,  64, 1, True, False, True, False, 0, 0)],
+                    '768' : [ h_traits('x', 'y', 'xs', 'ys', 1,  3, 4,  64, 4, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  6, 4,  64, 2, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1, 12, 4,  64, 1, True, False, True, False, 0, 0)],
+                    '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1,  1, 2, 128, 8, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 2, 128, 4, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 2, 128, 2, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 256, 1, True, False, True, False, 0, 0)],
+                    '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1,  3, 4,  64, 8, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  3, 2, 128, 4, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 256, 2, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  6, 1, 256, 1, True, False, True, False, 0, 0)],
+                    '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1,  1, 1, 256, 8, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 1, 256, 4, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 256, 2, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  8, 1, 256, 1, True, False, True, False, 0, 0)],
+                    '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 128, 8, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 256, 4, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  6, 1, 256, 2, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,1024, 1, True, False, True, False, 0, 0)],
+                    '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1,  2, 1, 256, 8, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 256, 4, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 1,1024, 2, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,1024, 1, True, False, True, False, 0, 0)],
+                    '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 256, 8, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  3, 1, 512, 4, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  3, 1,1024, 2, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  6, 1,1024, 1, True, False, True, False, 0, 0)],
+                    '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 256, 8, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 512, 4, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,1024, 2, True, False, True, False, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  8, 1,1024, 1, True, False, True, False, 0, 0)],
+                    'big'  :[ h_traits('x', 'y', 'xs', 'ys', 1,  2, 1, 256, 8, True, False, True, True, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1, 256, 4, True, False, True, True, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  2, 1,1024, 2, True, False, True, True, 0, 0),
+                              h_traits('x', 'y', 'xs', 'ys', 1,  4, 1,1024, 1, True, False, True, True, 0, 0)]}
 
     total_blob = list()
     for hs_key in h_trait_dict:
         hs = h_trait_dict[hs_key]
......
@@ -25,7 +25,10 @@ auto create_args(int argc, char* argv[])
     ck_tile::ArgParser arg_parser;
     arg_parser.insert("m", "3328", "m dimension")
         .insert("n", "4096", "n dimension")
-        .insert("stride", "-1", "stride per row, if -1 then equal to n")
+        .insert("x_stride", "-1", "x row_stride, if -1 then equal to n")
+        .insert("xr_stride", "-1", "x residual row_stride, if -1 then equal to n")
+        .insert("y_stride", "-1", "y row_stride, if -1 then equal to n")
+        .insert("yr_stride", "-1", "y residual row_stride, if -1 then equal to n")
         .insert("e", "1e-5", "epsilon")
         .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
         .insert("v", "1", "cpu validation or not")
@@ -54,11 +57,20 @@ template <typename InDataType,
           bool SaveMeanVar>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
     ck_tile::index_t m = arg_parser.get_int("m");
     ck_tile::index_t n = arg_parser.get_int("n");
-    ck_tile::index_t stride = arg_parser.get_int("stride");
-    if(stride < 0)
-        stride = n;
+    ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
+    if(x_stride < 0)
+        x_stride = n;
+    ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride");
+    if(xr_stride < 0)
+        xr_stride = n;
+    ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
+    if(y_stride < 0)
+        y_stride = n;
+    ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride");
+    if(yr_stride < 0)
+        yr_stride = n;
     float epsilon = arg_parser.get_float("e");
     std::string prec_i = arg_parser.get_str("prec_i");
     std::string prec_o = arg_parser.get_str("prec_o");
@@ -89,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return false;
     }
 
-    assert(stride >= n);
+    assert(x_stride >= n);
 
     using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>;
@@ -108,15 +120,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using ComputeDataType = typename TypeConfig::ComputeDataType;
 
     // host verify
-    ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
+    ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
     ck_tile::HostTensor<GammaDataType> gamma_host({n});
     ck_tile::HostTensor<BetaDataType> beta_host({n});
 
-    ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1});
-    ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1});
+    ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {xr_stride, 1});
+    ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {yr_stride, 1});
 
-    ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
-    ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
+    ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {y_stride, 1});
+    ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {y_stride, 1});
 
     ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
     ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
@@ -162,7 +174,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }();
 
     std::cout << "[" << prec_str << "]"
-              << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
+              << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
+              << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
+              << ", yr_stride:" << yr_stride << std::flush;
 
     layernorm2d_fwd_traits traits{
         prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
@@ -182,7 +196,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                   epsilon,
                                   m,
                                   n,
-                                  stride};
+                                  x_stride,   // x row_stride
+                                  xr_stride,  // x residual row_stride
+                                  y_stride,   // y row_stride
+                                  yr_stride}; // y residual row_stride
 
     float ave_time = layernorm2d_fwd(
         traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
@@ -285,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     y_buf.FromDevice(y_host_dev.data());
 
-    ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1});
+    ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {yr_stride, 1});
     if(fused_add == 1)
     {
         y_residual_buf.FromDevice(y_residual_host_dev.data());
@@ -293,7 +310,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     auto [rtol, atol] = get_elimit<InDataType>();
 
-    if(stride == n)
+    if(x_stride == n)
     {
         pass = ck_tile::check_err(
             y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
@@ -310,10 +327,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
     {
         for(int i_r = 0; i_r < m; i_r++)
         {
-            std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
-                                                  y_host_dev.begin() + i_r * stride + n);
-            std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
-                                                  y_host_ref.begin() + i_r * stride + n);
+            std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * y_stride,
+                                                  y_host_dev.begin() + i_r * y_stride + n);
+            std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * y_stride,
+                                                  y_host_ref.begin() + i_r * y_stride + n);
             pass &= ck_tile::check_err(y_host_dev_row,
                                        y_host_ref_row,
                                        std::string("OUT[") + std::to_string(i_r) +
@@ -323,10 +340,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
             if(fused_add == 1)
             {
                 std::vector<YResidualDataType> y_residual_host_dev_row(
-                    y_residual_host_dev.begin() + i_r * stride,
-                    y_residual_host_dev.begin() + i_r * stride + n);
+                    y_residual_host_dev.begin() + i_r * yr_stride,
+                    y_residual_host_dev.begin() + i_r * yr_stride + n);
                 std::vector<YResidualDataType> y_residual_host_ref_row(
-                    x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n);
+                    x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n);
                 pass &= ck_tile::check_err(y_residual_host_dev_row,
                                            y_residual_host_ref_row,
                                            std::string("ADD[") + std::to_string(i_r) +
......
@@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
 mkdir build && cd build
 # you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
 sh ../script/cmake-ck-dev.sh ../ <arch>
+# the basic GEMM pipeline
 make tile_example_gemm_basic -j
+# the memory-bound GEMM pipeline
+make tile_example_gemm_mem_pipeline -j
 ```
This will result in an executable `build/bin/tile_example_gemm_basic`
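Assuming the second make target produces a matching binary name (not stated in the README), both pipelines could then be smoke-tested with their default problem sizes:

```
# hypothetical smoke run of the two GEMM pipelines, from inside build/
./bin/tile_example_gemm_basic
./bin/tile_example_gemm_mem_pipeline
```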
......
@@ -17,10 +17,11 @@
 template <typename ALayout, typename BLayout, typename CLayout>
 float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
 {
-    // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
-    constexpr bool kPadA = true;
-    constexpr bool kPadB = true;
-    constexpr bool kPadC = true;
+    // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
+    constexpr bool kPadM = false;
+    constexpr bool kPadN = false;
+    constexpr bool kPadK = false;
 
     constexpr bool kTilePermute = false;
     // The rank and permutation will also be generated by the CodeGen part.
     constexpr ck_tile::index_t kOutputRank = 2;
@@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
         CShuffleEpilogue,
         ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
                                                                    CDataType,
-                                                                   kPadA,
-                                                                   kPadB,
+                                                                   kPadM,
+                                                                   kPadN,
                                                                    kTilePermute,
                                                                    kOutputRank,
                                                                    1,
@@ -65,13 +66,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
                                                                    TilePartitioner::kM,
                                                                    TilePartitioner::kN>>,
         ck_tile::Default2DEpilogue<
-            ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
+            ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
 
     using CodegenGemmTraits =
-        ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
+        ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
     using CodegenPipelineProblem = ck_tile::
         GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
-    using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy<ALayout, BLayout, CLayout>;
+    using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
     using CodegenGemmPipeline =
         ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
     // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
......
@@ -31,9 +31,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
     constexpr ck_tile::index_t K_Warp_Tile = 8;
 
     // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
-    constexpr bool kPadA = true;
-    constexpr bool kPadB = true;
-    constexpr bool kPadC = true;
+    constexpr bool kPadM = true;
+    constexpr bool kPadN = true;
+    constexpr bool kPadK = true;
 
     constexpr int kBlockPerCu = 1;
@@ -46,9 +46,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
     using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
 
     using GemmEpilogue = ck_tile::Default2DEpilogue<
-        ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
+        ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
 
-    using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
+    using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
     using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
         ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
......
add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp)
target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let the source compile without explicitly declaring function specializations
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(tile_example_moe_sorting PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS})
# moe-sorting
This folder contains an example of the moe-sorting kernel using the ck_tile tile-programming implementation. This kernel is often used in MoE models before launching the fused-moe-gemm block. The input and weight are `token*topk` 2D matrices. The op rearranges the input weights/ids by expert so they can be fed into the fused-moe-gemm kernel.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace <arch> with gfx90a, gfx942...
make tile_example_moe_sorting -j
```
This will result in an executable `build/bin/tile_example_moe_sorting`
## example
```
args:
-v     whether to do CPU validation or not (default:1)
-pr_i  index data type (currently only int32 is supported) (default:int32)
-pr_w  output weight data type (currently only fp32 is supported) (default:fp32)
-t number of input tokens (default:32)
-e number of experts (default:8)
-k topk (default:2)
-st_i row stride of input, -1 means same as experts (default:-1)
-seed seed to be used, -1 means random every time (default:-1)
-kname when set to 1 it will print kernel name (default:0)
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <set>
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "moe_sorting_api.hpp"
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not")
.insert("pr_i", "int32", "index data type. (currently only int32 supported now)")
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
.insert("t", "128", "number of input tokens")
.insert("e", "8", "number of num_experts")
.insert("k", "4", "topk")
.insert("unit", "32", "unit_size")
.insert("moe_buf_size", "0", "moe_buf_size")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// generate `topk` expert ids per token; ids are unique within each token's
// group of `topk` entries (a token never routes to the same expert twice)
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
size_t total_size = topk * tokens;
std::srand(seed);
std::set<IndexType> unique_set;
IndexType current_v;
for(size_t i = 0; i < total_size; i++)
{
if(i % topk == 0)
{
unique_set.clear();
}
current_v = std::rand() % num_expert;
while(unique_set.find(current_v) != unique_set.end())
{
current_v = std::rand() % num_expert;
}
unique_set.insert(current_v);
host_tensor[i] = current_v;
}
}
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool test_moe_sorting(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string index_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
int tokens = args.get_int("t");
int num_experts = args.get_int("e");
int topk = args.get_int("k");
int seed = args.get_int("seed");
int unit_size = args.get_int("unit");
int moe_buf_size = args.get_int("moe_buf_size");
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
    // upper bound on the sorted output length: tokens * topk ids plus up to one
    // unit_size pad block per expert (e.g. t=128, k=4, e=8, unit=32 -> 768 slots)
    int max_output_ids =
        ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
if(seed < 0)
{
seed = std::time(nullptr);
}
if(topk > num_experts)
{
printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n",
topk,
num_experts);
return false;
}
    // `tokens` already accounts for the batch size
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_expert_ids_dev(
sorted_expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data());
if(moe_buf_size > 0)
{
moe_buf_dev.ToDevice(moe_buf_host.data());
}
moe_sorting_trait trait{index_prec, weight_prec};
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
weights_dev.GetDeviceBuffer(),
sorted_ids_dev.GetDeviceBuffer(),
sorted_weights_dev.GetDeviceBuffer(),
sorted_expert_ids_dev.GetDeviceBuffer(),
sorted_id_cnt_dev.GetDeviceBuffer(),
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
tokens,
unit_size,
num_experts,
topk,
static_cast<ck_tile::index_t>(moe_buf_size * sizeof(float))};
ck_tile::stream_config sc{nullptr,
true,
/* log_level = */ (kname ? 1 : 0),
warmup,
repeat};
auto ms = moe_sorting(trait, karg, sc);
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ",
index_prec.c_str(),
weight_prec.c_str(),
tokens,
num_experts,
topk,
ms);
    if(ms < 0)
    {
        printf("not supported\n");
        fflush(stdout);
        return false;
    }
    fflush(stdout);
sorted_ids_dev.FromDevice(sorted_ids_host.data());
sorted_weights_dev.FromDevice(sorted_weights_host.data());
sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
if(moe_buf_size > 0)
{
moe_buf_dev.FromDevice(moe_buf_host.data());
}
bool rtn = true;
if(validate)
{
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> sorted_expert_ids_ref({max_output_ids / unit_size}, {1});
int32_t ref_total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
weights_host,
sorted_ids_ref,
sorted_weights_ref,
sorted_expert_ids_ref,
ref_total_tokens_post_pad,
num_experts,
unit_size);
rtn &= ck_tile::check_err(
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
rtn &= ck_tile::check_err(sorted_weights_host,
sorted_weights_ref,
std::string("OUT Error: Incorrect w!"),
1e-6,
1e-6);
rtn &= ck_tile::check_err(sorted_expert_ids_host,
sorted_expert_ids_ref,
std::string("OUT Error: Incorrect eid!"),
1e-6,
1e-6);
if(moe_buf_size)
{
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
rtn &= ck_tile::check_err(
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
}
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
}
printf("valid:%s\n", rtn ? "y" : "n");
fflush(stdout);
return rtn;
}
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string index_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
bool r = true;
if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0)
{
r &= test_moe_sorting<float, ck_tile::index_t>(args);
}
return r ? 0 : -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_sorting_api.hpp"
// dispatch helper: instantiate and launch the kernel for one compile-time unroll factor
#define MOE_SORTING_DISPATCH(unroll_num_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
{
if(a.num_experts > 127)
{
printf("lds size exceed, only support experts <127 \n");
return -1;
}
if(a.moe_buf_bytes % 16)
{
printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
return -1;
}
using index_t = ck_tile::index_t;
using ms_weight_type = float;
        // one unroll step covers 64 ids, so pick the factor that covers tokens * topk
        index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
switch(smem_io_unroll_num)
{
case(1): {
MOE_SORTING_DISPATCH(1);
}
case(2): {
MOE_SORTING_DISPATCH(2);
}
case(3): {
MOE_SORTING_DISPATCH(3);
}
case(5): {
MOE_SORTING_DISPATCH(5);
}
case(6): {
MOE_SORTING_DISPATCH(6);
}
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): {
MOE_SORTING_DISPATCH(8);
}
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): {
MOE_SORTING_DISPATCH(10);
}
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: {
MOE_SORTING_DISPATCH(4);
}
}
}
return -1;
}