Commit 6d9a07d7 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents b30d416c a776978c
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
......@@ -100,7 +100,7 @@ template <ck::index_t NDimSpatial,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
bool run_grouped_conv(bool do_verification,
int init_method,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -11,7 +11,7 @@ void print_helper_msg()
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
bool run_convnd_fwd_example(int argc, char* argv[])
bool run_convnd_example(int argc, char* argv[])
{
print_helper_msg();
......@@ -63,14 +63,14 @@ bool run_convnd_fwd_example(int argc, char* argv[])
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_grouped_conv_fwd<NDimSpatial,
return run_grouped_conv<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDFwdActivInstance>(do_verification,
DeviceGroupedConvNDActivInstance>(do_verification,
init_method,
time_kernel,
conv_param,
......
......@@ -2,48 +2,34 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_fwd_activ_xdl)
add_custom_target(example_convnd_activ_unary_xdl)
# Sigmoid
add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_sigmoid_fp16)
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_sigmoid_fp16)
# Tanh
add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_tanh_fp16)
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_tanh_fp16)
# Relu
add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_relu_fp16)
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_relu_fp16)
# SoftRelu
add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_softrelu_fp16)
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_softrelu_fp16)
# Abs
add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_abs_fp16)
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_abs_fp16)
# Pow
add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_pow_fp16)
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_pow_fp16)
# Clipped Relu
add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_clippedrelu_fp16)
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_clippedrelu_fp16)
# Leaky Relu
add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_leakyrelu_fp16)
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_leakyrelu_fp16)
# Elu
add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_elu_fp16)
# ScaleAdd on A and B
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp16 multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_fp16)
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp32 multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_fp32)
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_bf16)
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_int8)
# ScaleAdd ScaleAdd Relu
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)
add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_elu_fp16)
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -102,7 +102,7 @@ template <ck::index_t NDimSpatial,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
bool run_grouped_conv(bool do_verification,
int init_method,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::UnaryAbs;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::ClippedRelu;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Elu;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::LeakyRelu;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Power;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Relu;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Sigmoid;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::SoftRelu;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::TanH;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
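Note: each activation example above is identical apart from the element-wise operator it instantiates. A hypothetical sketch (not part of this commit) of adding one more unary activation, assuming CK provides a corresponding unary element-wise functor (Swish is used here purely as an illustration):
// SPDX-License-Identifier: MIT
// Hypothetical example following the same pattern as the files above.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Swish; // assumed to exist in CK
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }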
......@@ -37,7 +37,9 @@ template <index_t BlockSize,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
typename ComputeTypeA = FloatA,
typename ComputeTypeB = FloatB>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
static constexpr auto I0 = Number<0>{};
......@@ -59,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatA, MPerXDL, NPerXDL, KPack, FloatB>{};
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
......@@ -295,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
......@@ -319,20 +322,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) {
vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatB, KPack> b_thread_vec;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) = a_thread_buf
a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatB>()(i) = b_thread_buf
b_thread_vec.template AsType<ComputeTypeB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
});
using mfma_input_type_a =
typename vector_type<FloatA, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<FloatB, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
......@@ -360,7 +363,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
ComputeTypeA,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
......@@ -370,7 +373,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
ComputeTypeB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
......@@ -398,6 +401,8 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
typename ComputeTypeA = FloatA,
typename ComputeTypeB = FloatB,
index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
: public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
......@@ -410,7 +415,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
ComputeTypeA,
ComputeTypeB>
{
using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatA,
......@@ -422,7 +429,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
NPerXDL,
MRepeat,
NRepeat,
KPack>;
KPack,
ComputeTypeA,
ComputeTypeB>;
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
using Base::a_block_desc_m0_m1_m2_k;
......@@ -446,9 +455,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
......@@ -485,22 +494,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatB, KPack> b_thread_vec;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_vec.template AsType<ComputeTypeA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, 0, 0, k_ + i))>{}];
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_vec.template AsType<ComputeTypeB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, 0, 0, k_ + i))>{}];
});
using mfma_input_type_a =
typename vector_type<FloatA, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<FloatB, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
......@@ -550,7 +559,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
ComputeTypeA,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>,
......@@ -560,7 +569,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
ComputeTypeB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>,
......@@ -586,7 +595,9 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
LoopScheduler LoopSched>
LoopScheduler LoopSched,
typename ComputeTypeA = FloatA,
typename ComputeTypeB = FloatB>
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
......@@ -601,7 +612,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
KPack,
ComputeTypeA,
ComputeTypeB>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
......@@ -615,7 +628,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
KPack,
ComputeTypeA,
ComputeTypeB>{};
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -322,6 +322,19 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceElementwiseNormalizationImpl<";
str << NumDim << ", ";
str << MPerThread << ">";
// clang-format on
return str.str();
}
};
} // namespace device
......
......@@ -60,7 +60,9 @@ template <typename ADataType,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename ComputeType = CDataType,
PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler()>
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename LDSTypeA = ComputeType,
typename LDSTypeB = ComputeType>
struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
BLayout,
......@@ -81,6 +83,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams.
static constexpr index_t NumGemmKPrefetchStage = 1;
using ComputeTypeA = ComputeType;
using ComputeTypeB = ComputeType;
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType,
......@@ -125,7 +130,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopSched,
PipelineVer,
ComputeType>;
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>;
struct Argument : public GridwiseGemm::Argument
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -650,18 +650,35 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
constexpr auto Set = InMemoryDataOperationEnum::Set;
// For the bf16 data type only kbatch = 1 is supported. This condition is enforced
// in the IsSupportedArgument function.
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
{
if(has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
}
else
{
if(arg.k_batch_ > 1)
{
if(has_main_k_block_loop)
{
ave_time =
launch_kernel(integral_constant<bool, true>{},
ave_time = launch_kernel(
integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
else
{
ave_time =
launch_kernel(integral_constant<bool, false>{},
ave_time = launch_kernel(
integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
}
......@@ -669,15 +686,18 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
{
if(has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
ave_time =
launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
ave_time =
launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
}
}
return ave_time;
}
......@@ -718,6 +738,13 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
}
}
// For the bf16 data type only kbatch = 1 is supported: there is no AtomicAdd
// instruction for bf16, so split-K cannot be used.
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
{
supported = supported & (arg.k_batch_ == 1);
}
return supported;
}
......
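The bf16 restriction above follows from how split-K accumulates results: with k_batch > 1, each partial result must be merged into C via AtomicAdd, and no bf16 AtomicAdd instruction is available. A minimal hypothetical helper (not part of this commit) that a caller could use to respect the same constraint:
// Hypothetical sketch: clamp the requested split-K factor for bf16 inputs,
// mirroring the IsSupportedArgument check above.
#include <type_traits>
#include "ck/utility/data_type.hpp" // for ck::bhalf_t (assumed header location)

template <typename ADataType>
int clamp_k_batch(int requested_k_batch)
{
    // bf16 partial results cannot be merged with AtomicAdd, so force k_batch = 1.
    if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
        return 1;
    return requested_k_batch;
}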
......@@ -165,7 +165,7 @@ struct Subtract
struct Bilinear
{
Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){};
Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
......@@ -184,6 +184,14 @@ struct Bilinear
y = alpha_ * x0 + beta_ * x1;
};
template <>
__host__ __device__ constexpr void
operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
beta_ * type_convert<float>(x1));
};
template <>
__host__ __device__ constexpr void
operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
......@@ -221,7 +229,8 @@ struct Bilinear
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
{
y = type_convert<std::int8_t>(x0 + ck::type_convert<std::int32_t>(x1));
y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
beta_ * type_convert<float>(x1));
};
float alpha_;
......
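For reference, a hedged usage sketch of the Bilinear functor (not part of this commit), assuming the float and int8 specializations shown above are available and that the aggregate element-wise header is included:
// #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" // assumed location
// Bilinear computes y = alpha * x0 + beta * x1.
ck::tensor_operation::element_wise::Bilinear bilinear{0.5f, 2.0f};
float y_f32 = 0.f;
bilinear(y_f32, 4.0f, 3.0f); // y_f32 = 0.5f * 4.0f + 2.0f * 3.0f = 8.0f
int8_t y_i8 = 0;
bilinear(y_i8, int8_t{4}, int8_t{3}); // computed in float, then converted back to int8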
......@@ -21,50 +21,11 @@ struct PassThroughPack2
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
__host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::half2_t& x) const
{
// fake conversion
uint16_t t = ck::bit_cast<uint32_t>(x);
y = ck::bit_cast<ck::f8x2_t>(t);
}
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::half2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::f8x2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::float2_t& y, const ck::float2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::int8x2_t& y, const ck::int8x2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::bhalf2_t& y, const ck::bhalf2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::double2_t& y, const ck::double2_t& x) const
{
y = x;
}
constexpr const static bool is_pack2_invocable = true;
};
struct PassThrough
......
......@@ -9,7 +9,6 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
......@@ -96,7 +95,10 @@ template <index_t BlockSize,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = FloatC>
typename ComputeTypeA = FloatC,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
static constexpr auto I0 = Number<0>{};
......@@ -430,7 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
return math::max((a_block_space_size + b_block_space_size) * sizeof(ComputeType),
return math::max(a_block_space_size * sizeof(LDSTypeA) +
b_block_space_size * sizeof(LDSTypeB),
c_block_size * sizeof(FloatC));
}
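As a quick sanity check on the new LDS sizing (hypothetical numbers): with a_block_space_size = 4096 elements of a 2-byte LDSTypeA and b_block_space_size = 2048 elements of a 1-byte LDSTypeB, the A/B stage now requests 4096 * 2 + 2048 * 1 = 10240 bytes, whereas the old formula charged both buffers at a single element size; the result is still compared against c_block_size * sizeof(FloatC) via math::max.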
......@@ -785,7 +788,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatA,
ComputeType,
LDSTypeA,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
......@@ -815,7 +818,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
ComputeType,
LDSTypeB,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
......@@ -845,8 +848,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType, // ComputeType A
ComputeType, // ComputeType B
LDSTypeA,
LDSTypeB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
......@@ -855,7 +858,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
MRepeat,
NRepeat,
K1,
LoopSched>();
LoopSched,
ComputeTypeA,
ComputeTypeB>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
......@@ -863,8 +868,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
ComputeType* p_a_block = static_cast<ComputeType*>(p_shared_block);
ComputeType* p_b_block = static_cast<ComputeType*>(p_shared_block) + a_block_space_size;
auto p_a_block = reinterpret_cast<LDSTypeA*>(p_shared_block);
auto p_b_block = reinterpret_cast<LDSTypeB*>(p_a_block + a_block_space_size);
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
......