Commit 7cb3d6fd authored by Jing Zhang
Browse files

recover v3r1

parent 786a0faa
...@@ -335,6 +335,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -335,6 +335,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
using GemmSpecialization = tensor_operation::device::GemmSpecialization; using GemmSpecialization = tensor_operation::device::GemmSpecialization;
static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
GemmSpec != GemmSpecialization::Default),
"pk_i4_t does not support padding");
if constexpr(GemmSpec == GemmSpecialization::NKPadding || if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
......
...@@ -1125,10 +1125,8 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1125,10 +1125,8 @@ struct ThreadwiseTensorSliceTransfer_v4
using src_vector_t = typename decltype(src_tmp_vector)::type; using src_vector_t = typename decltype(src_tmp_vector)::type;
// const bool is_src_valid = const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
// coordinate_has_valid_offset_assuming_visible_index_is_valid( src_desc, src_desc, src_data_coord);
// src_data_coord);
const bool is_src_valid = true;
// copy data from src_buf into src_tmp_vector // copy data from src_buf into src_tmp_vector
if constexpr(SrcBuffer::IsDynamicBuffer()) if constexpr(SrcBuffer::IsDynamicBuffer())
......
...@@ -193,10 +193,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -193,10 +193,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{}); [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
// maintain a container record is_src_valid, waiting for RunWrite use. // maintain a container record is_src_valid, waiting for RunWrite use.
// const bool is_src_valid = const bool is_src_valid =
// coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
// src_oob_thread_scratch_tuple_(thread_scratch_id) src_oob_thread_scratch_tuple_(thread_scratch_id)
//.template SetAsType<bool>(src_data_idx_seq, is_src_valid); .template SetAsType<bool>(src_data_idx_seq, is_src_valid);
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>; using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type; using src_vector_t = typename src_vector_type::type;
...@@ -306,7 +306,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -306,7 +306,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
}); });
#else #else
#if 0 #if 1
// OOB Check // OOB Check
constexpr auto src_scalar_per_access = generate_sequence( constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{}); detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
...@@ -358,13 +358,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -358,13 +358,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using vector_t = typename vector_type_maker<DstData, SrcScalarPerVector>::type::type; using vector_t = typename vector_type_maker<DstData, SrcScalarPerVector>::type::type;
auto op_r_v = src_thread_scratch_tuple_(thread_scratch_id) auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
.template GetAsType<vector_t>(src_data_idx_seq); .template GetAsType<vector_t>(src_data_idx_seq);
// const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id) const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
//.template GetAsType<bool>(src_data_idx_seq); .template GetAsType<bool>(src_data_idx_seq);
// auto op_r_v = is_src_valid ? op_r : vector_t(0); auto op_r_v = is_src_valid ? op_r : vector_t(0);
src_thread_scratch_tuple_(thread_scratch_id) src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<vector_t>(src_data_idx_seq, op_r_v); .template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
...@@ -381,8 +381,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -381,8 +381,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
(is_same<f8_t, remove_cvref_t<DstData>>::value && (is_same<f8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{ {
static_assert(false, "no transpose allowed"); static_assert(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
#if 0 "transpose is not allowed for pk_i4_t");
#if 1
// each transpose does // each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_ // DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_ // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
...@@ -874,8 +875,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -874,8 +875,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
private: private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
// static constexpr auto src_oob_thread_scratch_desc_ = static constexpr auto src_oob_thread_scratch_desc_ =
// decltype(GetSrcThreadScratchDescriptor()){}; decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
using SrcThreadScratch = using SrcThreadScratch =
...@@ -885,12 +886,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -885,12 +886,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
decltype(src_thread_scratch_desc_), decltype(src_thread_scratch_desc_),
true>; true>;
// using SrcOOBThreadScratch = using SrcOOBThreadScratch =
// StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr, StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
// bool, // apply data_convert with SrcThreadScratch bool, // apply data_convert with SrcThreadScratch
// 1, 1,
// decltype(src_oob_thread_scratch_desc_), decltype(src_oob_thread_scratch_desc_),
// true>; true>;
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr, using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData, DstData,
...@@ -899,7 +900,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -899,7 +900,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
true>; true>;
StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_; StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
// StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_; StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
DstThreadScratch dst_thread_scratch_; DstThreadScratch dst_thread_scratch_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment