Commit 2a4c2316 authored by danyao12

Merge branch 'develop' into ck_tile/fa_asm_bwd

parents 1e01ee09 770d2b77
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#ifndef GUARD_CK_FILESYSTEM_HPP_
#define GUARD_CK_FILESYSTEM_HPP_
#include <string>
#include <string_view>
// clang-format off
#if defined(CPPCHECK)
#define CK_HAS_FILESYSTEM 1
#define CK_HAS_FILESYSTEM_TS 1
#elif defined(_WIN32)
#if _MSC_VER >= 1920
#define CK_HAS_FILESYSTEM 1
#define CK_HAS_FILESYSTEM_TS 0
#elif _MSC_VER >= 1900
#define CK_HAS_FILESYSTEM 0
#define CK_HAS_FILESYSTEM_TS 1
#else
#define CK_HAS_FILESYSTEM 0
#define CK_HAS_FILESYSTEM_TS 0
#endif
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define CK_HAS_FILESYSTEM 1
#else
#define CK_HAS_FILESYSTEM 0
#endif
#if __has_include(<experimental/filesystem>) && __cplusplus >= 201103L
#define CK_HAS_FILESYSTEM_TS 1
#else
#define CK_HAS_FILESYSTEM_TS 0
#endif
#else
#define CK_HAS_FILESYSTEM 0
#define CK_HAS_FILESYSTEM_TS 0
#endif
// clang-format on
#if CK_HAS_FILESYSTEM
#include <filesystem>
#elif CK_HAS_FILESYSTEM_TS
#include <experimental/filesystem>
#else
#error "No filesystem include available"
#endif
namespace CK {
#if CK_HAS_FILESYSTEM
namespace fs = ::std::filesystem;
#elif CK_HAS_FILESYSTEM_TS
namespace fs = ::std::experimental::filesystem;
#endif
} // namespace CK
inline std::string operator+(const std::string_view s, const CK::fs::path& path)
{
return path.string().insert(0, s);
}
inline std::string operator+(const CK::fs::path& path, const std::string_view s)
{
return path.string().append(s);
}
#define FS_ENUM_PERMS_ALL fs::perms::all
#if CK_HAS_FILESYSTEM_TS
#ifdef __linux__
#include <cstdlib> // for realpath
#include <linux/limits.h>
namespace CK {
inline fs::path weakly_canonical(const fs::path& path)
{
std::string result(PATH_MAX, '\0');
std::string p{path.is_relative() ? (fs::current_path() / path).string() : path.string()};
char* retval = realpath(p.c_str(), &result[0]);
return (retval == nullptr) ? path : fs::path{result};
}
} // namespace CK
#else
#error "Not implmeneted!"
#endif
#else
namespace CK {
inline fs::path weakly_canonical(const fs::path& path) { return fs::weakly_canonical(path); }
} // namespace CK
#endif
namespace CK {
#ifdef _WIN32
constexpr std::string_view executable_postfix{".exe"};
constexpr std::string_view library_prefix{""};
constexpr std::string_view dynamic_library_postfix{".dll"};
constexpr std::string_view static_library_postfix{".lib"};
constexpr std::string_view object_file_postfix{".obj"};
#else
constexpr std::string_view executable_postfix{""};
constexpr std::string_view library_prefix{"lib"};
constexpr std::string_view dynamic_library_postfix{".so"};
constexpr std::string_view static_library_postfix{".a"};
constexpr std::string_view object_file_postfix{".o"};
#endif
inline fs::path make_executable_name(const fs::path& path)
{
return path.parent_path() / (path.filename() + executable_postfix);
}
inline fs::path make_dynamic_library_name(const fs::path& path)
{
return path.parent_path() / (library_prefix + path.filename() + dynamic_library_postfix);
}
inline fs::path make_object_file_name(const fs::path& path)
{
return path.parent_path() / (path.filename() + object_file_postfix);
}
inline fs::path make_static_library_name(const fs::path& path)
{
return path.parent_path() / (library_prefix + path.filename() + static_library_postfix);
}
struct FsPathHash
{
std::size_t operator()(const fs::path& path) const { return fs::hash_value(path); }
};
} // namespace CK
#endif // GUARD_CK_FILESYSTEM_HPP_
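A minimal usage sketch of the header above (hypothetical include path and file names; assumes a C++17 toolchain where <filesystem> is available):

#include <iostream>
#include "ck/filesystem.hpp" // hypothetical location of the header above

int main()
{
    // CK::fs aliases std::filesystem (or the TS fallback picked above).
    CK::fs::path p{"build/ckgemm"};

    // On Linux this prints "build/libckgemm.so"; on Windows, "ckgemm.dll".
    std::cout << CK::make_dynamic_library_name(p).string() << '\n';

    // Resolves against the current directory; uses realpath() when only the
    // filesystem TS is available.
    std::cout << CK::weakly_canonical(p).string() << '\n';
    return 0;
}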
@@ -446,7 +446,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                 });
             });
         });
-        __builtin_amdgcn_sched_barrier(0);
+        // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
+        // latency
+        // __builtin_amdgcn_sched_barrier(0);
     }
 }
...
@@ -171,6 +171,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
                 Argument arg_ = arg;

+                const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
+                    arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
+                const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
+                    arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
+
+                auto size_a_buffer =
+                    a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
+                auto size_b_buffer =
+                    b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
+
                 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
                     arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
@@ -179,11 +189,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
                     DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
                 });
                 ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
-                    arg_,
-                    stream_config.rotating_count,
-                    arg_.M * arg_.K * sizeof(ADataType),
-                    arg_.K * arg_.N * sizeof(BDataType),
-                    DsSize);
+                    arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
                 rotating_mem.Print();

                 auto run_flush_cache = [&]() {
...
@@ -155,11 +155,19 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
             if(stream_config.flush_cache)
             {
                 Argument arg_ = arg;
+
+                const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
+                    arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
+                const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
+                    arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
+
+                auto size_a_buffer =
+                    a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
+                auto size_b_buffer =
+                    b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
+
                 ck::utility::RotatingMemWrapper<Argument> rotating_mem(
-                    arg_,
-                    stream_config.rotating_count,
-                    arg_.M * arg_.K * sizeof(ADataType),
-                    arg_.K * arg_.N * sizeof(BDataType));
+                    arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
                 rotating_mem.Print();

                 auto run_flush_cache = [&]() {
...
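The sizing change in both hunks matters whenever padding is active: a grid descriptor built with MPadded/KPadded addresses more elements than the raw M*K product. A standalone sketch with made-up tile sizes (plain C++, not the CK API):

#include <cstdio>

// Illustration only: the rotating-memory wrapper must cover the *padded*
// element space the descriptor walks, not just the logical M*K elements.
int main()
{
    const int M = 1000, K = 130;          // raw GEMM sizes (example values)
    const int MPerBlock = 256, KPerBlock = 64;
    const int MPadded = (M + MPerBlock - 1) / MPerBlock * MPerBlock; // 1024
    const int KPadded = (K + KPerBlock - 1) / KPerBlock * KPerBlock; // 192
    std::printf("raw elements:    %d\n", M * K);             // 130000
    std::printf("padded elements: %d\n", MPadded * KPadded); // 196608
    // A buffer sized from M*K would be too small for a descriptor over the
    // padded space, which is what GetElementSpaceSize() accounts for above.
    return 0;
}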
@@ -15,6 +15,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
 #include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
 #include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
+#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
@@ -22,7 +23,6 @@
 #include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -257,6 +257,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                                          KPerBlock / K1Number,
                                          ConvBackwardWeightSpecialization>{};

+    static constexpr index_t ClusterLengthMPerBlock =
+        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
+    static constexpr index_t ClusterLengthNPerBlock =
+        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
+
+    static constexpr auto conv_ngchw_to_nhwgc_transformer =
+        TransformConvNGCHWToNHWGC<InLayout,
+                                  WeiLayout,
+                                  OutLayout,
+                                  NDimSpatial,
+                                  MPerBlock / ClusterLengthMPerBlock,
+                                  NPerBlock / ClusterLengthNPerBlock>{};
+
     static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default;

     template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
@@ -359,141 +372,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                                                                 batch)[I2];
     }

-    static constexpr index_t ClusterLengthMPerBlock =
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
-    static constexpr index_t ClusterLengthNPerBlock =
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
-
-    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
-    static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
-                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
-    {
-        const index_t& G  = g_n_c_wis_lengths[0];
-        const index_t& N  = g_n_c_wis_lengths[1];
-        const index_t& C  = g_n_c_wis_lengths[2];
-        const index_t& Hi = g_n_c_wis_lengths[3];
-        const index_t& Wi = g_n_c_wis_lengths[4];
-
-        const index_t& GStride  = g_n_c_wis_strides[0];
-        const index_t& NStride  = g_n_c_wis_strides[1];
-        const index_t& CStride  = g_n_c_wis_strides[2];
-        const index_t& HiStride = g_n_c_wis_strides[3];
-        const index_t& WiStride = g_n_c_wis_strides[4];
-
-        const auto desc = make_naive_tensor_descriptor(
-            make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
-        const auto merged_desc =
-            transform_tensor_descriptor(desc,
-                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
-                                                   make_merge_transform(make_tuple(Hi, Wi))),
-                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
-        return PadTensorDescriptor(
-            merged_desc,
-            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
-            Sequence<true, true>{});
-    }
-
-    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
-    static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
-                                        std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
-    {
-        const index_t& G  = g_n_c_wis_lengths[0];
-        const index_t& N  = g_n_c_wis_lengths[1];
-        const index_t& C  = g_n_c_wis_lengths[2];
-        const index_t& Hi = g_n_c_wis_lengths[3];
-        const index_t& Wi = g_n_c_wis_lengths[4];
-
-        const index_t& NStride = g_n_c_wis_strides[1];
-        const index_t HiStride = Wi * G * C;
-        const index_t WiStride = G * C;
-        const index_t GStride  = C;
-        const index_t CStride  = 1;
-
-        const auto desc = make_naive_tensor_descriptor(
-            make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
-        const auto merged_desc =
-            transform_tensor_descriptor(desc,
-                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
-                                                   make_merge_transform(make_tuple(Hi, Wi))),
-                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
-        return PadTensorDescriptor(
-            merged_desc,
-            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
-            Sequence<true, true>{});
-    }
-
-    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
-    static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
-                                       std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
-    {
-        const index_t& G  = g_n_c_wis_lengths[0];
-        const index_t& N  = g_n_c_wis_lengths[1];
-        const index_t& C  = g_n_c_wis_lengths[2];
-        const index_t& Di = g_n_c_wis_lengths[3];
-        const index_t& Hi = g_n_c_wis_lengths[4];
-        const index_t& Wi = g_n_c_wis_lengths[5];
-
-        const index_t& GStride  = g_n_c_wis_strides[0];
-        const index_t& NStride  = g_n_c_wis_strides[1];
-        const index_t& CStride  = g_n_c_wis_strides[2];
-        const index_t& DiStride = g_n_c_wis_strides[3];
-        const index_t& HiStride = g_n_c_wis_strides[4];
-        const index_t& WiStride = g_n_c_wis_strides[5];
-
-        const auto desc = make_naive_tensor_descriptor(
-            make_tuple(N, G, C, Di, Hi, Wi),
-            make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
-        const auto merged_desc =
-            transform_tensor_descriptor(desc,
-                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
-                                                   make_merge_transform(make_tuple(Di, Hi, Wi))),
-                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
-        return PadTensorDescriptor(
-            merged_desc,
-            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
-            Sequence<true, true>{});
-    }
-
-    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
-    static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
-                                        std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
-    {
-        const index_t& G  = g_n_c_wis_lengths[0];
-        const index_t& N  = g_n_c_wis_lengths[1];
-        const index_t& C  = g_n_c_wis_lengths[2];
-        const index_t& Di = g_n_c_wis_lengths[3];
-        const index_t& Hi = g_n_c_wis_lengths[4];
-        const index_t& Wi = g_n_c_wis_lengths[5];
-
-        const index_t& NStride = g_n_c_wis_strides[1];
-        const index_t DiStride = Hi * Wi * G * C;
-        const index_t HiStride = Wi * G * C;
-        const index_t WiStride = G * C;
-        const index_t GStride  = C;
-        const index_t CStride  = 1;
-
-        const auto desc = make_naive_tensor_descriptor(
-            make_tuple(N, G, C, Di, Hi, Wi),
-            make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
-        const auto merged_desc =
-            transform_tensor_descriptor(desc,
-                                        make_tuple(make_merge_transform(make_tuple(N, G, C)),
-                                                   make_merge_transform(make_tuple(Di, Hi, Wi))),
-                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
-        return PadTensorDescriptor(
-            merged_desc,
-            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
-            Sequence<true, true>{});
-    }
-
-    using InputTransposeDescType =
-        remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
-    using OutputTransposeDescType =
-        remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;
+    using NGCHWTransposeDescType =
+        remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
+                                    .template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
+    using NHWGCTransposeDescType =
+        remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
+                                    .template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;

     using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
@@ -572,8 +456,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                             I1>;

     using GridwiseElementwiseTranspose =
-        GridwiseElementwise<Tuple<InputTransposeDescType>,
-                            Tuple<OutputTransposeDescType>,
+        GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
+                            Tuple<NHWGCTransposeDescType>,
                             Tuple<const ADataType*>,
                             Tuple<ADataType*>,
                             Block2TileMapElementwise,
@@ -652,43 +536,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                       begin(output_spatial_lengths_));

             std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
-                b_g_n_c_wis_strides;
+                conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths,
+                                                                 b_g_n_c_wis_strides);
             std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
-                a_g_n_k_wos_strides;
-
-            // NGKHW - transpose needed
-            if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
-                         is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
-            {
-                b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
-                b_g_n_c_wis_strides_transposed[I2] = I1;
-                a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
-                a_g_n_k_wos_strides_transposed[I2] = I1;
-
-                if constexpr(NDimSpatial == 2)
-                {
-                    b_g_n_c_wis_strides_transposed[I3] =
-                        input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
-                    b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
-                    a_g_n_k_wos_strides_transposed[I3] =
-                        output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
-                    a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
-                }
-                else if constexpr(NDimSpatial == 3)
-                {
-                    b_g_n_c_wis_strides_transposed[I3] =
-                        input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
-                    b_g_n_c_wis_strides_transposed[I4] =
-                        input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
-                    b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
-                    a_g_n_k_wos_strides_transposed[I3] = output_spatial_lengths_[I1] *
-                                                         input_spatial_lengths_[I2] * Conv_G_ *
-                                                         Conv_K_;
-                    a_g_n_k_wos_strides_transposed[I4] =
-                        input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
-                    a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
-                }
-            }
+                conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
+                                                                 a_g_n_k_wos_strides);

             const auto descs =
                 conv_to_gemm_transformer_v2
@@ -755,14 +607,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                          is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
             {
                 a_in_transpose_desc_ =
-                    MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
+                    conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
+                        a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
                 a_out_transpose_desc_ =
-                    MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
+                    conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
+                        a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
                 b_in_transpose_desc_ =
-                    MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
+                    conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
+                        b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
                 b_out_transpose_desc_ =
-                    MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
+                    conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
+                        b_g_n_c_wis_lengths, b_g_n_c_wis_strides);

                 elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
                     a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
@@ -816,8 +672,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
         Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
             elementwise_block_2_ctile_map_transpose_b_;

-        InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
-        OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
+        NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
+        NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;

         // for computing batch offset
         ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
@@ -1569,13 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                 (arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
                 sizeof(BDataType);

+            // Different data type for A and B is not supported
             auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
-                                                            ck::Tuple<InputTransposeDescType>,
-                                                            ck::Tuple<InputTransposeDescType>,
-                                                            ck::Tuple<OutputTransposeDescType>,
-                                                            ck::Tuple<OutputTransposeDescType>,
+                                                            ck::Tuple<NGCHWTransposeDescType>,
+                                                            ck::Tuple<NGCHWTransposeDescType>,
+                                                            ck::Tuple<NHWGCTransposeDescType>,
+                                                            ck::Tuple<NHWGCTransposeDescType>,
                                                             ck::Tuple<const ADataType*>,
-                                                            ck::Tuple<BDataType*>,
+                                                            ck::Tuple<ADataType*>,
                                                             Block2TileMapElementwise,
                                                             Block2TileMapElementwise,
                                                             element_wise::PassThrough>;
...
@@ -26,6 +26,15 @@ constexpr bool is_GNWC_GKXC_GNWK()
            is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
            is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
 }

+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NGCW_GKXC_NGKW()
+{
+    return is_same_v<InLayout, tensor_layout::convolution::NGCW> &&
+           is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
+           is_same_v<OutLayout, tensor_layout::convolution::NGKW>;
+}
+
 // 2d
 template <typename InLayout, typename WeiLayout, typename OutLayout>
 constexpr bool is_NHWGC_GKYXC_NHWGK()
@@ -91,6 +100,14 @@ constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
            is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
 }

+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
+{
+    return is_NGCW_GKXC_NGKW<InLayout, WeiLayout, OutLayout>() ||
+           is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+           is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>();
+}
+
 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
 struct ComputePtrOffsetOfStridedBatch
 {
...
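A compile-time sketch of how the new 1D predicate folds into the family check (include path and namespaces as they appear in this diff; checks only, no runtime code):

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"

namespace conv = ck::tensor_layout::convolution;
using ck::tensor_operation::device::is_NGCSpatial_GKSpatial_NGKSpatial;

// NGCW/GKXC/NGKW (1D) is now recognized as part of the NGC* family.
static_assert(is_NGCSpatial_GKSpatial_NGKSpatial<conv::NGCW, conv::GKXC, conv::NGKW>());
// Layouts from a different family are not matched.
static_assert(!is_NGCSpatial_GKSpatial_NGKSpatial<conv::NHWGC, conv::GKYXC, conv::NHWGK>());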
@@ -355,12 +355,39 @@ struct UnaryDivide
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, int32_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                       "Data type is not supported by this operation!");

         y = x / type_convert<T>(divider_);
     };

+    template <>
+    __host__ __device__ void operator()<half_t>(half_t& y, const half_t& x) const
+    {
+        float x_         = type_convert<float>(x);
+        float divider_f_ = type_convert<float>(divider_);
+
+        y = type_convert<half_t>(x_ / divider_f_);
+    };
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x) const
+    {
+        float x_         = type_convert<float>(x);
+        float divider_f_ = type_convert<float>(divider_);
+
+        y = type_convert<bhalf_t>(x_ / divider_f_);
+    };
+
+    template <>
+    __host__ __device__ void operator()<f8_t>(f8_t& y, const f8_t& x) const
+    {
+        float x_         = type_convert<float>(x);
+        float divider_f_ = type_convert<float>(divider_);
+
+        y = type_convert<f8_t>(x_ / divider_f_);
+    };
+
     int32_t divider_ = 1;
 };
...
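These specializations all follow the same widen-compute-narrow pattern: convert to float, do the arithmetic in float, convert back, because the narrow types have no reliable native division. A self-contained illustration with a toy 8-bit fixed-point type standing in for f8_t (illustrative only, not CK's type_convert):

#include <cstdint>
#include <cstdio>

// Toy 8-bit "float" with 1/16 resolution, purely to show the pattern.
using toy_f8 = std::int8_t;
float  to_float(toy_f8 v) { return v / 16.0f; }
toy_f8 from_float(float f) { return static_cast<toy_f8>(f * 16.0f); }

int main()
{
    toy_f8 x       = from_float(3.0f);
    int    divider = 4;

    // Same shape as UnaryDivide::operator()<f8_t>: widen both operands,
    // divide in float, narrow the result back.
    float  x_         = to_float(x);
    float  divider_f_ = static_cast<float>(divider);
    toy_f8 y          = from_float(x_ / divider_f_);

    std::printf("y = %f\n", to_float(y)); // 0.75
    return 0;
}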
@@ -221,7 +221,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
     }

-    __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
+    __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
         index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
     {
         const auto a_grid_desc_mraw_kraw = [&]() {
@@ -303,7 +303,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
     }

-    __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
+    __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
         index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
     {
         const auto b_grid_desc_nraw_kraw = [&]() {
@@ -576,12 +576,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
         {
-            a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
+            a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
         }

         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
         {
-            b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
+            b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
         {
...
@@ -255,7 +255,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                                            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
     }

-    __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
+    __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
         index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
     {
         const auto a_grid_desc_mraw_kraw = [&]() {
@@ -337,7 +337,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         }
     }

-    __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
+    __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
         index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
     {
         const auto b_grid_desc_nraw_kraw = [&]() {
@@ -647,12 +647,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
         {
-            a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
+            a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
         }

         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
         {
-            b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
+            b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
         {
...
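Why StrideA is the right factor in both gridwise kernels: for a column-major A with a padded leading dimension, each of the KRead columns consumed by one split-K slice is StrideA elements apart, so multiplying by the logical M undercounts whenever StrideA > M. A standalone check with illustrative sizes (plain C++, not the kernel code):

#include <cstdio>

int main()
{
    // Column-major A: element (m, k) lives at m + k * StrideA.
    const int M = 100, StrideA = 128; // padded leading dimension > M
    const int KRead = 32;             // K elements consumed per split-K slice
    const int split = 2;              // plays the role of blockIdx.z

    const int wrong   = split * KRead * M;       // 6400: lands mid-column
    const int correct = split * KRead * StrideA; // 8192: start of column KRead*split

    std::printf("wrong=%d correct=%d (column %d)\n",
                wrong, correct, correct / StrideA);
    return 0;
}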
@@ -315,7 +315,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1
             forward_sweep_(I0) = true;

             static_for<1, nDim, 1>{}([&](auto i) {
-                index_t tmp = ordered_dst_access_idx[I0];
+                index_t tmp = 0;

                 static_for<0, i, 1>{}([&](auto j) {
                     tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
...
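This fix matters because the inner static_for already folds dimension 0 in at j = 0; seeding tmp with ordered_dst_access_idx[I0] counted that dimension twice. A scalar sketch of the mixed-radix linearization the loop performs (hypothetical lengths and indices):

#include <cstdio>

int main()
{
    const int lengths[3] = {4, 3, 5};
    const int idx[3]     = {2, 1, 3};

    // Linearize idx over the first i dimensions, exactly as the static_for does.
    const int i = 3;
    int tmp     = 0;                     // the fix: start from 0 ...
    for(int j = 0; j < i; ++j)
        tmp = tmp * lengths[j] + idx[j]; // ... because j = 0 already adds idx[0]

    std::printf("%d\n", tmp); // (2*3 + 1)*5 + 3 = 38
    return 0;
}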
@@ -35,10 +35,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16>
     static constexpr index_t k_per_blk   = 8;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t idx_part,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
     {
-        intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
+        intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
+            a, b, idx, reg_c);
     }
 };
@@ -57,10 +63,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16>
     static constexpr index_t k_per_blk   = 16;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t idx_part,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
     {
-        intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
+        intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
+            a, b, idx, reg_c);
     }
 };
@@ -79,10 +91,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16>
     static constexpr index_t k_per_blk   = 8;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t idx_part,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
     {
-        intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
+        intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
+            a, b, idx, reg_c);
     }
 };
@@ -101,10 +119,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16>
     static constexpr index_t k_per_blk   = 16;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t idx_part,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
     {
-        intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
+        intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
+            a, b, idx, reg_c);
     }
 };
@@ -305,8 +329,8 @@ struct SparseXdlopsGemm
                       "base base_type must be half or bfloat16!");

         static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
-            smfmac_instr.template run<MPerXdlops, NPerXdlops>(
-                p_a_wave[k], p_b_wave[k], idx[k], p_c_thread);
+            smfmac_instr.template run<MPerXdlops, NPerXdlops, k % 4>(
+                p_a_wave[k], p_b_wave[k], idx[k / 4], p_c_thread);
        });
     }
...
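The new idx_part argument threads k % 4 down to the intrinsic's ABID operand: each 32-bit entry of the sparse-index register packs four 8-bit index sets, so the dispatch consumes idx[k / 4] and selects the byte within it via ABID. A host-side sketch of that byte selection (illustrative only, not the GPU intrinsic):

#include <cstdint>
#include <cstdio>

// Extract the 8-bit sparse-index set that ABID selects from a packed
// 32-bit index register (CBSZ assumed 0, as in the diff above).
std::uint8_t select_idx_set(std::uint32_t reg_idx, int abid)
{
    return static_cast<std::uint8_t>(reg_idx >> (8 * abid));
}

int main()
{
    const std::uint32_t idx[2] = {0xDDCCBBAAu, 0x44332211u};

    for(int k = 0; k < 8; ++k)
    {
        // Mirrors the dispatch: idx[k / 4] picks the register, k % 4 the byte.
        std::printf("k=%d -> set 0x%02X\n", k, select_idx_set(idx[k / 4], k % 4));
    }
    return 0;
}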
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
namespace ck {
namespace tensor_operation {
template <typename ALayout,
typename BLayout,
typename ELayout,
index_t NDimSpatial,
index_t MPerThread,
index_t NPerThread>
struct TransformConvNGCHWToNHWGC
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Wi = g_n_c_wis_lengths[I3];
const index_t& GStride = g_n_c_wis_strides[I0];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t& CStride = g_n_c_wis_strides[I2];
const index_t& WiStride = g_n_c_wis_strides[I3];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Wi = g_n_c_wis_lengths[I3];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Hi = g_n_c_wis_lengths[I3];
const index_t& Wi = g_n_c_wis_lengths[I4];
const index_t& GStride = g_n_c_wis_strides[I0];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t& CStride = g_n_c_wis_strides[I2];
const index_t& HiStride = g_n_c_wis_strides[I3];
const index_t& WiStride = g_n_c_wis_strides[I4];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Hi = g_n_c_wis_lengths[I3];
const index_t& Wi = g_n_c_wis_lengths[I4];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Di = g_n_c_wis_lengths[I3];
const index_t& Hi = g_n_c_wis_lengths[I4];
const index_t& Wi = g_n_c_wis_lengths[I5];
const index_t& GStride = g_n_c_wis_strides[I0];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t& CStride = g_n_c_wis_strides[I2];
const index_t& DiStride = g_n_c_wis_strides[I3];
const index_t& HiStride = g_n_c_wis_strides[I4];
const index_t& WiStride = g_n_c_wis_strides[I5];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[I0];
const index_t& N = g_n_c_wis_lengths[I1];
const index_t& C = g_n_c_wis_lengths[I2];
const index_t& Di = g_n_c_wis_lengths[I3];
const index_t& Hi = g_n_c_wis_lengths[I4];
const index_t& Wi = g_n_c_wis_lengths[I5];
const index_t& NStride = g_n_c_wis_strides[I1];
const index_t DiStride = Hi * Wi * G * C;
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return device::PadTensorDescriptor(
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
}
static auto TransposeStrides(const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_strides)
{
if constexpr(device::is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
device::is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
std::array<index_t, NDimSpatial + 3> g_n_c_wis_strides_transposed;
const auto G = g_n_c_wis_lengths[I0];
const auto C = g_n_c_wis_lengths[I2];
g_n_c_wis_strides_transposed[I0] = C;
g_n_c_wis_strides_transposed[I1] = g_n_c_wis_strides[I1];
g_n_c_wis_strides_transposed[I2] = I1;
if constexpr(NDimSpatial == 2)
{
g_n_c_wis_strides_transposed[I3] = g_n_c_wis_lengths[I4] * G * C;
g_n_c_wis_strides_transposed[I4] = G * C;
}
else if constexpr(NDimSpatial == 3)
{
g_n_c_wis_strides_transposed[I3] =
g_n_c_wis_lengths[I4] * g_n_c_wis_lengths[I5] * G * C;
g_n_c_wis_strides_transposed[I4] = g_n_c_wis_lengths[I5] * G * C;
g_n_c_wis_strides_transposed[I5] = G * C;
}
return g_n_c_wis_strides_transposed;
}
else
{
// transpose not needed
return g_n_c_wis_strides;
}
}
};
} // namespace tensor_operation
} // namespace ck
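For concreteness, a small standalone check of what TransposeStrides computes in the 2D NGCHW case (hand-picked sizes; the arithmetic mirrors the formulas above but uses plain ints rather than the CK types):

#include <array>
#include <cstdio>

int main()
{
    // Hypothetical dense NGCHW tensor: G=2, N=1, C=8, Hi=4, Wi=4.
    const int G = 2, C = 8, Hi = 4, Wi = 4;
    const int NStride = G * C * Hi * Wi; // 256, dense batch stride

    // Strides of the same data laid out as NHWGC, per TransposeStrides:
    std::array<int, 5> t{};
    t[0] = C;          // G stride:  8
    t[1] = NStride;    // N stride is preserved: 256
    t[2] = 1;          // C becomes innermost
    t[3] = Wi * G * C; // Hi stride: 64
    t[4] = G * C;      // Wi stride: 16

    std::printf("NHWGC strides: %d %d %d %d %d\n", t[0], t[1], t[2], t[3], t[4]);
    return 0;
}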
@@ -9,16 +9,18 @@ namespace ck {
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_smfmac_f32_16x16x32f16;

+// for every smfmac instruction if CBSZ[1:0]=0, ABID[1:0] selects one of four 8-bit sets of sparse
+// indices from reg_idx
 template <>
 struct intrin_smfmac_f32_16x16x32f16<16, 16>
 {
-    template <class FloatC>
+    template <class FloatC, index_t abid = 0>
     __device__ static void
-    Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
+    Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
     {
 #if defined(__gfx94__)
-        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
-            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
 #else
         ignore = reg_a;
         ignore = reg_b;
@@ -34,13 +36,13 @@ struct intrin_smfmac_f32_16x16x32bf16;

 template <>
 struct intrin_smfmac_f32_16x16x32bf16<16, 16>
 {
-    template <class FloatC>
+    template <class FloatC, index_t abid = 0>
     __device__ static void
-    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
+    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
     {
 #if defined(__gfx94__)
-        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
-            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
 #else
         ignore = reg_a;
         ignore = reg_b;
@@ -56,13 +58,13 @@ struct intrin_smfmac_f32_32x32x16f16;

 template <>
 struct intrin_smfmac_f32_32x32x16f16<32, 32>
 {
-    template <class FloatC>
+    template <class FloatC, index_t abid = 0>
     __device__ static void
-    Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
+    Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
     {
 #if defined(__gfx94__)
-        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
-            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
+        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
 #else
         ignore = reg_a;
         ignore = reg_b;
@@ -78,13 +80,13 @@ struct intrin_smfmac_f32_32x32x16bf16;

 template <>
 struct intrin_smfmac_f32_32x32x16bf16<32, 32>
 {
-    template <class FloatC>
+    template <class FloatC, index_t abid = 0>
     __device__ static void
-    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
+    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
     {
 #if defined(__gfx94__)
-        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
-            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
+        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
 #else
         ignore = reg_a;
         ignore = reg_b;
...
@@ -52,12 +52,28 @@ struct Add
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, int32_t>::value || is_same<T, half_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                       "The data type is not supported by the Add accumulator!");

         a = a + b;
     }

+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        a = type_convert<f8_t>(a_ + b_);
+    }
+
+    __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        a = type_convert<half_t>(a_ + b_);
+    }
+
     __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
     {
         float a_ = type_convert<float>(a);
@@ -112,12 +128,28 @@ struct Mul
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, int32_t>::value || is_same<T, half_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                       "The data type is not supported by the Mul accumulator!");

         a = a * b;
     }

+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        a = type_convert<f8_t>(a_ * b_);
+    }
+
+    __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        a = type_convert<half_t>(a_ * b_);
+    }
+
     __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
     {
         float a_ = type_convert<float>(a);
@@ -137,6 +169,16 @@ struct Max
             float val = NumericLimits<float>::Lowest();
             return type_convert<bhalf_t>(val);
         }
+        if constexpr(is_same_v<T, f8_t>)
+        {
+            float val = NumericLimits<float>::Lowest();
+            return type_convert<f8_t>(val);
+        }
+        if constexpr(is_same_v<T, half_t>)
+        {
+            float val = NumericLimits<float>::Lowest();
+            return type_convert<half_t>(val);
+        }
         else
         {
             return NumericLimits<T>::Lowest();
@@ -154,8 +196,7 @@ struct Max
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                       "The data type is not supported by the Max accumulator!");

         if(a < b)
@@ -171,12 +212,29 @@ struct Max
             a = b;
     }

+    __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ < b_)
+            a = b;
+    }
+
+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ < b_)
+            a = b;
+    }
+
     template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                       "The data type is not supported by the Max accumulator!");

         if(a < b)
@@ -197,6 +255,30 @@ struct Max
             changed = true;
         }
     }

+    __host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ < b_)
+        {
+            a       = b;
+            changed = true;
+        }
+    }
+
+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ < b_)
+        {
+            a       = b;
+            changed = true;
+        }
+    }
 };

 struct Min
@@ -209,6 +291,16 @@ struct Min
             float val = NumericLimits<float>::Max();
             return type_convert<bhalf_t>(val);
         }
+        else if constexpr(is_same_v<T, half_t>)
+        {
+            float val = NumericLimits<float>::Max();
+            return type_convert<half_t>(val);
+        }
+        else if constexpr(is_same_v<T, f8_t>)
+        {
+            float val = NumericLimits<float>::Max();
+            return type_convert<f8_t>(val);
+        }
         else
         {
             return NumericLimits<T>::Max();
@@ -227,8 +319,7 @@ struct Min
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
                       "The data type is not supported by the Min accumulator!");

         if(a > b)
@@ -244,6 +335,24 @@ struct Min
             a = b;
     }

+    __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ > b_)
+            a = b;
+    }
+
+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ > b_)
+            a = b;
+    }
+
     template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
     {
@@ -270,6 +379,30 @@ struct Min
             changed = true;
         }
     }

+    __host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ > b_)
+        {
+            a       = b;
+            changed = true;
+        }
+    }
+
+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ > b_)
+        {
+            a       = b;
+            changed = true;
+        }
+    }
 };

 struct AMax
@@ -299,6 +432,15 @@ struct AMax
             a = b;
     }

+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ < b_)
+            a = b;
+    }
+
     template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
     {
@@ -313,6 +455,18 @@ struct AMax
             changed = true;
         }
     }

+    __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
+    {
+        float a_ = type_convert<float>(a);
+        float b_ = type_convert<float>(b);
+
+        if(a_ < b_)
+        {
+            a       = b;
+            changed = true;
+        }
+    }
 };

 template <typename T>
@@ -352,7 +506,8 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set,
     static constexpr bool value =
         is_same<DataType, float>::value || is_same<DataType, double>::value ||
         is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
-        is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
+        is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value ||
+        is_same<DataType, f8_t>::value;
 };

 template <typename DataType>
...
@@ -5,6 +5,7 @@

 #include "ck_tile/core.hpp"
 #include "ck_tile/host/host_tensor.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"

 #include <thread>

 namespace ck_tile {
@@ -13,6 +14,9 @@ template <typename ADataType,
           typename BDataType,
           typename AccDataType,
           typename CDataType,
+          typename LayoutA,
+          typename LayoutB,
+          typename LayoutC,
           typename AElementOp   = ck_tile::identity,
           typename BElementOp   = ck_tile::identity,
           typename ACCElementOp = ck_tile::identity>
@@ -24,7 +28,12 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                  const ACCElementOp& acc_element_op = {})
 {
     const int N = b_n_k.mDesc.get_lengths()[0];
-    const int K = b_n_k.mDesc.get_lengths()[1];
+    const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
+                      ? a_m_k.mDesc.get_lengths()[1]
+                      : a_m_k.mDesc.get_lengths()[0];
+    const int M = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
+                      ? a_m_k.mDesc.get_lengths()[0]
+                      : a_m_k.mDesc.get_lengths()[1];

     auto f = [&](auto m) {
         for(int n = 0; n < N; ++n)
@@ -33,7 +42,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
             for(int k = 0; k < K; ++k)
             {
-                ADataType v_a = a_element_op(a_m_k(m, k));
+                ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
+                                    ? a_element_op(a_m_k(m, k))
+                                    : a_element_op(a_m_k(k, m));
                 BDataType v_b = b_element_op(b_n_k(n, k));

                 v_acc += ck_tile::type_convert<AccDataType>(v_a) *
@@ -44,7 +55,123 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
         }
     };

-    make_ParallelTensorFunctor(f,
-                               c_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
+    make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
 }
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
__global__ void naive_gemm_kernel(ADataType* A,
BDataType* B,
CDataType* C,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t strideA,
ck_tile::index_t strideB,
ck_tile::index_t strideC)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N; // Compute row index
int col = idx % N; // Compute column index
if(row < M && col < N)
{
AccDataType acc = 0.0;
for(int k = 0; k < K; ++k)
{
acc += static_cast<AccDataType>(A[row * strideA + k]) *
static_cast<AccDataType>(B[col * strideB + k]);
}
C[row * strideC + col] = acc; // implicit conversion from AccDataType to CDataType
}
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device,
DeviceMem& c_device,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c)
{
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
if(errA != hipSuccess)
{
std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
<< std::endl;
return; // Early exit on error
}
if(errB != hipSuccess)
{
std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
<< std::endl;
return; // Early exit on error
}
if(errC != hipSuccess)
{
std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
<< std::endl;
return; // Early exit on error
}
errA = hipMemcpy(
d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice);
if(errA != hipSuccess)
{
std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
}
errB = hipMemcpy(
d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice);
if(errB != hipSuccess)
{
std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
}
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType>
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
errC = hipMemcpy(
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
if(errC != hipSuccess)
{
std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
}
errA = hipFree(d_A);
if(errA != hipSuccess)
{
std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
}
errB = hipFree(d_B);
if(errB != hipSuccess)
{
std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
}
errC = hipFree(d_C);
if(errC != hipSuccess)
{
std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
}
return;
} }
} // namespace ck_tile
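A hedged usage sketch of the new GPU reference path (sizes and types illustrative; assumes ck_tile::DeviceMem buffers already populated, and the row-major A, N-by-K B, row-major C convention the naive kernel indexes with):

// Illustrative only, not part of the commit.
const ck_tile::index_t M = 512, N = 256, K = 128;

ck_tile::DeviceMem a_dev(M * K * sizeof(ck_tile::half_t)); // A: M x K, row-major
ck_tile::DeviceMem b_dev(N * K * sizeof(ck_tile::half_t)); // B: N x K (K contiguous)
ck_tile::DeviceMem c_dev(M * N * sizeof(float));           // C: M x N, row-major

// ... upload A and B into a_dev / b_dev ...

ck_tile::reference_gemm_gpu<ck_tile::half_t, ck_tile::half_t, float, float>(
    a_dev, b_dev, c_dev, M, N, K, /*stride_a=*/K, /*stride_b=*/K, /*stride_c=*/N);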