Unverified Commit c8a8385f authored by Jun Liu's avatar Jun Liu Committed by GitHub
Browse files

[HotFix] add config and version files to pass on build info (#856)

* experiment with config file

* experiment with version.h config

* add more info to version.h

* minor updates

* minor updates

* fix case where DTYPE is not used

* large amount of files but minor changes

* remove white space

* minor changes to add more MACROs

* fix cmakedefine01

* fix issue with CK internal conflict

* fix define and define value

* fix clang-format

* fix formatting issue

* experiment with cmake

* clang format v12 to be consistent with miopen

* avoid clang-format for config file
parent 350d64f3
...@@ -3,7 +3,7 @@ repos: ...@@ -3,7 +3,7 @@ repos:
hooks: hooks:
- id: clang-format - id: clang-format
name: clang-format name: clang-format
entry: clang-format-10 -i --style=file entry: clang-format-12 -i --style=file
language: system language: system
types_or: [c++, inc] types_or: [c++, inc]
- id: copyright-year-checker - id: copyright-year-checker
......
cmake_minimum_required(VERSION 3.14) cmake_minimum_required(VERSION 3.14)
set(version 1.1.0)
# Check support for CUDA/HIP in Cmake # Check support for CUDA/HIP in Cmake
project(composable_kernel) project(composable_kernel VERSION ${version})
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
if (DTYPES) if (DTYPES)
add_definitions(-DDTYPES) add_definitions(-DDTYPES)
if (DTYPES MATCHES "int8") if (DTYPES MATCHES "int8")
add_definitions(-D__int8__) add_definitions(-DCK_ENABLE_INT8)
set(CK_ENABLE_INT8 "ON")
endif() endif()
if (DTYPES MATCHES "fp8") if (DTYPES MATCHES "fp8")
add_definitions(-D__fp8__) add_definitions(-DCK_ENABLE_FP8)
set(CK_ENABLE_FP8 "ON")
endif() endif()
if (DTYPES MATCHES "fp16") if (DTYPES MATCHES "fp16")
add_definitions(-D__fp16__) add_definitions(-DCK_ENABLE_FP16)
set(CK_ENABLE_FP16 "ON")
endif() endif()
if (DTYPES MATCHES "fp32") if (DTYPES MATCHES "fp32")
add_definitions(-D__fp32__) add_definitions(-DCK_ENABLE_FP32)
set(CK_ENABLE_FP32 "ON")
endif() endif()
if (DTYPES MATCHES "fp64") if (DTYPES MATCHES "fp64")
add_definitions(-D__fp64__) add_definitions(-DCK_ENABLE_FP64)
set(CK_ENABLE_FP64 "ON")
endif() endif()
if (DTYPES MATCHES "bf16") if (DTYPES MATCHES "bf16")
add_definitions(-D__bf16__) add_definitions(-DCK_ENABLE_BF16)
set(CK_ENABLE_BF16 "ON")
endif() endif()
message("DTYPES macro set to ${DTYPES}") message("DTYPES macro set to ${DTYPES}")
else() else()
add_definitions(-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__) add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
set(CK_ENABLE_ALL_DTYPES "ON")
endif() endif()
if(DL_KERNELS) if(DL_KERNELS)
add_definitions(-DDL_KERNELS) add_definitions(-DDL_KERNELS)
set(CK_ENABLE_DL_KERNELS "ON")
endif() endif()
if(INSTANCES_ONLY) if(INSTANCES_ONLY)
add_definitions(-DINSTANCES_ONLY) add_definitions(-DINSTANCES_ONLY)
set(CK_ENABLE_INSTANCES_ONLY "ON")
endif() endif()
# CK config file to record supported datatypes, etc.
configure_file("${PROJECT_SOURCE_DIR}/include/ck/config.h.in" "${PROJECT_BINARY_DIR}/include/ck/config.h")
# CK version file to record release version as well as git commit hash
find_package(Git REQUIRED)
execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE)
configure_file("${PROJECT_SOURCE_DIR}/include/ck/version.h.in" "${PROJECT_BINARY_DIR}/include/ck/version.h")
enable_testing() enable_testing()
set(ROCM_SYMLINK_LIBS OFF) set(ROCM_SYMLINK_LIBS OFF)
...@@ -50,8 +68,10 @@ include(ROCMInstallSymlinks) ...@@ -50,8 +68,10 @@ include(ROCMInstallSymlinks)
include(ROCMCreatePackage) include(ROCMCreatePackage)
include(CheckCXXCompilerFlag) include(CheckCXXCompilerFlag)
include(ROCMCheckTargetIds) include(ROCMCheckTargetIds)
rocm_setup_version(VERSION 0.2.0)
include(TargetFlags) include(TargetFlags)
rocm_setup_version(VERSION ${version})
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip)
message("GPU_TARGETS= ${GPU_TARGETS}") message("GPU_TARGETS= ${GPU_TARGETS}")
...@@ -315,13 +335,14 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) ...@@ -315,13 +335,14 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
# set CK project include directories
include_directories(BEFORE include_directories(BEFORE
${PROJECT_BINARY_DIR}/include
${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/library/include ${PROJECT_SOURCE_DIR}/library/include
${HIP_INCLUDE_DIRS} ${HIP_INCLUDE_DIRS}
) )
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV) if(BUILD_DEV)
add_compile_options(-Werror) add_compile_options(-Werror)
...@@ -409,7 +430,6 @@ endif() ...@@ -409,7 +430,6 @@ endif()
#Create an interface target for the include only files and call it "composablekernels" #Create an interface target for the include only files and call it "composablekernels"
include(CMakePackageConfigHelpers) include(CMakePackageConfigHelpers)
set(version 1.0.0)
write_basic_package_version_file( write_basic_package_version_file(
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake" "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
VERSION "${version}" VERSION "${version}"
...@@ -428,6 +448,13 @@ rocm_install(FILES ...@@ -428,6 +448,13 @@ rocm_install(FILES
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
) )
# Install CK version and configuration files
install(FILES
${PROJECT_BINARY_DIR}/include/ck/version.h
${PROJECT_BINARY_DIR}/include/ck/config.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck/
)
set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")
set(CPACK_RPM_PACKAGE_LICENSE "MIT") set(CPACK_RPM_PACKAGE_LICENSE "MIT")
......
...@@ -63,7 +63,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -63,7 +63,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
nano \ nano \
zlib1g-dev \ zlib1g-dev \
openssh-server \ openssh-server \
clang-format-10 \ clang-format-12 \
kmod && \ kmod && \
apt-get clean && \ apt-get clean && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
......
...@@ -689,7 +689,7 @@ pipeline { ...@@ -689,7 +689,7 @@ pipeline {
-o -iname \'*.cpp.in\' \ -o -iname \'*.cpp.in\' \
-o -iname \'*.cl\' \ -o -iname \'*.cl\' \
| grep -v 'build/' \ | grep -v 'build/' \
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-10 -style=file {} | diff - {}\'" | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'"
} }
steps{ steps{
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
......
...@@ -5,29 +5,50 @@ add_compile_options(-std=c++17) ...@@ -5,29 +5,50 @@ add_compile_options(-std=c++17)
if (DTYPES) if (DTYPES)
add_definitions(-DDTYPES) add_definitions(-DDTYPES)
if (DTYPES MATCHES "int8") if (DTYPES MATCHES "int8")
add_definitions(-D__int8__) add_definitions(-DCK_ENABLE_INT8)
if(NOT DEFINED ${CK_ENABLE_INT8})
set(CK_ENABLE_INT8 "ON")
endif()
endif() endif()
if (DTYPES MATCHES "fp8") if (DTYPES MATCHES "fp8")
add_definitions(-D__fp8__) add_definitions(-DCK_ENABLE_FP8)
if(NOT DEFINED ${CK_ENABLE_FP8})
set(CK_ENABLE_FP8 "ON")
endif()
endif() endif()
if (DTYPES MATCHES "fp16") if (DTYPES MATCHES "fp16")
add_definitions(-D__fp16__) add_definitions(-DCK_ENABLE_FP16)
if(NOT DEFINED ${CK_ENABLE_FP16})
set(CK_ENABLE_FP16 "ON")
endif()
endif() endif()
if (DTYPES MATCHES "fp32") if (DTYPES MATCHES "fp32")
add_definitions(-D__fp32__) add_definitions(-DCK_ENABLE_FP32)
if(NOT DEFINED ${CK_ENABLE_FP32})
set(CK_ENABLE_FP32 "ON")
endif()
endif() endif()
if (DTYPES MATCHES "fp64") if (DTYPES MATCHES "fp64")
add_definitions(-D__fp64__) add_definitions(-DCK_ENABLE_FP64)
if(NOT DEFINED ${CK_ENABLE_FP64})
set(CK_ENABLE_FP64 "ON")
endif()
endif() endif()
if (DTYPES MATCHES "bf16") if (DTYPES MATCHES "bf16")
add_definitions(-D__bf16__) add_definitions(-DCK_ENABLE_BF16)
if(NOT DEFINED ${CK_ENABLE_BF16})
set(CK_ENABLE_BF16 "ON")
endif()
endif() endif()
message("DTYPES macro set to ${DTYPES}") message("DTYPES macro set to ${DTYPES}")
else() else()
add_definitions(-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__) add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
if(NOT DEFINED ${CK_ENABLE_ALL_DTYPES})
set(CK_ENABLE_ALL_DTYPES "ON")
endif()
endif() endif()
find_package(composable_kernel 1.0.0 COMPONENTS device_operations) find_package(composable_kernel COMPONENTS device_operations)
find_package(hip REQUIRED PATHS /opt/rocm) find_package(hip REQUIRED PATHS /opt/rocm)
message(STATUS "Build with HIP ${hip_VERSION}") message(STATUS "Build with HIP ${hip_VERSION}")
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#pragma once #pragma once
#include "ck/config.h"
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2023 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_CONFIG_H_IN
#define CK_CONFIG_H_IN
// clang-format off
//
// DataType supports in the current CK build
//
#ifndef DTYPES
#cmakedefine DTYPES "@DTYPES@"
#endif
// if DTYPES is not defined, enable all datatypes in headerfiles
#ifndef CK_ENABLE_ALL_DTYPES
#cmakedefine CK_ENABLE_ALL_DTYPES @CK_ENABLE_ALL_DTYPES@
#if defined(CK_ENABLE_ALL_DTYPES)
#ifndef CK_ENABLE_INT8
#define CK_ENABLE_INT8 "ON"
#endif
#ifndef CK_ENABLE_FP8
#define CK_ENABLE_FP8 "ON"
#endif
#ifndef CK_ENABLE_FP16
#define CK_ENABLE_FP16 "ON"
#endif
#ifndef CK_ENABLE_BF16
#define CK_ENABLE_BF16 "ON"
#endif
#ifndef CK_ENABLE_FP32
#define CK_ENABLE_FP32 "ON"
#endif
#ifndef CK_ENABLE_FP64
#define CK_ENABLE_FP64 "ON"
#endif
#endif
#endif
// if DTYPES are selectively enabled
#ifndef CK_ENABLE_INT8
#cmakedefine CK_ENABLE_INT8 @CK_ENABLE_INT8@
#endif
#ifndef CK_ENABLE_FP8
#cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@
#endif
#ifndef CK_ENABLE_FP16
#cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@
#endif
#ifndef CK_ENABLE_BF16
#cmakedefine CK_ENABLE_BF16 @CK_ENABLE_BF16@
#endif
#ifndef CK_ENABLE_FP32
#cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@
#endif
#ifndef CK_ENABLE_FP64
#cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@
#endif
//
// Legacy DL kernel supports in the current CK build
// by default DL kernels are turned OFF
//
#ifndef CK_ENABLE_DL_KERNELS
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
// clang-format on
#endif // CK_CONFIG_H_IN
...@@ -1042,12 +1042,12 @@ struct Merge_v2_magic_division ...@@ -1042,12 +1042,12 @@ struct Merge_v2_magic_division
using UpLengths = using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
using LowLengthsMagicDivisorMultipiler = decltype( using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple(
generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{}, lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{},
Number<NDimLow>{})); Number<NDimLow>{}));
using LowLengthsMagicDivisorShift = decltype( using LowLengthsMagicDivisorShift = decltype(generate_tuple(
generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengths>{}, lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengths>{},
Number<NDimLow>{})); Number<NDimLow>{}));
LowLengths low_lengths_; LowLengths low_lengths_;
...@@ -1201,8 +1201,8 @@ struct Merge_v2r2_magic_division ...@@ -1201,8 +1201,8 @@ struct Merge_v2r2_magic_division
lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengthsScan>{}, lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengthsScan>{},
Number<NDimLow>{})); Number<NDimLow>{}));
using LowLengthsScanMagicDivisorShift = decltype( using LowLengthsScanMagicDivisorShift = decltype(generate_tuple(
generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengthsScan>{}, lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengthsScan>{},
Number<NDimLow>{})); Number<NDimLow>{}));
LowLengths low_lengths_; LowLengths low_lengths_;
......
...@@ -35,8 +35,8 @@ struct BlockwiseSoftmax ...@@ -35,8 +35,8 @@ struct BlockwiseSoftmax
static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0); static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1); static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
using ThreadSliceDesc_M = decltype( using ThreadSliceDesc_M = decltype(make_naive_tensor_descriptor_packed(
make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0)))); make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));
using ThreadwiseMaxReduce = typename conditional< using ThreadwiseMaxReduce = typename conditional<
IgnoreNaN, IgnoreNaN,
......
...@@ -588,14 +588,18 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -588,14 +588,18 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
LoopSched>; LoopSched>;
// desc for blockwise copy // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( AGridDesc_M_K{}))>;
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; using BGridDesc_BK0_N_BK1 =
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; BGridDesc_N_K{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map // block-to-e-tile map
using Block2ETileMap = using Block2ETileMap =
......
...@@ -378,13 +378,16 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout, ...@@ -378,13 +378,16 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( AGridDesc_M_K{}))>;
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( BGridDesc_N_K{}))>;
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}));
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument // Argument
......
...@@ -368,14 +368,18 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -368,14 +368,18 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
LoopSched>; LoopSched>;
// desc for blockwise copy // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( AGridDesc_M_K{}))>;
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; using BGridDesc_BK0_N_BK1 =
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; BGridDesc_N_K{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map // block-to-e-tile map
using Block2ETileMap = using Block2ETileMap =
......
...@@ -510,12 +510,15 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -510,12 +510,15 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
using A0GridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using A0GridDesc_AK0_M_AK1 =
GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(A0GridDesc_M_K{}))>; remove_cvref_t<decltype(GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(
using B0GridDesc_BK0_N_BK1 = remove_cvref_t<decltype( A0GridDesc_M_K{}))>;
GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(B0GridDesc_N_K{}))>; using B0GridDesc_BK0_N_BK1 =
using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype( remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(
GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(B1GridDesc_N_K{}))>; B0GridDesc_N_K{}))>;
using B1GridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(
B1GridDesc_N_K{}))>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
......
...@@ -355,14 +355,18 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -355,14 +355,18 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
LoopSched>; LoopSched>;
// desc for blockwise copy // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( AGridDesc_M_K{}))>;
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; using BGridDesc_BK0_N_BK1 =
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; BGridDesc_N_K{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map // block-to-e-tile map
using Block2ETileMap = using Block2ETileMap =
......
...@@ -364,11 +364,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -364,11 +364,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
// We have to separate mean var descriptor for gemm and layernorm bacause of different grid // We have to separate mean var descriptor for gemm and layernorm bacause of different grid
// layout(different padding) // layout(different padding)
using GemmMeanVarGridDesc_M_NBlock = decltype( using GemmMeanVarGridDesc_M_NBlock =
MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1)); decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1,
1));
using GemmCountGridDesc_M_NBlock = decltype( using GemmCountGridDesc_M_NBlock =
MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1)); decltype(MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1,
1));
using LayernormMeanVarGridDesc_M_NBlock = using LayernormMeanVarGridDesc_M_NBlock =
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, true>, decltype(MakeMeanVarDescriptor_M_N<Sequence<true, true>,
......
...@@ -337,10 +337,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -337,10 +337,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
RThreadTransferDstScalarPerVector_MPerBlock, RThreadTransferDstScalarPerVector_MPerBlock,
LoopSched>; LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( AGridDesc_M_K{}))>;
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
......
...@@ -288,14 +288,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -288,14 +288,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
PipelineVer>; PipelineVer>;
// desc for blockwise copy // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( AGridDesc_M_K{}))>;
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; using BGridDesc_BK0_N_BK1 =
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; BGridDesc_N_K{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map // block-to-e-tile map
using Block2ETileMap = using Block2ETileMap =
......
...@@ -248,10 +248,12 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -248,10 +248,12 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>; CShuffleBlockTransferScalarPerVector_NPerBlock>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( AGridDesc_M_K{}))>;
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
......
...@@ -400,14 +400,18 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -400,14 +400,18 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
LoopSched>; LoopSched>;
// desc for blockwise copy // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( AGridDesc_M_K{}))>;
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; using BGridDesc_BK0_N_BK1 =
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; BGridDesc_N_K{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
struct GroupedContractionBlock2ETileMap struct GroupedContractionBlock2ETileMap
{ {
......
...@@ -422,10 +422,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -422,10 +422,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{}));
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{})); decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( DsGridDesc_M_N{}));
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}));
// block-to-e-tile map // block-to-e-tile map
using Block2ETileMap = using Block2ETileMap =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment