Commit 7cb3d6fd authored by Jing Zhang

recover v3r1: restore the coordinate-validity (OOB) checks and the src_oob_thread_scratch_ bookkeeping in ThreadwiseTensorSliceTransfer_v3r1 and _v4, replace the unconditional static_assert(false) transpose guard with a pk_i4_t-specific one, and add a static_assert that pk_i4_t does not support padded GEMM specializations.

parent 786a0faa
@@ -335,6 +335,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
     using GemmSpecialization = tensor_operation::device::GemmSpecialization;
+    static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
+                    GemmSpec != GemmSpecialization::Default),
+                  "pk_i4_t does not support padding");
     if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                  GemmSpec == GemmSpecialization::MNKPadding)
     {
......
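For context, the hunk above adds a compile-time guard that rejects the pk_i4_t element type whenever a padded GEMM specialization is requested. Below is a minimal standalone sketch of the same guard pattern; it is not part of the diff, and the enum values and type names are simplified stand-ins for the CK ones.

#include <type_traits>

// Simplified stand-ins for the CK types used in the guard above.
enum class GemmSpecialization { Default, NKPadding, MNKPadding };
struct pk_i4_t {}; // packed 4-bit integer element type

template <typename ADataType, GemmSpecialization GemmSpec>
struct GridwiseGemmSketch
{
    // Reject the unsupported combination at compile time, before any
    // descriptor with padding transforms could be built for pk_i4_t data.
    static_assert(!(std::is_same_v<std::remove_cv_t<ADataType>, pk_i4_t> &&
                    GemmSpec != GemmSpecialization::Default),
                  "pk_i4_t does not support padding");
};

int main()
{
    // Compiles: default (unpadded) specialization.
    GridwiseGemmSketch<pk_i4_t, GemmSpecialization::Default> ok{};
    (void)ok;
    // Fails to compile if uncommented: padding requested for pk_i4_t.
    // GridwiseGemmSketch<pk_i4_t, GemmSpecialization::MNKPadding> bad{};
}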
@@ -1125,10 +1125,8 @@ struct ThreadwiseTensorSliceTransfer_v4
     using src_vector_t = typename decltype(src_tmp_vector)::type;
-    // const bool is_src_valid =
-    //     coordinate_has_valid_offset_assuming_visible_index_is_valid( src_desc,
-    //     src_data_coord);
-    const bool is_src_valid = true;
+    const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
+        src_desc, src_data_coord);
     // copy data from src_buf into src_tmp_vector
     if constexpr(SrcBuffer::IsDynamicBuffer())
......
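The v4 hunk above replaces a hard-coded is_src_valid = true with the real coordinate check, so out-of-range source reads are detected again instead of being assumed valid. A minimal standalone sketch of that guard-then-zero-fill idea follows, assuming a plain array in place of the CK descriptor and coordinate machinery; the names are illustrative only.

#include <array>
#include <cstddef>
#include <cstdio>

// Stand-in for coordinate_has_valid_offset_assuming_visible_index_is_valid:
// here, validity is just a bounds check against the buffer size.
template <std::size_t N>
bool has_valid_offset(std::size_t offset)
{
    return offset < N;
}

// Guarded element load: invalid offsets contribute a neutral zero instead of
// reading past the end of the buffer.
template <std::size_t N>
float guarded_load(const std::array<float, N>& src, std::size_t offset)
{
    const bool is_src_valid = has_valid_offset<N>(offset);
    return is_src_valid ? src[offset] : 0.0f;
}

int main()
{
    std::array<float, 4> src{1.0f, 2.0f, 3.0f, 4.0f};
    std::printf("%.1f %.1f\n", guarded_load(src, 2), guarded_load(src, 7)); // 3.0 0.0
}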
@@ -193,10 +193,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
     // maintain a container record is_src_valid, waiting for RunWrite use.
-    // const bool is_src_valid =
-    //     coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
-    // src_oob_thread_scratch_tuple_(thread_scratch_id)
-    //     .template SetAsType<bool>(src_data_idx_seq, is_src_valid);
+    const bool is_src_valid =
+        coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+    src_oob_thread_scratch_tuple_(thread_scratch_id)
+        .template SetAsType<bool>(src_data_idx_seq, is_src_valid);
     using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
     using src_vector_t = typename src_vector_type::type;
@@ -306,7 +306,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     });
 #else
-#if 0
+#if 1
     // OOB Check
     constexpr auto src_scalar_per_access = generate_sequence(
         detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
@@ -358,13 +358,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     using vector_t = typename vector_type_maker<DstData, SrcScalarPerVector>::type::type;
-    auto op_r_v = src_thread_scratch_tuple_(thread_scratch_id)
-                      .template GetAsType<vector_t>(src_data_idx_seq);
+    auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
+                    .template GetAsType<vector_t>(src_data_idx_seq);
-    // const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
-    //     .template GetAsType<bool>(src_data_idx_seq);
+    const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
+                                  .template GetAsType<bool>(src_data_idx_seq);
-    // auto op_r_v = is_src_valid ? op_r : vector_t(0);
+    auto op_r_v = is_src_valid ? op_r : vector_t(0);
     src_thread_scratch_tuple_(thread_scratch_id)
         .template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
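Taken together, the two v3r1 hunks above restore a record-then-consume pattern: the read pass stores a per-element validity flag in a parallel OOB scratch buffer, and the element-wise pass later reads that flag and zero-fills the value (op_r_v = is_src_valid ? op_r : vector_t(0)) before any conversion. A minimal standalone sketch of the same two-step idea follows, using plain arrays in place of the CK static scratch buffers; all names here are illustrative.

#include <array>
#include <cstddef>
#include <cstdio>

constexpr std::size_t ScratchSize = 4;

// Parallel scratch buffers: one holds loaded values, the other the validity
// flag recorded for each element during the read pass.
std::array<float, ScratchSize> src_scratch{};
std::array<bool, ScratchSize>  oob_scratch{};

// Read pass: load each element that is in range and record its validity.
void run_read(const float* src, std::size_t src_size)
{
    for(std::size_t i = 0; i < ScratchSize; ++i)
    {
        const bool is_src_valid = (i < src_size);
        src_scratch[i] = is_src_valid ? src[i] : src_scratch[i]; // stale value if invalid
        oob_scratch[i] = is_src_valid;
    }
}

// Element-op pass: consume the recorded flags and zero-fill invalid lanes
// before applying any element-wise operation or type conversion.
void run_element_op()
{
    for(std::size_t i = 0; i < ScratchSize; ++i)
    {
        const float op_r   = src_scratch[i];
        const float op_r_v = oob_scratch[i] ? op_r : 0.0f;
        src_scratch[i]     = op_r_v;
    }
}

int main()
{
    const float src[2] = {10.0f, 20.0f}; // shorter than the scratch: lanes 2 and 3 are OOB
    run_read(src, 2);
    run_element_op();
    for(float v : src_scratch)
        std::printf("%.1f ", v); // 10.0 20.0 0.0 0.0
    std::printf("\n");
}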
@@ -381,8 +381,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                  (is_same<f8_t, remove_cvref_t<DstData>>::value &&
                   SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
     {
-        static_assert(false, "no transpose allowed");
-#if 0
+        static_assert(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
+                      "transpose is not allowed for pk_i4_t");
+#if 1
         // each transpose does
         // DstScalarPerVector # of src vectors in src_thread_scratch_
         // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
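The guard change above swaps an unconditional static_assert(false, ...) for one whose condition depends on the SrcData template parameter. In a template, static_assert(false) can be rejected by the compiler even when the enclosing if constexpr branch is never taken, because the condition does not depend on any template parameter; making the condition dependent defers the check to instantiation of the branch. A minimal, simplified standalone sketch of that difference follows; the type names and branch structure are illustrative, not the CK code.

#include <type_traits>

struct pk_i4_t {}; // stand-in for the packed 4-bit type

template <typename SrcData>
void transpose_path()
{
    if constexpr(!std::is_same_v<SrcData, pk_i4_t>)
    {
        // Non-dependent form: may be diagnosed at template definition time,
        // even though this branch is discarded for pk_i4_t, because `false`
        // does not depend on SrcData.
        // static_assert(false, "no transpose allowed");

        // Dependent form: checked only when this branch is instantiated,
        // i.e. only for types that actually reach it.
        static_assert(std::is_same_v<SrcData, pk_i4_t>,
                      "transpose is not allowed for this type");
    }
}

int main()
{
    transpose_path<pk_i4_t>(); // compiles: the asserting branch is discarded
    // transpose_path<float>(); // would trigger the static_assert if uncommented
}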
@@ -874,8 +875,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     private:
     static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
-    // static constexpr auto src_oob_thread_scratch_desc_ =
-    //     decltype(GetSrcThreadScratchDescriptor()){};
+    static constexpr auto src_oob_thread_scratch_desc_ =
+        decltype(GetSrcThreadScratchDescriptor()){};
     static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
     using SrcThreadScratch =
@@ -885,12 +886,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                         decltype(src_thread_scratch_desc_),
                                         true>;
-    // using SrcOOBThreadScratch =
-    //     StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
-    //                                     bool, // apply data_convert with SrcThreadScratch
-    //                                     1,
-    //                                     decltype(src_oob_thread_scratch_desc_),
-    //                                     true>;
+    using SrcOOBThreadScratch =
+        StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
+                                        bool, // apply data_convert with SrcThreadScratch
+                                        1,
+                                        decltype(src_oob_thread_scratch_desc_),
+                                        true>;
     using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                              DstData,
@@ -899,7 +900,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                         true>;
     StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
-    // StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
+    StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
     DstThreadScratch dst_thread_scratch_;
......
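The last three hunks re-enable the members that back the checks above: a descriptor for the OOB scratch, a bool-typed scratch buffer with the same shape as the data scratch, and one instance per scratch slot (NumThreadScratch). A minimal standalone sketch of that parallel-flag-buffer layout follows, using std::array in place of the CK static-tensor types; the sizes and names are illustrative.

#include <array>
#include <cstddef>

// Simplified stand-ins for the CK scratch types: each scratch slot owns a
// data buffer plus a bool buffer of the same shape for per-element OOB flags.
constexpr std::size_t ElementsPerThread = 8;
constexpr std::size_t NumThreadScratch  = 2;

using SrcThreadScratch    = std::array<float, ElementsPerThread>;
using SrcOOBThreadScratch = std::array<bool, ElementsPerThread>;

struct ThreadwiseTransferSketch
{
    // One data buffer and one matching flag buffer per scratch slot, mirroring
    // src_thread_scratch_tuple_ / src_oob_thread_scratch_tuple_ above.
    std::array<SrcThreadScratch, NumThreadScratch>    src_thread_scratch_tuple_{};
    std::array<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_{};
};

int main()
{
    ThreadwiseTransferSketch t{};
    t.src_thread_scratch_tuple_[0][3]     = 42.0f; // value loaded at slot 0, element 3
    t.src_oob_thread_scratch_tuple_[0][3] = true;  // and the validity flag recorded with it
    return t.src_oob_thread_scratch_tuple_[0][3] ? 0 : 1;
}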