"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "30132aba308b83187997bd66579c34993e035f6a"
Unverified commit 82d7d993, authored by rocking5566, committed by GitHub

Hotfix binary elementwise (for broadcast on fastest axis) (#254)



* Support different length of ScalarPerVector

* Add example of broadcast on fastest axis

* Typo

* Refine fastest example

* Add dimension check

* Modify fastest broadcast example to 3d

* Require users to specify ScalarPerVector explicitly

* 1. Add CScalarPerVector
  2. Broadcast on the fastest axis is not the only case that requires ScalarPerVector to be set to 1

* Rename var

* Move IsScalarPerVectorValid() inside IsSupportedArgument()

* Separate GridDesc_M0 into A, B and C

* rename var

* Rename var of length
Co-authored-by: rocking <chunylai@amd.com>
parent e579c9e5
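The net effect of the commit shows up in the template parameter list of DeviceBinaryElementwise. Below is a minimal sketch (assuming the headers touched by this commit are on the include path) of instantiating the op for the new fastest-axis-broadcast case; the parameter order and values mirror broadcast_add_3d_am_bmnk.cpp further down.

// Sketch only - mirrors the instantiation in broadcast_add_3d_am_bmnk.cpp below.
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"

using F16 = ck::half_t;
using F32 = float;
using Add = ck::tensor_operation::binary_element_wise::Add;

// Parameter order after this commit:
//   ADataType, BDataType, CDataType, ComputeDataType, ElementwiseFunctor,
//   NDim, MPerThread, AScalarPerVector, BScalarPerVector, CScalarPerVector.
// A is broadcast on the fastest axis (its last-dimension stride is 0, i.e. not
// coalesced), so AScalarPerVector must be 1; B and C stay fully vectorized.
using BroadcastAddInstance =
    ck::tensor_operation::device::DeviceBinaryElementwise<F16, // ADataType
                                                          F16, // BDataType
                                                          F16, // CDataType
                                                          F32, // ComputeDataType
                                                          Add,
                                                          3,   // NDim
                                                          8,   // MPerThread
                                                          1,   // AScalarPerVector
                                                          8,   // BScalarPerVector
                                                          8>;  // CScalarPerVector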
CMakeLists.txt:
-add_example_executable(example_broadcast_add_2d broadcast_add_2d.cpp)
+add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp)
+add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp)
add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp)
add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp)
\ No newline at end of file
broadcast_add_2d_amn_bn.cpp (renamed from broadcast_add_2d.cpp):
@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;

using Add = ck::tensor_operation::binary_element_wise::Add;

-using DeviceElementwiseAddInstance = ck::tensor_operation::device::
-    DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 2, 8>;
+using DeviceElementwiseAddInstance =
+    ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
+                                                          ABDataType,
+                                                          CDataType,
+                                                          EltwiseComputeDataType,
+                                                          Add,
+                                                          2,
+                                                          8,
+                                                          8,
+                                                          8,
+                                                          8>;

template <typename HostTensorA,
          typename HostTensorB,
@@ -100,7 +109,7 @@ int main()
    if(!broadcastAdd.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error("The runtime parameters seems not supported by the "
-                                 "DeviceBinaryElementwise_2D instance, exiting!");
+                                 "DeviceBinaryElementwise instance, exiting!");
    };

    auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
@@ -123,7 +132,7 @@ int main()
                                      0>(host_c_m_n, a_m_n, b_n, M, N, Add{});

        pass &= ck::utils::check_err(
-            c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
+            c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3);
    }

    return pass ? 0 : 1;
broadcast_add_3d_am_bmnk.cpp (new file):

#include <iostream>
#include <cstdlib>

#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"

using F16 = ck::half_t;
using F32 = float;

using ABDataType             = F16;
using CDataType              = F16;
using EltwiseComputeDataType = F32;

using Add = ck::tensor_operation::binary_element_wise::Add;

using DeviceElementwiseAddInstance =
    ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
                                                          ABDataType,
                                                          CDataType,
                                                          EltwiseComputeDataType,
                                                          Add,
                                                          3,
                                                          8,
                                                          1,
                                                          8,
                                                          8>;

template <typename HostTensorA,
          typename HostTensorB,
          typename HostTensorC,
          typename ComputeDataType,
          typename Functor>
void host_broadcast3D_am_bmnk(HostTensorC& C,
                              const HostTensorA& A,
                              const HostTensorB& B,
                              const std::vector<std::size_t>& shape,
                              Functor functor)
{
    using ctype = ck::remove_reference_t<decltype(C(0, 0))>;

    for(std::size_t m = 0; m < shape[0]; ++m)
        for(std::size_t n = 0; n < shape[1]; ++n)
            for(std::size_t k = 0; k < shape[2]; ++k)
            {
                ComputeDataType a_val = static_cast<ComputeDataType>(A(m));
                ComputeDataType b_val = static_cast<ComputeDataType>(B(m, n, k));
                ComputeDataType c_val = 0;
                functor(c_val, a_val, b_val);
                C(m, n, k) = static_cast<ctype>(c_val);
            }
}

int main()
{
    bool do_verification = true;
    bool time_kernel     = false;

    std::vector<std::size_t> mnk = {4, 16, 32};
    ck::index_t M                = mnk[0];

    Tensor<ABDataType> a_m({M});
    Tensor<ABDataType> b_m_n_k(mnk);
    Tensor<CDataType> c_m_n_k(mnk);

    a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
    b_m_n_k.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});

    DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
    DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace());
    DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace());

    a_m_device_buf.ToDevice(a_m.mData.data());
    b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());

    auto broadcastAdd = DeviceElementwiseAddInstance{};
    auto argument     = broadcastAdd.MakeArgumentPointer(
        a_m_device_buf.GetDeviceBuffer(),
        b_m_n_k_device_buf.GetDeviceBuffer(),
        c_m_n_k_device_buf.GetDeviceBuffer(),
        std::vector<ck::index_t>{mnk.begin(), mnk.end()},
        {1, 0, 0}, // broadcast A on second and third dimension
        std::vector<ck::index_t>{b_m_n_k.mDesc.GetStrides().begin(),
                                 b_m_n_k.mDesc.GetStrides().end()},
        std::vector<ck::index_t>{c_m_n_k.mDesc.GetStrides().begin(),
                                 c_m_n_k.mDesc.GetStrides().end()},
        Add{});

    if(!broadcastAdd.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error("The runtime parameters seems not supported by the "
                                 "DeviceBinaryElementwise instance, exiting!");
    };

    auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();

    float ave_time =
        broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});

    std::cout << "Perf: " << ave_time << " ms" << std::endl;

    bool pass = true;

    if(do_verification)
    {
        c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data());

        Tensor<CDataType> host_c_m_n_k(mnk);
        host_broadcast3D_am_bmnk<Tensor<ABDataType>,
                                 Tensor<ABDataType>,
                                 Tensor<CDataType>,
                                 EltwiseComputeDataType,
                                 Add>(host_c_m_n_k, a_m, b_m_n_k, mnk, Add{});

        pass &= ck::utils::check_err(
            c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3);
    }

    return pass ? 0 : 1;
}
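The broadcast in the example above is expressed purely through strides: A is given the full lengths {M, N, K} but strides {1, 0, 0}, so every index (m, n, k) resolves to element m of the 1-D A buffer. A standalone sketch of that addressing rule (plain C++, independent of the CK headers; the helper name is illustrative only):

#include <cstddef>
#include <iostream>
#include <vector>

// Linear offset of a multi-index under explicit strides; a stride of 0 turns
// the corresponding dimension into a broadcast dimension.
std::size_t linear_offset(const std::vector<std::size_t>& idx,
                          const std::vector<std::size_t>& strides)
{
    std::size_t off = 0;
    for(std::size_t d = 0; d < idx.size(); ++d)
        off += idx[d] * strides[d];
    return off;
}

int main()
{
    const std::vector<std::size_t> a_strides{1, 0, 0}; // same strides passed for A above
    std::cout << linear_offset({2, 5, 7}, a_strides) << std::endl; // prints 2: A(m=2) is reused for all n, k
    return 0;
}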
elementwise_add_1d.cpp:
@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;

using Add = ck::tensor_operation::binary_element_wise::Add;

-using DeviceElementwiseAddInstance = ck::tensor_operation::device::
-    DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 1, 8>;
+using DeviceElementwiseAddInstance =
+    ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
+                                                          ABDataType,
+                                                          CDataType,
+                                                          EltwiseComputeDataType,
+                                                          Add,
+                                                          1,
+                                                          8,
+                                                          8,
+                                                          8,
+                                                          8>;

template <typename HostTensorA,
          typename HostTensorB,
@@ -81,7 +90,7 @@ int main()
    if(!broadcastAdd.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error("The runtime parameters seems not supported by the "
-                                 "DeviceBinaryElementwise_2D instance, exiting!");
+                                 "DeviceBinaryElementwise instance, exiting!");
    };

    auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
@@ -103,7 +112,7 @@ int main()
                                 Add>(host_c_m, a_m, b_m, M, Add{});

        pass &= ck::utils::check_err(
-            c_m.mData, host_c_m.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
+            c_m.mData, host_c_m.mData, "Error: Incorrect results c", 1e-3, 1e-3);
    }

    return pass ? 0 : 1;
elementwise_add_4d.cpp:
@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;

using Add = ck::tensor_operation::binary_element_wise::Add;

-using DeviceElementwiseAddInstance = ck::tensor_operation::device::
-    DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 4, 8>;
+using DeviceElementwiseAddInstance =
+    ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
+                                                          ABDataType,
+                                                          CDataType,
+                                                          EltwiseComputeDataType,
+                                                          Add,
+                                                          4,
+                                                          8,
+                                                          8,
+                                                          8,
+                                                          8>;

template <typename HostTensorA,
          typename HostTensorB,
@@ -83,7 +92,7 @@ int main()
    if(!broadcastAdd.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error("The runtime parameters seems not supported by the "
-                                 "DeviceBinaryElementwise_2D instance, exiting!");
+                                 "DeviceBinaryElementwise instance, exiting!");
    };

    auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
@@ -105,7 +114,7 @@ int main()
                             Add>(host_c, a, b, nchw, Add{});

        pass &=
-            ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
+            ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results c", 1e-3, 1e-3);
    }

    return pass ? 0 : 1;
device_binary_elementwise.hpp:
@@ -15,91 +15,107 @@ template <typename ADataType,
          typename CDataType,
          typename ComputeDataType,
          typename ElementwiseFunctor,
-          index_t Dim,
-          index_t ScalarPerVector>
+          index_t NDim,
+          index_t MPerThread,
+          index_t AScalarPerVector,
+          index_t BScalarPerVector,
+          index_t CScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator
{
    static constexpr auto I0 = Number<0>{};

-    template <typename Desc_M0>
-    static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
+    template <typename Desc_M>
+    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
    {
-        const auto m0           = desc_m0.GetLength(I0);
-        const index_t loop_step = gridSize * blockSize * ScalarPerVector;
-        const auto pad          = math::integer_least_multiple(m0, loop_step) - m0;
-        const auto desc_m0_pad =
-            transform_tensor_descriptor(desc_m0,
-                                        make_tuple(make_right_pad_transform(m0, pad)),
+        const auto M            = desc_m.GetLength(I0);
+        const index_t loop_step = gridSize * blockSize * MPerThread;
+        const auto pad          = math::integer_least_multiple(M, loop_step) - M;
+        const auto desc_m_pad =
+            transform_tensor_descriptor(desc_m,
+                                        make_tuple(make_right_pad_transform(M, pad)),
                                         make_tuple(Sequence<0>{}),
                                         make_tuple(Sequence<0>{}));
-        return desc_m0_pad;
+        return desc_m_pad;
    }

-    static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
-                                  const std::vector<index_t>& stride,
-                                  index_t gridSize,
-                                  index_t blockSize)
+    static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
+                                 const std::vector<index_t>& strides,
+                                 index_t gridSize,
+                                 index_t blockSize)
    {
-        auto tupleOfShape  = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
-        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
+        auto tupleOfShape  = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
+        auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<NDim>{});

        // nd desc - [s0, s1, s2, ...]
        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);

        // merge nd to 1d desc - [s0 * s1 * ...]
-        if constexpr(Dim > 1)
+        if constexpr(NDim > 1)
        {
-            const auto desc_m0 = transform_tensor_descriptor(
+            const auto desc_m = transform_tensor_descriptor(
                desc,
                make_tuple(make_merge_transform(tupleOfShape)),
-                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
+                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
                make_tuple(Sequence<0>{}));

-            return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
+            return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
        }
        else
-            return PadDescriptor_M0_1d(desc, gridSize, blockSize);
+            return PadDescriptor_M_1d(desc, gridSize, blockSize);
    }

-    using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
+    using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
+    using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
+    using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));

    using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
                                                            BDataType,
                                                            CDataType,
                                                            ComputeDataType,
-                                                            GridDesc_M0,
+                                                            AGridDesc_M,
+                                                            BGridDesc_M,
+                                                            CGridDesc_M,
                                                            ElementwiseFunctor,
-                                                            ScalarPerVector>;
+                                                            MPerThread,
+                                                            AScalarPerVector,
+                                                            BScalarPerVector,
+                                                            CScalarPerVector>;

    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a,
                 const BDataType* p_b,
                 CDataType* p_c,
-                 const std::vector<index_t>& shape,
-                 const std::vector<index_t>& stride_a,
-                 const std::vector<index_t>& stride_b,
-                 const std::vector<index_t>& stride_c,
+                 const std::vector<index_t>& lengths,
+                 const std::vector<index_t>& a_strides,
+                 const std::vector<index_t>& b_strides,
+                 const std::vector<index_t>& c_strides,
                 ElementwiseFunctor functor)
            : p_a_(p_a),
              p_b_(p_b),
              p_c_(p_c),
-              shape_(shape),
+              lengths_(lengths),
+              a_strides_(a_strides),
+              b_strides_(b_strides),
+              c_strides_(c_strides),
              functor_(functor),
              blockSize_(256),
              gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
        {
-            a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
-            b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
-            c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize_);
+            a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
+            b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
+            c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
        }

        const ADataType* p_a_;
        const BDataType* p_b_;
        CDataType* p_c_;
-        std::vector<int> shape_;
-        GridDesc_M0 a_grid_desc_m0_;
-        GridDesc_M0 b_grid_desc_m0_;
-        GridDesc_M0 c_grid_desc_m0_;
+        std::vector<int> lengths_;
+        AGridDesc_M a_grid_desc_m_;
+        BGridDesc_M b_grid_desc_m_;
+        CGridDesc_M c_grid_desc_m_;
+        std::vector<index_t> a_strides_;
+        std::vector<index_t> b_strides_;
+        std::vector<index_t> c_strides_;
        ElementwiseFunctor functor_;
        index_t blockSize_;
        index_t gridSize_;
@@ -113,7 +129,9 @@ struct DeviceBinaryElementwise : public BaseOperator
                                                        ADataType,
                                                        BDataType,
                                                        CDataType,
-                                                        GridDesc_M0,
+                                                        AGridDesc_M,
+                                                        BGridDesc_M,
+                                                        CGridDesc_M,
                                                        ElementwiseFunctor>;

            float elapsed_time = launch_and_time_kernel(stream_config,
@@ -124,9 +142,9 @@ struct DeviceBinaryElementwise : public BaseOperator
                                                        arg.p_a_,
                                                        arg.p_b_,
                                                        arg.p_c_,
-                                                        arg.a_grid_desc_m0_,
-                                                        arg.b_grid_desc_m0_,
-                                                        arg.c_grid_desc_m0_,
+                                                        arg.a_grid_desc_m_,
+                                                        arg.b_grid_desc_m_,
+                                                        arg.c_grid_desc_m_,
                                                        arg.functor_);
            return elapsed_time;
        }
@@ -146,7 +164,30 @@ struct DeviceBinaryElementwise : public BaseOperator
        if(pArg == nullptr)
            return false;

-        if(pArg->shape_.back() % ScalarPerVector != 0)
+        if(pArg->lengths_.size() != NDim)
+            return false;
+
+        if(pArg->lengths_.back() % MPerThread != 0)
+            return false;
+
+        auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
+            bool ret = true;
+            if(!isLastDimensionCoalesced)
+                ret = scalarPerVector == 1;
+            else
+                ret = MPerThread % scalarPerVector == 0;
+
+            return ret;
+        };
+
+        if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
+            return false;
+
+        if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
+            return false;
+
+        if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
            return false;

        return true;
@@ -155,19 +196,19 @@ struct DeviceBinaryElementwise : public BaseOperator
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
-                                                      std::vector<index_t> shape,
-                                                      std::vector<index_t> stride_a,
-                                                      std::vector<index_t> stride_b,
-                                                      std::vector<index_t> stride_c,
+                                                      std::vector<index_t> lengths,
+                                                      std::vector<index_t> a_strides,
+                                                      std::vector<index_t> b_strides,
+                                                      std::vector<index_t> c_strides,
                                                      ElementwiseFunctor functor)
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
-                                          shape,
-                                          stride_a,
-                                          stride_b,
-                                          stride_c,
+                                          lengths,
+                                          a_strides,
+                                          b_strides,
+                                          c_strides,
                                          functor);
    }
@@ -180,7 +221,7 @@ struct DeviceBinaryElementwise : public BaseOperator
        // clang-format off
        str << "DeviceBinaryElementwise"
            << "<"
-            << "ScalarPerVector = " << ScalarPerVector
+            << "MPerThread = " << MPerThread
            << ">";
        // clang-format on
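The new argument check above reduces to one rule per tensor: if the tensor's last-dimension stride is not 1 (for example a fastest-axis broadcast, where that stride is 0), its ScalarPerVector must be 1; otherwise MPerThread must be divisible by that ScalarPerVector. A standalone restatement of the predicate, for illustration only (the free-function name is not part of the library):

// Mirrors the IsScalarPerVectorValid() lambda added in IsSupportedArgument():
// vectorized access is only allowed when the last dimension is coalesced
// (stride == 1), and the per-thread tile must then be a whole number of vectors.
constexpr bool is_scalar_per_vector_valid(bool last_dim_coalesced,
                                          int scalar_per_vector,
                                          int m_per_thread)
{
    return last_dim_coalesced ? (m_per_thread % scalar_per_vector == 0)
                              : (scalar_per_vector == 1);
}

static_assert(is_scalar_per_vector_valid(true, 8, 8), "contiguous tensor, 8-wide vectors");
static_assert(!is_scalar_per_vector_valid(false, 8, 8), "fastest-axis broadcast must use 1");
static_assert(is_scalar_per_vector_valid(false, 1, 8), "scalar access is always legal");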
Gridwise kernel (GridwiseBinaryElementwise_1D):
@@ -11,138 +11,140 @@ template <typename GridwiseBinEltwise,
          typename ADataType,
          typename BDataType,
          typename CDataType,
-          typename GridDesc_M0,
+          typename AGridDesc_M,
+          typename BGridDesc_M,
+          typename CGridDesc_M,
          typename ElementwiseFunctor>
__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                             const BDataType* __restrict__ p_b_global,
                                             CDataType* __restrict__ p_c_global,
-                                             const GridDesc_M0 a_grid_desc_m0,
-                                             const GridDesc_M0 b_grid_desc_m0,
-                                             const GridDesc_M0 c_grid_desc_m0,
+                                             const AGridDesc_M a_grid_desc_m,
+                                             const BGridDesc_M b_grid_desc_m,
+                                             const CGridDesc_M c_grid_desc_m,
                                             const ElementwiseFunctor functor)
{
-    GridwiseBinEltwise::Run(p_a_global,
-                            p_b_global,
-                            p_c_global,
-                            a_grid_desc_m0,
-                            b_grid_desc_m0,
-                            c_grid_desc_m0,
-                            functor);
+    GridwiseBinEltwise::Run(
+        p_a_global, p_b_global, p_c_global, a_grid_desc_m, b_grid_desc_m, c_grid_desc_m, functor);
}

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ComputeDataType,
-          typename GridDesc_M0,
+          typename AGridDesc_M,
+          typename BGridDesc_M,
+          typename CGridDesc_M,
          typename ElementwiseFunctor,
-          index_t ScalarPerVector>
+          index_t MPerThread,
+          index_t AScalarPerVector,
+          index_t BScalarPerVector,
+          index_t CScalarPerVector>
struct GridwiseBinaryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};

-    static constexpr auto thread_desc_m0 =
-        make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
+    static constexpr auto thread_desc_m =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
-        return make_multi_index(global_thread_id * ScalarPerVector);
+        return make_multi_index(global_thread_id * MPerThread);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               CDataType* __restrict__ p_c_global,
-                               const GridDesc_M0 a_grid_desc_m0,
-                               const GridDesc_M0 b_grid_desc_m0,
-                               const GridDesc_M0 c_grid_desc_m0,
+                               const AGridDesc_M a_grid_desc_m,
+                               const BGridDesc_M b_grid_desc_m,
+                               const CGridDesc_M c_grid_desc_m,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_a_global, a_grid_desc_m0.GetElementSpaceSize());
+            p_a_global, a_grid_desc_m.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_global, b_grid_desc_m0.GetElementSpaceSize());
+            p_b_global, b_grid_desc_m.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_c_global, c_grid_desc_m0.GetElementSpaceSize());
+            p_c_global, c_grid_desc_m.GetElementSpaceSize());

-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;

        const auto thread_store_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ComputeDataType,
-                                             GridDesc_M0,
-                                             decltype(thread_desc_m0),
-                                             Sequence<ScalarPerVector>, // SliceLengths
+                                             AGridDesc_M,
+                                             decltype(thread_desc_m),
+                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
-                                             ScalarPerVector,
+                                             AScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
-                                             false>{a_grid_desc_m0, thread_store_global_offset};
+                                             false>{a_grid_desc_m, thread_store_global_offset};

        auto b_global_load =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             ComputeDataType,
-                                             GridDesc_M0,
-                                             decltype(thread_desc_m0),
-                                             Sequence<ScalarPerVector>, // SliceLengths
+                                             BGridDesc_M,
+                                             decltype(thread_desc_m),
+                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
-                                             ScalarPerVector,
+                                             BScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
-                                             false>{b_grid_desc_m0, thread_store_global_offset};
+                                             false>{b_grid_desc_m, thread_store_global_offset};

        auto c_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               CDataType,
-                                               decltype(thread_desc_m0),
-                                               GridDesc_M0,
+                                               decltype(thread_desc_m),
+                                               CGridDesc_M,
                                               PassThrough,
-                                               Sequence<ScalarPerVector>, // SliceLengths
+                                               Sequence<MPerThread>, // SliceLengths
                                               Sequence<0>,          // DimAccessOrder
                                               0,                    // DstVectorDim
-                                               ScalarPerVector,
+                                               CScalarPerVector,     // ScalarPerVector
                                               InMemoryDataOperationEnum::Set,
                                               1,                    // DstScalarStrideInVector
                                               false>{
-                c_grid_desc_m0, thread_store_global_offset, PassThrough{}};
+                c_grid_desc_m, thread_store_global_offset, PassThrough{}};

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
-        const auto m0              = c_grid_desc_m0.GetLength(I0);
-        const index_t loop_step    = blockPerGrid * blockSize * ScalarPerVector;
+        const auto M               = c_grid_desc_m.GetLength(I0);
+        const index_t loop_step    = blockPerGrid * blockSize * MPerThread;
        const auto loop_step_index = make_multi_index(loop_step);

-        index_t num_iter = m0 / (loop_step);
+        index_t num_iter = M / (loop_step);
        do
        {
-            // read and process ScalarPerVector elements
+            // read and process MPerThread elements
            a_global_load.Run(
-                a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
+                a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);

            b_global_load.Run(
-                b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf);
+                b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);

-            static_for<0, ScalarPerVector, 1>{}([&](auto m) {
-                constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
+            static_for<0, MPerThread, 1>{}([&](auto m) {
+                constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
                functor(c_thread_buf(Number<offset>{}),
                        a_thread_buf(Number<offset>{}),
                        b_thread_buf(Number<offset>{}));
            });

-            c_global_write.Run(thread_desc_m0,
+            c_global_write.Run(thread_desc_m,
                               make_tuple(I0), // SrcSliceOriginIdx
                               c_thread_buf,
-                               c_grid_desc_m0,
+                               c_grid_desc_m,
                               c_global_buf);

-            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
-            b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index);
-            c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index);
+            a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
+            b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
+            c_global_write.MoveDstSliceWindow(c_grid_desc_m, loop_step_index);
        } while(--num_iter);
    }
};
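For orientation, the traversal in Run() is a plain grid-stride loop: thread t starts at t * MPerThread and the whole grid advances by blockPerGrid * blockSize * MPerThread per iteration, with the descriptor padded so the merged length is a multiple of that step. A host-side sketch of just the indexing (not the CK kernel itself; names are illustrative):

// Host-side illustration of the per-thread index pattern used by
// GridwiseBinaryElementwise_1D: each thread owns MPerThread consecutive
// elements per iteration and steps by blockPerGrid * blockSize * MPerThread.
#include <cstdio>

void traversal_sketch(int padded_length, int blockPerGrid, int blockSize, int MPerThread)
{
    const int loop_step = blockPerGrid * blockSize * MPerThread;
    for(int tid = 0; tid < blockPerGrid * blockSize; ++tid) // one iteration == one GPU thread
    {
        for(int start = tid * MPerThread; start < padded_length; start += loop_step)
        {
            std::printf("thread %d processes [%d, %d)\n", tid, start, start + MPerThread);
        }
    }
}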