Commit c42aded1 authored by turneram's avatar turneram
Browse files

Formatting

parent 961cf059
...@@ -83,18 +83,18 @@ struct ck_elementwise_compiler : compiler<ck_elementwise_compiler> ...@@ -83,18 +83,18 @@ struct ck_elementwise_compiler : compiler<ck_elementwise_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
hip_compile_options options; hip_compile_options options;
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
//options.virtual_inputs = reduce_dims(inputs); // options.virtual_inputs = reduce_dims(inputs);
//std::cout << options.virtual_inputs << std::endl; // std::cout << options.virtual_inputs << std::endl;
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
// auto axis = find_fast_axis(options.virtual_inputs); // auto axis = find_fast_axis(options.virtual_inputs);
// auto vec = vectorize::elements(axis, options.virtual_inputs); // auto vec = vectorize::elements(axis, options.virtual_inputs);
// auto preloads = preload::broadcasts(axis, options.virtual_inputs); // auto preloads = preload::broadcasts(axis, options.virtual_inputs);
auto axis = find_fast_axis(inputs); auto axis = find_fast_axis(inputs);
auto vec = vectorize::elements(axis, inputs); auto vec = vectorize::elements(axis, inputs);
auto preloads = preload::broadcasts(axis, inputs); auto preloads = preload::broadcasts(axis, inputs);
options.kernel_name = "ck_elementwise_kernel"; options.kernel_name = "ck_elementwise_kernel";
options.set_launch_params( options.set_launch_params(
v, v,
compute_global_for(ctx, compute_global_for(ctx,
......
...@@ -95,7 +95,7 @@ template <ck::index_t ndim> ...@@ -95,7 +95,7 @@ template <ck::index_t ndim>
struct CKBinaryElementwise2 struct CKBinaryElementwise2
{ {
template <class Desc_M> template <class Desc_M>
/* constexpr */__device__ auto PadDescriptor_M_1d(Desc_M desc_m) /* constexpr */ __device__ auto PadDescriptor_M_1d(Desc_M desc_m)
{ {
auto gridSize = 72; auto gridSize = 72;
auto blockSize = 1024; auto blockSize = 1024;
...@@ -112,12 +112,16 @@ struct CKBinaryElementwise2 ...@@ -112,12 +112,16 @@ struct CKBinaryElementwise2
} }
template <class L, class S> template <class L, class S>
/* constexpr */__device__ auto MakeDescriptor_M(const L& lengths, const S& strides) /* constexpr */ __device__ auto MakeDescriptor_M(const L& lengths, const S& strides)
{ {
auto tupleOfShape = generate_tuple( auto tupleOfShape = generate_tuple(
[&](auto I) { return static_cast<ck::index_t>(lengths[I]); }, ck::Number<ndim>{}); [&](auto I) { return static_cast<ck::index_t>(lengths[I]); }, ck::Number<ndim>{});
auto tupleOfStride = generate_tuple( auto tupleOfStride = generate_tuple(
[&](auto I) { printf ("Stride %i: %i\n", int(I), int(strides[I])); return static_cast<ck::index_t>(strides[I]); }, ck::Number<ndim>{}); [&](auto I) {
printf("Stride %i: %i\n", int(I), int(strides[I]));
return static_cast<ck::index_t>(strides[I]);
},
ck::Number<ndim>{});
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...] // merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(ndim > 1) if constexpr(ndim > 1)
...@@ -166,37 +170,37 @@ struct Div ...@@ -166,37 +170,37 @@ struct Div
template <class T, class U, class V> template <class T, class U, class V>
__device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t) __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
{ {
//auto idx = make_index(); // auto idx = make_index();
constexpr auto a_lens = get_shape_c<T>{}.lens; constexpr auto a_lens = get_shape_c<T>{}.lens;
constexpr auto a_strides = get_shape_c<T>{}.strides; constexpr auto a_strides = get_shape_c<T>{}.strides;
constexpr ck::index_t a_ndim = a_lens.size(); //decltype(a_lens.size()){}; constexpr ck::index_t a_ndim = a_lens.size(); // decltype(a_lens.size()){};
// if (idx.global == 0) // if (idx.global == 0)
// printf("a_ndim: %i\n", int(a_ndim)); // printf("a_ndim: %i\n", int(a_ndim));
auto a_bin_op = CKBinaryElementwise<a_ndim>{}; auto a_bin_op = CKBinaryElementwise<a_ndim>{};
constexpr auto a_desc = a_bin_op.MakeDescriptor_M(a_lens, a_strides); constexpr auto a_desc = a_bin_op.MakeDescriptor_M(a_lens, a_strides);
constexpr auto b_lens = get_shape_c<U>{}.lens; constexpr auto b_lens = get_shape_c<U>{}.lens;
constexpr auto b_strides = get_shape_c<U>{}.strides; constexpr auto b_strides = get_shape_c<U>{}.strides;
constexpr ck::index_t b_ndim = b_lens.size(); //decltype(b_lens.size()){}; constexpr ck::index_t b_ndim = b_lens.size(); // decltype(b_lens.size()){};
// if (idx.global == 0) // if (idx.global == 0)
// printf("b_ndim: %i\n", int(b_ndim)); // printf("b_ndim: %i\n", int(b_ndim));
auto b_bin_op = CKBinaryElementwise<b_ndim>{}; auto b_bin_op = CKBinaryElementwise<b_ndim>{};
constexpr auto b_desc = b_bin_op.MakeDescriptor_M(b_lens, b_strides); constexpr auto b_desc = b_bin_op.MakeDescriptor_M(b_lens, b_strides);
constexpr auto c_lens = get_shape_c<V>{}.lens; constexpr auto c_lens = get_shape_c<V>{}.lens;
constexpr auto c_strides = get_shape_c<V>{}.strides; constexpr auto c_strides = get_shape_c<V>{}.strides;
constexpr ck::index_t c_ndim = c_lens.size(); //decltype(c_lens.size()){}; constexpr ck::index_t c_ndim = c_lens.size(); // decltype(c_lens.size()){};
auto c_bin_op = CKBinaryElementwise<c_ndim>{}; auto c_bin_op = CKBinaryElementwise<c_ndim>{};
constexpr auto c_desc = c_bin_op.MakeDescriptor_M(c_lens, c_strides); constexpr auto c_desc = c_bin_op.MakeDescriptor_M(c_lens, c_strides);
using AGridDesc_M = decltype(a_desc); using AGridDesc_M = decltype(a_desc);
using BGridDesc_M = decltype(b_desc); using BGridDesc_M = decltype(b_desc);
using CGridDesc_M = decltype(c_desc); using CGridDesc_M = decltype(c_desc);
constexpr ck::index_t MPerThread = 8; constexpr ck::index_t MPerThread = 8;
constexpr ck::index_t AScalarPerVector = 8; constexpr ck::index_t AScalarPerVector = 8;
constexpr ck::index_t BScalarPerVector = 8; constexpr ck::index_t BScalarPerVector = 8;
constexpr ck::index_t CScalarPerVector = 8; constexpr ck::index_t CScalarPerVector = 8;
using GridwiseBinEltwise = ck::GridwiseBinaryElementwise_1D<ADataType, using GridwiseBinEltwise = ck::GridwiseBinaryElementwise_1D<ADataType,
BDataType, BDataType,
CDataType, CDataType,
CDataType, CDataType,
...@@ -208,7 +212,7 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t) ...@@ -208,7 +212,7 @@ __device__ void ck_elementwise(const T& a_t, const U& b_t, const V& c_t)
AScalarPerVector, AScalarPerVector,
BScalarPerVector, BScalarPerVector,
CScalarPerVector>; CScalarPerVector>;
auto op = Add{}; auto op = Add{};
GridwiseBinEltwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, b_desc, c_desc, op); GridwiseBinEltwise::Run(a_t.data(), b_t.data(), c_t.data(), a_desc, b_desc, c_desc, op);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment