Commit 24e18ae8 authored by Jing Zhang

fixed coord reset

parent c3a4652a
@@ -48,10 +48,9 @@ using DeviceGemmV2Instance =
         S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
         2, 8, 8, 0,
         S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
-        2, 32, 32, 1,
+        2, 32, 32, 0,
         1, 1, S<1, 16, 1, 8>, 4,
         ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
 #endif
 // clang-format on
...
@@ -224,14 +224,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
         //                 get_rtol<CDataType>(),
         //                 get_atol<CDataType>());
-        for(int i = 0; i < M; i++)
-        {
-            for(int j = 0; j < N; j++)
-            {
-                std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
-            }
-            std::cout << std::endl;
-        }
+        //for(int i = 0; i < M; i++)
+        //{
+        //    for(int j = 0; j < N; j++)
+        //    {
+        //        std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
+        //    }
+        //    std::cout << std::endl;
+        //}
     }
     if(config.time_kernel)
...
@@ -775,7 +775,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
     {
         // NLdsLayer * K0 as logical Bank
         constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType);
-        constexpr auto NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
+        constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
         constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
             make_tuple(
                 BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
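Note on the hunk above: `NLdsLayer` is derived with integer division, so it collapses to 0 for large `KPerBlock` or a wide `BDataType`, and the ternary clamps it back to 1. A minimal standalone sketch with hypothetical values (the real `KPerBlock` and `BDataType` are template parameters of the kernel, not fixed by this commit):

```cpp
#include <cstdint>

using index_t = int32_t;

// Hypothetical tile size and data type, chosen only to exercise the clamp.
constexpr index_t KPerBlock = 128;
using BDataType             = float;

// Same expression as in the kernel: 32 * 4 / 128 / 4 == 0 with these values,
// so the clamp below is what keeps NLdsLayer at a minimum of 1.
constexpr index_t LdsSize   = 32 * 4 / KPerBlock / sizeof(BDataType);
constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;

static_assert(LdsSize == 0 && NLdsLayer == 1, "clamp applies when the division underflows to 0");
```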
@@ -1318,17 +1318,14 @@ struct GridwiseGemm_xdl_cshuffle_v3
         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
-        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
-            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
         // Cast after lds
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ADataType*>(p_shared), a_block_space_size_aligned);
+            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                          a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
-            b_block_space_size_aligned);
+            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
...
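The hunk above keeps the B tile's byte offset based on the aligned A footprint, but now sizes each LDS buffer by its descriptor's raw element-space size. A rough standalone sketch of that carve-out with made-up numbers (`a_elems`, `b_elems`, `max_lds_align`, and the element size are all hypothetical; the real values come from the block descriptors):

```cpp
#include <cstddef>

// Round x up to the next multiple of align (what math::integer_least_multiple does).
constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t align)
{
    return ((x + align - 1) / align) * align;
}

constexpr std::size_t a_elems       = 2048; // hypothetical A tile element count
constexpr std::size_t b_elems       = 1000; // hypothetical B tile element count
constexpr std::size_t max_lds_align = 8;    // hypothetical alignment, in elements
constexpr std::size_t sizeof_a      = 2;    // e.g. fp16
constexpr std::size_t APackedSize   = 1;    // >1 only for packed sub-byte types

// A tile starts at byte 0 of the shared block; B starts right after the
// *aligned* A footprint. Before this commit both buffers were also given the
// aligned lengths; afterwards they get the raw element-space sizes instead.
constexpr std::size_t a_aligned     = integer_least_multiple(a_elems, max_lds_align);
constexpr std::size_t b_byte_offset = a_aligned * sizeof_a / APackedSize;

static_assert(b_byte_offset == 4096, "B tile begins 4 KiB into p_shared with these numbers");
```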
@@ -1133,12 +1133,13 @@ struct ThreadwiseTensorSliceTransfer_v4
     }
     else if constexpr(SrcBuffer::IsStaticBuffer())
     {
+        static_assert(false, "");
         static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
             constexpr index_t src_offset = src_desc.CalculateOffset(
                 src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
                 i * src_scalar_step_in_vector);
-            src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset / PackedSize>{}];
+            src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
         });
     }
...
@@ -185,10 +185,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
         // maintain a container record is_src_valid, waiting for RunWrite use.
-        const bool is_src_valid =
-            coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
-        src_oob_thread_scratch_tuple_(thread_scratch_id)
-            .template SetAsType<bool>(src_data_idx_seq, is_src_valid);
+        //const bool is_src_valid =
+        //    coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+        //src_oob_thread_scratch_tuple_(thread_scratch_id)
+        //    .template SetAsType<bool>(src_data_idx_seq, is_src_valid);
         using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
         using src_vector_t = typename src_vector_type::type;
@@ -347,13 +347,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         using vector_t = typename vector_type_maker<DstData, SrcScalarPerVector>::type::type;
-        auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
+        auto op_r_v = src_thread_scratch_tuple_(thread_scratch_id)
                         .template GetAsType<vector_t>(src_data_idx_seq);
-        const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
-                                      .template GetAsType<bool>(src_data_idx_seq);
-        auto op_r_v = is_src_valid ? op_r : vector_t(0);
+        //const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
+        //                              .template GetAsType<bool>(src_data_idx_seq);
+        //auto op_r_v = is_src_valid ? op_r : vector_t(0);
         src_thread_scratch_tuple_(thread_scratch_id)
             .template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
@@ -537,8 +537,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         constexpr auto dst_data_idx_seq = generate_sequence_v2(
             [&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
-        const bool is_dst_valid =
-            coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+        //const bool is_dst_valid =
+        //    coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
         using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
         using dst_vector_t = typename dst_vector_type::type;
@@ -552,15 +552,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             // apply DstElementwiseOperation
             dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
-            dst_vector_container.template AsType<DstData>()(i) = dst_v;
         });
         // copy data from dst_vector_container to dst_buf
         dst_buf.template Set<dst_vector_t>(
             dst_coord_.GetOffset() / PackedSize,
-            is_dst_valid,
+            true,
             dst_vector_container.template AsType<dst_vector_t>()[I0]);
         constexpr auto move_on_dim = [&]() constexpr
         {
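In the hunk above, the hard-coded `true` drops the destination-side OOB predication that `is_dst_valid` carried (its computation is commented out in the -537 hunk). As a rough sketch of what that flag guards, assuming `Set` behaves like a predicated store (the helper below is purely illustrative, not the CK implementation):

```cpp
#include <cstddef>

// Hypothetical stand-in for dst_buf.Set(offset, is_valid, value): the flag
// masks writes whose destination coordinate fell outside the tensor, so
// passing a literal `true` means out-of-bounds lanes are written as well.
template <typename T>
void guarded_store(T* buf, std::size_t offset, bool is_valid, const T& value)
{
    if(is_valid)
    {
        buf[offset] = value;
    }
}
```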
@@ -612,7 +610,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -670,7 +668,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -756,12 +754,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     __device__ static constexpr auto GetSrcThreadScratchDescriptor()
     {
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
         constexpr auto src_access_lengths_and_vector_length = container_push_back(
-            sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
+            sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector_>{});
         // 1st stage of transforms
         constexpr auto desc0 =
@@ -805,7 +803,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
     {
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -816,12 +814,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     {
         // 1st stage of transforms
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
         constexpr auto dst_access_lengths_and_vector_length = container_push_back(
-            sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
+            sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector_>{});
         constexpr auto desc0 =
             make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
@@ -874,12 +872,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                         decltype(src_thread_scratch_desc_),
                                         true>;
-    using SrcOOBThreadScratch =
-        StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
-                                        bool, // apply data_convert with SrcThreadScratch
-                                        1,
-                                        decltype(src_oob_thread_scratch_desc_),
-                                        true>;
+    //using SrcOOBThreadScratch =
+    //    StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
+    //                                    bool, // apply data_convert with SrcThreadScratch
+    //                                    1,
+    //                                    decltype(src_oob_thread_scratch_desc_),
+    //                                    true>;
     using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                              DstData,
@@ -888,7 +886,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                                              true>;
     StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
-    StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
+    //StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
     DstThreadScratch dst_thread_scratch_;
...