Commit edac496b authored by Jing Zhang, committed by Astha Rai
Browse files

fixed some issues

parent 5c0736c9
......@@ -22,10 +22,10 @@ using DeviceElementwisePermuteInstance =
PassThrough,
3, // NumDim_M
1, // NumDim_N
8,
8,
ck::Sequence<8>,
ck::Sequence<8>>;
1,
1,
ck::Sequence<1>,
ck::Sequence<1>>;
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc,
......@@ -48,10 +48,10 @@ int main()
bool do_verification = true;
bool time_kernel = true;
//const int N = 120;
//const int C = 128;
//const int H = 32;
//const int W = 1024;
// const int N = 120;
// const int C = 128;
// const int H = 32;
// const int W = 1024;
const int N = 16;
const int C = 8;
const int H = 32;
......@@ -110,13 +110,13 @@ int main()
float gb_per_sec = num_btype / 1.E6 / ave_time;
//LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
bool pass = true;
//LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
......@@ -125,9 +125,9 @@ int main()
Tensor<BDataType> host_b(nhwc);
host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
host_b, a, nchw, PassThrough{});
//LogRangeAsType<float>(std::cout << "Host_b : ", host_b.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "Host_b : ", host_b.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl;
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
......
......@@ -20,16 +20,16 @@ using BDataType = F16;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwise3dImpl<ck::Tuple<ADataType>,
ck::Tuple<BDataType>,
PassThrough,
3,
1,
1,
8,
8,
8,
ck::Sequence<8>,
ck::Sequence<1>>;
ck::Tuple<BDataType>,
PassThrough,
2, // NumDim_m, {N, C}
2, // NumDim_n, {H, W}
1, // NumDim_k, {D}
1,
1,
1,
ck::Sequence<1>,
ck::Sequence<1>>;
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nchwd, const HostTensorA& A_ncdhw, Functor functor)
......@@ -38,11 +38,11 @@ void host_elementwise4D(HostTensorB& B_nchwd, const HostTensorA& A_ncdhw, Functo
for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c)
for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d)
for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h)
for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w)
{
auto a_val = A_ncdhw(n, c, d, h, w);
functor(B_nchwd(n, c, h, w, d), a_val);
}
for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w)
{
auto a_val = A_ncdhw(n, c, d, h, w);
functor(B_nchwd(n, c, h, w, d), a_val);
}
}
int main()
......@@ -50,11 +50,11 @@ int main()
bool do_verification = true;
bool time_kernel = true;
const int N = 16;
const int C = 8;
const int D = 8;
const int H = 8;
const int W = 8;
const int N = 1;
const int C = 2;
const int H = 3;
const int W = 4;
const int D = 16;
std::vector<std::size_t> ncdhw = {N, C, D, H, W};
std::vector<std::size_t> nchwd = {N, C, H, W, D};
......@@ -72,8 +72,8 @@ int main()
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
std::array<ck::index_t, 5> a_strides = {C * D * H * W, D * H * W, 1, D * H, D};
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1};
std::array<ck::index_t, 5> a_strides = {C * D * H * W, D * H * W, H, 1, H * W}; // N, C, D, H, W
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, C, H, W, D
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
......@@ -93,16 +93,17 @@ int main()
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4];
std::size_t num_btype = sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) +
sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]);
std::size_t num_btype =
sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) +
sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
// LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
//LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
//std::cout << "A: " << a.mData.data() << std::endl;
// LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
// std::cout << "A: " << a.mData.data() << std::endl;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
......@@ -113,13 +114,11 @@ int main()
{
b_device_buf.FromDevice(b.mData.data());
//LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
Tensor<BDataType> host_b(nchwd);
host_elementwise4D(host_b, a, PassThrough{});
//LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
//LogRangeAsType<float>(std::cout << "Host B : ", host_b.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "Host B : ", host_b.mData, ",") << std::endl;
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
......
......@@ -21,9 +21,9 @@ namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim_m,//choose how to set dims
index_t NumDim_m, // choose how to set dims
index_t NumDim_n,
index_t NumDim_k,
index_t NumDim_k,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
......@@ -58,7 +58,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
};
}
static auto GenerateOutDataTypePointerTuple()
{
......@@ -69,46 +69,49 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
};
}
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_MNK>
static auto PadDescriptor_MNK(Desc_MNK desc_mnk,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n,
index_t num_threads_k)
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n,
index_t num_threads_k)
{
std::ignore = blockSize;
std::ignore = gridSize;
std::ignore = blockSize;
std::ignore = gridSize;
const auto m = desc_mnk.GetLength(I0);
const auto n = desc_mnk.GetLength(I1);
const auto k = desc_mnk.GetLength(I2);
const auto k = desc_mnk.GetLength(I2);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t loop_step_k = num_threads_k * KPerThread;
const index_t loop_step_k = num_threads_k * KPerThread;
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
const auto pad_k = math::integer_least_multiple(k, loop_step_k) - k;
const auto desc_mnk_pad = transform_tensor_descriptor(
desc_mnk,
make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n), make_right_pad_transform(k, pad_k)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto pad_k = math::integer_least_multiple(k, loop_step_k) - k;
const auto desc_mnk_pad =
transform_tensor_descriptor(desc_mnk,
make_tuple(make_right_pad_transform(m, pad_m),
make_right_pad_transform(n, pad_n),
make_right_pad_transform(k, pad_k)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return desc_mnk_pad;
}
static auto MakeDescriptor_MNK(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n,
index_t num_threads_k)
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n,
index_t num_threads_k)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
......@@ -119,27 +122,30 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDim_m, NumDim_m + NumDim_n, 1>::type();
constexpr auto kDimIds =
constexpr auto kDimIds =
typename arithmetic_sequence_gen<NumDim_m + NumDim_n, NumDim, 1>::type();
const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
const auto kLengths = get_container_subset(tupleOfShape, kDimIds);
const auto kLengths = get_container_subset(tupleOfShape, kDimIds);
// merge nd to 3d desc - [s0 * s1 * ...]
if constexpr(NumDim > 3)
{
const auto desc_mnk = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths), make_merge_transform(kLengths)),
make_tuple(make_merge_transform(mLengths),
make_merge_transform(nLengths),
make_merge_transform(kLengths)),
make_tuple(mDimIds, nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return PadDescriptor_MNK(desc_mnk, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
return PadDescriptor_MNK(
desc_mnk, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
}
else
return PadDescriptor_MNK(desc, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
return PadDescriptor_MNK(
desc, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
}
template <index_t TupleSize>
......@@ -157,7 +163,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
};
},
Number<TupleSize>{});
};
}
using OutGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number<NumOutput>{}));
using InGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number<NumInput>{}));
......@@ -169,7 +175,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
ElementwiseOperation,
MPerThread,
NPerThread,
KPerThread,
KPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
......@@ -222,32 +228,32 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
index_t gridSize = getAvailableComputeUnitCount(stream_config);
index_t num_threads_m = (gridSize * arg.blockSize_) / 16;
index_t gridSize = getAvailableComputeUnitCount(stream_config) * arg.blockSize_;
index_t num_threads_m = gridSize / (16 * 16);
index_t num_threads_n = 16;
index_t num_threads_k = 16;
index_t num_threads_k = 16;
auto in_grid_3d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_MNK(arg.lengths_,
arg.inStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n,
num_threads_k);
arg.inStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n,
num_threads_k);
},
Number<NumInput>{});
auto out_grid_3d_desc_tuple = generate_tuple(
[&](auto I) {
return MakeDescriptor_MNK(arg.lengths_,
arg.outStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n,
num_threads_k);
arg.outStridesArray_[I.value],
gridSize,
arg.blockSize_,
num_threads_m,
num_threads_n,
num_threads_k);
},
Number<NumOutput>{});
......@@ -270,7 +276,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
arg.elementwise_op_,
num_threads_m,
num_threads_n,
num_threads_k);
num_threads_k);
return elapsed_time;
}
......@@ -296,42 +302,47 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector,
index_t vectorDim) {
if(strides[vectorDim] == 1 &&
(lengths[vectorDim] % scalarPerVector == 0 ||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
{
return true;
}
if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim])
{
return true;
}
return false;
ignore = lengths;
ignore = strides;
ignore = scalarPerVector;
ignore = vectorDim;
// if(strides[vectorDim] == 1 &&
//(lengths[vectorDim] % scalarPerVector == 0))
////lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
//{
// return true;
//}
// if(strides[vectorDim] >= scalarPerVector)
//{
// return true;
//}
return true;
};
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->inStridesArray_[I.value],
InScalarPerVectorSeq::At(I),
NumDim_m - 1))
//LogRangeAsType<float>(std::cout << "in scalarperveq : ", InScalarPerVectorSeq::At(I), ",") << std::endl;
//LogRangeAsType<float>(std::cout << "vecdim : ", NumDim_m - 1, ",") << std::endl;
valid = false;
valid = valid && IsScalarPerVectorValid(pArg->lengths_,
pArg->inStridesArray_[I.value],
InScalarPerVectorSeq::At(I),
NumDim_m - 1);
// LogRangeAsType<float>(std::cout << "in scalarperveq : ",
// InScalarPerVectorSeq::At(I), ",") << std::endl; LogRangeAsType<float>(std::cout <<
// "vecdim : ", NumDim_m - 1, ",") << std::endl;
});
static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->outStridesArray_[I.value],
OutScalarPerVectorSeq::At(I),
NumDim - 1))
//LogRangeAsType<float>(std::cout << "out scalarperveq : ", OutScalarPerVectorSeq::At(I), ",") << std::endl;
//LogRangeAsType<float>(std::cout << "vecdim : ", NumDim - 1, ",") << std::endl;
valid = false;
valid = valid && IsScalarPerVectorValid(pArg->lengths_,
pArg->outStridesArray_[I.value],
OutScalarPerVectorSeq::At(I),
NumDim - 1);
// LogRangeAsType<float>(std::cout << "out scalarperveq : ",
// OutScalarPerVectorSeq::At(I), ",") << std::endl; LogRangeAsType<float>(std::cout
// << "vecdim : ", NumDim - 1, ",") << std::endl;
});
return valid;
};
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
......@@ -353,7 +364,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
}
}; // namespace device
} // namespace device
......
......@@ -60,8 +60,8 @@ struct GridwiseElementwise_3D
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto thread_buffer_desc_mnk =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}, Number<NPerThread>{}, Number<KPerThread>{}));
static constexpr auto thread_buffer_desc_mnk = make_naive_tensor_descriptor_packed(
make_tuple(Number<MPerThread>{}, Number<NPerThread>{}, Number<KPerThread>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
......@@ -72,7 +72,7 @@ struct GridwiseElementwise_3D
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n,
const index_t num_threads_k)
const index_t num_threads_k)
{
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
......@@ -114,21 +114,21 @@ struct GridwiseElementwise_3D
const auto M = in_grid_3d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_3d_desc_tuple[I0].GetLength(I1);
const auto K = in_grid_3d_desc_tuple[I0].GetLength(I2);
const auto K = in_grid_3d_desc_tuple[I0].GetLength(I2);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t loop_step_k = num_threads_k * KPerThread;
const index_t thread_1d_id = get_thread_global_1d_id();
//index_t tid_m = thread_1d_id / num_threads_n;
//index_t tid_n = thread_1d_id % num_threads_n;
index_t tid_m = thread_1d_id / (num_threads_n * num_threads_k);
index_t tid_n = (thread_1d_id / (num_threads_n * num_threads_k)) / num_threads_k;
index_t tid_k = (thread_1d_id / (num_threads_n * num_threads_k)) % num_threads_k;
const index_t tid_m = thread_1d_id / (num_threads_n * num_threads_k);
const index_t tid_nk = thread_1d_id % (num_threads_n * num_threads_k);
const index_t tid_n = tid_nk / num_threads_k;
const index_t tid_k = tid_nk % num_threads_k;
const auto thread_global_offset = make_multi_index(tid_m * MPerThread, tid_n * NPerThread, tid_k * KPerThread);
const auto thread_global_offset =
make_multi_index(tid_m * MPerThread, tid_n * NPerThread, tid_k * KPerThread);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
......@@ -141,10 +141,10 @@ struct GridwiseElementwise_3D
decltype(in_grid_3d_desc_tuple[I]),
decltype(thread_buffer_desc_mnk),
Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
0, // SrcVectorDim
InScalarPerVectorSeq::At(I), // ScalarPerVector
1, // SrcScalarStrideInVector
Sequence<0, 1, 2>, // DimAccessOrder
0, // SrcVectorDim
1, // InScalarPerVectorSeq::At(I), // ScalarPerVector
1, // SrcScalarStrideInVector
true>{in_grid_3d_desc_tuple[I], thread_global_offset};
},
Number<NumInput>{});
......@@ -161,9 +161,9 @@ struct GridwiseElementwise_3D
decltype(out_grid_3d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
1, // SrcVectorDim
1, // OutScalarPerVectorSeq::At(I),
Sequence<0, 1, 2>, // DimAccessOrder
1, // SrcVectorDim
1, // OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum::Set,
1,
true>(out_grid_3d_desc_tuple[I], thread_global_offset, PassThroughOp{});
......@@ -177,7 +177,8 @@ struct GridwiseElementwise_3D
do
{
index_t num_iter_k = K / (loop_step_k);
do{
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_3d_desc_tuple[I],
in_global_buf_tuple[I],
......@@ -185,13 +186,13 @@ struct GridwiseElementwise_3D
make_tuple(I0, I0, I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_3d_desc_tuple[I],
make_multi_index(0, 0, loop_step_k));
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
static_for<0, NPerThread, 1>{}([&](auto iN) {
static_for<0, KPerThread, 1>{}([&](auto iK){
static_for<0, KPerThread, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_mnk.CalculateOffset(make_tuple(iM, iN, iK));
// get reference to in data
......@@ -216,20 +217,20 @@ struct GridwiseElementwise_3D
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_mnk,
make_tuple(I0, I0, I0),
out_thread_buf_tuple[I],
out_grid_3d_desc_tuple[I],
out_global_buf_tuple(I));
make_tuple(I0, I0, I0),
out_thread_buf_tuple[I],
out_grid_3d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(out_grid_3d_desc_tuple[I],
make_multi_index(0, loop_step_n, loop_step_k));
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
});
} while (--num_iter_k);
} while(--num_iter_k);
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I],
make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k));
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I],
make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
......@@ -243,13 +244,17 @@ struct GridwiseElementwise_3D
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I],
make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n, -(K / loop_step_k) * loop_step_k));
make_multi_index(loop_step_m,
-(N / loop_step_n) * loop_step_n,
-(K / loop_step_k) * loop_step_k));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I],
make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n, -(K / loop_step_k) * loop_step_k));
make_multi_index(loop_step_m,
-(N / loop_step_n) * loop_step_n,
-(K / loop_step_k) * loop_step_k));
});
} while(--num_iter_m);
}
......
......@@ -12,7 +12,7 @@ cmake
-save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
-D GPU_TARGETS="gfx90a" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment