"docs/vscode:/vscode.git/clone" did not exist on "a256753221ad2a33ec9750b31f6284b581c1e1fd"
Commit edac496b authored by Jing Zhang, committed by Astha Rai

fixed some issues

parent 5c0736c9
@@ -22,10 +22,10 @@ using DeviceElementwisePermuteInstance =
                                                           PassThrough,
                                                           3, // NumDim_M
                                                           1, // NumDim_N
-                                                          8,
-                                                          8,
-                                                          ck::Sequence<8>,
-                                                          ck::Sequence<8>>;
+                                                          1,
+                                                          1,
+                                                          ck::Sequence<1>,
+                                                          ck::Sequence<1>>;
 
 template <typename HostTensorA, typename HostTensorB, typename Functor>
 void host_elementwise4D(HostTensorB& B_nhwc,
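This hunk drops the per-thread tile widths and the scalar-per-vector sequences from 8 to 1, i.e. it falls back to scalar memory access. As a rule of thumb, a vector width is only usable when the addressed dimension is contiguous (stride 1) and the vector width divides its length; width 1 always qualifies. A minimal standalone sketch of that constraint (my naming, not CK's API):

    // Minimal sketch (illustrative, not part of the commit): when a vector
    // width is legal for the fastest-varying dimension.
    #include <cstdio>

    bool can_vectorize(int length, int stride, int scalar_per_vector)
    {
        // the dim must be contiguous and hold a whole number of vectors
        return stride == 1 && length % scalar_per_vector == 0;
    }

    int main()
    {
        // W = 1024 is divisible by 8, so Sequence<8> would be legal for a
        // packed innermost dim; width 1 is the always-legal fallback.
        std::printf("%d %d\n", can_vectorize(1024, 1, 8), can_vectorize(1024, 1, 1));
    }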
@@ -48,10 +48,10 @@ int main()
     bool do_verification = true;
     bool time_kernel     = true;
 
-    //const int N = 120;
-    //const int C = 128;
-    //const int H = 32;
-    //const int W = 1024;
+    // const int N = 120;
+    // const int C = 128;
+    // const int H = 32;
+    // const int W = 1024;
     const int N = 16;
     const int C = 8;
     const int H = 32;
@@ -110,13 +110,13 @@ int main()
     float gb_per_sec = num_btype / 1.E6 / ave_time;
 
-    //LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
+    // LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
               << std::endl;
 
     bool pass = true;
-    //LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
+    // LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
     if(do_verification)
     {
         b_device_buf.FromDevice(b.mData.data());
@@ -125,9 +125,9 @@ int main()
         Tensor<BDataType> host_b(nhwc);
         host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
             host_b, a, nchw, PassThrough{});
-        //LogRangeAsType<float>(std::cout << "Host_b : ", host_b.mData, ",") << std::endl;
-        LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "Host_b : ", host_b.mData, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
         // LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl;
 
         pass &=
             ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
......
@@ -22,13 +22,13 @@ using DeviceElementwisePermuteInstance =
     ck::tensor_operation::device::DeviceElementwise3dImpl<ck::Tuple<ADataType>,
                                                           ck::Tuple<BDataType>,
                                                           PassThrough,
-                                                          3,
-                                                          1,
-                                                          1,
-                                                          8,
-                                                          8,
-                                                          8,
-                                                          ck::Sequence<8>,
+                                                          2, // NumDim_m, {N, C}
+                                                          2, // NumDim_n, {H, W}
+                                                          1, // NumDim_k, {D}
+                                                          1,
+                                                          1,
+                                                          1,
+                                                          ck::Sequence<1>,
                                                           ck::Sequence<1>>;
 
 template <typename HostTensorA, typename HostTensorB, typename Functor>
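The instance now splits the five tensor dimensions into explicit m/n/k groups instead of one merged M group. A standalone sketch of how that 5D shape collapses into the 3D view the new template arguments describe, using the example's dimensions from the hunk below:

    // Minimal sketch (illustrative, not the commit's code): the 3D (m, n, k)
    // problem shape implied by NumDim_m = 2, NumDim_n = 2, NumDim_k = 1.
    #include <cstdio>

    int main()
    {
        const int N = 1, C = 2, H = 3, W = 4, D = 16;
        const int M  = N * C; // NumDim_m = 2 merges {N, C}
        const int Nn = H * W; // NumDim_n = 2 merges {H, W}
        const int K  = D;     // NumDim_k = 1 is {D}
        std::printf("3D view: %d x %d x %d\n", M, Nn, K); // 2 x 12 x 16
    }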
@@ -50,11 +50,11 @@ int main()
     bool do_verification = true;
     bool time_kernel     = true;
 
-    const int N = 16;
-    const int C = 8;
-    const int D = 8;
-    const int H = 8;
-    const int W = 8;
+    const int N = 1;
+    const int C = 2;
+    const int H = 3;
+    const int W = 4;
+    const int D = 16;
 
     std::vector<std::size_t> ncdhw = {N, C, D, H, W};
     std::vector<std::size_t> nchwd = {N, C, H, W, D};
@@ -72,8 +72,8 @@ int main()
     std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
 
     std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
-    std::array<ck::index_t, 5> a_strides = {C * D * H * W, D * H * W, 1, D * H, D};
-    std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1};
+    std::array<ck::index_t, 5> a_strides = {C * D * H * W, D * H * W, H, 1, H * W}; // N, C, D, H, W
+    std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, C, H, W, D
 
     auto broadcastPermute = DeviceElementwisePermuteInstance{};
     auto argument = broadcastPermute.MakeArgumentPointer(
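For reference, a packed tensor's stride for each dimension is the product of the lengths of all faster-varying dimensions, and a permute is expressed by listing those same strides in the output's dimension order. A minimal sketch of the packed-stride computation, independent of CK:

    // Minimal sketch (not from the commit): packed strides, innermost dim last.
    // For NCDHW with lengths {N, C, D, H, W} this yields
    // {C*D*H*W, D*H*W, H*W, W, 1}.
    #include <array>
    #include <cstdio>

    std::array<long, 5> packed_strides(const std::array<long, 5>& lengths)
    {
        std::array<long, 5> strides{};
        long running = 1;
        for(int i = 4; i >= 0; --i) // walk from the fastest dim outwards
        {
            strides[i] = running;
            running *= lengths[i];
        }
        return strides;
    }

    int main()
    {
        const auto s = packed_strides({1, 2, 16, 3, 4}); // N, C, D, H, W
        std::printf("%ld %ld %ld %ld %ld\n", s[0], s[1], s[2], s[3], s[4]);
    }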
@@ -93,7 +93,8 @@ int main()
     broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
 
     std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4];
-    std::size_t num_btype = sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) +
+    std::size_t num_btype =
+        sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) +
         sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]);
 
     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
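Unit check on the two metrics: `ave_time` is in milliseconds, so dividing FLOPs by 1e9 and by milliseconds yields TFLOP/s, and dividing bytes by 1e6 and by milliseconds yields GB/s. A worked sketch with hypothetical numbers:

    // Minimal sketch (hypothetical values): the unit arithmetic behind the
    // example's tflops and gb_per_sec expressions.
    #include <cstdio>

    int main()
    {
        const double flop      = 2.0 * 1 * 2 * 16 * 3 * 4;        // 2 ops per element
        const double num_btype = 2.0 * 2 * (1 * 2 * 16 * 3 * 4);  // 2-byte in + out
        const double ave_time  = 0.01;                            // milliseconds
        // flop / 1e9 / ms == flop / (1e12 * s) == TFLOP/s
        const double tflops = flop / 1.E9 / ave_time;
        // bytes / 1e6 / ms == bytes / (1e9 * s) == GB/s
        const double gb_per_sec = num_btype / 1.E6 / ave_time;
        std::printf("%f TFlops, %f GB/s\n", tflops, gb_per_sec);
    }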
@@ -101,8 +102,8 @@ int main()
     float gb_per_sec = num_btype / 1.E6 / ave_time;
 
     // LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
-    //LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
-    //std::cout << "A: " << a.mData.data() << std::endl;
+    // LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
+    // std::cout << "A: " << a.mData.data() << std::endl;
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
               << std::endl;
@@ -113,13 +114,11 @@ int main()
     {
         b_device_buf.FromDevice(b.mData.data());
 
-        //LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
         LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
 
         Tensor<BDataType> host_b(nchwd);
         host_elementwise4D(host_b, a, PassThrough{});
-        LogRangeAsType<float>(std::cout << "Host B : ", host_b.mData, ",") << std::endl;
-        //LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
-        //LogRangeAsType<float>(std::cout << "Host B : ", host_b.mData, ",") << std::endl;
 
         pass &=
             ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
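The verification compares against a host reference through `ck::utils::check_err` with absolute and relative tolerances of 1e-3. A minimal sketch of that style of comparison (a stand-in of my own, not CK's implementation):

    // Minimal sketch (stand-in, not CK's check_err): pass when every element is
    // within the absolute or the relative tolerance.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    bool check_err(const std::vector<float>& out, const std::vector<float>& ref,
                   float atol = 1e-3f, float rtol = 1e-3f)
    {
        if(out.size() != ref.size()) return false;
        for(std::size_t i = 0; i < out.size(); ++i)
        {
            const float err = std::fabs(out[i] - ref[i]);
            if(err > atol && err > rtol * std::fabs(ref[i])) return false;
        }
        return true;
    }

    int main() { std::printf("%d\n", check_err({1.0f, 2.0f}, {1.0005f, 2.0f})); }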
......
@@ -21,7 +21,7 @@ namespace device {
 template <typename InDataTypeTuple,
           typename OutDataTypeTuple,
           typename ElementwiseOperation,
-          index_t NumDim_m,//choose how to set dims
+          index_t NumDim_m, // choose how to set dims
           index_t NumDim_n,
           index_t NumDim_k,
           index_t MPerThread,
@@ -58,7 +58,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
                 return static_cast<const DataType*>(nullptr);
             },
             Number<NumInput>{});
-    };
+    }
 
     static auto GenerateOutDataTypePointerTuple()
     {
@@ -69,7 +69,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
                 return static_cast<DataType*>(nullptr);
             },
             Number<NumOutput>{});
-    };
+    }
 
     using InDataTypePointerTuple  = decltype(GenerateInDataTypePointerTuple());
     using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
@@ -84,6 +84,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
     {
         std::ignore = blockSize;
         std::ignore = gridSize;
+
         const auto m = desc_mnk.GetLength(I0);
         const auto n = desc_mnk.GetLength(I1);
         const auto k = desc_mnk.GetLength(I2);
@@ -94,9 +95,11 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
         const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
         const auto pad_k = math::integer_least_multiple(k, loop_step_k) - k;
 
-        const auto desc_mnk_pad = transform_tensor_descriptor(
-            desc_mnk,
-            make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n), make_right_pad_transform(k, pad_k)),
+        const auto desc_mnk_pad =
+            transform_tensor_descriptor(desc_mnk,
+                                        make_tuple(make_right_pad_transform(m, pad_m),
+                                                   make_right_pad_transform(n, pad_n),
+                                                   make_right_pad_transform(k, pad_k)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
         return desc_mnk_pad;
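For reference, `math::integer_least_multiple(x, step)` rounds `x` up to the next multiple of `step`, and the pad is the difference, so each of m, n, k becomes evenly divisible by its loop step. A one-function sketch of that arithmetic:

    // Minimal sketch (illustrative): the right-pad arithmetic used above.
    #include <cstdio>

    long integer_least_multiple(long x, long step) { return ((x + step - 1) / step) * step; }

    int main()
    {
        const long n = 37, loop_step_n = 16;
        const long pad_n = integer_least_multiple(n, loop_step_n) - n; // 48 - 37 = 11
        std::printf("pad_n = %ld\n", pad_n);
    }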
@@ -127,19 +130,22 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
         const auto kLengths = get_container_subset(tupleOfShape, kDimIds);
 
         // merge nd to 3d desc - [s0 * s1 * ...]
         if constexpr(NumDim > 3)
         {
             const auto desc_mnk = transform_tensor_descriptor(
                 desc,
-                make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths), make_merge_transform(kLengths)),
+                make_tuple(make_merge_transform(mLengths),
+                           make_merge_transform(nLengths),
+                           make_merge_transform(kLengths)),
                 make_tuple(mDimIds, nDimIds, kDimIds),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
-            return PadDescriptor_MNK(desc_mnk, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
+            return PadDescriptor_MNK(
+                desc_mnk, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
         }
         else
-            return PadDescriptor_MNK(desc, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
+            return PadDescriptor_MNK(
+                desc, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
     }
 
     template <index_t TupleSize>
@@ -157,7 +163,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
                 };
             },
             Number<TupleSize>{});
-    };
+    }
 
     using OutGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number<NumOutput>{}));
     using InGrid3dDescTuple  = decltype(GenerateInOutGrid3dDescTuple(Number<NumInput>{}));
@@ -222,8 +228,8 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            index_t gridSize      = getAvailableComputeUnitCount(stream_config);
-            index_t num_threads_m = (gridSize * arg.blockSize_) / 16;
+            index_t gridSize      = getAvailableComputeUnitCount(stream_config) * arg.blockSize_;
+            index_t num_threads_m = gridSize / (16 * 16);
             index_t num_threads_n = 16;
             index_t num_threads_k = 16;
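After this change `gridSize` holds the total thread count (compute-unit count times block size), and the m/n/k thread grid is carved out of it with 16 threads each on n and k, the remainder on m. A sketch of the arithmetic with assumed numbers:

    // Minimal sketch (assumed CU count and block size): splitting the total
    // thread count into an (m, n, k) thread grid as the new Run() does.
    #include <cstdio>

    int main()
    {
        const int compute_units = 104;  // hypothetical device CU count
        const int block_size    = 256;  // hypothetical arg.blockSize_
        const int total_threads = compute_units * block_size; // the diff's gridSize
        const int num_threads_n = 16;
        const int num_threads_k = 16;
        const int num_threads_m = total_threads / (num_threads_n * num_threads_k);
        std::printf("m x n x k threads: %d x %d x %d\n",
                    num_threads_m, num_threads_n, num_threads_k);
    }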
@@ -296,42 +302,47 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
                                          const std::array<index_t, NumDim>& strides,
                                          index_t scalarPerVector,
                                          index_t vectorDim) {
-            if(strides[vectorDim] == 1 &&
-               (lengths[vectorDim] % scalarPerVector == 0 ||
-                lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
-            {
-                return true;
-            }
-            if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim])
-            {
-                return true;
-            }
-            return false;
+            ignore = lengths;
+            ignore = strides;
+            ignore = scalarPerVector;
+            ignore = vectorDim;
+            // if(strides[vectorDim] == 1 &&
+            //(lengths[vectorDim] % scalarPerVector == 0))
+            ////lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
+            //{
+            //    return true;
+            //}
+            // if(strides[vectorDim] >= scalarPerVector)
+            //{
+            //    return true;
+            //}
+            return true;
         };
 
             bool valid = true;
             static_for<0, NumInput, 1>{}([&](auto I) {
-                if(!IsScalarPerVectorValid(pArg->lengths_,
-                                           pArg->inStridesArray_[I.value],
-                                           InScalarPerVectorSeq::At(I),
-                                           NumDim_m - 1))
-                    //LogRangeAsType<float>(std::cout << "in scalarperveq : ", InScalarPerVectorSeq::At(I), ",") << std::endl;
-                    //LogRangeAsType<float>(std::cout << "vecdim : ", NumDim_m - 1, ",") << std::endl;
-                    valid = false;
+                valid = valid && IsScalarPerVectorValid(pArg->lengths_,
+                                                        pArg->inStridesArray_[I.value],
+                                                        InScalarPerVectorSeq::At(I),
+                                                        NumDim_m - 1);
+                // LogRangeAsType<float>(std::cout << "in scalarperveq : ",
+                // InScalarPerVectorSeq::At(I), ",") << std::endl; LogRangeAsType<float>(std::cout <<
+                // "vecdim : ", NumDim_m - 1, ",") << std::endl;
             });
 
             static_for<0, NumOutput, 1>{}([&](auto I) {
-                if(!IsScalarPerVectorValid(pArg->lengths_,
-                                           pArg->outStridesArray_[I.value],
-                                           OutScalarPerVectorSeq::At(I),
-                                           NumDim - 1))
-                    //LogRangeAsType<float>(std::cout << "out scalarperveq : ", OutScalarPerVectorSeq::At(I), ",") << std::endl;
-                    //LogRangeAsType<float>(std::cout << "vecdim : ", NumDim - 1, ",") << std::endl;
-                    valid = false;
+                valid = valid && IsScalarPerVectorValid(pArg->lengths_,
+                                                        pArg->outStridesArray_[I.value],
+                                                        OutScalarPerVectorSeq::At(I),
+                                                        NumDim - 1);
+                // LogRangeAsType<float>(std::cout << "out scalarperveq : ",
+                // OutScalarPerVectorSeq::At(I), ",") << std::endl; LogRangeAsType<float>(std::cout
+                // << "vecdim : ", NumDim - 1, ",") << std::endl;
             });
 
             return valid;
-        };
+        }
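Note that with the lambda stubbed to `return true;`, `IsSupportedArgument` now accepts every stride configuration, so an unsupported vector access would only surface at kernel run time. The rewritten loops also fold each check into `valid` with `&&` instead of conditionally assigning `false`. A tiny sketch of that accumulation pattern:

    // Minimal sketch (illustrative): accumulating a validity flag across
    // several checks, as the rewritten static_for loops do.
    #include <array>
    #include <cstdio>

    int main()
    {
        std::array<bool, 3> checks = {true, true, false};
        bool valid = true;
        for(bool c : checks)
            valid = valid && c; // one failed check poisons the result
        std::printf("valid = %d\n", valid);
    }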
 
     std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
@@ -353,7 +364,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
     {
         return std::make_unique<Invoker>();
-    };
+    }
 }; // namespace device
 } // namespace device
......
@@ -60,8 +60,8 @@ struct GridwiseElementwise_3D
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
 
-    static constexpr auto thread_buffer_desc_mnk =
-        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}, Number<NPerThread>{}, Number<KPerThread>{}));
+    static constexpr auto thread_buffer_desc_mnk = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MPerThread>{}, Number<NPerThread>{}, Number<KPerThread>{}));
 
     using PassThroughOp = tensor_operation::element_wise::PassThrough;
@@ -121,14 +121,14 @@ struct GridwiseElementwise_3D
         const index_t loop_step_k   = num_threads_k * KPerThread;
         const index_t thread_1d_id  = get_thread_global_1d_id();
 
-        //index_t tid_m = thread_1d_id / num_threads_n;
-        //index_t tid_n = thread_1d_id % num_threads_n;
-        index_t tid_m = thread_1d_id / (num_threads_n * num_threads_k);
-        index_t tid_n = (thread_1d_id / (num_threads_n * num_threads_k)) / num_threads_k;
-        index_t tid_k = (thread_1d_id / (num_threads_n * num_threads_k)) % num_threads_k;
+        const index_t tid_m  = thread_1d_id / (num_threads_n * num_threads_k);
+        const index_t tid_nk = thread_1d_id % (num_threads_n * num_threads_k);
+        const index_t tid_n  = tid_nk / num_threads_k;
+        const index_t tid_k  = tid_nk % num_threads_k;
 
-        const auto thread_global_offset = make_multi_index(tid_m * MPerThread, tid_n * NPerThread, tid_k * KPerThread);
+        const auto thread_global_offset =
+            make_multi_index(tid_m * MPerThread, tid_n * NPerThread, tid_k * KPerThread);
 
         auto in_global_load_tuple = generate_tuple(
             [&](auto I) {
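The old tid_n/tid_k expressions both started from `thread_1d_id / (num_threads_n * num_threads_k)`, so they indexed off the m quotient rather than the n-k remainder; the new code strips off tid_m first and then splits the remainder. A quick standalone sketch verifying the decomposition round-trips:

    // Minimal sketch (illustrative): the fixed 1D -> 3D thread-id decomposition.
    // tid == (tid_m * num_threads_n + tid_n) * num_threads_k + tid_k must hold.
    #include <cstdio>

    int main()
    {
        const int num_threads_n = 16, num_threads_k = 16;
        const int tid    = 1234;
        const int tid_m  = tid / (num_threads_n * num_threads_k);
        const int tid_nk = tid % (num_threads_n * num_threads_k);
        const int tid_n  = tid_nk / num_threads_k;
        const int tid_k  = tid_nk % num_threads_k;
        const int back   = (tid_m * num_threads_n + tid_n) * num_threads_k + tid_k;
        std::printf("m=%d n=%d k=%d roundtrip=%d\n", tid_m, tid_n, tid_k, back == tid);
    }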
@@ -143,7 +143,7 @@ struct GridwiseElementwise_3D
                     Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
                     Sequence<0, 1, 2>,                            // DimAccessOrder
                     0,                                            // SrcVectorDim
-                    InScalarPerVectorSeq::At(I),                  // ScalarPerVector
+                    1, // InScalarPerVectorSeq::At(I), // ScalarPerVector
                     1,                                            // SrcScalarStrideInVector
                     true>{in_grid_3d_desc_tuple[I], thread_global_offset};
             },
@@ -177,7 +177,8 @@ struct GridwiseElementwise_3D
         do
         {
             index_t num_iter_k = K / (loop_step_k);
-            do{
+            do
+            {
                 static_for<0, NumInput, 1>{}([&](auto I) {
                     in_global_load_tuple(I).Run(in_grid_3d_desc_tuple[I],
                                                 in_global_buf_tuple[I],
@@ -185,13 +186,13 @@ struct GridwiseElementwise_3D
                                                 make_tuple(I0, I0, I0),
                                                 in_thread_buf_tuple(I));
 
-                    in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_3d_desc_tuple[I],
-                                                               make_multi_index(0, 0, loop_step_k));
+                    in_global_load_tuple(I).MoveSrcSliceWindow(
+                        in_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
                 });
 
                 static_for<0, MPerThread, 1>{}([&](auto iM) {
                     static_for<0, NPerThread, 1>{}([&](auto iN) {
-                        static_for<0, KPerThread, 1>{}([&](auto iK){
+                        static_for<0, KPerThread, 1>{}([&](auto iK) {
                             constexpr auto offset =
                                 thread_buffer_desc_mnk.CalculateOffset(make_tuple(iM, iN, iK));
 
                             // get reference to in data
@@ -221,10 +222,10 @@ struct GridwiseElementwise_3D
                                                  out_grid_3d_desc_tuple[I],
                                                  out_global_buf_tuple(I));
 
-                    out_global_store_tuple(I).MoveDstSliceWindow(out_grid_3d_desc_tuple[I],
-                                                                 make_multi_index(0, loop_step_n, loop_step_k));
+                    out_global_store_tuple(I).MoveDstSliceWindow(
+                        out_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
                 });
-            } while (--num_iter_k);
+            } while(--num_iter_k);
 
             static_for<0, NumInput, 1>{}([&](auto I) {
                 in_global_load_tuple(I).MoveSrcSliceWindow(
@@ -243,13 +244,17 @@ struct GridwiseElementwise_3D
             static_for<0, NumInput, 1>{}([&](auto I) {
                 in_global_load_tuple(I).MoveSrcSliceWindow(
                     in_grid_3d_desc_tuple[I],
-                    make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n, -(K / loop_step_k) * loop_step_k));
+                    make_multi_index(loop_step_m,
+                                     -(N / loop_step_n) * loop_step_n,
+                                     -(K / loop_step_k) * loop_step_k));
             });
 
             static_for<0, NumOutput, 1>{}([&](auto I) {
                 out_global_store_tuple(I).MoveDstSliceWindow(
                     out_grid_3d_desc_tuple[I],
-                    make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n, -(K / loop_step_k) * loop_step_k));
+                    make_multi_index(loop_step_m,
+                                     -(N / loop_step_n) * loop_step_n,
+                                     -(K / loop_step_k) * loop_step_k));
             });
         } while(--num_iter_m);
     }
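One behavioral fix in this file is worth calling out: inside the inner k-loop, the output window previously advanced by (0, loop_step_n, loop_step_k), drifting diagonally in n, while the new step (0, 0, loop_step_k) moves along k only and leaves the n advance to the outer loop (together with the negative rewind terms above). A standalone sketch of that tile walk:

    // Minimal sketch (illustrative): tile coverage of nested n/k window loops.
    // Advancing k only in the inner loop visits every (n, k) tile exactly once.
    #include <cstdio>

    int main()
    {
        const int num_n = 3, num_k = 4; // tiles per dimension
        int visited = 0;
        int n = 0, k = 0;
        for(int in = 0; in < num_n; ++in)
        {
            for(int ik = 0; ik < num_k; ++ik)
            {
                ++visited; // process tile (n, k)
                k += 1;    // fixed step: advance k only
            }
            k -= num_k;    // rewind k, like -(K / loop_step_k) * loop_step_k
            n += 1;        // advance n in the outer loop
        }
        std::printf("visited %d of %d tiles\n", visited, num_n * num_k);
    }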
......
@@ -12,7 +12,7 @@ cmake
       -save-temps=$PWD" \
     -D CMAKE_BUILD_TYPE=Release \
     -D BUILD_DEV=ON \
-    -D GPU_TARGETS="gfx908;gfx90a;gfx940" \
+    -D GPU_TARGETS="gfx90a" \
     -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
    -D USE_BITINT_EXTENSION_INT4=OFF \
    ${MY_PROJECT_SOURCE}
......