Commit 1196b676 authored by turneram's avatar turneram
Browse files

Formatting

parent e4737e2f
...@@ -36,25 +36,27 @@ ...@@ -36,25 +36,27 @@
namespace migraphx { namespace migraphx {
using ADataType = float; using ADataType = float;
using BDataType = float; using BDataType = float;
using CDataType = float; using CDataType = float;
using ElementwiseFunctor = float; using ElementwiseFunctor = float;
static constexpr auto I0 = ck::Number<0>{}; static constexpr auto I0 = ck::Number<0>{};
using index_t = index_int; using index_t = index_int;
template<class L, class S> template <class L, class S>
__host__ __device__ constexpr auto MakeDescriptor_M(const L& lengths, const S& strides) __host__ __device__ constexpr auto MakeDescriptor_M(const L& lengths, const S& strides)
{ {
auto idx = make_index(); auto idx = make_index();
auto tupleOfShape = generate_tuple([&](auto I) { return static_cast<ck::index_t>(lengths[I]); }, ck::Number<1>{}); auto tupleOfShape = generate_tuple([&](auto I) { return static_cast<ck::index_t>(lengths[I]); },
auto tupleOfStride = generate_tuple([&](auto I) { return static_cast<ck::index_t>(strides[I]); }, ck::Number<1>{}); ck::Number<1>{});
auto tupleOfStride = generate_tuple(
[&](auto I) { return static_cast<ck::index_t>(strides[I]); }, ck::Number<1>{});
const auto desc_m = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); const auto desc_m = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
const auto M = desc_m.GetLength(I0); const auto M = desc_m.GetLength(I0);
const index_t loop_step = idx.nglobal();//gridSize * blockSize * MPerThread; const index_t loop_step = idx.nglobal(); // gridSize * blockSize * MPerThread;
const auto pad = ck::math::integer_least_multiple(M, loop_step) - M; const auto pad = ck::math::integer_least_multiple(M, loop_step) - M;
const auto desc_m_pad = const auto desc_m_pad =
transform_tensor_descriptor(desc_m, transform_tensor_descriptor(desc_m,
make_tuple(ck::make_right_pad_transform(M, pad)), make_tuple(ck::make_right_pad_transform(M, pad)),
...@@ -83,30 +85,29 @@ struct Add ...@@ -83,30 +85,29 @@ struct Add
}; };
}; };
template <class T, class U, class V> template <class T, class U, class V>
__device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t) __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
{ {
//auto add = [](auto a, auto b) { return a + b; }; // auto add = [](auto a, auto b) { return a + b; };
auto lengths = a_t.get_shape().lens; auto lengths = a_t.get_shape().lens;
auto strides = a_t.get_shape().strides; auto strides = a_t.get_shape().strides;
auto a_desc = MakeDescriptor_M(lengths, strides); auto a_desc = MakeDescriptor_M(lengths, strides);
using AGridDesc_M = decltype(a_desc); using AGridDesc_M = decltype(a_desc);
//using Add = ck::tensor_operation::element_wise::Add; // using Add = ck::tensor_operation::element_wise::Add;
using GridwiseBinEltwise = ck::GridwiseBinaryElementwise_1D<ADataType, using GridwiseBinEltwise = ck::GridwiseBinaryElementwise_1D<ADataType,
BDataType, BDataType,
CDataType, CDataType,
CDataType, CDataType,
AGridDesc_M, AGridDesc_M,
AGridDesc_M, AGridDesc_M,
AGridDesc_M, AGridDesc_M,
Add, Add,
8, 8,
8, 8,
8, 8,
8>; 8>;
auto op = Add{}; auto op = Add{};
GridwiseBinEltwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, a_desc, a_desc, op); GridwiseBinEltwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, a_desc, a_desc, op);
// auto kernel = ck::kernel_binary_elementwise_1d<GridwiseBinEltwise, // auto kernel = ck::kernel_binary_elementwise_1d<GridwiseBinEltwise,
// ADataType, // ADataType,
...@@ -118,8 +119,8 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t) ...@@ -118,8 +119,8 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
// Add>; // Add>;
// kernel(a_t.data(), b_t.data(), c_t.data(), a_desc, a_desc, a_desc, Add); // kernel(a_t.data(), b_t.data(), c_t.data(), a_desc, a_desc, a_desc, Add);
// Argument arg{a_t.data(), b_t.data(), c_t.data(), c_t.get_shape().lens,
// Argument arg{a_t.data(), b_t.data(), c_t.data(), c_t.get_shape().lens, a_t.get_shape().strides, b_t.get_shape().strides, c_t.get_shape().strides, // a_t.get_shape().strides, b_t.get_shape().strides, c_t.get_shape().strides,
// add}; // add};
// auto lengths = a_t.get_shape().lens; // auto lengths = a_t.get_shape().lens;
// auto strides = a_t.get_shape().strides; // auto strides = a_t.get_shape().strides;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment