Commit 5530440b authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed v7r2 copy

parent 2baf0613
......@@ -37,7 +37,7 @@ using DDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using BLayout = Row;
using DLayout = Row;
using ELayout = Row;
......@@ -141,9 +141,9 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
1,
1,
1,
1,
1,
......@@ -161,10 +161,10 @@ int main(int argc, char* argv[])
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideD = 4096;
ck::index_t StrideE = 4096;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideD = N;
ck::index_t StrideE = N;
float alpha = 1.0f;
float beta = 1.0f;
......
......@@ -141,7 +141,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
// loop over space-filling curve
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto dst_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
......@@ -199,7 +199,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
},
Number<nDst>{});
......@@ -214,7 +214,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
unpack2(element_op_, dst_data_refs, src_data_refs);
});
dst_vectors_tuple_(iAccess) = dst_vectors;
elm_vectors_tuple_(iAccess) = elm_vectors;
// move coordinate
if constexpr(iAccess.value != src_num_access - 1)
......@@ -247,6 +247,8 @@ struct ThreadwiseTensorSliceTransfer_v7r2
enable_if_t<DstDescs::Size() == DstBuffers::Size(), bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
{
dst_vectors_tuple_ = bit_cast<decltype(dst_vectors_tuple_)>(elm_vectors_tuple_);
// loop over space-filling curve
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[iAccess];
......@@ -372,11 +374,13 @@ struct ThreadwiseTensorSliceTransfer_v7r2
private:
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
using ElmVectorsType = decltype(generate_vectors<DstDatas, SrcScalarPerVector>());
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
StaticallyIndexedArray<ElmVectorsType, src_num_access> elm_vectors_tuple_;
StaticallyIndexedArray<DstVectorsType, dst_num_access> dst_vectors_tuple_;
SrcCoords src_coords_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment