You need to sign in or sign up before continuing.
Unverified Commit 82d7d993 authored by rocking5566's avatar rocking5566 Committed by GitHub
Browse files

Hotfix binary elementwise (for broadcast on fastest axis) (#254)



* Support different length of ScalarPerVector

* Add example of broadcast on fastest axis

* Typo

* Refine fastest example

* Add dimension check

* Modify fastest broadcast example to 3d

* Enforce users give scalarPerVector explicitely

* 1. Add CscalarPerVedctor
2. Not only broadcast on fastest need to set scalarPerVector to 1

* Rename var

* Move IsScalarPerVectorValid() inside IsSupportedArgument()

* Separate GridDesc_M0 into A, B and C

* rename var

* Rename var of length
Co-authored-by: default avatarrocking <chunylai@amd.com>
parent e579c9e5
add_example_executable(example_broadcast_add_2d broadcast_add_2d.cpp)
add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp)
add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp)
add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp)
add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp)
\ No newline at end of file
......@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 2, 8>;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
Add,
2,
8,
8,
8,
8>;
template <typename HostTensorA,
typename HostTensorB,
......@@ -100,7 +109,7 @@ int main()
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise_2D instance, exiting!");
"DeviceBinaryElementwise instance, exiting!");
};
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
......@@ -123,7 +132,7 @@ int main()
0>(host_c_m_n, a_m_n, b_n, M, N, Add{});
pass &= ck::utils::check_err(
c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3);
}
return pass ? 0 : 1;
......
#include <iostream>
#include <cstdlib>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using F16 = ck::half_t;
using F32 = float;
using ABDataType = F16;
using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
Add,
3,
8,
1,
8,
8>;
template <typename HostTensorA,
typename HostTensorB,
typename HostTensorC,
typename ComputeDataType,
typename Functor>
void host_broadcast3D_am_bmnk(HostTensorC& C,
const HostTensorA& A,
const HostTensorB& B,
const std::vector<std::size_t>& shape,
Functor functor)
{
using ctype = ck::remove_reference_t<decltype(C(0, 0))>;
for(std::size_t m = 0; m < shape[0]; ++m)
for(std::size_t n = 0; n < shape[1]; ++n)
for(std::size_t k = 0; k < shape[2]; ++k)
{
ComputeDataType a_val = static_cast<ComputeDataType>(A(m));
ComputeDataType b_val = static_cast<ComputeDataType>(B(m, n, k));
ComputeDataType c_val = 0;
functor(c_val, a_val, b_val);
C(m, n, k) = static_cast<ctype>(c_val);
}
}
int main()
{
bool do_verification = true;
bool time_kernel = false;
std::vector<std::size_t> mnk = {4, 16, 32};
ck::index_t M = mnk[0];
Tensor<ABDataType> a_m({M});
Tensor<ABDataType> b_m_n_k(mnk);
Tensor<CDataType> c_m_n_k(mnk);
a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b_m_n_k.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace());
DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace());
a_m_device_buf.ToDevice(a_m.mData.data());
b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(
a_m_device_buf.GetDeviceBuffer(),
b_m_n_k_device_buf.GetDeviceBuffer(),
c_m_n_k_device_buf.GetDeviceBuffer(),
std::vector<ck::index_t>{mnk.begin(), mnk.end()},
{1, 0, 0}, // broadcast A on second and third dimension
std::vector<ck::index_t>{b_m_n_k.mDesc.GetStrides().begin(),
b_m_n_k.mDesc.GetStrides().end()},
std::vector<ck::index_t>{c_m_n_k.mDesc.GetStrides().begin(),
c_m_n_k.mDesc.GetStrides().end()},
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise instance, exiting!");
};
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
float ave_time =
broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::cout << "Perf: " << ave_time << " ms" << std::endl;
bool pass = true;
if(do_verification)
{
c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data());
Tensor<CDataType> host_c_m_n_k(mnk);
host_broadcast3D_am_bmnk<Tensor<ABDataType>,
Tensor<ABDataType>,
Tensor<CDataType>,
EltwiseComputeDataType,
Add>(host_c_m_n_k, a_m, b_m_n_k, mnk, Add{});
pass &= ck::utils::check_err(
c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3);
}
return pass ? 0 : 1;
}
......@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 1, 8>;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
Add,
1,
8,
8,
8,
8>;
template <typename HostTensorA,
typename HostTensorB,
......@@ -81,7 +90,7 @@ int main()
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise_2D instance, exiting!");
"DeviceBinaryElementwise instance, exiting!");
};
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
......@@ -103,7 +112,7 @@ int main()
Add>(host_c_m, a_m, b_m, M, Add{});
pass &= ck::utils::check_err(
c_m.mData, host_c_m.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
c_m.mData, host_c_m.mData, "Error: Incorrect results c", 1e-3, 1e-3);
}
return pass ? 0 : 1;
......
......@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add;
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 4, 8>;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
Add,
4,
8,
8,
8,
8>;
template <typename HostTensorA,
typename HostTensorB,
......@@ -83,7 +92,7 @@ int main()
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise_2D instance, exiting!");
"DeviceBinaryElementwise instance, exiting!");
};
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
......@@ -105,7 +114,7 @@ int main()
Add>(host_c, a, b, nchw, Add{});
pass &=
ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results c", 1e-3, 1e-3);
}
return pass ? 0 : 1;
......
......@@ -15,91 +15,107 @@ template <typename ADataType,
typename CDataType,
typename ComputeDataType,
typename ElementwiseFunctor,
index_t Dim,
index_t ScalarPerVector>
index_t NDim,
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator
{
static constexpr auto I0 = Number<0>{};
template <typename Desc_M0>
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
{
const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * blockSize * ScalarPerVector;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0,
make_tuple(make_right_pad_transform(m0, pad)),
const auto M = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * MPerThread;
const auto pad = math::integer_least_multiple(M, loop_step) - M;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(M, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m0_pad;
return desc_m_pad;
}
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
const std::vector<index_t>& stride,
index_t gridSize,
index_t blockSize)
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<NDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(Dim > 1)
if constexpr(NDim > 1)
{
const auto desc_m0 = transform_tensor_descriptor(
const auto desc_m = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
}
else
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
return PadDescriptor_M_1d(desc, gridSize, blockSize);
}
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
BDataType,
CDataType,
ComputeDataType,
GridDesc_M0,
AGridDesc_M,
BGridDesc_M,
CGridDesc_M,
ElementwiseFunctor,
ScalarPerVector>;
MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const std::vector<index_t>& shape,
const std::vector<index_t>& stride_a,
const std::vector<index_t>& stride_b,
const std::vector<index_t>& stride_c,
const std::vector<index_t>& lengths,
const std::vector<index_t>& a_strides,
const std::vector<index_t>& b_strides,
const std::vector<index_t>& c_strides,
ElementwiseFunctor functor)
: p_a_(p_a),
p_b_(p_b),
p_c_(p_c),
shape_(shape),
lengths_(lengths),
a_strides_(a_strides),
b_strides_(b_strides),
c_strides_(c_strides),
functor_(functor),
blockSize_(256),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize_);
a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
}
const ADataType* p_a_;
const BDataType* p_b_;
CDataType* p_c_;
std::vector<int> shape_;
GridDesc_M0 a_grid_desc_m0_;
GridDesc_M0 b_grid_desc_m0_;
GridDesc_M0 c_grid_desc_m0_;
std::vector<int> lengths_;
AGridDesc_M a_grid_desc_m_;
BGridDesc_M b_grid_desc_m_;
CGridDesc_M c_grid_desc_m_;
std::vector<index_t> a_strides_;
std::vector<index_t> b_strides_;
std::vector<index_t> c_strides_;
ElementwiseFunctor functor_;
index_t blockSize_;
index_t gridSize_;
......@@ -113,7 +129,9 @@ struct DeviceBinaryElementwise : public BaseOperator
ADataType,
BDataType,
CDataType,
GridDesc_M0,
AGridDesc_M,
BGridDesc_M,
CGridDesc_M,
ElementwiseFunctor>;
float elapsed_time = launch_and_time_kernel(stream_config,
......@@ -124,9 +142,9 @@ struct DeviceBinaryElementwise : public BaseOperator
arg.p_a_,
arg.p_b_,
arg.p_c_,
arg.a_grid_desc_m0_,
arg.b_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.a_grid_desc_m_,
arg.b_grid_desc_m_,
arg.c_grid_desc_m_,
arg.functor_);
return elapsed_time;
}
......@@ -146,7 +164,30 @@ struct DeviceBinaryElementwise : public BaseOperator
if(pArg == nullptr)
return false;
if(pArg->shape_.back() % ScalarPerVector != 0)
if(pArg->lengths_.size() != NDim)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
bool ret = true;
if(!isLastDimensionCoalesced)
ret = scalarPerVector == 1;
else
ret = MPerThread % scalarPerVector == 0;
return ret;
};
if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
return false;
return true;
......@@ -155,19 +196,19 @@ struct DeviceBinaryElementwise : public BaseOperator
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
std::vector<index_t> shape,
std::vector<index_t> stride_a,
std::vector<index_t> stride_b,
std::vector<index_t> stride_c,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
ElementwiseFunctor functor)
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
shape,
stride_a,
stride_b,
stride_c,
lengths,
a_strides,
b_strides,
c_strides,
functor);
}
......@@ -180,7 +221,7 @@ struct DeviceBinaryElementwise : public BaseOperator
// clang-format off
str << "DeviceBinaryElementwise"
<< "<"
<< "ScalarPerVector = " << ScalarPerVector
<< "MPerThread = " << MPerThread
<< ">";
// clang-format on
......
......@@ -11,138 +11,140 @@ template <typename GridwiseBinEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
typename GridDesc_M0,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename ElementwiseFunctor>
__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const GridDesc_M0 c_grid_desc_m0,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const ElementwiseFunctor functor)
{
GridwiseBinEltwise::Run(p_a_global,
p_b_global,
p_c_global,
a_grid_desc_m0,
b_grid_desc_m0,
c_grid_desc_m0,
functor);
GridwiseBinEltwise::Run(
p_a_global, p_b_global, p_c_global, a_grid_desc_m, b_grid_desc_m, c_grid_desc_m, functor);
}
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ComputeDataType,
typename GridDesc_M0,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename ElementwiseFunctor,
index_t ScalarPerVector>
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector>
struct GridwiseBinaryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_m0 =
make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
static constexpr auto thread_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * ScalarPerVector);
return make_multi_index(global_thread_id * MPerThread);
}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
CDataType* __restrict__ p_c_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const GridDesc_M0 c_grid_desc_m0,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m0.GetElementSpaceSize());
p_a_global, a_grid_desc_m.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m0.GetElementSpaceSize());
p_b_global, b_grid_desc_m.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m0.GetElementSpaceSize());
p_c_global, c_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
const auto thread_store_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ComputeDataType,
GridDesc_M0,
decltype(thread_desc_m0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
ScalarPerVector,
1, // SrcScalarStrideInVector
false>{a_grid_desc_m0, thread_store_global_offset};
AGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
AScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{a_grid_desc_m, thread_store_global_offset};
auto b_global_load =
ThreadwiseTensorSliceTransfer_v2<BDataType,
ComputeDataType,
GridDesc_M0,
decltype(thread_desc_m0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
ScalarPerVector,
1, // SrcScalarStrideInVector
false>{b_grid_desc_m0, thread_store_global_offset};
BGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
BScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{b_grid_desc_m, thread_store_global_offset};
auto c_global_write =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
CDataType,
decltype(thread_desc_m0),
GridDesc_M0,
decltype(thread_desc_m),
CGridDesc_M,
PassThrough,
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
ScalarPerVector,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
CScalarPerVector, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
c_grid_desc_m0, thread_store_global_offset, PassThrough{}};
c_grid_desc_m, thread_store_global_offset, PassThrough{}};
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto m0 = c_grid_desc_m0.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector;
const auto M = c_grid_desc_m.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
index_t num_iter = m0 / (loop_step);
index_t num_iter = M / (loop_step);
do
{
// read and process ScalarPerVector elements
// read and process MPerThread elements
a_global_load.Run(
a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
b_global_load.Run(
b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf);
b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
static_for<0, ScalarPerVector, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
static_for<0, MPerThread, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
functor(c_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}),
b_thread_buf(Number<offset>{}));
});
c_global_write.Run(thread_desc_m0,
c_global_write.Run(thread_desc_m,
make_tuple(I0), // SrcSliceOriginIdx
c_thread_buf,
c_grid_desc_m0,
c_grid_desc_m,
c_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index);
c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
c_global_write.MoveDstSliceWindow(c_grid_desc_m, loop_step_index);
} while(--num_iter);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment