Commit 69977fab authored by aska-0096

tempsave

parent 1e339898
@@ -28,15 +28,15 @@ using DeviceGemmV2Instance =
         ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
         PassThrough, PassThrough, PassThrough, GemmDefault,
         256,
-        224, 256,
+        256, 256,
         64, 8, 8,
         16, 16,
-        7, 8,
+        8, 8,
         S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
         2, 8, 8, 0,
         S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
         1, 8, 8, 0,
-        1, 2, S<1, 32, 1, 8>, 8,
+        1, 2, S<1, 32, 1, 8>, 8, // TODO: Deprecated
         ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
 // clang-format on
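Note on the instance change above: the M tile grows from 224 to 256, and MXdlPerWave grows from 7 to 8 to match. A minimal sketch of the arithmetic, assuming the usual CK relation `MPerBlock == MXdlPerWave * MPerXdl * MWaves` (with `MWaves = 2` inferred from the 4-wave block; names here are for illustration, not the template's parameter list):

```cpp
// Sketch: why MXdlPerWave must track the MPerBlock change (values from this diff).
constexpr int MPerXdl = 16, MWaves = 2;
static_assert(7 * MPerXdl * MWaves == 224, "old M tile");
static_assert(8 * MPerXdl * MWaves == 256, "new M tile");
```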
...
@@ -218,32 +218,6 @@ struct StaticTensorTupleOfVectorBuffer
         }
     }
-    template <typename X,
-              typename Idx,
-              typename enable_if<has_same_scalar_type<S, X>::value &&
-                                     is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
-                                 bool>::type = false>
-    __host__ __device__ constexpr void SetAsType_Print(Idx, X x)
-    {
-        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
-        constexpr index_t offset = coord.GetOffset();
-        if(get_thread_local_1d_id() == 0)
-        {
-            printf("Tid: %d, Index: (%d, %d, %d, %d), Offset: %d\n",
-                   get_thread_local_1d_id(),
-                   Idx{}.At(Number<0>{}).value,
-                   Idx{}.At(Number<1>{}).value,
-                   Idx{}.At(Number<2>{}).value,
-                   Idx{}.At(Number<3>{}).value,
-                   offset);
-        }
-        constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
-        if constexpr(is_valid)
-        {
-            data_.template SetAsType<X>(Number<offset>{}, x);
-        }
-    }
     // Get read access to V. No is_valid check
     // Idx is for S, not V. Idx should be aligned with V
     template <typename Idx>
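Note: `SetAsType_Print` was a printf-instrumented clone of `SetAsType` (same compile-time coordinate and validity check, plus a thread-0 dump of index and offset). Removing it here matches the call-site revert in `ThreadwiseTensorSliceTransfer_v5` further down.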
...
@@ -302,21 +302,17 @@ struct BlockwiseGemmXdlops_pipeline_base
         return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
     }
-    __host__ __device__ static constexpr auto
-    GetCBlockDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
+    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4()
     {
-        constexpr auto c_block_desc_mblock_nblock_m0_n0_m1_n1_m2_n2 =
-            make_naive_tensor_descriptor_packed(make_tuple(I1,
-                                                           I1,
-                                                           Number<MRepeat>{},
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                            Number<NRepeat>{},
                                                            Number<MWaves>{},
                                                            Number<NWaves>{},
                                                            Number<MPerXDL>{},
                                                            Number<NPerXDL>{}));
-        return xdlops_gemm.MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4(
-            c_block_desc_mblock_nblock_m0_n0_m1_n1_m2_n2);
+        return xdlops_gemm.MakeCDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
     }
     __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
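Note: the block descriptor loses its two leading unit dimensions (the `I1, I1` standing in for MBlock/NBlock), so the 8-dimension `M0_M1_N0_M2_M3_N1_N2_M4` view is produced directly; the MBlock/NBlock split is reintroduced later on the grid descriptor instead.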
...
@@ -332,12 +332,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
         // Local prefetch 1
         block_sync_lds();
         static_for<0, KRepeat, 1>{}([&](auto k0) {
-            // a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-            //                    make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
-            //                    a_block_buf,
-            //                    a_thread_desc_,
-            //                    make_tuple(I0, I0, k0, I0),
-            //                    a_thread_buf);
+            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                               make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I0, k0, I0),
+                               a_thread_buf);
             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
@@ -399,12 +399,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
             block_sync_lds();
             static_for<0, KRepeat, 1>{}([&](auto k0) {
-                // a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                //                    make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
-                //                    a_block_buf,
-                //                    a_thread_desc_,
-                //                    make_tuple(I0, I0, k0, I0),
-                //                    a_thread_buf);
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(I0, I0, k0, I0),
+                                   a_thread_buf);
                 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                    make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
...
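Note on the two pipeline hunks above: the `a_thread_copy_.Run` calls in the local-prefetch stages were commented out in the previous state and are re-enabled here, so the A fragment is read from LDS again alongside the B fragment.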
@@ -146,7 +146,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
     static constexpr auto BK1Number = Number<BK1Value>{};
     static constexpr index_t KPack =
-        math::max(math::lcm(AK1Number, BK1Number),
+        math::max(math::gcd(AK1Number, BK1Number),
                   MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
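The `KPack` floor switches from `lcm(AK1, BK1)` to `gcd(AK1, BK1)`. For this instance (`AK1 == BK1 == 8`) the two agree; they diverge only for mixed K1 values, where `gcd` yields the smaller, less restrictive pack. A standalone illustration, using the C++17 `<numeric>` equivalents of CK's `math::` helpers:

```cpp
#include <numeric> // std::gcd, std::lcm

// Equal K1 values: no behavioral change from this hunk.
static_assert(std::lcm(8, 8) == 8 && std::gcd(8, 8) == 8, "equal K1");
// Mixed K1 values: gcd picks the smaller pack, lcm the larger.
static_assert(std::lcm(8, 4) == 8 && std::gcd(8, 4) == 4, "mixed K1");
```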
@@ -1424,25 +1424,27 @@ struct GridwiseGemm_xdl_cshuffle_v3
         constexpr auto c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
             blockwise_gemm_pipeline.GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
-        constexpr auto c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
-            blockwise_gemm_pipeline.GetCBlockDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
-        constexpr auto M0 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
-        constexpr auto M1 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
-        constexpr auto N0 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
-        constexpr auto M2 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
-        constexpr auto M3 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
-        constexpr auto N1 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
-        constexpr auto N2 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<8>{});
-        constexpr auto M4 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<9>{});
+        constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
+            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();
+        constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
+        constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
+        constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
+        constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
+        constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
+        constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
+        constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
+        constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
+        const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 = transform_tensor_descriptor(
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+            make_tuple(make_pass_through_transform(problem.MBlock),
+                       make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
+                       make_pass_through_transform(problem.NBlock),
+                       make_unmerge_transform(make_tuple(N0, N1, N2))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(
+                Sequence<0>{}, Sequence<2, 3, 5, 6, 9>{}, Sequence<1>{}, Sequence<4, 7, 8>{}));
         const auto c_thread_mtx_on_block =
             blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexContiguous(I0, I0, I0, I0);
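The epilogue now builds a 10-dimension view of the C grid tensor itself: the per-block M extent is unmerged into `(M0, M1, M2, M3, M4)` and the N extent into `(N0, N1, N2)`, with the upper `Sequence` ids interleaving them into the `MBlock, NBlock, M0, M1, N0, M2, M3, N1, N2, M4` order. A plain-C++ analogue of what one such unmerge means (lengths below are made up for illustration, not this kernel's values):

```cpp
// Unmerge = mixed-radix index split: a linear index m becomes (m0, m1, m2, m3, m4),
// with m4 the fastest-varying sub-index.
constexpr int M1 = 2, M2 = 4, M3 = 1, M4 = 4;
constexpr int m  = 123;
constexpr int m4 = m % M4, m3 = (m / M4) % M3, m2 = (m / (M4 * M3)) % M2,
              m1 = (m / (M4 * M3 * M2)) % M1, m0 = m / (M4 * M3 * M2 * M1);
static_assert((((m0 * M1 + m1) * M2 + m2) * M3 + m3) * M4 + m4 == m, "round-trips");
```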
@@ -1474,7 +1476,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
             AccDataType,
             CDataType,
             decltype(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
-            decltype(c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
+            decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
             CElementwiseOperation,
             Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
             Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
@@ -1484,7 +1486,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
             M4,
             N2,
             InMemoryDataOperationEnum::Set,
-            false>{c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
+            false>{c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                    make_multi_index(block_m_id,
                                     block_n_id,
                                     m_thread_data_on_block_idx[I0],
@@ -1500,7 +1502,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
         c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                                          c_thread_buf,
-                                         c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
+                                         c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                                          c_grid_buf);
     }
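Note on the three epilogue hunks above: the VGPR-to-global copy was already targeting `c_grid_buf`; it now receives the matching grid descriptor instead of the block (LDS) descriptor, so the thread-copy type, its constructor origin, and the `Run` call all describe the same global tensor.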
...
@@ -399,10 +399,57 @@ struct ThreadwiseTensorSliceTransfer_v1r4
         constexpr auto dst_dim_access_order = DstDimAccessOrder{};
         constexpr auto dst_access_lengths   = SliceLengths{} / dst_scalar_per_access;
         constexpr auto ordered_dst_access_lengths =
-            container_reorder_given_new2old(access_lengths, dst_dim_access_order);
+            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
+        // make forward steps
+        const auto dst_forward_steps = generate_tuple(
+            [&](auto i) {
+                Index forward_step_idx;
+                static_for<0, nDim, 1>{}([&](auto j) {
+                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
+                });
+                return make_tensor_coordinate_step(dst_desc, forward_step_idx);
+            },
+            Number<nDim>{});
+        // make backward steps
+        const auto dst_backward_steps = generate_tuple(
+            [&](auto i) {
+                Index backward_step_idx;
+                static_for<0, nDim, 1>{}([&](auto j) {
+                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
+                });
+                return make_tensor_coordinate_step(dst_desc, backward_step_idx);
+            },
+            Number<nDim>{});
+        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
+            // judge move forward or move backward
+            constexpr auto forward_sweep = [&]() {
+                StaticallyIndexedArray<bool, nDim> forward_sweep_;
+                forward_sweep_(Number<0>{}) = true;
+                static_for<1, nDim, 1>{}([&](auto i) {
+                    index_t tmp = ordered_dst_access_idx[Number<0>{}];
+                    static_for<1, i, 1>{}([&](auto j) {
+                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
+                    });
+                    forward_sweep_(i) = tmp % 2 == 0;
+                });
+                return forward_sweep_;
+            }();
             using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
             using dst_vector_t    = typename dst_vector_type::type;
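Two things happen in this hunk: the `access_lengths` vs `dst_access_lengths` mix-up is fixed (the same fix recurs in `ThreadwiseTensorSliceTransfer_v5` below), and `forward_sweep` introduces a snake (boustrophedon) traversal: dimension `i` runs forward when the mixed-radix prefix of the slower dimensions is even, backward when odd, so consecutive accesses differ by a single precomputed forward or backward step. A minimal host-side sketch of the order this produces for a hypothetical 2-D access grid:

```cpp
#include <cstdio>

// Snake order for lengths {2, 3}: (0,0)(0,1)(0,2)(1,2)(1,1)(1,0).
// forward = (prefix % 2 == 0) mirrors the tmp % 2 == 0 test above.
int main()
{
    const int len0 = 2, len1 = 3;
    for(int i0 = 0; i0 < len0; ++i0)
        for(int k = 0; k < len1; ++k)
        {
            const bool forward = (i0 % 2 == 0);
            const int i1       = forward ? k : len1 - 1 - k;
            std::printf("(%d, %d)\n", i0, i1);
        }
    return 0;
}
```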
@@ -423,10 +470,39 @@ struct ThreadwiseTensorSliceTransfer_v1r4
                 is_dst_valid,
                 dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
-            move_tensor_coordinate(
-                dst_desc,
-                dst_coord_,
-                make_tensor_coordinate_step(dst_desc, to_multi_index(data_to_origin_disp_idx)));
+            constexpr auto move_on_dim = [&]() constexpr
+            {
+                StaticallyIndexedArray<bool, nDim> move_on_dim_;
+                static_for<0, nDim, 1>{}([&](auto i) {
+                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
+                    static_for<i + 1, nDim, 1>{}([&](auto j) {
+                        move_on_dim_(i) &=
+                            ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
+                    });
+                });
+                return move_on_dim_;
+            }
+            ();
+            // move dst coord
+            static_for<0, nDim, 1>{}([&](auto i) {
+                if constexpr(move_on_dim[i])
+                {
+                    if constexpr(forward_sweep[i])
+                    {
+                        move_tensor_coordinate(
+                            dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
+                    }
+                    else
+                    {
+                        move_tensor_coordinate(
+                            dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
+                    }
+                }
+            });
         });
         // move dst coordinate back to slice origin (or not)
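Note: `move_on_dim` is the matching odometer rule: dimension `i` advances only when it has accesses left and every faster-varying dimension `j > i` sits at its last access, which is exactly the point where the snake in the sketch above turns around. The coordinate therefore moves by one precomputed step per access instead of re-deriving a step from `data_to_origin_disp_idx` each time, as the removed `move_tensor_coordinate` call did.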
@@ -1697,28 +1773,20 @@ struct ThreadwiseTensorSliceTransfer_v5
         constexpr auto src_scalar_per_access = generate_sequence(
             detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
         constexpr auto access_lengths       = SliceLengths{} / src_scalar_per_access;
         constexpr auto src_dim_access_order = SrcDimAccessOrder{};
         constexpr auto ordered_access_lengths =
             container_reorder_given_new2old(access_lengths, src_dim_access_order);
         static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
             // position in slice window
             constexpr auto data_to_origin_disp_idx =
                 ordered_access_idx.ReorderGivenOld2New(src_dim_access_order) *
                 src_scalar_per_access;
-#if 0
-            if(get_thread_local_1d_id() == 0)
-            {
-                printf("%d, %d, %d, %d\n",
-                       data_to_origin_disp_idx.At(Number<0>{}).value,
-                       data_to_origin_disp_idx.At(Number<1>{}).value,
-                       data_to_origin_disp_idx.At(Number<2>{}).value,
-                       data_to_origin_disp_idx.At(Number<3>{}).value);
-            }
-#endif
             // src coordinate
             constexpr auto src_ref_to_data_disp_idx =
                 src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
@@ -1740,16 +1808,9 @@ struct ThreadwiseTensorSliceTransfer_v5
             // copy data from src_buf into src_tmp_vector
             src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                 src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
-#if 0
-            if(get_thread_local_1d_id() < 32)
-            {
-                printf("Tid: %02d, Index(%d, %d, %d, %d), offset: %d\n",
-                       get_thread_local_1d_id(),
-                       src_data_coord.GetIndex().At(Number<0>{}),
-                       src_data_coord.GetIndex().At(Number<1>{}),
-                       src_data_coord.GetIndex().At(Number<2>{}),
-                       src_data_coord.GetIndex().At(Number<3>{}),
-                       src_data_coord.GetOffset());
-            }
-#endif
             // Set data to scratch
-            src_thread_scratch_.template SetAsType_Print<src_vector_t>(
+            src_thread_scratch_.template SetAsType<src_vector_t>(
                 data_to_origin_disp_idx, src_tmp_vector.template AsType<src_vector_t>()[I0]);
         });
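Note: with the two `#if 0` printf blocks dropped and `SetAsType_Print` deleted from `StaticTensorTupleOfVectorBuffer` above, the scratch write reverts to the regular `SetAsType` path.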
@@ -1847,8 +1908,10 @@ struct ThreadwiseTensorSliceTransfer_v5
         constexpr auto dst_dim_access_order = DstDimAccessOrder{};
         constexpr auto dst_access_lengths   = SliceLengths{} / dst_scalar_per_access;
         constexpr auto ordered_dst_access_lengths =
-            container_reorder_given_new2old(access_lengths, dst_dim_access_order);
+            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
         static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
             // position in slice window
...
@@ -950,22 +950,18 @@ struct XdlopsGemm
                        Sequence<7>{}));
     }
-    template <typename CDesc_MBlock_NBlock_M0_N0_M1_N1_M2_N2>
-    __host__ __device__ static constexpr auto MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4(
-        const CDesc_MBlock_NBlock_M0_N0_M1_N1_M2_N2& c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2)
+    template <typename CDesc_M0_N0_M1_N1_M2_N2>
+    __host__ __device__ static constexpr auto
+    MakeCDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
     {
-        const auto MBlock = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I0);
-        const auto NBlock = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I1);
-        const auto M0     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I2);
-        const auto N0     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I3);
-        const auto M1     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I4);
-        const auto N1     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I5);
+        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
+        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
+        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
+        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
         return transform_tensor_descriptor(
-            c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2,
-            make_tuple(make_pass_through_transform(MBlock),
-                       make_pass_through_transform(NBlock),
-                       make_pass_through_transform(M0),
+            c_desc_m0_n0_m1_n1_m2_n2,
+            make_tuple(make_pass_through_transform(M0),
                        make_pass_through_transform(N0),
                        make_pass_through_transform(M1),
                        make_pass_through_transform(N1),
@@ -978,17 +974,13 @@ struct XdlopsGemm
                        Sequence<2>{},
                        Sequence<3>{},
                        Sequence<4>{},
-                       Sequence<5>{},
-                       Sequence<6>{},
-                       Sequence<7>{}),
+                       Sequence<5>{}),
             make_tuple(Sequence<0>{},
+                       Sequence<6>{},
                        Sequence<1>{},
                        Sequence<2>{},
-                       Sequence<8>{},
-                       Sequence<3>{},
-                       Sequence<4>{},
-                       Sequence<5, 6, 9>{},
-                       Sequence<7>{}));
+                       Sequence<3, 4, 7>{},
+                       Sequence<5>{}));
     }
     // transposed XDL output supporting C' = B' * A'
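Note: `MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4` becomes `MakeCDescriptor_M0_M1_N0_M2_M3_N1_N2_M4`: the two block dimensions and their pass-through transforms are dropped, the input shrinks from 8 to 6 dimensions, and the upper `Sequence` ids are renumbered for the 8-dimension output accordingly.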
...