Commit b89a88b5 authored by Adam Osewski's avatar Adam Osewski

Merge branch 'develop' into wavelet_model

parents 41d5fca7 43c898f6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
                                              |------------|
                                                  Gemm0
                                              |---------------------|
                                                       Gemm1
*/
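// Illustrative sketch only (assumes plain row-major float buffers, which is not
// what this example uses): the fused operation above written as two naive GEMM
// loops, with the Gemm0 result materialized in `acc` before Gemm1 consumes it.
inline void naive_gemm_gemm_sketch(const float* A,  // M x K
                                   const float* B0, // K x N
                                   const float* B1, // N x O
                                   float* acc,      // M x N scratch
                                   float* C,        // M x O
                                   int M,
                                   int N,
                                   int K,
                                   int O)
{
    // Gemm0: acc[m][n] = sum_k A[m][k] * B0[k][n]
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float sum = 0.f;
            for(int k = 0; k < K; ++k)
                sum += A[m * K + k] * B0[k * N + n];
            acc[m * N + n] = sum;
        }
    // Gemm1: C[m][o] = sum_n acc[m][n] * B1[n][o]
    for(int m = 0; m < M; ++m)
        for(int o = 0; o < O; ++o)
        {
            float sum = 0.f;
            for(int n = 0; n < N; ++n)
                sum += acc[m * N + n] * B1[n * O + o];
            C[m * O + o] = sum;
        }
}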
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error This example must be compiled with ck::int4_t support (CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
#endif
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::int4_t;
using B0DataType = ck::int4_t;
using B1DataType = ck::int4_t;
using KernelADataType = int8_t;
using KernelB0DataType = int8_t;
using KernelB1DataType = int8_t;
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
using CDataType = ck::int4_t;
using KernelCDataType = int8_t;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CLayout,
KernelADataType,
KernelB0DataType,
KernelB1DataType,
KernelCDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmDefault,
1,
256,
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1KPerBlock
16, // AK1
16, // BK1
4, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
16,
16,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
16,
16,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
4,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
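// With the block configuration above, each 256-thread workgroup computes a
// 128 (M) x 128 (N) Gemm0 tile, stepping over K in chunks of 64, and feeds it
// directly into Gemm1 to produce a 128 (M) x 128 (O) tile of C, stepping over
// N in chunks of 64.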
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
ADataType,
AccDataType,
AElementOp,
B0ElementOp,
CElementOp>;
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
#define BUILD_INT4_EXAMPLE
#include "run_batched_gemm_gemm_example.inc"
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
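// On the host, ck::int4_t is stored in an 8-bit container, so int4 tensors can
// be converted to the int8_t kernel types and copied to device buffers without
// repacking; the assert below documents that size assumption.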
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
                                              |------------|
                                                  Gemm0
                                              |---------------------|
                                                       Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = int8_t;
using B0DataType = int8_t;
using B1DataType = int8_t;
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
using CDataType = int8_t;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmDefault,
1,
256,
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1KPerBlock
16, // AK1
16, // BK1
4, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
16,
16,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
16,
16,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
4,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
ADataType,
AccDataType,
AElementOp,
B0ElementOp,
CElementOp>;
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
#include "run_batched_gemm_gemm_example.inc"
int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_batched_gemm_gemm_example(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideB1 = -1;
ck::index_t StrideC = -1;
ck::index_t BatchStrideA = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideC = -1;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
}
else if(argc == 17)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
StrideA = std::stoi(argv[9]);
StrideB0 = std::stoi(argv[10]);
StrideB1 = std::stoi(argv[11]);
StrideC = std::stoi(argv[12]);
BatchStrideA = std::stoi(argv[13]);
BatchStrideB0 = std::stoi(argv[14]);
BatchStrideB1 = std::stoi(argv[15]);
BatchStrideC = std::stoi(argv[16]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC\n");
exit(0);
}
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? O : M;
StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideC = (StrideC < 0) ? DefaultStrideC : StrideC;
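// Default batch strides place consecutive batches contiguously in memory: the
// number of rows (row-major) or columns (column-major) times the leading
// dimension of the matrix.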
const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC;
BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
// C_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_g_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
Tensor<CDataType> c_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
break;
case 2:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_g_m_k_device_buf(sizeof(KernelADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_g_k_n_device_buf(sizeof(KernelB0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(KernelB1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
DeviceMem c_g_m_o_device_buf(sizeof(KernelCDataType) *
c_g_m_o_device_result.mDesc.GetElementSpaceSize());
const Tensor<KernelADataType> a_g_m_k_converted(a_g_m_k);
const Tensor<KernelB0DataType> b0_g_k_n_converted(b0_g_k_n);
const Tensor<KernelB1DataType> b1_g_n_o_converted(b1_g_n_o);
a_g_m_k_device_buf.ToDevice(a_g_m_k_converted.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n_converted.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o_converted.mData.data());
#else
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
DeviceMem c_g_m_o_device_buf(sizeof(CDataType) *
c_g_m_o_device_result.mDesc.GetElementSpaceSize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
#endif
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelB0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelB1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
O,
BatchCount,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
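// Gemm0 contributes 2 * M * N * K flops and Gemm1 contributes 2 * M * N * O
// flops per batch (one multiply and one add per inner-product step).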
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(do_verification)
{
// Output of Gemm0 is input A of Gemm1
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, a1_g_m_n, a_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument);
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
#ifdef BUILD_INT4_EXAMPLE
Tensor<KernelCDataType> c_g_m_o_device_result_converted(c_g_m_o_host_result.mDesc);
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result_converted.mData.data());
c_g_m_o_device_result = c_g_m_o_device_result_converted.CopyAsType<CDataType>();
#else
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
#endif
return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData);
}
return true;
}
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16 padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_custom_target(example_batched_gemm_scale_softmax_gemm)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
                                                                  |-----------------|
                                                                          Gemm0
                                                          |-------------------------------------|
                                                                           Gemm1
*/
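#include <cmath> // only for the illustrative sketch below
// Illustrative sketch only (assumes a plain row-major float buffer, not the
// types used by this example): the numerically stable, row-wise softmax that
// the fused kernel applies to the alpha-scaled Gemm0 result before Gemm1.
inline void softmax_rows_sketch(float* acc, int M, int N, float alpha)
{
    for(int m = 0; m < M; ++m)
    {
        float* row = acc + m * N;
        // scale and find the row maximum for numerical stability
        float row_max = alpha * row[0];
        for(int n = 1; n < N; ++n)
            row_max = (alpha * row[n] > row_max) ? alpha * row[n] : row_max;
        // exponentiate and accumulate the row sum
        float row_sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            row[n] = std::exp(alpha * row[n] - row_max);
            row_sum += row[n];
        }
        // normalize so each row sums to one
        for(int n = 0; n < N; ++n)
            row[n] /= row_sum;
    }
}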
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
using CPermuteNumDims_G_M_O =
S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNOPadding;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CPermuteNumDims_G_M_O,
ADataType,
B0DataType,
B1DataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
AccDataType,
AccDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 128;
ck::index_t N = 1024;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t StrideA = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideB1 = -1;
ck::index_t BatchStrideA = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1;
float alpha = 1;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7;
ck::index_t G1 = 13;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
exit(0);
}
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
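// Logical dimension order is [G0, G1, M, O]; the strides give M a larger step
// than G1, so the physical layout in memory is C[G0, M, G1, O] as described in
// the comment above.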
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
const int BatchCount = G0 * G1;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
// C_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_gs_ms_os_host_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
Tensor<CDataType> c_gs_ms_os_device_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
break;
case 2:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
StrideA,
StrideB0,
StrideB1,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(do_verification)
{
c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
// Output of Gemm0 is input A of Gemm1
Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O},
std::vector<int>{M * O, O, 1});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
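// Softmax is taken over dimension 2 (the N axis of the [G, M, N] accumulator),
// matching the row-wise softmax performed by the fused kernel.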
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
return ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData)
? 0
: 1;
}
return 0;
}
@@ -2,11 +2,11 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
-Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
-                                              |------------|
-                                                  Gemm0
-                                              |---------------------|
-                                                       Gemm1
+Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
+                                                                  |-----------------|
+                                                                          Gemm0
+                                                          |-------------------------------------|
+                                                                           Gemm1
*/
#include <iostream>
@@ -51,7 +51,7 @@ using CLayout = Row;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
-using Acc0ElementOp = PassThrough;
+using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
@@ -122,7 +122,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
AccDataType,
AElementOp,
B0ElementOp,
-CElementOp>;
+Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
@@ -157,6 +157,7 @@ int main(int argc, char* argv[])
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideC = -1;
+float alpha = 1;
if(argc == 1)
{
@@ -181,7 +182,7 @@ int main(int argc, char* argv[])
BatchCount = std::stoi(argv[8]);
}
-else if(argc == 17)
+else if(argc == 18)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
@@ -203,14 +204,17 @@ int main(int argc, char* argv[])
BatchStrideB0 = std::stoi(argv[14]);
BatchStrideB1 = std::stoi(argv[15]);
BatchStrideC = std::stoi(argv[16]);
+alpha = std::stof(argv[17]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
-printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
+printf("arg4 to 16: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC\n");
+printf("arg17: scale (alpha)\n");
exit(0);
}
@@ -293,10 +297,11 @@ int main(int argc, char* argv[])
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
-DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize());
-DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
-DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
-DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSize());
+DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
+DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
+DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
+DeviceMem c_g_m_o_device_buf(sizeof(CDataType) *
+                             c_g_m_o_device_result.mDesc.GetElementSpaceSize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
@@ -304,7 +309,7 @@ int main(int argc, char* argv[])
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
-auto acc0_element_op = Acc0ElementOp{};
+auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
@@ -368,7 +373,7 @@ int main(int argc, char* argv[])
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
-a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{});
+a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
                                                                  |-----------------|
                                                                          Gemm0
                                                          |-------------------------------------|
                                                                           Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MNPadding,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
AccDataType,
AccDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1020;
ck::index_t N = 1020;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideB1 = -1;
ck::index_t StrideC = -1;
ck::index_t BatchStrideA = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideC = -1;
float alpha = 1;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
}
else if(argc == 18)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
StrideA = std::stoi(argv[9]);
StrideB0 = std::stoi(argv[10]);
StrideB1 = std::stoi(argv[11]);
StrideC = std::stoi(argv[12]);
BatchStrideA = std::stoi(argv[13]);
BatchStrideB0 = std::stoi(argv[14]);
BatchStrideB1 = std::stoi(argv[15]);
BatchStrideC = std::stoi(argv[16]);
alpha = std::stof(argv[17]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 16: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC\n");
printf("arg17: scale (alpha)\n");
exit(0);
}
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? O : M;
StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideC = (StrideC < 0) ? DefaultStrideC : StrideC;
const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC;
BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
// C_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_g_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
Tensor<CDataType> c_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
break;
case 2:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
DeviceMem c_g_m_o_device_buf(sizeof(CDataType) *
c_g_m_o_device_result.mDesc.GetElementSpaceSize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
if(do_verification)
{
// Output of Gemm0 is input A of Gemm1
Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 0 : 1;
}
return 0;
}
# TODO: add example batched_gemm_gemm_xdl_fp16
add_example_executable(example_batched_gemm_softmax_gemm_xdl_fp16 batched_gemm_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_dual_reduce_multiblock dual_reduce_multiblock.cpp)
add_example_executable(example_dual_reduce_threadwise dual_reduce_threadwise.cpp)
# Instructions for ```example_dual_reduce```
## Run ```example_dual_reduce_multiblock```
```bash
# -D <xxx> : input 4-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg2: time kernel (0=no, 1=yes)
./bin/example_dual_reduce_multiblock -D 600,28,28,256 -v 1 2 1
```
Result
```
./bin/example_dual_reduce_multiblock -D 600,28,28,256 -v 1 2 1
launch_and_time_kernel: grid_dim {150, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 1.19529 ms, 201.499 GB/s, DeviceMultipleReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_1_InSrcVectorSize_1,OutDstVectorSize_1_1>
```
## Run ```example_dual_reduce_threadwise```
```bash
# -D <xxx> : input 4-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg2: time kernel (0=no, 1=yes)
./bin/example_dual_reduce_threadwise -D 8000,4,4,4 -v 1 2 1
```
Result
```
./bin/example_dual_reduce_threadwise -D 8000,4,4,4 -v 1 2 1
launch_and_time_kernel: grid_dim {32, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 0.01512 ms, 71.9577 GB/s, DeviceMultipleReduceThreadwise<256,M_C256_S1,K_C1_S4,InSrcVectorDim_1_InSrcVectorSize_2,OutDstVectorSize_1_1>
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
#include <vector>
#include <array>
#include <algorithm>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'},
{"help", no_argument, nullptr, '?'},
{nullptr, 0, nullptr, 0}};
class SimpleAppArgs
{
private:
int option_index = 0;
public:
std::vector<size_t> inLengths = {600, 28, 28, 256};
size_t n, h, w, c;
bool do_verification = true;
int init_method = 2;
bool time_kernel = true;
public:
SimpleAppArgs()
{
n = inLengths[0];
h = inLengths[1];
w = inLengths[2];
c = inLengths[3];
};
void show_usage(const char* cmd)
{
std::cout << "Usage of " << cmd << std::endl;
std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
<< std::endl;
std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
"comparing with the host-based reduction"
<< std::endl;
std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
"value, 3=decimal value)"
<< std::endl;
std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl;
};
int processArgs(int argc, char* argv[])
{
using ck::host_common::getTypeValuesFromString;
int ch;
while(1)
{
ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index);
if(ch == -1)
break;
switch(ch)
{
case 'D':
if(!optarg)
throw std::runtime_error("Invalid option format!");
inLengths = getTypeValuesFromString<size_t>(optarg);
if(inLengths.size() != 4)
throw std::runtime_error(
"Invalid option format! The number of integers is incorrect!");
break;
case 'v':
if(!optarg)
throw std::runtime_error("Invalid option format!");
do_verification = static_cast<bool>(std::atoi(optarg));
break;
case '?':
if(std::string(long_options[option_index].name) == "help")
{
show_usage(argv[0]);
return (-1);
};
break;
default: show_usage(argv[0]); return (-1);
};
};
if(optind + 2 > argc)
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
init_method = std::atoi(argv[optind++]);
time_kernel = static_cast<bool>(std::atoi(argv[optind]));
n = inLengths[0];
h = inLengths[1];
w = inLengths[2];
c = inLengths[3];
return (0);
};
};
template <typename InDataType, typename OutDataType1, typename OutDataType2, typename AccDataType>
static void mean_meansquare_host(const Tensor<InDataType>& in,
Tensor<OutDataType1>& mean_ref,
Tensor<OutDataType2>& meansquare_ref,
size_t n,
size_t h,
size_t w,
size_t c)
{
auto thread_reduce_func = [&](auto iN) {
AccDataType mean = ck::type_convert<AccDataType>(0.0f);
AccDataType meansquare = ck::type_convert<AccDataType>(0.0f);
// accumulate sums for mean and meansquare (variance and invVariance can be derived from them)
for(std::size_t iH = 0; iH < h; iH++)
{
for(std::size_t iW = 0; iW < w; iW++)
{
for(std::size_t iC = 0; iC < c; iC++)
{
AccDataType curr_value = ck::type_convert<AccDataType>(in(iN, iH, iW, iC));
mean += curr_value;
meansquare += curr_value * curr_value;
};
}
};
mean = mean / (h * w * c);
meansquare = meansquare / (h * w * c);
mean_ref(iN) = ck::type_convert<OutDataType1>(mean);
meansquare_ref(iN) = ck::type_convert<OutDataType2>(meansquare);
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread = (n + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; it++)
{
std::size_t iN_begin = it * work_per_thread;
std::size_t iN_end = std::min(static_cast<size_t>((it + 1) * work_per_thread), n);
auto f = [=] {
for(std::size_t iN = iN_begin; iN < iN_end; iN++)
{
thread_reduce_func(iN);
}
};
threads[it] = joinable_thread(f);
}
};
using ReduceOperation = ck::reduce::Add;
using InElementwiseOperation_Mean = ck::tensor_operation::element_wise::PassThrough;
using AccElementwiseOperation_Mean = ck::tensor_operation::element_wise::UnaryDivide;
using InElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnarySquare;
using AccElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnaryDivide;
using InElementwiseOperationTuple =
ck::Tuple<InElementwiseOperation_Mean, InElementwiseOperation_Meansquare>;
using AccElementwiseOperationTuple =
ck::Tuple<AccElementwiseOperation_Mean, AccElementwiseOperation_Meansquare>;
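// Illustrative sketch only (assumed helper, not used below): the two fused
// reductions above produce, per slice, the mean E[x] (PassThrough + divide)
// and the mean of squares E[x^2] (UnarySquare + divide) in a single pass over
// the input; a layer-norm/batchnorm implementation would then combine them as
inline float variance_from_moments(float mean, float meansquare)
{
    // Var[x] = E[x^2] - (E[x])^2
    return meansquare - mean * mean;
}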
template <typename DeviceDualReduce,
typename InDataType,
typename OutDataType,
typename AccDataType,
int Rank,
int NumReduceDim>
int mean_meansquare_dual_reduce_test(size_t n,
size_t h,
size_t w,
size_t c,
bool do_verification,
int init_method,
bool time_kernel,
const std::array<int, NumReduceDim> reduceDims)
{
const std::vector<size_t> inLengths = {n, h, w, c};
Tensor<InDataType> in(inLengths);
std::vector<size_t> outLengths{n};
Tensor<OutDataType> mean_ref(outLengths);
Tensor<OutDataType> mean(outLengths);
Tensor<OutDataType> meansquare_ref(outLengths);
Tensor<OutDataType> meansquare(outLengths);
auto inStrides = in.mDesc.GetStrides();
auto outStrides = mean.mDesc.GetStrides();
size_t invariant_total_length = n;
size_t reduce_total_length = h * w * c;
const AccDataType alpha = ck::type_convert<AccDataType>(1.0f);
const AccDataType beta = ck::type_convert<AccDataType>(0.0f);
std::size_t num_thread = 1;
if(do_verification)
{
switch(init_method)
{
case 0: break;
case 1: in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread); break;
case 2: in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread); break;
default: in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
}
};
// these buffers are usually provided by the user application
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(OutDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem meansquare_dev(sizeof(OutDataType) * meansquare.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data());
if(do_verification)
{
mean_meansquare_host<InDataType, OutDataType, OutDataType, AccDataType>(
in, mean_ref, meansquare_ref, n, h, w, c);
};
constexpr ck::index_t NumInputDim = Rank;
constexpr ck::index_t NumOutputDim = (Rank - NumReduceDim > 1) ? Rank - NumReduceDim : 1;
std::array<ck::index_t, NumInputDim> i_inLengths;
std::array<ck::index_t, NumInputDim> i_inStrides;
std::array<ck::index_t, NumOutputDim> i_outLengths;
std::array<ck::index_t, NumOutputDim> i_outStrides;
std::copy(inLengths.begin(), inLengths.end(), i_inLengths.begin());
std::copy(inStrides.begin(), inStrides.end(), i_inStrides.begin());
std::copy(outLengths.begin(), outLengths.end(), i_outLengths.begin());
std::copy(outStrides.begin(), outStrides.end(), i_outStrides.begin());
auto dual_reduce_op = DeviceDualReduce{};
auto argument_ptr = dual_reduce_op.MakeArgumentPointer(
i_inLengths,
i_inStrides,
i_outLengths,
{i_outStrides, i_outStrides},
reduceDims,
{&alpha, &alpha},
{&beta, &beta},
in_dev.GetDeviceBuffer(),
{mean_dev.GetDeviceBuffer(), meansquare_dev.GetDeviceBuffer()},
ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}),
ck::make_tuple(
AccElementwiseOperation_Mean{static_cast<int32_t>(reduce_total_length)},
AccElementwiseOperation_Meansquare{static_cast<int32_t>(reduce_total_length)}));
if(!dual_reduce_op.IsSupportedArgument(argument_ptr.get()))
{
std::cout
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
<< std::endl;
return (-1);
};
std::string reduce_name = dual_reduce_op.GetTypeString();
auto invoker_ptr = dual_reduce_op.MakeInvokerPointer();
float avg_time = 0.0f;
avg_time += invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InDataType) +
2 * invariant_total_length * sizeof(OutDataType);
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name
<< std::endl;
bool pass = true;
if(do_verification)
{
mean_dev.FromDevice(mean.mData.data());
meansquare_dev.FromDevice(meansquare.mData.data());
pass = pass && ck::utils::check_err(mean.mData, mean_ref.mData);
pass = pass && ck::utils::check_err(meansquare.mData, meansquare_ref.mData);
};
return (pass ? 0 : 1);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
#include <vector>
#include <array>
#include <algorithm>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "dual_reduce_common.hpp"
using namespace ck;
using namespace ck::tensor_operation::device;
using InDataType = ck::half_t;
using OutDataType = float;
using OutDataTypeTuple = Tuple<OutDataType, OutDataType>;
using AccDataType = float;
// for NHWC layer-norm calculation of mean and meansquare
constexpr int Rank = 4;
constexpr int NumReduceDim = 3;
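// Reducing over dims {1, 2, 3} (H, W, C) of the NHWC input leaves one mean and
// one mean-of-squares value per N.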
constexpr bool PropagateNan = false;
constexpr InMemoryDataOperationEnum OutMemoryDataOperation = InMemoryDataOperationEnum::Set;
using DeviceDualReduce = DeviceMultipleReduceMultiBlock<2,
InDataType,
AccDataType,
OutDataTypeTuple,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperationTuple,
AccElementwiseOperationTuple,
OutMemoryDataOperation,
PropagateNan,
256,
4,
64,
1,
1,
1, // InSrcVectorDim
1,
ck::Sequence<1, 1>>;
int main(int argc, char* argv[])
{
int retval = 0;
if(argc > 1)
{
SimpleAppArgs arg;
if(arg.processArgs(argc, argv) < 0)
return (-1);
std::array<int, NumReduceDim> reduceDims = {1, 2, 3};
retval = mean_meansquare_dual_reduce_test<DeviceDualReduce,
InDataType,
OutDataType,
AccDataType,
Rank,
NumReduceDim>(arg.n,
arg.h,
arg.w,
arg.c,
arg.do_verification,
arg.init_method,
arg.time_kernel,
reduceDims);
}
else
{
std::array<int, NumReduceDim> reduceDims = {1, 2, 3};
retval = mean_meansquare_dual_reduce_test<DeviceDualReduce,
InDataType,
OutDataType,
AccDataType,
Rank,
NumReduceDim>(
600, 28, 28, 256, true, 2, true, reduceDims);
};
return (retval);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
#include <vector>
#include <array>
#include <algorithm>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce_threadwise.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "dual_reduce_common.hpp"
using namespace ck;
using namespace ck::tensor_operation::device;
using InDataType = ck::half_t;
using OutDataType = float;
using OutDataTypeTuple = Tuple<OutDataType, OutDataType>;
using AccDataType = float;
// for NHWC layer-norm calculation of mean and meansquare
constexpr int Rank = 4;
constexpr int NumReduceDim = 3;
constexpr bool PropagateNan = false;
using DeviceDualReduce = DeviceMultipleReduceThreadWise<2,
InDataType,
AccDataType,
OutDataTypeTuple,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperationTuple,
AccElementwiseOperationTuple,
PropagateNan,
256,
1,
4,
1, // InSrcVectorDim
2,
ck::Sequence<1, 1>>;
int main(int argc, char* argv[])
{
int retval = 0;
if(argc > 1)
{
SimpleAppArgs arg;
if(arg.processArgs(argc, argv) < 0)
return (-1);
std::array<int, NumReduceDim> reduceDims = {1, 2, 3};
retval = mean_meansquare_dual_reduce_test<DeviceDualReduce,
InDataType,
OutDataType,
AccDataType,
Rank,
NumReduceDim>(arg.n,
arg.h,
arg.w,
arg.c,
arg.do_verification,
arg.init_method,
arg.time_kernel,
reduceDims);
}
else
{
std::array<int, NumReduceDim> reduceDims = {1, 2, 3};
retval = mean_meansquare_dual_reduce_test<DeviceDualReduce,
InDataType,
OutDataType,
AccDataType,
Rank,
NumReduceDim>(
8000, 4, 4, 4, true, 2, true, reduceDims);
};
return (retval);
}
add_example_executable(example_batchnorm_forward batchnorm_forward_nhwc.cpp)
add_example_executable(example_batchnorm_infer batchnorm_infer_nhwc.cpp)
# Instructions for ```batchnorm nhwc``` Example
## Run ```batchnorm forward nhwc```
```bash
# -D <xxx> : input 4-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
#arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bf16, 6: fp64)
#arg2: 1/0 to indicate whether to update the moving average and variance (0=no, 1=yes)
#arg3: 1/0 to indicate whether to save result mean/invVariance (0=no, 1=yes)
#arg4: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg5: time kernel (0=no, 1=yes)
./bin/example_batchnorm_forward -D 128,16,16,1024 -v 1 0 0 1 2 1
```
Result
```
./bin/example_batchnorm_forward -D 128,16,16,1024 -v 1 0 0 1 2 1
launch_and_time_kernel: grid_dim {64, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
launch_and_time_kernel: grid_dim {120, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
launch_and_time_kernel: grid_dim {120, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 2.08231 ms, 354.519 GB/s
```
Result (verification-only run; exit code 0 means the check passed)
```
./bin/example_batchnorm_forward -D 128,16,16,1024 -v 1 0 1 0 2 0
echo $?
0
```
## Run ```batchnorm infer nhwc```
```bash
# -D <xxx> : input 4-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
#arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bf16, 6: fp64)
#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg3: time kernel (0=no, 1=yes)
./bin/example_batchnorm_infer -D 128,16,16,1024 -v 1 0 2 1
```
Result
```
./bin/example_batchnorm_infer -D 128,16,16,1024 -v 1 0 2 1
launch_and_time_kernel: grid_dim {120, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 1.28235 ms, 523.329 GB/s
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cassert>
#include <vector>
#include <array>
#include <type_traits>
#include "ck/utility/data_type.hpp"
// binary operation used to calculate invVariance from mean and meansquare
struct InvVariance
{
InvVariance(double epsilon) : epsilon_(epsilon){};
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& mean, const T& meansquare) const
{
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
using ck::math::sqrt;
T tmp_epsilon = type_convert<T>(epsilon_);
y = meansquare - mean * mean;
y = 1.0f / sqrt(tmp_epsilon + y);
};
double epsilon_;
};
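// Illustrative note: given E[x] (mean) and E[x^2] (meansquare), the variance is
// E[x^2] - E[x]^2 and invVariance = 1 / sqrt(variance + epsilon). For example, with
// mean = 2.0, meansquare = 5.0 and epsilon = 0: variance = 5.0 - 4.0 = 1.0, so invVariance = 1.0.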
// (4-in, 2-out) element-wise operation used to update the moving average of mean and variance
struct MovingAverage
{
MovingAverage(double factor) : factor_(factor){};
template <typename T>
__host__ __device__ constexpr void operator()(T& y0,
T& y1,
const T& mean,
const T& runningMean,
const T& meansquare,
const T& runningVariance) const
{
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
T tmp_factor = type_convert<T>(factor_);
T variance = meansquare - mean * mean;
y0 = runningMean * (type_convert<T>(1.0f) - tmp_factor) + mean * tmp_factor;
y1 = runningVariance * (type_convert<T>(1.0f) - tmp_factor) + variance * tmp_factor;
};
double factor_;
};
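// Illustrative note: with blending factor f, the running statistics are updated as
//   newRunningMean     = (1 - f) * runningMean     + f * batchMean
//   newRunningVariance = (1 - f) * runningVariance + f * batchVariance
// where batchVariance is derived on the fly as meansquare - mean * mean.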
struct MovingAverageAndInvVariance
{
MovingAverageAndInvVariance(double epsilon, double factor)
: epsilon_(epsilon), factor_(factor){};
template <typename T>
__host__ __device__ constexpr void operator()(T& y0, // resultRunningMean
T& y1, // resultRunningVariance
T& y2, // saveInvVariance
const T& mean,
const T& runningMean,
const T& meansquare,
const T& runningVariance) const
{
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
using ck::math::sqrt;
T tmp_epsilon = type_convert<T>(epsilon_);
T tmp_factor = type_convert<T>(factor_);
T variance = meansquare - mean * mean;
y0 = runningMean * (type_convert<T>(1.0f) - tmp_factor) + mean * tmp_factor;
y1 = runningVariance * (type_convert<T>(1.0f) - tmp_factor) + variance * tmp_factor;
y2 = 1.0f / sqrt(tmp_epsilon + variance);
};
double epsilon_;
double factor_;
};
struct NormalizeInInfer
{
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T1, typename T2>
__host__ __device__ constexpr void operator()(T1& y,
const T1& x,
const T2& mean,
const T2& variance,
const T2& gamma,
const T2& beta) const
{
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
using ck::math::sqrt;
T2 tmp_x, tmp_y;
tmp_x = type_convert<T2>(x);
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) * gamma + beta;
y = type_convert<T1>(tmp_y);
};
double epsilon_;
};
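// Illustrative note: this is the standard batchnorm inference transform,
//   y = gamma * (x - mean) / sqrt(variance + epsilon) + beta,
// applied element-wise with mean/variance taken from the estimated (running) statistics.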
struct NormalizeInForward
{
NormalizeInForward(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T1, typename T2>
__host__ __device__ constexpr void operator()(T1& y,
const T1& x,
const T2& mean,
const T2& meansquare,
const T2& gamma,
const T2& beta) const
{
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
using ck::math::sqrt;
T2 tmp_x, tmp_y;
T2 variance = meansquare - mean * mean;
tmp_x = type_convert<T2>(x);
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) * gamma + beta;
y = type_convert<T1>(tmp_y);
};
double epsilon_;
};
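// Illustrative note: same transform as NormalizeInInfer, except that the variance is derived
// from the freshly reduced statistics as Var[x] = E[x^2] - E[x]^2 instead of being read from
// stored estimates.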
template <int Rank, int NumReduceDim>
static inline std::array<int, Rank - NumReduceDim>
get_invariant_dims(const std::array<int, NumReduceDim>& reduceDims)
{
int reduceFlag = 0;
// flag the bits for the reduceDims
for(int i = 0; i < NumReduceDim; i++)
{
reduceFlag |= 1 << reduceDims[i];
};
std::array<int, Rank - NumReduceDim> invariantDims;
// collect invariant dimensions
int dim = 0;
for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) == 0)
{
invariantDims[dim] = i;
dim++;
};
return invariantDims;
};
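// Illustrative usage sketch (values chosen to match the NHWC batchnorm examples): reducing
// dimensions {0, 1, 2} (N, H, W) of a rank-4 tensor leaves {3} (C) as the invariant dimension.
//   std::array<int, 3> reduceDims{0, 1, 2};
//   auto invariantDims = get_invariant_dims<4, 3>(reduceDims); // -> {3}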
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cassert>
#include <vector>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "batchnorm_common.hpp"
template <typename InOutDataType,
typename AccDataType,
ck::index_t Rank,
ck::index_t NumBatchNormReduceDim,
bool fastest_dim_is_reduced = false>
int bnorm_fwd(bool time_kernel,
bool updateMovingAverage,
bool saveMeanAndInvVariance,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, Rank> xyLengths,
const std::array<ck::index_t, Rank> xStrides,
const std::array<ck::index_t, Rank> yStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
const void* p_x,
const void* p_scale,
const void* p_bias,
void* p_y,
double exponentialAverageFactor,
void* p_runningMean,
void* p_runningVariance,
double epsilon,
void* p_saveMean,
void* p_saveInvVariance,
void* p_tmp_mean,
void* p_tmp_meansquare)
{
static_assert(NumBatchNormReduceDim < Rank,
"Invalid number of reduced dimensions for batchnorm!");
constexpr ck::index_t NumScaleBiasMeanVarDim = Rank - NumBatchNormReduceDim;
using InElementwiseOperation_Mean = ck::tensor_operation::element_wise::PassThrough;
using AccElementwiseOperation_Mean = ck::tensor_operation::element_wise::UnaryDivide;
using InElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnarySquare;
using AccElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnaryDivide;
using DeviceMeanAndMeansquareInstance =
ck::tensor_operation::device::DeviceMultipleReduceMultiBlock<
2,
InOutDataType,
AccDataType,
ck::Tuple<AccDataType, AccDataType>,
Rank,
NumBatchNormReduceDim,
ck::reduce::Add,
ck::Tuple<InElementwiseOperation_Mean, InElementwiseOperation_Meansquare>,
ck::Tuple<AccElementwiseOperation_Mean, AccElementwiseOperation_Meansquare>,
ck::InMemoryDataOperationEnum::Set,
false, // PropagateNan
256,
16,
16,
1,
1,
fastest_dim_is_reduced ? 1 : 0, // InSrcVectorDim
1,
ck::Sequence<1, 1>>;
using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
// meansquare,
// scale, bias
ck::Tuple<InOutDataType>, // y
NormalizeInForward,
Rank,
2, // MPerthread
ck::Sequence<1, 1, 1, 1, 1>, // scalarPerVector: x, mean, meansquare, scale, bias
ck::Sequence<1>>; // scalarPerVector: y
using DeviceInvVarianceInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<AccDataType, AccDataType>, // mean, meansquare
ck::Tuple<AccDataType>, // invVariance
InvVariance,
NumScaleBiasMeanVarDim,
2, // MPerthread
ck::Sequence<1, 1>, // scalarPerVector: mean, meansquare
ck::Sequence<1>>; // scalarPerVector: invVariance
using DeviceMovingAverageInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<AccDataType, AccDataType, AccDataType, AccDataType>, // old moving mean, new mean,
// old moving variance, new
// meansquare
ck::Tuple<AccDataType, AccDataType>, // updated moving mean, updated moving variance
MovingAverage,
NumScaleBiasMeanVarDim,
4, // MPerthread
ck::Sequence<1, 1, 1, 1>, // scalarPerVector: old moving mean, new mean, old moving
// variance, new meansquare
ck::Sequence<1, 1>>; // scalarPerVector: updated moving mean, updated moving variance
using DeviceMovingAverageAndInvVarianceInstance =
ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<AccDataType, AccDataType, AccDataType, AccDataType>, // old moving mean, new
// mean, old moving
// variance, new
// meansquare
ck::Tuple<AccDataType, AccDataType, AccDataType>, // updated moving mean, updated moving
// variance, invVariance
MovingAverageAndInvVariance,
NumScaleBiasMeanVarDim,
4, // MPerthread
ck::Sequence<1, 1, 1, 1>, // scalarPerVector: old moving mean, new mean, old moving
// variance, new meansquare
ck::Sequence<1, 1, 1>>; // scalarPerVector: updated moving mean, updated moving variance, invVariance
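// Overview (illustrative): batchnorm forward is assembled from up to three kernels:
//   1) DeviceMeanAndMeansquareInstance - a multi-block dual reduction over the reduce dims;
//      the Add reduction followed by UnaryDivide (division by reduceLength) yields E[x] and
//      E[x^2] per invariant element (per channel for NHWC).
//   2) DeviceNormalizeInstance - element-wise y = gamma * (x - mean) / sqrt(var + eps) + beta,
//      with var derived from mean and meansquare.
//   3) Depending on saveMeanAndInvVariance / updateMovingAverage, one of the element-wise
//      instances above updates the running statistics and/or writes the saved invVariance.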
auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0};
int i = 0;
for(auto dim : invariantDims)
{
assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);
aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i];
i++;
};
int32_t reduceLength = 1;
for(auto dim : reduceDims)
reduceLength *= xyLengths[dim];
int32_t invariantLength = 1;
for(auto dim : invariantDims)
invariantLength *= xyLengths[dim];
size_t total_length = static_cast<size_t>(invariantLength) * reduceLength;
float avg_time = 0.0f;
std::size_t num_bytes = 0;
auto dev_mean_and_meansquare = DeviceMeanAndMeansquareInstance{};
void* p_mean = saveMeanAndInvVariance ? p_saveMean : p_tmp_mean;
const AccDataType alpha = ck::type_convert<AccDataType>(1.0f);
const AccDataType beta = ck::type_convert<AccDataType>(0.0f);
auto argument_ptr1 = dev_mean_and_meansquare.MakeArgumentPointer(
xyLengths,
xStrides,
bnScaleBiasMeanVarLengths,
{bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
reduceDims,
{&alpha, &alpha},
{&beta, &beta},
p_x,
{p_mean, p_tmp_meansquare},
ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}),
ck::make_tuple(AccElementwiseOperation_Mean{reduceLength},
AccElementwiseOperation_Meansquare{reduceLength}));
auto dev_normalize = DeviceNormalizeInstance{};
auto argument_ptr2 =
dev_normalize.MakeArgumentPointer(xyLengths,
{xStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides},
{yStrides},
{p_x, p_mean, p_tmp_meansquare, p_scale, p_bias},
{p_y},
NormalizeInForward{epsilon});
if(!dev_mean_and_meansquare.IsSupportedArgument(argument_ptr1.get()) ||
!dev_normalize.IsSupportedArgument(argument_ptr2.get()))
{
std::cout << "The runtime parameters seems not supported by the Devic, exiting!"
<< std::endl;
return (-1);
};
auto invoker_ptr1 = dev_mean_and_meansquare.MakeInvokerPointer();
auto invoker_ptr2 = dev_normalize.MakeInvokerPointer();
avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});
avg_time += invoker_ptr2->Run(argument_ptr2.get(), StreamConfig{nullptr, time_kernel});
num_bytes +=
(total_length * sizeof(InOutDataType) + invariantLength * 2 * sizeof(AccDataType)) + // No.1
(total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) +
total_length * sizeof(InOutDataType)); // No.2
if(saveMeanAndInvVariance && updateMovingAverage)
{
auto dev_moving_average_inv_variance = DeviceMovingAverageAndInvVarianceInstance{};
auto argument_ptr3 = dev_moving_average_inv_variance.MakeArgumentPointer(
bnScaleBiasMeanVarLengths,
{bnScaleBiasMeanVarStrides,
bnScaleBiasMeanVarStrides,
bnScaleBiasMeanVarStrides,
bnScaleBiasMeanVarStrides},
{bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
{p_mean, p_runningMean, p_tmp_meansquare, p_runningVariance},
{p_runningMean, p_runningVariance, p_saveInvVariance},
MovingAverageAndInvVariance{epsilon, exponentialAverageFactor});
if(!dev_moving_average_inv_variance.IsSupportedArgument(argument_ptr3.get()))
{
std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl;
return (-1);
};
auto invoker_ptr3 = dev_moving_average_inv_variance.MakeInvokerPointer();
avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel});
num_bytes += invariantLength * (4 + 3) * sizeof(AccDataType) * 2; // No.3
}
else if(saveMeanAndInvVariance)
{
auto dev_inv_variance = DeviceInvVarianceInstance{};
auto argument_ptr3 = dev_inv_variance.MakeArgumentPointer(
bnScaleBiasMeanVarLengths,
{bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
{bnScaleBiasMeanVarStrides},
{p_mean, p_tmp_meansquare},
{p_saveInvVariance},
InvVariance{epsilon});
if(!dev_inv_variance.IsSupportedArgument(argument_ptr3.get()))
{
std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl;
return (-1);
};
auto invoker_ptr3 = dev_inv_variance.MakeInvokerPointer();
avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel});
num_bytes += invariantLength * (2 + 1) * sizeof(AccDataType);
}
else if(updateMovingAverage)
{
auto dev_moving_average = DeviceMovingAverageInstance{};
auto argument_ptr3 = dev_moving_average.MakeArgumentPointer(
bnScaleBiasMeanVarLengths,
{bnScaleBiasMeanVarStrides,
bnScaleBiasMeanVarStrides,
bnScaleBiasMeanVarStrides,
bnScaleBiasMeanVarStrides},
{bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
{p_mean, p_runningMean, p_tmp_meansquare, p_runningVariance},
{p_runningMean, p_runningVariance},
MovingAverage{exponentialAverageFactor});
if(!dev_moving_average.IsSupportedArgument(argument_ptr3.get()))
{
std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl;
return (-1);
};
auto invoker_ptr3 = dev_moving_average.MakeInvokerPointer();
avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel});
num_bytes += invariantLength * (4 + 2) * sizeof(AccDataType) * 2; // No.3
};
if(time_kernel)
{
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
};
return (0);
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <limits>
#include <iostream>
#include <vector>
#include <array>
#include <algorithm>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
#include "batchnorm_forward_impl.hpp"
template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
AccDataType>;
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'},
{"help", no_argument, nullptr, '?'},
{nullptr, 0, nullptr, 0}};
class BatchNormFwdArg
{
private:
int option_index = 0;
public:
std::vector<size_t> inOutLengths;
bool do_verification = false;
bool updateMovingAverage;
bool saveMeanAndInvVariance;
int data_type = 0;
int init_method = 2;
bool time_kernel = false;
public:
void show_usage(const char* cmd)
{
std::cout << "Usage of " << cmd << std::endl;
std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension "
"lengths, must have 4 integers for nhwc"
<< std::endl;
std::cout << "--verify or -v, 1/0 to indicate whether to verify the batch-normalization "
"result by "
"comparing with the host-based batch-normalization"
<< std::endl;
std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl;
std::cout << "Arg2: 1/0 to indicate whether to update the moving average and variance "
"(0=no, 1=yes)"
<< std::endl;
std::cout << "Arg3: 1/0 to indicate whether to save the calculated mean and invVariance "
"(0=no, 1=yes)"
<< std::endl;
std::cout << "Arg4: init method used for bnScale and bnBias (0=no init, 1=single integer "
"value, 2=scope integer "
"value, 3=decimal value)"
<< std::endl;
std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl;
};
int processArgs(int argc, char* argv[])
{
using ck::host_common::getTypeValuesFromString;
int ch;
while(1)
{
ch = getopt_long(argc, argv, "D:v:", long_options, &option_index);
if(ch == -1)
break;
switch(ch)
{
case 'D':
if(!optarg)
throw std::runtime_error("Invalid option format!");
inOutLengths = getTypeValuesFromString<size_t>(optarg);
if(inOutLengths.size() != 4)
throw std::runtime_error(
"NHWC tensor layout should have 4 length values specified!");
break;
case 'v':
if(!optarg)
throw std::runtime_error("Invalid option format!");
do_verification = static_cast<bool>(std::atoi(optarg));
break;
case '?':
if(std::string(long_options[option_index].name) == "help")
{
show_usage(argv[0]);
return (-1);
};
break;
default: show_usage(argv[0]); return (-1);
};
};
if(optind + 5 > argc)
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
data_type = std::atoi(argv[optind++]);
updateMovingAverage = std::atoi(argv[optind++]);
saveMeanAndInvVariance = std::atoi(argv[optind++]);
init_method = std::atoi(argv[optind++]);
time_kernel = static_cast<bool>(std::atoi(argv[optind]));
if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
return (-1);
return (0);
};
};
using namespace ck;
template <typename InOutDataType, typename AccDataType>
bool bnorm_fwd_nhwc_test(bool do_verification,
int init_method,
bool time_kernel,
const std::vector<size_t> inOutLengths,
bool updateMovingAverage,
bool saveMeanAndInvVariance,
double averageFactor,
double epsilon)
{
// for NHWC BatchNorm calculation of mean and meansquare
constexpr int Rank = 4;
constexpr int NumReduceDim = 3;
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
// input data of the batchnorm forward algorithm
Tensor<InOutDataType> x(inOutLengths);
Tensor<AccDataType> bnScale(scaleBiasMeanVarLengths);
Tensor<AccDataType> bnBias(scaleBiasMeanVarLengths);
// output data of the batchnorm forward algorithm
Tensor<InOutDataType> y_ref(inOutLengths);
Tensor<InOutDataType> y(inOutLengths);
Tensor<AccDataType> resultSaveMean_ref(scaleBiasMeanVarLengths);
Tensor<AccDataType> resultSaveInvVariance_ref(scaleBiasMeanVarLengths);
Tensor<AccDataType> resultRunningMean_ref(scaleBiasMeanVarLengths);
Tensor<AccDataType> resultRunningVariance_ref(scaleBiasMeanVarLengths);
auto inOutStrides = x.mDesc.GetStrides();
auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides();
std::size_t num_thread = std::thread::hardware_concurrency();
if(updateMovingAverage)
{
if constexpr(std::is_same<InOutDataType, int8_t>::value)
{
x.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
const float x_mean = 0.0f;
const float x_stddev = 2.5f;
const float noise_stddev = 0.04f;
resultRunningMean_ref.GenerateTensorValue(
GeneratorTensor_4<AccDataType>{x_mean, noise_stddev}, num_thread);
resultRunningVariance_ref.GenerateTensorValue(
GeneratorTensor_4<AccDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
}
else
{
const float x_mean = 0.0f;
const float x_stddev = 1.0f;
const float noise_stddev = 0.04f;
// input data in normal distribution
x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
// initialize the runningMean to be values with tiny variation to the mean of the x
// values
resultRunningMean_ref.GenerateTensorValue(
GeneratorTensor_4<AccDataType>{x_mean, noise_stddev}, num_thread);
// initialize the runningVariance to be values with tiny variation to the variance of
// the x values
resultRunningVariance_ref.GenerateTensorValue(
GeneratorTensor_4<AccDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
};
}
else
{
if constexpr(std::is_same<InOutDataType, int8_t>::value)
x.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
else
x.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-5.0f, 5.0f}, num_thread);
};
if(do_verification)
{
switch(init_method)
{
case 0:
bnScale.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
break;
case 1:
bnScale.GenerateTensorValue(GeneratorTensor_1<AccDataType>{1}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0}, num_thread);
break;
case 2:
bnScale.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-5, 5}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-5, 5}, num_thread);
break;
default:
bnScale.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-5.0f, 5.0f}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-5.0f, 5.0f}, num_thread);
}
};
// these buffers are usually provided by the user application
DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(InOutDataType) * y.mDesc.GetElementSpaceSize());
DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize());
DeviceMem bnBias_dev(sizeof(AccDataType) * bnBias.mDesc.GetElementSpaceSize());
// mean_dev or resultSaveMean_dev
DeviceMem resultSaveMean_dev(sizeof(AccDataType) *
resultSaveMean_ref.mDesc.GetElementSpaceSize());
// meansquare_dev or resultSaveInvVariance_dev
DeviceMem resultSaveInvVariance_dev(sizeof(AccDataType) *
resultSaveInvVariance_ref.mDesc.GetElementSpaceSize());
// resultRunningMean_dev
DeviceMem resultRunningMean_dev(sizeof(AccDataType) *
resultRunningMean_ref.mDesc.GetElementSpaceSize());
// resultRunningVariance_dev
DeviceMem resultRunningVariance_dev(sizeof(AccDataType) *
resultRunningVariance_ref.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data());
bnScale_dev.ToDevice(bnScale.mData.data());
bnBias_dev.ToDevice(bnBias.mData.data());
if(updateMovingAverage)
{
resultRunningMean_dev.ToDevice(resultRunningMean_ref.mData.data());
resultRunningVariance_dev.ToDevice(resultRunningVariance_ref.mData.data());
};
std::array<index_t, Rank> i_inOutLengths;
std::array<index_t, Rank> i_inOutStrides;
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarLengths;
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarStrides;
std::copy(inOutLengths.begin(), inOutLengths.end(), i_inOutLengths.begin());
std::copy(inOutStrides.begin(), inOutStrides.end(), i_inOutStrides.begin());
std::copy(scaleBiasMeanVarLengths.begin(),
scaleBiasMeanVarLengths.end(),
i_scaleBiasMeanVarLengths.begin());
std::copy(scaleBiasMeanVarStrides.begin(),
scaleBiasMeanVarStrides.end(),
i_scaleBiasMeanVarStrides.begin());
int result = 0;
// workspace holding the temporary mean and meansquare buffers; the meansquare buffer starts
// at the next 64-byte-aligned offset after the mean buffer
DeviceMem workspace(sizeof(AccDataType) * 2 * resultSaveMean_ref.mDesc.GetElementSpaceSize() +
128);
void* p_tmp_mean = workspace.GetDeviceBuffer();
void* p_tmp_meansquare =
static_cast<char*>(p_tmp_mean) +
(sizeof(AccDataType) * resultSaveMean_ref.mDesc.GetElementSpaceSize() + 63) / 64 * 64;
result = bnorm_fwd<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
time_kernel,
updateMovingAverage,
saveMeanAndInvVariance,
{0, 1, 2},
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
averageFactor,
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr,
epsilon,
saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
p_tmp_mean,
p_tmp_meansquare);
if(result < 0)
return (false);
bool pass = true;
if(do_verification)
{
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance<InOutDataType, AccDataType>{};
auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer(
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
x.mData.data(),
bnScale.mData.data(),
bnBias.mData.data(),
y_ref.mData.data(),
averageFactor, // exponentialAverageFactor
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, // resultRunningMean
updateMovingAverage ? resultRunningVariance_ref.mData.data()
: nullptr, // resultRunningVariance
epsilon,
saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr);
if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
{
std::cout
<< "The runtime parameters do not seem to be supported by the BatchNorm instance, exiting!"
<< std::endl;
return false;
};
auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer();
(void)invoker_ptr_ref->Run(argument_ptr_ref.get());
y_dev.FromDevice(y.mData.data());
pass = pass && ck::utils::check_err(y.mData, y_ref.mData);
if(updateMovingAverage)
{
Tensor<AccDataType> resultRunningMean(scaleBiasMeanVarLengths);
Tensor<AccDataType> resultRunningVariance(scaleBiasMeanVarLengths);
resultRunningMean_dev.FromDevice(resultRunningMean.mData.data());
resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data());
pass =
pass && ck::utils::check_err(resultRunningMean.mData, resultRunningMean_ref.mData);
pass = pass && ck::utils::check_err(resultRunningVariance.mData,
resultRunningVariance_ref.mData);
};
if(saveMeanAndInvVariance)
{
Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);
resultSaveMean_dev.FromDevice(resultSaveMean.mData.data());
resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data());
pass = pass && ck::utils::check_err(resultSaveMean.mData, resultSaveMean_ref.mData);
pass = pass && ck::utils::check_err(resultSaveInvVariance.mData,
resultSaveInvVariance_ref.mData);
};
};
return (pass);
};
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
int main(int argc, char* argv[])
{
bool pass = true;
if(argc > 1)
{
BatchNormFwdArg arg;
if(arg.processArgs(argc, argv) < 0)
return (-1);
if(arg.data_type == 0)
{
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
}
else if(arg.data_type == 1)
{
pass = bnorm_fwd_nhwc_test<float, float>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
}
else if(arg.data_type == 3)
{
pass = bnorm_fwd_nhwc_test<int8_t, float>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
}
else if(arg.data_type == 5)
{
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
}
else if(arg.data_type == 6)
{
pass = bnorm_fwd_nhwc_test<double, double>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
}
}
else
{
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(true,
2,
false, // don't time kernel
{128, 16, 16, 1024},
true,
false,
averageFactor,
epsilon);
};
return (pass ? 0 : 1);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cassert>
#include <vector>
#include "ck/ck.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "batchnorm_common.hpp"
template <typename InOutDataType,
typename AccDataType,
ck::index_t Rank,
ck::index_t NumBatchNormReduceDim,
bool fastest_dim_is_reduced = false>
int bnorm_infer(
bool time_kernel,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, Rank> xyLengths,
const std::array<ck::index_t, Rank> xStrides,
const std::array<ck::index_t, Rank> yStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
const void* p_x,
const void* p_scale,
const void* p_bias,
double epsilon,
const void* p_estimatedMean,
const void* p_estimatedVariance,
void* p_y)
{
(void)bnScaleBiasMeanVarLengths;
static_assert(NumBatchNormReduceDim < Rank,
"Invalid number of reduced dimensions for batchnorm!");
using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
// variance,
// scale,
// bias,
ck::Tuple<InOutDataType>, // y
NormalizeInInfer,
Rank,
2, // MPerthread
ck::Sequence<1, 1, 1, 1, 1>, // x, mean, variance, scale, bias
ck::Sequence<1>>; // scalarPerVector: y
auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0};
int i = 0;
for(auto dim : invariantDims)
{
assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);
aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i];
i++;
};
int32_t reduceLength = 1;
for(auto dim : reduceDims)
reduceLength *= xyLengths[dim];
int32_t invariantLength = 1;
for(auto dim : invariantDims)
invariantLength *= xyLengths[dim];
size_t total_length = static_cast<size_t>(invariantLength) * reduceLength;
float avg_time = 0.0f;
std::size_t num_bytes = 0;
auto dev_normalize = DeviceNormalizeInstance{};
auto argument_ptr1 = dev_normalize.MakeArgumentPointer(
xyLengths,
{xStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides},
{yStrides},
{p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias},
{p_y},
NormalizeInInfer{epsilon});
if(!dev_normalize.IsSupportedArgument(argument_ptr1.get()))
{
std::cout << "The runtime parameters seems not supported by the Devic, exiting!"
<< std::endl;
return (-1);
};
auto invoker_ptr1 = dev_normalize.MakeInvokerPointer();
avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});
num_bytes += (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) +
total_length * sizeof(InOutDataType));
if(time_kernel)
{
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
};
return (0);
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <limits>
#include <iostream>
#include <vector>
#include <array>
#include <algorithm>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp"
#include "batchnorm_infer_impl.hpp"
template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<InOutDataType,
AccDataType>;
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'},
{"help", no_argument, nullptr, '?'},
{nullptr, 0, nullptr, 0}};
class BatchNormInferArg
{
private:
int option_index = 0;
public:
std::vector<size_t> inOutLengths;
bool do_verification = false;
int data_type = 0;
int init_method = 2;
bool time_kernel = false;
public:
void show_usage(const char* cmd)
{
std::cout << "Usage of " << cmd << std::endl;
std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension "
"lengths, must have 4 integers for nhwc"
<< std::endl;
std::cout << "--verify or -v, 1/0 to indicate whether to verify the batch-normalization "
"result by "
"comparing with the host-based batch-normalization"
<< std::endl;
std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl;
std::cout << "Arg2: init method used for bnScale and bnBias (0=no init, 1=single integer "
"value, 2=scope integer "
"value, 3=decimal value)"
<< std::endl;
std::cout << "Arg3: time kernel (0=no, 1=yes)" << std::endl;
};
int processArgs(int argc, char* argv[])
{
using ck::host_common::getTypeValuesFromString;
int ch;
while(1)
{
ch = getopt_long(argc, argv, "D:v:", long_options, &option_index);
if(ch == -1)
break;
switch(ch)
{
case 'D':
if(!optarg)
throw std::runtime_error("Invalid option format!");
inOutLengths = getTypeValuesFromString<size_t>(optarg);
if(inOutLengths.size() != 4)
throw std::runtime_error(
"NHWC tensor layout should have 4 length values specified!");
break;
case 'v':
if(!optarg)
throw std::runtime_error("Invalid option format!");
do_verification = static_cast<bool>(std::atoi(optarg));
break;
case '?':
if(std::string(long_options[option_index].name) == "help")
{
show_usage(argv[0]);
return (-1);
};
break;
default: show_usage(argv[0]); return (-1);
};
};
if(optind + 3 > argc)
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
data_type = std::atoi(argv[optind++]);
init_method = std::atoi(argv[optind++]);
time_kernel = static_cast<bool>(std::atoi(argv[optind]));
if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
return (-1);
return (0);
};
};
using namespace ck;
template <typename InOutDataType, typename AccDataType>
bool bnorm_infer_nhwc_test(bool do_verification,
int init_method,
bool time_kernel,
const std::vector<size_t> inOutLengths,
double epsilon)
{
// for NHWC BatchNorm calculation of mean and meansquare
constexpr int Rank = 4;
constexpr int NumReduceDim = 3;
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
// input data of the batchnorm forward algorithm
Tensor<InOutDataType> x(inOutLengths);
Tensor<AccDataType> bnScale(scaleBiasMeanVarLengths);
Tensor<AccDataType> bnBias(scaleBiasMeanVarLengths);
// output data of the batchnorm forward algorithm
Tensor<InOutDataType> y_ref(inOutLengths);
Tensor<InOutDataType> y(inOutLengths);
Tensor<AccDataType> estimatedMean(scaleBiasMeanVarLengths);
Tensor<AccDataType> estimatedVariance(scaleBiasMeanVarLengths);
auto inOutStrides = x.mDesc.GetStrides();
auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides();
std::size_t num_thread = std::thread::hardware_concurrency();
if constexpr(std::is_same<InOutDataType, int8_t>::value)
{
x.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
const float x_mean = 0.0f;
const float x_stddev = 2.5f;
const float noise_stddev = 0.0001f;
estimatedMean.GenerateTensorValue(GeneratorTensor_4<AccDataType>{x_mean, noise_stddev},
num_thread);
estimatedVariance.GenerateTensorValue(
GeneratorTensor_4<AccDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
}
else
{
const float x_mean = 0.0f;
const float x_stddev = 1.0f;
const float noise_stddev = 0.0001f;
x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
// initialize the savedMean to be values with tiny variation to the mean of the x values
estimatedMean.GenerateTensorValue(GeneratorTensor_4<AccDataType>{x_mean, noise_stddev},
num_thread);
// initialize the variance to be values with tiny variation to the variance of the x values
estimatedVariance.GenerateTensorValue(
GeneratorTensor_4<AccDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
};
if(do_verification)
{
switch(init_method)
{
case 0:
bnScale.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
break;
case 1:
bnScale.GenerateTensorValue(GeneratorTensor_1<AccDataType>{1}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0}, num_thread);
break;
case 2:
bnScale.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-5, 5}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-5, 5}, num_thread);
break;
default:
bnScale.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-5.0f, 5.0f}, num_thread);
bnBias.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-5.0f, 5.0f}, num_thread);
}
};
// these buffers are usually provided by the user application
DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(InOutDataType) * y.mDesc.GetElementSpaceSize());
DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize());
DeviceMem bnBias_dev(sizeof(AccDataType) * bnBias.mDesc.GetElementSpaceSize());
// mean_dev or resultSaveMean_dev
DeviceMem estimatedMean_dev(sizeof(AccDataType) * estimatedMean.mDesc.GetElementSpaceSize());
// meansquare_dev or resultSaveInvVariance_dev
DeviceMem estimatedVariance_dev(sizeof(AccDataType) *
estimatedVariance.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data());
bnScale_dev.ToDevice(bnScale.mData.data());
bnBias_dev.ToDevice(bnBias.mData.data());
estimatedMean_dev.ToDevice(estimatedMean.mData.data());
estimatedVariance_dev.ToDevice(estimatedVariance.mData.data());
using ck::index_t;
std::array<index_t, Rank> i_inOutLengths;
std::array<index_t, Rank> i_inOutStrides;
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarLengths;
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarStrides;
std::copy(inOutLengths.begin(), inOutLengths.end(), i_inOutLengths.begin());
std::copy(inOutStrides.begin(), inOutStrides.end(), i_inOutStrides.begin());
std::copy(scaleBiasMeanVarLengths.begin(),
scaleBiasMeanVarLengths.end(),
i_scaleBiasMeanVarLengths.begin());
std::copy(scaleBiasMeanVarStrides.begin(),
scaleBiasMeanVarStrides.end(),
i_scaleBiasMeanVarStrides.begin());
int result = 0;
result = bnorm_infer<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
time_kernel,
{0, 1, 2},
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(),
epsilon,
estimatedMean_dev.GetDeviceBuffer(),
estimatedVariance_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer());
if(result < 0)
return (false);
bool pass = true;
if(do_verification)
{
auto batchNormInfer_ref = ReferenceBatchNormInferInstance<InOutDataType, AccDataType>{};
auto argument_ptr_ref =
batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
x.mData.data(),
bnScale.mData.data(),
bnBias.mData.data(),
epsilon,
estimatedMean.mData.data(),
estimatedVariance.mData.data(),
y_ref.mData.data());
if(!batchNormInfer_ref.IsSupportedArgument(argument_ptr_ref.get()))
{
std::cout
<< "The runtime parameters do not seem to be supported by the BatchNorm instance, exiting!"
<< std::endl;
return false;
};
auto invoker_ptr_ref = batchNormInfer_ref.MakeInvokerPointer();
(void)invoker_ptr_ref->Run(argument_ptr_ref.get());
y_dev.FromDevice(y.mData.data());
pass = pass && ck::utils::check_err(y.mData, y_ref.mData);
};
return (pass);
};
static const double epsilon = std::numeric_limits<float>::epsilon();
int main(int argc, char* argv[])
{
bool pass = true;
if(argc > 1)
{
BatchNormInferArg arg;
if(arg.processArgs(argc, argv) < 0)
return (-1);
if(arg.data_type == 0)
{
pass = bnorm_infer_nhwc_test<ck::half_t, float>(
arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths, epsilon);
}
else if(arg.data_type == 1)
{
pass = bnorm_infer_nhwc_test<float, float>(
arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths, epsilon);
}
else if(arg.data_type == 3)
{
pass = bnorm_infer_nhwc_test<int8_t, float>(
arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths, epsilon);
}
else if(arg.data_type == 5)
{
pass = bnorm_infer_nhwc_test<ck::bhalf_t, float>(
arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths, epsilon);
}
else if(arg.data_type == 6)
{
pass = bnorm_infer_nhwc_test<double, double>(
arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths, epsilon);
};
}
else
{
pass = bnorm_infer_nhwc_test<ck::half_t, float>(true,
2,
false, // don't time kernel
{128, 16, 16, 1024},
epsilon);
};
return (pass ? 0 : 1);
}