Commit 1196b676 authored by turneram
Browse files

Formatting

parent e4737e2f
......@@ -44,16 +44,18 @@ using ElementwiseFunctor = float;
static constexpr auto I0 = ck::Number<0>{};
using index_t = index_int;
template<class L, class S>
template <class L, class S>
__host__ __device__ constexpr auto MakeDescriptor_M(const L& lengths, const S& strides)
{
auto idx = make_index();
auto tupleOfShape = generate_tuple([&](auto I) { return static_cast<ck::index_t>(lengths[I]); }, ck::Number<1>{});
auto tupleOfStride = generate_tuple([&](auto I) { return static_cast<ck::index_t>(strides[I]); }, ck::Number<1>{});
auto tupleOfShape = generate_tuple([&](auto I) { return static_cast<ck::index_t>(lengths[I]); },
ck::Number<1>{});
auto tupleOfStride = generate_tuple(
[&](auto I) { return static_cast<ck::index_t>(strides[I]); }, ck::Number<1>{});
const auto desc_m = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
const auto M = desc_m.GetLength(I0);
const index_t loop_step = idx.nglobal();//gridSize * blockSize * MPerThread;
const index_t loop_step = idx.nglobal(); // gridSize * blockSize * MPerThread;
const auto pad = ck::math::integer_least_multiple(M, loop_step) - M;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
......@@ -83,17 +85,16 @@ struct Add
};
};
template <class T, class U, class V>
__device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
{
//auto add = [](auto a, auto b) { return a + b; };
// auto add = [](auto a, auto b) { return a + b; };
auto lengths = a_t.get_shape().lens;
auto strides = a_t.get_shape().strides;
auto a_desc = MakeDescriptor_M(lengths, strides);
using AGridDesc_M = decltype(a_desc);
//using Add = ck::tensor_operation::element_wise::Add;
// using Add = ck::tensor_operation::element_wise::Add;
using GridwiseBinEltwise = ck::GridwiseBinaryElementwise_1D<ADataType,
BDataType,
CDataType,
......@@ -118,8 +119,8 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
// Add>;
// kernel(a_t.data(), b_t.data(), c_t.data(), a_desc, a_desc, a_desc, Add);
// Argument arg{a_t.data(), b_t.data(), c_t.data(), c_t.get_shape().lens, a_t.get_shape().strides, b_t.get_shape().strides, c_t.get_shape().strides,
// Argument arg{a_t.data(), b_t.data(), c_t.data(), c_t.get_shape().lens,
// a_t.get_shape().strides, b_t.get_shape().strides, c_t.get_shape().strides,
// add};
// auto lengths = a_t.get_shape().lens;
// auto strides = a_t.get_shape().strides;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment