Commit e2878e25 authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into migx-jit-lib

parents 1ec96717 642d5e91
...@@ -19,7 +19,7 @@ def runShell(String command){ ...@@ -19,7 +19,7 @@ def runShell(String command){
def getDockerImageName(){ def getDockerImageName(){
def img def img
if (params.ROCMVERSION != "5.5" && params.ROCMVERSION != "5.6"){ if (params.ROCMVERSION != "5.6"){
if (params.COMPILER_VERSION == "") { if (params.COMPILER_VERSION == "") {
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}"
} }
...@@ -597,7 +597,7 @@ def process_results(Map conf=[:]){ ...@@ -597,7 +597,7 @@ def process_results(Map conf=[:]){
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true
0 21 * * * % ROCMVERSION=5.4.3;COMPILER_VERSION=release;COMPILER_COMMIT= 0 21 * * * % ROCMVERSION=5.5;COMPILER_VERSION=release;COMPILER_COMMIT=
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : ""
pipeline { pipeline {
......
...@@ -92,6 +92,7 @@ else() ...@@ -92,6 +92,7 @@ else()
-Wno-unused-command-line-argument -Wno-unused-command-line-argument
-Wno-weak-vtables -Wno-weak-vtables
-Wno-covered-switch-default -Wno-covered-switch-default
-Wno-unsafe-buffer-usage
) )
else() else()
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX") if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX")
......
...@@ -5,6 +5,7 @@ add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) ...@@ -5,6 +5,7 @@ add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
add_dependencies(example_grouped_gemm_xdl add_dependencies(example_grouped_gemm_xdl
...@@ -12,7 +13,8 @@ add_dependencies(example_grouped_gemm_xdl ...@@ -12,7 +13,8 @@ add_dependencies(example_grouped_gemm_xdl
example_grouped_gemm_xdl_fp16 example_grouped_gemm_xdl_fp16
example_grouped_gemm_xdl_bfp16 example_grouped_gemm_xdl_bfp16
example_grouped_gemm_xdl_int8 example_grouped_gemm_xdl_int8
example_grouped_gemm_multiple_d_dl_fp16) example_grouped_gemm_multiple_d_dl_fp16
example_grouped_gemm_xdl_splitk_fp16)
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
#include "run_grouped_gemm_example.inc"
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
problem_size.group_count = 16;
problem_size.Ms = {
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ns.push_back(768);
problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
return !run_grouped_gemm(problem_size, config);
}
...@@ -147,6 +147,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -147,6 +147,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
#else #else
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
c_tensors_device[i]->SetZero();
#endif #endif
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp" #include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device:: ...@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device::
using DeviceOpInstance = DeviceOpInstanceKKNN; using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
for(int k1 = 0; k1 < K1; ++k1)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M2_N2_K2"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
...@@ -385,22 +251,22 @@ int main(int argc, char* argv[]) ...@@ -385,22 +251,22 @@ int main(int argc, char* argv[])
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM, using ReferenceOpInstance =
NumDimN, ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimK, NumDimN,
ADataType, NumDimK,
BDataType, ADataType,
CShuffleDataType, BDataType,
AccDataType, CShuffleDataType,
AElementOp, AccDataType,
BElementOp, AElementOp,
PassThrough>; BElementOp>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument =
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp" #include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device:: ...@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device::
using DeviceOpInstance = DeviceOpInstanceKKNN; using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
for(int k1 = 0; k1 < K1; ++k1)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M2_N2_K2"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
...@@ -385,22 +251,22 @@ int main(int argc, char* argv[]) ...@@ -385,22 +251,22 @@ int main(int argc, char* argv[])
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM, using ReferenceOpInstance =
NumDimN, ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimK, NumDimN,
ADataType, NumDimK,
BDataType, ADataType,
CShuffleDataType, BDataType,
AccDataType, CShuffleDataType,
AElementOp, AccDataType,
BElementOp, AElementOp,
PassThrough>; BElementOp>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument =
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp" #include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device:: ...@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device::
using DeviceOpInstance = DeviceOpInstanceKKN; using DeviceOpInstance = DeviceOpInstanceKKN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
for(int k1 = 0; k1 < K1; ++k1)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M2_N2_K2"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
...@@ -368,22 +234,23 @@ int main(int argc, char* argv[]) ...@@ -368,22 +234,23 @@ int main(int argc, char* argv[])
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM, using ReferenceOpInstance =
NumDimN, ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimK, NumDimN,
ADataType, NumDimK,
BDataType, ADataType,
CShuffleDataType, BDataType,
AccDataType, CShuffleDataType,
AElementOp, AccDataType,
BElementOp, AElementOp,
PassThrough>; BElementOp>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); auto ref_argument =
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp" #include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device:: ...@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device::
using DeviceOpInstance = DeviceOpInstanceKKN; using DeviceOpInstance = DeviceOpInstanceKKN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
for(int k1 = 0; k1 < K1; ++k1)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M2_N2_K2"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
...@@ -368,22 +234,23 @@ int main(int argc, char* argv[]) ...@@ -368,22 +234,23 @@ int main(int argc, char* argv[])
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM, using ReferenceOpInstance =
NumDimN, ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimK, NumDimN,
ADataType, NumDimK,
BDataType, ADataType,
CShuffleDataType, BDataType,
AccDataType, CShuffleDataType,
AElementOp, AccDataType,
BElementOp, AElementOp,
PassThrough>; BElementOp>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); auto ref_argument =
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
add_example_executable(example_layernorm_blockwise layernorm_blockwise.cpp) add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp)
add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
PassThrough,
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterM
32, // ClusterK
1, // SliceM
8, // SliceK
1, // XYVectorDim (0=M, 1=K)
8, // SrcScalarPerVector
1, // GammaVecDim (0=M, 1=K)
8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector
8>; // OutScalarPerVector
#include "run_layernorm_example.inc"
int main() { return run_groupnorm_example<DeviceInstance>(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationSplitKImpl<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
PassThrough,
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterM
32, // ClusterK
1, // SliceM
8, // SliceK
1, // XYVectorDim (0=M, 1=K)
8, // XScalarPerVector
1, // GammaVecDim (0=M, 1=K)
8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector
8>; // YScalarPerVector
#include "run_layernorm_example.inc"
int main() { return run_groupnorm_example<DeviceInstance>(); }
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #pragma once
#include <numeric>
#include <initializer_list> template <typename DeviceInstance>
#include <cstdlib> int run_groupnorm_example()
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
PassThrough,
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterM
32, // ClusterK
1, // SliceM
8, // SliceK
1, // SrcVecDim (0=M, 1=K)
8, // SrcScalarPerVector
1, // GammaVecDim (0=M, 1=K)
8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector
8>; // OutScalarPerVector
int main()
{ {
bool time_kernel = false; bool time_kernel = false;
...@@ -111,6 +63,10 @@ int main() ...@@ -111,6 +63,10 @@ int main()
return 1; return 1;
}; };
size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = device_instance.MakeInvokerPointer(); auto invoker_ptr = device_instance.MakeInvokerPointer();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
...@@ -133,7 +89,8 @@ int main() ...@@ -133,7 +89,8 @@ int main()
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results d1", 1e-3, 1e-3); pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
} }
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) if(NOT GPU_TARGETS MATCHES "gfx940")
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
......
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) if(NOT GPU_TARGETS MATCHES "gfx940")
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp)
add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp) add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp" #include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/fill.hpp" #include "ck/library/utility/fill.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using ComputeDataType = float;
using YElementOp = ck::tensor_operation::element_wise::Swish;
using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationSplitKImpl<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
YElementOp,
Rank,
NumReduceDim,
256, // BlockSize
1, // ClusterM
256, // ClusterK
1, // SliceM
16, // SliceK
1, // SrcVecDim (0=M, 1=K)
2, // SrcScalarPerVector
1, // GammaVecDim (0=M, 1=K)
2, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
2, // BetaScalarPerVector
2>; // OutScalarPerVector
#include "run_groupnorm_example.inc"
int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); }
...@@ -73,6 +73,10 @@ int run_groupnorm_example(int argc, char* argv[]) ...@@ -73,6 +73,10 @@ int run_groupnorm_example(int argc, char* argv[])
return 1; return 1;
}; };
size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = device_instance.MakeInvokerPointer(); auto invoker_ptr = device_instance.MakeInvokerPointer();
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, true}); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, true});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment