Commit 204ef976 (unverified), authored by Chao Liu, committed by GitHub
Browse files

add more datatype to gemm+gemm and conv+conv example (#397)

* refactor

* refactor

* adding int4/int8/fp16/bf16 for conv+conv and gemm+gemm

* adding int4/int8/fp16/bf16 for conv+conv and gemm+gemm

* clean
parent 46a675aa
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
// Element types for the int8 variant of the conv+conv example.
// Suffix "0" refers to the first convolution, suffix "1" to the second.
using In0DataType = int8_t;
using Wei0DataType = int8_t;
using Acc0DataType = int32_t;
using Wei1DataType = int8_t;
// NOTE(review): Acc1DataType is not passed to the device instance below (only Acc0DataType is);
// confirm whether anything else still needs it.
using Acc1DataType = int32_t;
using C1ShuffleDataType = int32_t;
using Out1DataType = int8_t;
// This is used for reference code
// (element type of the intermediate conv0 output tensor built on the host for verification).
using Out0DataType = int8_t;
// Shorthand for ck::Sequence over a pack of ck::index_t values; used for the
// block-transfer arguments of the device instance defined in this file.
template <ck::index_t... Values>
using S = ck::Sequence<Values...>;
// Element-wise operators attached to each tensor of the fused operation.
// Everything is a pass-through except the final output, which uses UnaryConvert
// (presumably a type conversion on store — confirm in element_wise_operation.hpp).
using In0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using Wei0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using Wei1ElementOp = ck::tensor_operation::element_wise::PassThrough;
using Out0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using Out1ElementOp = ck::tensor_operation::element_wise::UnaryConvert;
// GEMM layout tags used for the A/B0/B1/C arguments of the device instance.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// GEMM specialization passed to the device instance.
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Device operation used by this example: a batched GEMM+GEMM XDL/CShuffle instance.
// It is handed to the conv+conv runner as its DeviceOpInstance, which executes the two
// convolutions as GEMMs (the runner notes this only works for 1x1 convolutions).
//
// The argument list is positional. Arguments that carried no label in the original are
// annotated with the parameter they presumably bind to.
// NOTE(review): confirm those labels against the template parameter list of
// DeviceBatchedGemmGemm_Xdl_CShuffle before relying on them.
using DeviceBatchedGemmGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
Row, // ALayout
Col, // B0Layout
Col, // B1Layout
Row, // CLayout
In0DataType, // ADataType,
Wei0DataType, // B0DataType,
Wei1DataType, // B1DataType,
Out1DataType, // CDataType,
Acc0DataType, // AccDataType,
C1ShuffleDataType, // CShuffleDataType,
In0ElementOp, // AElementOp,
Wei0ElementOp, // B0ElementOp,
Out0ElementOp, // Acc0ElementOp,
Wei1ElementOp, // B1ElementOp,
Out1ElementOp, // CElementOp,
GemmDefault, // GemmSpecialization
1, // presumably NumGemmKPrefetchStage
256, // presumably BlockSize
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1KPerBlock
16, // AK1
16, // BK1
4, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, // presumably ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // presumably ABlockTransferSrcAccessOrder
2, // presumably ABlockTransferSrcVectorDim
16, // presumably ABlockTransferSrcScalarPerVector
16, // presumably ABlockTransferDstScalarPerVector_AK1
true, // presumably ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>, // presumably BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // presumably BBlockTransferSrcAccessOrder
2, // presumably BBlockTransferSrcVectorDim
16, // presumably BBlockTransferSrcScalarPerVector
16, // presumably BBlockTransferDstScalarPerVector_BK1
true, // presumably BBlockLdsExtraN
S<4, 64, 1>, // B1BlockTransfer
S<1, 0, 2>, // presumably B1BlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // presumably B1BlockTransferSrcAccessOrder
2, // presumably B1BlockTransferSrcVectorDim
4, // presumably B1BlockTransferSrcScalarPerVector
4, // presumably B1BlockTransferDstScalarPerVector_BK1
true, // presumably B1BlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
#include "run_grouped_conv_conv_fwd_example.inc"
// Program entry point: runs the conv+conv example and maps its result to the
// process exit status (0 when the example reports success, 1 otherwise).
int main(int argc, char* argv[])
{
    const bool example_passed = run_grouped_conv_conv_fwd_example(argc, argv);

    return example_passed ? 0 : 1;
}
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #pragma once
#include <iostream>
#include <numeric>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename In0DataType, typename In0DataType,
typename Wei0DataType, typename Wei0DataType,
typename Acc0DataType, typename Out0DataType,
typename Wei1DataType, typename Wei1DataType,
typename Out1DataType, typename Out1DataType,
typename In0ElementOp, typename In0ElementOp,
...@@ -30,21 +15,21 @@ template <ck::index_t NDimSpatial, ...@@ -30,21 +15,21 @@ template <ck::index_t NDimSpatial,
typename Wei1ElementOp, typename Wei1ElementOp,
typename Out1ElementOp, typename Out1ElementOp,
typename DeviceOpInstance> typename DeviceOpInstance>
int run_grouped_conv_conv_fwd(bool do_verification, bool run_grouped_conv_conv_fwd(bool do_verification,
int init_method, int init_method,
bool time_kernel, bool time_kernel,
const ck::utils::conv::ConvParam& conv0_param, const ck::utils::conv::ConvParam& conv0_param,
const ck::utils::conv::ConvParam& conv1_param, const ck::utils::conv::ConvParam& conv1_param,
const HostTensorDescriptor& in0_g_n_c_wis_desc, const HostTensorDescriptor& in0_g_n_c_wis_desc,
const HostTensorDescriptor& wei0_g_k_c_xs_desc, const HostTensorDescriptor& wei0_g_k_c_xs_desc,
const HostTensorDescriptor& out0_g_n_k_wos_desc, const HostTensorDescriptor& out0_g_n_k_wos_desc,
const HostTensorDescriptor& wei1_g_k_c_xs_desc, const HostTensorDescriptor& wei1_g_k_c_xs_desc,
const HostTensorDescriptor& out1_g_n_k_wos_desc, const HostTensorDescriptor& out1_g_n_k_wos_desc,
const In0ElementOp& in0_element_op, const In0ElementOp& in0_element_op,
const Wei0ElementOp& wei0_element_op, const Wei0ElementOp& wei0_element_op,
const Wei1ElementOp& wei1_element_op, const Wei1ElementOp& wei1_element_op,
const Out0ElementOp& out0_element_op, const Out0ElementOp& out0_element_op,
const Out1ElementOp& out1_element_op) const Out1ElementOp& out1_element_op)
{ {
Tensor<In0DataType> in0(in0_g_n_c_wis_desc); Tensor<In0DataType> in0(in0_g_n_c_wis_desc);
Tensor<Wei0DataType> wei0(wei0_g_k_c_xs_desc); Tensor<Wei0DataType> wei0(wei0_g_k_c_xs_desc);
...@@ -71,6 +56,20 @@ int run_grouped_conv_conv_fwd(bool do_verification, ...@@ -71,6 +56,20 @@ int run_grouped_conv_conv_fwd(bool do_verification,
wei1.GenerateTensorValue(GeneratorTensor_3<Wei1DataType>{-0.5, 0.5}); wei1.GenerateTensorValue(GeneratorTensor_3<Wei1DataType>{-0.5, 0.5});
} }
#ifdef BUILD_INT4_EXAMPLE
DeviceMem in0_device_buf(sizeof(KernelIn0DataType) * in0.mDesc.GetElementSpaceSize());
DeviceMem wei0_device_buf(sizeof(KernelWei0DataType) * wei0.mDesc.GetElementSpaceSize());
DeviceMem wei1_device_buf(sizeof(KernelWei1DataType) * wei1.mDesc.GetElementSpaceSize());
DeviceMem out1_device_buf(sizeof(KernelOut1DataType) * out1_device.mDesc.GetElementSpaceSize());
const Tensor<KernelIn0DataType> in0_converted(in0);
const Tensor<KernelWei0DataType> wei0_converted(wei0);
const Tensor<KernelWei1DataType> wei1_converted(wei1);
in0_device_buf.ToDevice(in0_converted.mData.data());
wei0_device_buf.ToDevice(wei0_converted.mData.data());
wei1_device_buf.ToDevice(wei1_converted.mData.data());
#else
DeviceMem in0_device_buf(sizeof(In0DataType) * in0.mDesc.GetElementSpaceSize()); DeviceMem in0_device_buf(sizeof(In0DataType) * in0.mDesc.GetElementSpaceSize());
DeviceMem wei0_device_buf(sizeof(Wei0DataType) * wei0.mDesc.GetElementSpaceSize()); DeviceMem wei0_device_buf(sizeof(Wei0DataType) * wei0.mDesc.GetElementSpaceSize());
DeviceMem wei1_device_buf(sizeof(Wei1DataType) * wei1.mDesc.GetElementSpaceSize()); DeviceMem wei1_device_buf(sizeof(Wei1DataType) * wei1.mDesc.GetElementSpaceSize());
...@@ -79,6 +78,7 @@ int run_grouped_conv_conv_fwd(bool do_verification, ...@@ -79,6 +78,7 @@ int run_grouped_conv_conv_fwd(bool do_verification,
in0_device_buf.ToDevice(in0.mData.data()); in0_device_buf.ToDevice(in0.mData.data());
wei0_device_buf.ToDevice(wei0.mData.data()); wei0_device_buf.ToDevice(wei0.mData.data());
wei1_device_buf.ToDevice(wei1.mData.data()); wei1_device_buf.ToDevice(wei1.mData.data());
#endif
std::array<ck::index_t, NDimSpatial + 3> a0_g_n_c_wis_lengths{}; std::array<ck::index_t, NDimSpatial + 3> a0_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a0_g_n_c_wis_strides{}; std::array<ck::index_t, NDimSpatial + 3> a0_g_n_c_wis_strides{};
...@@ -116,7 +116,6 @@ int run_grouped_conv_conv_fwd(bool do_verification, ...@@ -116,7 +116,6 @@ int run_grouped_conv_conv_fwd(bool do_verification,
copy(conv1_param.input_left_pads_, input1_left_pads); copy(conv1_param.input_left_pads_, input1_left_pads);
copy(conv1_param.input_right_pads_, input1_right_pads); copy(conv1_param.input_right_pads_, input1_right_pads);
#if 1
// do Conv using GEMM, only works for 1x1 conv for now // do Conv using GEMM, only works for 1x1 conv for now
const ck::index_t gemm_batch = a0_g_n_c_wis_lengths[0]; const ck::index_t gemm_batch = a0_g_n_c_wis_lengths[0];
...@@ -150,29 +149,36 @@ int run_grouped_conv_conv_fwd(bool do_verification, ...@@ -150,29 +149,36 @@ int run_grouped_conv_conv_fwd(bool do_verification,
auto device_op = DeviceOpInstance{}; auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker(); auto invoker = device_op.MakeInvoker();
auto argument = auto argument = device_op.MakeArgument(
device_op.MakeArgument(static_cast<In0DataType*>(in0_device_buf.GetDeviceBuffer()), #ifdef BUILD_INT4_EXAMPLE
static_cast<Wei0DataType*>(wei0_device_buf.GetDeviceBuffer()), static_cast<KernelIn0DataType*>(in0_device_buf.GetDeviceBuffer()),
static_cast<Wei1DataType*>(wei1_device_buf.GetDeviceBuffer()), static_cast<KernelWei0DataType*>(wei0_device_buf.GetDeviceBuffer()),
static_cast<Out1DataType*>(out1_device_buf.GetDeviceBuffer()), static_cast<KernelWei1DataType*>(wei1_device_buf.GetDeviceBuffer()),
gemm0_m_length, static_cast<KernelOut1DataType*>(out1_device_buf.GetDeviceBuffer()),
gemm0_n_length, #else
gemm0_k_length, static_cast<In0DataType*>(in0_device_buf.GetDeviceBuffer()),
gemm1_n_length, static_cast<Wei0DataType*>(wei0_device_buf.GetDeviceBuffer()),
gemm_batch, static_cast<Wei1DataType*>(wei1_device_buf.GetDeviceBuffer()),
a0_stride, static_cast<Out1DataType*>(out1_device_buf.GetDeviceBuffer()),
b0_stride, #endif
b1_stride, gemm0_m_length,
e1_stride, gemm0_n_length,
a0_batch_stride, gemm0_k_length,
b0_batch_stride, gemm1_n_length,
b1_batch_stride, gemm_batch,
e1_batch_stride, a0_stride,
in0_element_op, b0_stride,
wei0_element_op, b1_stride,
out0_element_op, e1_stride,
wei1_element_op, a0_batch_stride,
out1_element_op); b0_batch_stride,
b1_batch_stride,
e1_batch_stride,
in0_element_op,
wei0_element_op,
out0_element_op,
wei1_element_op,
out1_element_op);
if(!device_op.IsSupportedArgument(argument)) if(!device_op.IsSupportedArgument(argument))
{ {
...@@ -193,24 +199,23 @@ int run_grouped_conv_conv_fwd(bool do_verification, ...@@ -193,24 +199,23 @@ int run_grouped_conv_conv_fwd(bool do_verification,
float gb_per_sec = num_btype / 1.E6 / avg_time; float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< device_op.GetTypeString() << std::endl; << device_op.GetTypeString() << std::endl;
#endif
if(do_verification) if(do_verification)
{ {
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
Tensor<Acc0DataType> out0_host(out0_g_n_k_wos_desc); Tensor<Out0DataType> out0_host(out0_g_n_k_wos_desc);
auto ref_conv0 = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial, auto ref_conv0 = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
In0DataType, In0DataType,
Wei0DataType, Wei0DataType,
Acc0DataType, Out0DataType,
In0ElementOp, In0ElementOp,
Wei0ElementOp, Wei0ElementOp,
Out0ElementOp>(); Out0ElementOp>();
auto ref_conv1 = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial, auto ref_conv1 = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
Acc0DataType, Out0DataType,
Wei1DataType, Wei1DataType,
Out1DataType, Out1DataType,
PassThrough, PassThrough,
...@@ -245,13 +250,134 @@ int run_grouped_conv_conv_fwd(bool do_verification, ...@@ -245,13 +250,134 @@ int run_grouped_conv_conv_fwd(bool do_verification,
ref_conv0_invoker.Run(ref_conv0_argument); ref_conv0_invoker.Run(ref_conv0_argument);
ref_conv1_invoker.Run(ref_conv1_argument); ref_conv1_invoker.Run(ref_conv1_argument);
#ifdef BUILD_INT4_EXAMPLE
Tensor<KernelOut1DataType> out1_device_converted(out1_host.mDesc);
out1_device_buf.FromDevice(out1_device_converted.mData.data());
out1_device = out1_device_converted.CopyAsType<Out1DataType>();
#else
out1_device_buf.FromDevice(out1_device.mData.data()); out1_device_buf.FromDevice(out1_device.mData.data());
#endif
return ck::utils::check_err( return ck::utils::check_err(
out1_device.mData, out1_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f) out1_device.mData, out1_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
? 0 }
: 1;
return true;
}
// Command-line driver for the fused conv+conv forward example.
//
// Builds packed host tensor descriptors for two fixed convolution problems
// (conv0_param / conv1_param) and forwards them, together with the element-wise
// operators, to run_grouped_conv_conv_fwd using DeviceBatchedGemmGemmInstance as
// the device operation.
//
// Command line: either no arguments (defaults below) or exactly three:
//   argv[1]: verification   (0=no, 1=yes)
//   argv[2]: initialization (0=no init, 1=integer value, 2=decimal value)
//   argv[3]: time kernel    (0=no, 1=yes)
//
// Returns the result of run_grouped_conv_conv_fwd. Returns false, without
// running anything, when the argument count is wrong or when the number of
// spatial dimensions is not 1, 2 or 3, so the caller can report failure.
bool run_grouped_conv_conv_fwd_example(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    // Fixed problem sizes for the two chained convolutions.
    // NOTE(review): field order is presumably (ndim, G, N, K, C, filter, input, strides,
    // dilations, left pads, right pads) — confirm against ck::utils::conv::ConvParam.
    ck::utils::conv::ConvParam conv0_param{
        2, 1, 128, 512, 128, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};

    ck::utils::conv::ConvParam conv1_param{
        2, 1, 128, 128, 512, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]) != 0;
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]) != 0;
    }
    else
    {
        std::cout << "arg1: verification (0=no, 1=yes)\n";
        std::cout << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n";
        std::cout << "arg3: time kernel (0=no, 1=yes)\n";

        // A malformed command line is a failure: report it through the return value
        // instead of terminating the process with a success status.
        return false;
    }

    const auto in0_element_op  = In0ElementOp{};
    const auto wei0_element_op = Wei0ElementOp{};
    const auto wei1_element_op = Wei1ElementOp{};
    const auto out0_element_op = Out0ElementOp{};
    const auto out1_element_op = Out1ElementOp{};

    // Generic runner: the spatial rank and the tensor layouts are passed as tag
    // objects so that they become compile-time template arguments.
    const auto run = [&](auto ndim_spatial,
                         auto in0_layout,
                         auto wei0_layout,
                         auto wei1_layout,
                         auto out1_layout) {
        constexpr ck::index_t ndim_spatial_value = ndim_spatial.value;

        using In0Layout  = decltype(in0_layout);
        using Wei0Layout = decltype(wei0_layout);
        using Wei1Layout = decltype(wei1_layout);
        using Out1Layout = decltype(out1_layout);

        const auto in0_g_n_c_wis_desc =
            ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<In0Layout>(
                conv0_param);

        const auto wei0_g_k_c_xs_desc =
            ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<Wei0Layout>(
                conv0_param);

        // out0 does not physically exist on the device, so any layout is fine for the
        // host-side verification tensor; the final output layout is reused here.
        const auto out0_g_n_k_wos_desc =
            ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<Out1Layout>(
                conv0_param);

        const auto wei1_g_k_c_xs_desc =
            ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<Wei1Layout>(
                conv1_param);

        const auto out1_g_n_k_wos_desc =
            ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<Out1Layout>(
                conv1_param);

        return run_grouped_conv_conv_fwd<ndim_spatial_value,
                                         In0DataType,
                                         Wei0DataType,
                                         Out0DataType,
                                         Wei1DataType,
                                         Out1DataType,
                                         In0ElementOp,
                                         Wei0ElementOp,
                                         Out0ElementOp,
                                         Wei1ElementOp,
                                         Out1ElementOp,
                                         DeviceBatchedGemmGemmInstance>(do_verification,
                                                                        init_method,
                                                                        time_kernel,
                                                                        conv0_param,
                                                                        conv1_param,
                                                                        in0_g_n_c_wis_desc,
                                                                        wei0_g_k_c_xs_desc,
                                                                        out0_g_n_k_wos_desc,
                                                                        wei1_g_k_c_xs_desc,
                                                                        out1_g_n_k_wos_desc,
                                                                        in0_element_op,
                                                                        wei0_element_op,
                                                                        wei1_element_op,
                                                                        out0_element_op,
                                                                        out1_element_op);
    };

    namespace ctc = ck::tensor_layout::convolution;

    // Dispatch on the spatial rank of the first convolution.
    if(conv0_param.num_dim_spatial_ == 1)
    {
        return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GKXC{}, ctc::GNWK{});
    }
    else if(conv0_param.num_dim_spatial_ == 2)
    {
        return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GKYXC{}, ctc::GNHWK{});
    }
    else if(conv0_param.num_dim_spatial_ == 3)
    {
        return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GKZYXC{}, ctc::GNDHWK{});
    }

    // Nothing was executed for this rank, so do not report success.
    std::cerr << "Error: unsupported number of spatial dimensions: "
              << conv0_param.num_dim_spatial_ << std::endl;

    return false;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment