Commit e2eb0418 authored by Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents f2ba10b8 cbb6f2ab
...@@ -541,6 +541,9 @@ if(NOT DEFINED INSTANCES_ONLY) ...@@ -541,6 +541,9 @@ if(NOT DEFINED INSTANCES_ONLY)
PACKAGE_NAME examples PACKAGE_NAME examples
) )
add_subdirectory(example) add_subdirectory(example)
if(GPU_TARGETS MATCHES "gfx9" AND NOT INSTANCES_ONLY)
add_subdirectory(codegen)
endif()
if(BUILD_TESTING) if(BUILD_TESTING)
add_subdirectory(test) add_subdirectory(test)
endif() endif()
......
...@@ -746,10 +746,6 @@ pipeline { ...@@ -746,10 +746,6 @@ pipeline {
name: "RUN_PERFORMANCE_TESTS", name: "RUN_PERFORMANCE_TESTS",
defaultValue: true, defaultValue: true,
description: "Run the performance tests (default: ON)") description: "Run the performance tests (default: ON)")
booleanParam(
name: "RUN_CODEGEN_TESTS",
defaultValue: true,
description: "Run the codegen tests (default: ON)")
booleanParam( booleanParam(
name: "RUN_CK_TILE_TESTS", name: "RUN_CK_TILE_TESTS",
defaultValue: false, defaultValue: false,
...@@ -841,33 +837,6 @@ pipeline { ...@@ -841,33 +837,6 @@ pipeline {
} }
} }
} }
stage("Run Codegen Tests")
{
parallel
{
stage("Run Codegen Tests on gfx90a")
{
when {
beforeAgent true
expression { params.RUN_CODEGEN_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx90a")}
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ cd ../codegen && rm -rf build && mkdir build && cd build && \
cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx90a" \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j check"""
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
}
}
stage("Run CK_TILE Tests") stage("Run CK_TILE Tests")
{ {
parallel parallel
......
cmake_minimum_required(VERSION 3.16)
project(composable_kernel_host LANGUAGES CXX HIP)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
...@@ -8,17 +5,9 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) ...@@ -8,17 +5,9 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..)
find_package(ROCM)
include(ROCMInstallTargets)
include(ROCMTest)
add_compile_options(-std=c++17) add_compile_options(-std=c++17)
find_package(hip) find_package(hip)
## HIP add_custom_target(codegen)
set(CMAKE_HIP_PLATFORM amd)
set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
set(CMAKE_HIP_EXTENSIONS ON)
message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}")
# add include directories # add include directories
include_directories(BEFORE include_directories(BEFORE
...@@ -32,8 +21,9 @@ list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake) ...@@ -32,8 +21,9 @@ list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake)
include(Embed) include(Embed)
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
${CK_ROOT}/include/ck/*.hpp) ${CK_ROOT}/include/ck/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") #printouts for debug purposes
message(STATUS "RELATIVE: ${CK_ROOT}/include") #message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
#message(STATUS "RELATIVE: ${CK_ROOT}/include")
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
......
...@@ -76,8 +76,11 @@ std::string SequenceStr(const std::vector<int>& v); ...@@ -76,8 +76,11 @@ std::string SequenceStr(const std::vector<int>& v);
std::string MakeTuple(const std::vector<std::string>& v); std::string MakeTuple(const std::vector<std::string>& v);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
template <int... xs> template <int... xs>
const std::string S = SequenceStr({xs...}); const std::string S = SequenceStr({xs...});
#pragma clang diagnostic pop
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough"; constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear"; constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "ck/host/device_gemm_multiple_d/operation.hpp" #include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/stringutils.hpp" #include "ck/host/stringutils.hpp"
#include "ck/host/types.hpp"
#include "ck/host/utils.hpp" #include "ck/host/utils.hpp"
#include <cassert> #include <cassert>
...@@ -32,11 +33,11 @@ static std::string GetGemmSpec(const std::size_t m, ...@@ -32,11 +33,11 @@ static std::string GetGemmSpec(const std::size_t m,
} }
// function to update prologue/epilogue with user provided operation // function to update prologue/epilogue with user provided operation
void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue) void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
{ {
if(!prologue.empty()) if(!pro.empty())
{ {
this->prologue = prologue; this->prologue = pro;
this->cde_elem_op = "CDEElementOp"; this->cde_elem_op = "CDEElementOp";
} }
else else
...@@ -45,11 +46,11 @@ void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue) ...@@ -45,11 +46,11 @@ void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue)
} }
} }
void Operation_Xdl_CShuffle::update_epilogue(const std::string& epilogue) void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
{ {
if(!epilogue.empty()) if(!epi.empty())
{ {
this->epilogue = epilogue; this->epilogue = epi;
this->cde_elem_op = "CDEElementOp"; this->cde_elem_op = "CDEElementOp";
} }
else else
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include <iostream> #include <iostream>
#include "ck/host/stringutils.hpp" #include "ck/host/stringutils.hpp"
#include "ck/host/types.hpp"
#include "ck/host/utils.hpp" #include "ck/host/utils.hpp"
#include <cassert> #include <cassert>
...@@ -11,34 +12,15 @@ namespace ck { ...@@ -11,34 +12,15 @@ namespace ck {
namespace host { namespace host {
namespace conv { namespace conv {
// calculate appropriate Gemm Specification based on input tensor dimensions // NOTE: in CK, MNKPadding is always used for forward convolution, so didn't
// NOTE: in CK, MNKPadding is always used for forward convolution // add GemmSpec function here
static std::string GetGemmSpec(const std::size_t m,
const std::size_t n,
const std::size_t k,
const std::size_t m_per_block,
const std::size_t n_per_block,
const std::size_t k_per_block)
{
std::string spec = "";
if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
spec += "M";
if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
spec += "N";
if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
spec += "K";
if(spec == "")
return "ck::tensor_operation::device::GemmSpecialization::Default";
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
// function to update prologue/epilogue with user provided operation // function to update prologue/epilogue with user provided operation
void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologue) void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& pro)
{ {
if(!prologue.empty()) if(!pro.empty())
{ {
this->prologue = prologue; this->prologue = pro;
this->cde_elem_op = "CDEElementOp"; this->cde_elem_op = "CDEElementOp";
} }
else else
...@@ -47,11 +29,11 @@ void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologu ...@@ -47,11 +29,11 @@ void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologu
} }
} }
void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epilogue) void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epi)
{ {
if(!epilogue.empty()) if(!epi.empty())
{ {
this->epilogue = epilogue; this->epilogue = epi;
this->cde_elem_op = "CDEElementOp"; this->cde_elem_op = "CDEElementOp";
} }
else else
......
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
namespace ck { namespace ck {
namespace host { namespace host {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
const std::string config_header = ""; const std::string config_header = "";
#pragma clang diagnostic pop
std::unordered_map<std::string_view, std::string_view> GetHeaders() std::unordered_map<std::string_view, std::string_view> GetHeaders()
{ {
......
...@@ -4,7 +4,9 @@ file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) ...@@ -4,7 +4,9 @@ file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
foreach(TEST_SRC ${TEST_SRCS}) foreach(TEST_SRC ${TEST_SRCS})
set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP) set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP)
get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE)
rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC}) add_executable(test_host_${BASE_NAME} ${TEST_SRC})
add_dependencies(codegen test_host_${BASE_NAME})
add_test(NAME codegen_test_${BASE_NAME} COMMAND test_host_${BASE_NAME})
target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host) target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host)
# target_link_libraries(test_host_${BASE_NAME} ${CK_ROOT}/build/lib/libutility.a) # target_link_libraries(test_host_${BASE_NAME} ${CK_ROOT}/build/lib/libutility.a)
target_include_directories(test_host_${BASE_NAME} PUBLIC include()) target_include_directories(test_host_${BASE_NAME} PUBLIC include())
......
...@@ -92,7 +92,6 @@ struct Epilogue ...@@ -92,7 +92,6 @@ struct Epilogue
static_cast<int>(prob.C), static_cast<int>(prob.C),
static_cast<int>(prob.Y), static_cast<int>(prob.Y),
static_cast<int>(prob.X)}; static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C), ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C), static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
...@@ -109,7 +108,6 @@ struct Epilogue ...@@ -109,7 +108,6 @@ struct Epilogue
1, 1,
static_cast<int>(prob.X * prob.C), static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)}; static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2}; ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1}; ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
......
...@@ -92,7 +92,6 @@ struct Epilogue ...@@ -92,7 +92,6 @@ struct Epilogue
static_cast<int>(prob.C), static_cast<int>(prob.C),
static_cast<int>(prob.Y), static_cast<int>(prob.Y),
static_cast<int>(prob.X)}; static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C), ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C), static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
...@@ -109,7 +108,6 @@ struct Epilogue ...@@ -109,7 +108,6 @@ struct Epilogue
1, 1,
static_cast<int>(prob.X * prob.C), static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)}; static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1}; ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1}; ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
......
...@@ -92,7 +92,6 @@ struct Epilogue ...@@ -92,7 +92,6 @@ struct Epilogue
static_cast<int>(prob.C), static_cast<int>(prob.C),
static_cast<int>(prob.Y), static_cast<int>(prob.Y),
static_cast<int>(prob.X)}; static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C), ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C), static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
...@@ -109,7 +108,6 @@ struct Epilogue ...@@ -109,7 +108,6 @@ struct Epilogue
1, 1,
static_cast<int>(prob.X * prob.C), static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)}; static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2}; ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1}; ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
......
...@@ -92,7 +92,6 @@ struct Epilogue ...@@ -92,7 +92,6 @@ struct Epilogue
static_cast<int>(prob.C), static_cast<int>(prob.C),
static_cast<int>(prob.Y), static_cast<int>(prob.Y),
static_cast<int>(prob.X)}; static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C), ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C), static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
...@@ -109,7 +108,6 @@ struct Epilogue ...@@ -109,7 +108,6 @@ struct Epilogue
1, 1,
static_cast<int>(prob.X * prob.C), static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)}; static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1}; ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1}; ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
......
...@@ -118,4 +118,4 @@ void kernel::launch(hipStream_t stream, ...@@ -118,4 +118,4 @@ void kernel::launch(hipStream_t stream,
launch_kernel(impl->fun, stream, global, local, kernargs.data(), size); launch_kernel(impl->fun, stream, global, local, kernargs.data(), size);
} }
} // namespace rtc } // namespace rtc
\ No newline at end of file
...@@ -45,4 +45,4 @@ void tmp_dir::execute(const std::string& cmd) const ...@@ -45,4 +45,4 @@ void tmp_dir::execute(const std::string& cmd) const
tmp_dir::~tmp_dir() { std::filesystem::remove_all(this->path); } tmp_dir::~tmp_dir() { std::filesystem::remove_all(this->path); }
} // namespace rtc } // namespace rtc
\ No newline at end of file
...@@ -75,7 +75,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) ...@@ -75,7 +75,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(FILE_NAME) if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl") if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(FILE_NAME MATCHES "_wmma") elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif() endif()
...@@ -162,7 +162,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) ...@@ -162,7 +162,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(FILE_NAME) if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl") if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(FILE_NAME MATCHES "_wmma") elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif() endif()
......
...@@ -86,7 +86,6 @@ __global__ void ...@@ -86,7 +86,6 @@ __global__ void
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const index_t groups_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -101,10 +100,8 @@ __global__ void ...@@ -101,10 +100,8 @@ __global__ void
defined(__gfx94__)) defined(__gfx94__))
// offset base pointer for each work-group // offset base pointer for each work-group
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t& num_blocks_per_n = groups_count; const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const long_index_t e_batch_offset = const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
...@@ -200,7 +197,6 @@ __global__ void ...@@ -200,7 +197,6 @@ __global__ void
ignore = p_bs_grid; ignore = p_bs_grid;
ignore = p_ds_grid; ignore = p_ds_grid;
ignore = p_e_grid; ignore = p_e_grid;
ignore = groups_count;
ignore = a_grid_desc_k0_m_k1; ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1; ignore = b_grid_desc_k0_n_k1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
...@@ -321,8 +317,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -321,8 +317,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization, ConvForwardSpecialization,
true /*SplitN*/, true /*SplitN*/,
ALayout, ADataType,
ELayout>; EDataType>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
...@@ -730,8 +726,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -730,8 +726,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N; const index_t gdy = arg.num_group_;
const index_t gdz = 1; const index_t gdz = num_workgroups_per_Conv_N;
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
...@@ -780,7 +776,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -780,7 +776,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
as_grid_desc_ak0_m_ak1, as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1, bs_grid_desc_bk0_n_bk1,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
...@@ -824,7 +819,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -824,7 +819,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
......
...@@ -81,11 +81,11 @@ function(add_instance_library INSTANCE_NAME) ...@@ -81,11 +81,11 @@ function(add_instance_library INSTANCE_NAME)
set(INST_TARGETS ${GPU_TARGETS}) set(INST_TARGETS ${GPU_TARGETS})
endif() endif()
if(source MATCHES "_xdl") if(source MATCHES "_xdl")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(ARGN MATCHES "_wmma") elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
elseif(ARGN MATCHES "mha") elseif(ARGN MATCHES "mha")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
endif() endif()
set(offload_targets) set(offload_targets)
foreach(target IN LISTS INST_TARGETS) foreach(target IN LISTS INST_TARGETS)
......
#!/bin/bash
## CI driver for bin/test_reduce_with_index.
##
## Invokes the test binary on the tensor lengths 64,4,280,82 once for every
## combination of data-type code and reduce-dimension list defined below, in
## the order listed, with command tracing (`set -x`) enabled around the runs.
##
## Every case is executed even if an earlier one fails. The script's final
## status is non-zero when at least one invocation failed, so a CI job that
## checks the status can detect the failure.

## Data-type codes passed as the first positional argument, in run order:
##   0 = float, 6 = float64, 1 = float16, 3 = int8_t, 5 = bfloat16
data_types="0 6 1 3 5"

## Reduce-dimension lists passed to -R, in run order.
reduce_dims="0,1,2,3 0,1,2 0,1,3 0,2,3 1,2,3 0 1 2 3"

## Number of invocations that returned a non-zero status.
failures=0

set -x
for data_type in $data_types; do
    for dims in $reduce_dims; do
        ## NOTE(review): the trailing "2" is carried over verbatim from the
        ## hand-written command lines; its meaning is not visible in this
        ## script -- confirm against reduce_with_index.cpp.
        bin/test_reduce_with_index -D 64,4,280,82 -R "$dims" "$data_type" 2 || failures=$((failures + 1))
    done
done
set +x

## Report failure through the status of the last command rather than `exit`,
## so the script behaves the same whether it is executed or sourced.
[ "$failures" -eq 0 ]
...@@ -68,11 +68,11 @@ function(add_test_executable TEST_NAME) ...@@ -68,11 +68,11 @@ function(add_test_executable TEST_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(ARGN) if(ARGN)
if(ARGN MATCHES "_xdl") if(ARGN MATCHES "_xdl")
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(ARGN MATCHES "_wmma") elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
elseif(ARGN MATCHES "_smfmac") elseif(ARGN MATCHES "_smfmac")
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a) list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201)
endif() endif()
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
add_executable(${TEST_NAME} ${ARGN}) add_executable(${TEST_NAME} ${ARGN})
...@@ -149,11 +149,11 @@ function(add_gtest_executable TEST_NAME) ...@@ -149,11 +149,11 @@ function(add_gtest_executable TEST_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(ARGN) if(ARGN)
if(ARGN MATCHES "_xdl") if(ARGN MATCHES "_xdl")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(ARGN MATCHES "_wmma") elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
elseif(ARGN MATCHES "_smfmac") elseif(ARGN MATCHES "_smfmac")
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a) list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201)
endif() endif()
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
add_executable(${TEST_NAME} ${ARGN}) add_executable(${TEST_NAME} ${ARGN})
......
add_test_executable(test_reduce_no_index reduce_no_index.cpp) add_gtest_executable(test_reduce_no_index reduce_no_index.cpp)
add_test_executable(test_reduce_with_index reduce_with_index.cpp) add_gtest_executable(test_reduce_with_index reduce_with_index.cpp)
target_link_libraries(test_reduce_no_index PRIVATE utility device_reduce_instance) target_link_libraries(test_reduce_no_index PRIVATE utility device_reduce_instance)
target_link_libraries(test_reduce_with_index PRIVATE utility device_reduce_instance) target_link_libraries(test_reduce_with_index PRIVATE utility device_reduce_instance)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment