Commit b2290854 authored by rocking's avatar rocking
Browse files

Merge commit '3e6c2610' into gemm_norm

parents 253f7ef2 3e6c2610
#include <iostream>
#include <cstdlib>
#include "config.hpp"
#include "tensor_layout.hpp"
#include "reduction_enums.hpp"
#include "pool2d_fwd_common.hpp"
using InDataType = float;
using OutDataType = float;
using AccDataType = float;
using IndexDataType = int32_t;
using InLayout = ck::tensor_layout::convolution::NHWC;
using OutLayout = ck::tensor_layout::convolution::NHWC;
#if 1
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
#else
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
#endif
static constexpr bool OutputIndex = false;
static constexpr bool PropagateNan = false;
int main(int argc, char* argv[])
{
using namespace ck::host_reduce;
bool do_verification;
int init_method;
bool time_kernel;
// Pool shape
ck::index_t N = 128;
ck::index_t C = 192;
ck::index_t Y = 3;
ck::index_t X = 3;
ck::index_t Hi = 71;
ck::index_t Wi = 71;
ck::index_t window_stride_h = 2;
ck::index_t window_stride_w = 2;
ck::index_t in_left_pad_h = 1;
ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(argc == 1)
{
do_verification = true;
init_method = 1;
time_kernel = true;
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = static_cast<bool>(std::stoi(argv[3]));
}
else if(argc == 16)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = static_cast<bool>(std::stoi(argv[3]));
N = std::stoi(argv[4]);
C = std::stoi(argv[5]);
Y = std::stoi(argv[6]);
X = std::stoi(argv[7]);
Hi = std::stoi(argv[8]);
Wi = std::stoi(argv[9]);
window_stride_h = std::stoi(argv[10]);
window_stride_w = std::stoi(argv[11]);
in_left_pad_h = std::stoi(argv[12]);
in_left_pad_w = std::stoi(argv[13]);
in_right_pad_h = std::stoi(argv[14]);
in_right_pad_w = std::stoi(argv[15]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(0);
}
bool pass = pool_test<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InLayout,
OutLayout,
ReduceOpId,
PropagateNan,
OutputIndex>(do_verification,
init_method,
time_kernel,
N,
C,
Y,
X,
Hi,
Wi,
window_stride_h,
window_stride_w,
in_left_pad_h,
in_left_pad_w,
in_right_pad_h,
in_right_pad_w);
return (pass ? 0 : 1);
}
......@@ -100,8 +100,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, RequantReluRequant>;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
float,
PassThrough,
PassThrough,
RequantReluRequant>;
int main(int argc, char* argv[])
{
......
......@@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
......
......@@ -32,6 +32,7 @@ using CDataType = F16;
using ReduceAccDataType = F32;
using DDataType = F64;
using DPtrsGlobal = ck::Tuple<DDataType*>;
using AccDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
......@@ -59,7 +60,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
......
......@@ -32,6 +32,7 @@ using CDataType = F16;
using ReduceAccDataType = F32;
using DDataType = F32;
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
using AccDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
......@@ -70,7 +71,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
......
add_example_executable(example_broadcast_add_2d broadcast_add_2d.cpp)
add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp)
add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp)
add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp)
add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp)
\ No newline at end of file
......@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 2, 8>;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
Add,
2,
8,
8,
8,
8>;
template <typename HostTensorA,
typename HostTensorB,
......@@ -100,7 +109,7 @@ int main()
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise_2D instance, exiting!");
"DeviceBinaryElementwise instance, exiting!");
};
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
......@@ -123,7 +132,7 @@ int main()
0>(host_c_m_n, a_m_n, b_n, M, N, Add{});
pass &= ck::utils::check_err(
c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3);
}
return pass ? 0 : 1;
......
#include <iostream>
#include <cstdlib>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using F16 = ck::half_t;
using F32 = float;
using ABDataType = F16;
using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
Add,
3,
8,
1,
8,
8>;
template <typename HostTensorA,
typename HostTensorB,
typename HostTensorC,
typename ComputeDataType,
typename Functor>
void host_broadcast3D_am_bmnk(HostTensorC& C,
const HostTensorA& A,
const HostTensorB& B,
const std::vector<std::size_t>& shape,
Functor functor)
{
using ctype = ck::remove_reference_t<decltype(C(0, 0))>;
for(std::size_t m = 0; m < shape[0]; ++m)
for(std::size_t n = 0; n < shape[1]; ++n)
for(std::size_t k = 0; k < shape[2]; ++k)
{
ComputeDataType a_val = static_cast<ComputeDataType>(A(m));
ComputeDataType b_val = static_cast<ComputeDataType>(B(m, n, k));
ComputeDataType c_val = 0;
functor(c_val, a_val, b_val);
C(m, n, k) = static_cast<ctype>(c_val);
}
}
int main()
{
bool do_verification = true;
bool time_kernel = false;
std::vector<std::size_t> mnk = {4, 16, 32};
ck::index_t M = mnk[0];
Tensor<ABDataType> a_m({M});
Tensor<ABDataType> b_m_n_k(mnk);
Tensor<CDataType> c_m_n_k(mnk);
a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b_m_n_k.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace());
DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace());
a_m_device_buf.ToDevice(a_m.mData.data());
b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(
a_m_device_buf.GetDeviceBuffer(),
b_m_n_k_device_buf.GetDeviceBuffer(),
c_m_n_k_device_buf.GetDeviceBuffer(),
std::vector<ck::index_t>{mnk.begin(), mnk.end()},
{1, 0, 0}, // broadcast A on second and third dimension
std::vector<ck::index_t>{b_m_n_k.mDesc.GetStrides().begin(),
b_m_n_k.mDesc.GetStrides().end()},
std::vector<ck::index_t>{c_m_n_k.mDesc.GetStrides().begin(),
c_m_n_k.mDesc.GetStrides().end()},
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise instance, exiting!");
};
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
float ave_time =
broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::cout << "Perf: " << ave_time << " ms" << std::endl;
bool pass = true;
if(do_verification)
{
c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data());
Tensor<CDataType> host_c_m_n_k(mnk);
host_broadcast3D_am_bmnk<Tensor<ABDataType>,
Tensor<ABDataType>,
Tensor<CDataType>,
EltwiseComputeDataType,
Add>(host_c_m_n_k, a_m, b_m_n_k, mnk, Add{});
pass &= ck::utils::check_err(
c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3);
}
return pass ? 0 : 1;
}
......@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 1, 8>;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
Add,
1,
8,
8,
8,
8>;
template <typename HostTensorA,
typename HostTensorB,
......@@ -81,7 +90,7 @@ int main()
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise_2D instance, exiting!");
"DeviceBinaryElementwise instance, exiting!");
};
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
......@@ -103,7 +112,7 @@ int main()
Add>(host_c_m, a_m, b_m, M, Add{});
pass &= ck::utils::check_err(
c_m.mData, host_c_m.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
c_m.mData, host_c_m.mData, "Error: Incorrect results c", 1e-3, 1e-3);
}
return pass ? 0 : 1;
......
......@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 4, 8>;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
Add,
4,
8,
8,
8,
8>;
template <typename HostTensorA,
typename HostTensorB,
......@@ -83,7 +92,7 @@ int main()
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise_2D instance, exiting!");
"DeviceBinaryElementwise instance, exiting!");
};
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
......@@ -105,7 +114,7 @@ int main()
Add>(host_c, a, b, nchw, Add{});
pass &=
ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results c", 1e-3, 1e-3);
}
return pass ? 0 : 1;
......
......@@ -257,11 +257,11 @@ int main(int argc, char* argv[])
case 0: break;
case 1:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
break;
default:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
......@@ -296,15 +296,53 @@ int main(int argc, char* argv[])
OutElementOp{},
split_k);
if(!conv->IsSupportedArgument(argument.get()))
// alloc work space
size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get());
float ave_time = 0.f;
if(std::is_same<InDataType, ck::bhalf_t>::value && split_k > 1)
{
std::cout << "wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<< std::endl;
return 1;
}
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
wei_work_space_device_buf.SetZero();
argument = conv->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<AccDataType*>(wei_work_space_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
params.N_,
params.K_,
params.C_,
params.input_spatial_lengths_,
params.filter_spatial_lengths_,
output_spatial_lengths,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
InElementOp{},
WeiElementOp{},
OutElementOp{},
split_k);
if(!conv->IsSupportedArgument(argument.get()))
{
std::cout << "wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<< std::endl;
return 1;
}
float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
}
else
{
if(!conv->IsSupportedArgument(argument.get()))
{
std::cout << "wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<< std::endl;
return 1;
}
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
}
std::size_t flop = ck::utils::conv::get_flops(
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
......
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/include/ck
${PROJECT_SOURCE_DIR}/include/ck/utility
${PROJECT_SOURCE_DIR}/include/ck/host_utility
${PROJECT_SOURCE_DIR}/include/ck/tensor_description
${PROJECT_SOURCE_DIR}/include/ck/tensor
${PROJECT_SOURCE_DIR}/include/ck/problem_transform
......
#pragma once
#include <string>
#include <map>
namespace ck {
inline std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return std::string();
}
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return std::string();
}
const std::string raw_name(props.gcnArchName);
// https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
static std::map<std::string, std::string> device_name_map = {
{"Ellesmere", "gfx803"},
{"Baffin", "gfx803"},
{"RacerX", "gfx803"},
{"Polaris10", "gfx803"},
{"Polaris11", "gfx803"},
{"Tonga", "gfx803"},
{"Fiji", "gfx803"},
{"gfx800", "gfx803"},
{"gfx802", "gfx803"},
{"gfx804", "gfx803"},
{"Vega10", "gfx900"},
{"gfx901", "gfx900"},
{"10.3.0 Sienna_Cichlid 18", "gfx1030"},
};
const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
auto match = device_name_map.find(name);
if(match != device_name_map.end())
return match->second;
return name;
}
} // namespace ck
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
#pragma once
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_contraction_dlops.hpp"
#include "threadwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_contraction_dl.hpp"
namespace ck {
......@@ -41,7 +39,7 @@ template <index_t BlockSize,
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
......@@ -148,7 +146,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
public:
__device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
__device__ BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())},
a_thread_copy_{
......@@ -175,6 +173,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
"wrong!");
// TODO: remove this restriction
static_assert(BM0 == 2, "wrong");
static_assert(BM0 == 2 && BN0 == 2, "wrong");
}
......@@ -226,7 +225,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
constexpr auto threadwise_contraction =
ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
FloatA,
FloatB,
FloatC,
......@@ -407,4 +406,3 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
};
} // namespace ck
#endif
......@@ -75,14 +75,13 @@ struct BlockwiseTensorSliceTransfer_v5r1
}
}
template <typename SrcBuffer, typename SrcStepHacks>
__device__ void
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
threadwise_transfer_.RunRead(src_desc, src_buf);
}
}
......
......@@ -40,6 +40,8 @@ struct BaseOperator
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const { return ""; }
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
virtual ~BaseOperator() {}
};
......
......@@ -15,91 +15,107 @@ template <typename ADataType,
typename CDataType,
typename ComputeDataType,
typename ElementwiseFunctor,
index_t Dim,
index_t ScalarPerVector>
index_t NDim,
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator
{
static constexpr auto I0 = Number<0>{};
template <typename Desc_M0>
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
{
const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * blockSize * ScalarPerVector;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0,
make_tuple(make_right_pad_transform(m0, pad)),
const auto M = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * MPerThread;
const auto pad = math::integer_least_multiple(M, loop_step) - M;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(M, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m0_pad;
return desc_m_pad;
}
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
const std::vector<index_t>& stride,
index_t gridSize,
index_t blockSize)
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<NDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(Dim > 1)
if constexpr(NDim > 1)
{
const auto desc_m0 = transform_tensor_descriptor(
const auto desc_m = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
}
else
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
return PadDescriptor_M_1d(desc, gridSize, blockSize);
}
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
BDataType,
CDataType,
ComputeDataType,
GridDesc_M0,
AGridDesc_M,
BGridDesc_M,
CGridDesc_M,
ElementwiseFunctor,
ScalarPerVector>;
MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const std::vector<index_t>& shape,
const std::vector<index_t>& stride_a,
const std::vector<index_t>& stride_b,
const std::vector<index_t>& stride_c,
const std::vector<index_t>& lengths,
const std::vector<index_t>& a_strides,
const std::vector<index_t>& b_strides,
const std::vector<index_t>& c_strides,
ElementwiseFunctor functor)
: p_a_(p_a),
p_b_(p_b),
p_c_(p_c),
shape_(shape),
lengths_(lengths),
a_strides_(a_strides),
b_strides_(b_strides),
c_strides_(c_strides),
functor_(functor),
blockSize_(256),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize_);
a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
}
const ADataType* p_a_;
const BDataType* p_b_;
CDataType* p_c_;
std::vector<int> shape_;
GridDesc_M0 a_grid_desc_m0_;
GridDesc_M0 b_grid_desc_m0_;
GridDesc_M0 c_grid_desc_m0_;
std::vector<int> lengths_;
AGridDesc_M a_grid_desc_m_;
BGridDesc_M b_grid_desc_m_;
CGridDesc_M c_grid_desc_m_;
std::vector<index_t> a_strides_;
std::vector<index_t> b_strides_;
std::vector<index_t> c_strides_;
ElementwiseFunctor functor_;
index_t blockSize_;
index_t gridSize_;
......@@ -113,7 +129,9 @@ struct DeviceBinaryElementwise : public BaseOperator
ADataType,
BDataType,
CDataType,
GridDesc_M0,
AGridDesc_M,
BGridDesc_M,
CGridDesc_M,
ElementwiseFunctor>;
float elapsed_time = launch_and_time_kernel(stream_config,
......@@ -124,9 +142,9 @@ struct DeviceBinaryElementwise : public BaseOperator
arg.p_a_,
arg.p_b_,
arg.p_c_,
arg.a_grid_desc_m0_,
arg.b_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.a_grid_desc_m_,
arg.b_grid_desc_m_,
arg.c_grid_desc_m_,
arg.functor_);
return elapsed_time;
}
......@@ -146,7 +164,30 @@ struct DeviceBinaryElementwise : public BaseOperator
if(pArg == nullptr)
return false;
if(pArg->shape_.back() % ScalarPerVector != 0)
if(pArg->lengths_.size() != NDim)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
bool ret = true;
if(!isLastDimensionCoalesced)
ret = scalarPerVector == 1;
else
ret = MPerThread % scalarPerVector == 0;
return ret;
};
if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
return false;
return true;
......@@ -155,19 +196,19 @@ struct DeviceBinaryElementwise : public BaseOperator
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
std::vector<index_t> shape,
std::vector<index_t> stride_a,
std::vector<index_t> stride_b,
std::vector<index_t> stride_c,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
ElementwiseFunctor functor)
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
shape,
stride_a,
stride_b,
stride_c,
lengths,
a_strides,
b_strides,
c_strides,
functor);
}
......@@ -180,7 +221,7 @@ struct DeviceBinaryElementwise : public BaseOperator
// clang-format off
str << "DeviceBinaryElementwise"
<< "<"
<< "ScalarPerVector = " << ScalarPerVector
<< "MPerThread = " << MPerThread
<< ">";
// clang-format on
......
......@@ -1175,6 +1175,57 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return str.str();
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static size_t GetWorkSpaceSize(const Argument& arg)
{
size_t WorkSpaceSize = 0;
if(arg.k_batch_ > 1)
{
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{
WorkSpaceSize =
arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * sizeof(float);
}
}
return WorkSpaceSize;
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static size_t GetWorkSpaceSize(const Argument& arg)
{
size_t WorkSpaceSize = 0;
if(arg.k_batch_ > 1)
{
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{
WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] *
arg.filter_spatial_lengths_[1] * sizeof(float);
}
}
return WorkSpaceSize;
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static size_t GetWorkSpaceSize(const Argument& arg)
{
size_t WorkSpaceSize = 0;
if(arg.k_batch_ > 1)
{
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{
WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] *
arg.filter_spatial_lengths_[1] * arg.filter_spatial_lengths_[2] *
sizeof(float);
}
}
return WorkSpaceSize;
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override final
{
return GetWorkSpaceSize<NumDimSpatial>(*dynamic_cast<const Argument*>(p_arg));
}
};
} // namespace device
......
This diff is collapsed.
#ifndef DEVICE_GEMM_XDL_HPP
#define DEVICE_GEMM_XDL_HPP
#pragma once
#include <iostream>
#include <sstream>
......@@ -12,6 +11,7 @@
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "gemm_specialization.hpp"
#include "device_prop.hpp"
namespace ck {
namespace tensor_operation {
......@@ -408,6 +408,11 @@ struct DeviceGemmXdl
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
......@@ -515,4 +520,3 @@ struct DeviceGemmXdl
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment