Commit 69977fab authored by aska-0096

tempsave

parent 1e339898
@@ -28,15 +28,15 @@ using DeviceGemmV2Instance =
         ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
         PassThrough, PassThrough, PassThrough, GemmDefault,
         256,
-        224, 256,
+        256, 256,
         64, 8, 8,
         16, 16,
-        7, 8,
+        8, 8,
         S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
         2, 8, 8, 0,
         S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
         1, 8, 8, 0,
-        1, 2, S<1, 32, 1, 8>, 8,
+        1, 2, S<1, 32, 1, 8>, 8, // TODO: Deprecated
         ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
 // clang-format on
...
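A quick sanity check on why the two numbers move together: MPerBlock is the product MXdlPerWave * MWaves * MPerXdl. A minimal standalone sketch, assuming this instance uses 2 waves along M and 16-wide XDL ops (those two values are my assumption, not stated in the diff):

#include <cstdio>

int main()
{
    // Assumed wave/XDL geometry for this instance (not shown in the diff).
    constexpr int MWaves = 2, MPerXdl = 16;

    // MPerBlock = MXdlPerWave * MWaves * MPerXdl
    static_assert(7 * MWaves * MPerXdl == 224, "old tile: MXdlPerWave = 7");
    static_assert(8 * MWaves * MPerXdl == 256, "new tile: MXdlPerWave = 8");

    std::printf("224 -> 256 in MPerBlock matches 7 -> 8 in MXdlPerWave\n");
    return 0;
}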
@@ -218,32 +218,6 @@ struct StaticTensorTupleOfVectorBuffer
         }
     }
 
-    template <typename X,
-              typename Idx,
-              typename enable_if<has_same_scalar_type<S, X>::value &&
-                                     is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
-                                 bool>::type = false>
-    __host__ __device__ constexpr void SetAsType_Print(Idx, X x)
-    {
-        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
-        constexpr index_t offset = coord.GetOffset();
-
-        if(get_thread_local_1d_id()==0){
-            printf("Tid: %d, Index: (%d, %d, %d, %d), Offset: %d\n", get_thread_local_1d_id(),
-                   Idx{}.At(Number<0>{}).value,
-                   Idx{}.At(Number<1>{}).value,
-                   Idx{}.At(Number<2>{}).value,
-                   Idx{}.At(Number<3>{}).value, offset);
-        }
-
-        constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
-
-        if constexpr(is_valid)
-        {
-            data_.template SetAsType<X>(Number<offset>{}, x);
-        }
-    }
-
     // Get read access to V. No is_valid check
     // Idx is for S, not V. Idx should be aligned with V
     template <typename Idx>
...
@@ -302,21 +302,17 @@ struct BlockwiseGemmXdlops_pipeline_base
         return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
     }
 
-    __host__ __device__ static constexpr auto
-    GetCBlockDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
+    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4()
     {
-        constexpr auto c_block_desc_mblock_nblock_m0_n0_m1_n1_m2_n2 =
-            make_naive_tensor_descriptor_packed(make_tuple(I1,
-                                                           I1,
-                                                           Number<MRepeat>{},
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                            Number<NRepeat>{},
                                                            Number<MWaves>{},
                                                            Number<NWaves>{},
                                                            Number<MPerXDL>{},
                                                            Number<NPerXDL>{}));
 
-        return xdlops_gemm.MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4(
-            c_block_desc_mblock_nblock_m0_n0_m1_n1_m2_n2);
+        return xdlops_gemm.MakeCDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
     }
 
     __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
...
@@ -332,12 +332,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
         // Local prefetch 1
         block_sync_lds();
         static_for<0, KRepeat, 1>{}([&](auto k0) {
-            // a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-            //                    make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
-            //                    a_block_buf,
-            //                    a_thread_desc_,
-            //                    make_tuple(I0, I0, k0, I0),
-            //                    a_thread_buf);
+            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                               make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I0, k0, I0),
+                               a_thread_buf);
 
             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
@@ -399,12 +399,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
         block_sync_lds();
         static_for<0, KRepeat, 1>{}([&](auto k0) {
-            // a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-            //                    make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
-            //                    a_block_buf,
-            //                    a_thread_desc_,
-            //                    make_tuple(I0, I0, k0, I0),
-            //                    a_thread_buf);
+            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                               make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I0, k0, I0),
+                               a_thread_buf);
 
             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
...
@@ -146,7 +146,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
     static constexpr auto BK1Number = Number<BK1Value>{};
 
     static constexpr index_t KPack =
-        math::max(math::lcm(AK1Number, BK1Number),
+        math::max(math::gcd(AK1Number, BK1Number),
                   MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
 
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
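The lcm-to-gcd swap only changes KPack when AK1 and BK1 differ; with the 8/8 instance above both forms give 8. A standalone sketch using the standard library in place of CK's math helpers (the 8/2/4 values are made up for illustration):

#include <algorithm>
#include <cstdio>
#include <numeric>

int main()
{
    // Hypothetical mixed-vector case; the instance above has AK1 = BK1 = 8.
    constexpr int ak1 = 8, bk1 = 2, k_per_blk = 4; // k_per_blk from the MFMA selector

    constexpr int kpack_lcm = std::max(std::lcm(ak1, bk1), k_per_blk); // old: max(8, 4) = 8
    constexpr int kpack_gcd = std::max(std::gcd(ak1, bk1), k_per_blk); // new: max(2, 4) = 4

    std::printf("KPack: lcm-based = %d, gcd-based = %d\n", kpack_lcm, kpack_gcd);
    return 0;
}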
@@ -1424,25 +1424,27 @@ struct GridwiseGemm_xdl_cshuffle_v3
         constexpr auto c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
             blockwise_gemm_pipeline.GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
 
-        constexpr auto c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
-            blockwise_gemm_pipeline.GetCBlockDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
-
-        constexpr auto M0 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
-        constexpr auto M1 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
-        constexpr auto N0 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
-        constexpr auto M2 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
-        constexpr auto M3 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
-        constexpr auto N1 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
-        constexpr auto N2 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<8>{});
-        constexpr auto M4 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<9>{});
+        constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
+            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();
+
+        constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
+        constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
+        constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
+        constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
+        constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
+        constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
+        constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
+        constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
+
+        const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 = transform_tensor_descriptor(
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+            make_tuple(make_pass_through_transform(problem.MBlock),
+                       make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
+                       make_pass_through_transform(problem.NBlock),
+                       make_unmerge_transform(make_tuple(N0, N1, N2))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(
+                Sequence<0>{}, Sequence<2, 3, 5, 6, 9>{}, Sequence<1>{}, Sequence<4, 7, 8>{}));
 
         const auto c_thread_mtx_on_block =
             blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexContiguous(I0, I0, I0, I0);
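To read the new unmerge: the per-block M extent is split into the factor tuple (M0, M1, M2, M3, M4) and scattered to output dims 2, 3, 5, 6, 9 (NPerBlock likewise into dims 4, 7, 8). As I understand the CK convention, the first listed factor is the most significant. A minimal standalone sketch of what that decomposition does to one index, with factor lengths made up here (8*2*4*1*4 = 256, plausible for a 16x16 XDL tile but assumed, not taken from the diff):

#include <cstdio>

int main()
{
    // Hypothetical factor lengths for a 256-wide M tile: M0 = 8 (most
    // significant, not needed below), then M1..M4.
    constexpr int M1 = 2, M2 = 4, M3 = 1, M4 = 4;

    int m = 123; // an index inside MPerBlock

    // Row-major unmerge: peel factors from least to most significant.
    const int m4 = m % M4; m /= M4;
    const int m3 = m % M3; m /= M3;
    const int m2 = m % M2; m /= M2;
    const int m1 = m % M1; m /= M1;
    const int m0 = m;

    std::printf("m = 123 -> (m0, m1, m2, m3, m4) = (%d, %d, %d, %d, %d)\n",
                m0, m1, m2, m3, m4);
    return 0;
}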
@@ -1474,7 +1476,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                 AccDataType,
                 CDataType,
                 decltype(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
-                decltype(c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
+                decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
                 CElementwiseOperation,
                 Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
                 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
@@ -1484,7 +1486,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                 M4,
                 N2,
                 InMemoryDataOperationEnum::Set,
-                false>{c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
+                false>{c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                        make_multi_index(block_m_id,
                                         block_n_id,
                                         m_thread_data_on_block_idx[I0],
@@ -1500,7 +1502,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
             c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                                              c_thread_buf,
-                                             c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
+                                             c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                                              c_grid_buf);
         }
...
@@ -399,10 +399,57 @@ struct ThreadwiseTensorSliceTransfer_v1r4
         constexpr auto dst_dim_access_order = DstDimAccessOrder{};
 
+        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
+
         constexpr auto ordered_dst_access_lengths =
-            container_reorder_given_new2old(access_lengths, dst_dim_access_order);
+            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
+
+        // make forward steps
+        const auto dst_forward_steps = generate_tuple(
+            [&](auto i) {
+                Index forward_step_idx;
+                static_for<0, nDim, 1>{}([&](auto j) {
+                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
+                });
+                return make_tensor_coordinate_step(dst_desc, forward_step_idx);
+            },
+            Number<nDim>{});
+
+        // make backward steps
+        const auto dst_backward_steps = generate_tuple(
+            [&](auto i) {
+                Index backward_step_idx;
+                static_for<0, nDim, 1>{}([&](auto j) {
+                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
+                });
+                return make_tensor_coordinate_step(dst_desc, backward_step_idx);
+            },
+            Number<nDim>{});
 
         static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
+            // judge move forward or move backward
+            constexpr auto forward_sweep = [&]() {
+                StaticallyIndexedArray<bool, nDim> forward_sweep_;
+
+                forward_sweep_(Number<0>{}) = true;
+
+                static_for<1, nDim, 1>{}([&](auto i) {
+                    index_t tmp = ordered_dst_access_idx[Number<0>{}];
+
+                    static_for<1, i, 1>{}([&](auto j) {
+                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
+                    });
+
+                    forward_sweep_(i) = tmp % 2 == 0;
+                });
+
+                return forward_sweep_;
+            }();
+
             using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
             using dst_vector_t    = typename dst_vector_type::type;
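The forward_sweep block implements a snake (boustrophedon) traversal: a dimension sweeps forward when the flattened prefix of slower ordered indices is even, backward when it is odd, so consecutive accesses differ by a single coordinate step. A standalone 2-D illustration of the parity rule (lengths made up):

#include <cstdio>

int main()
{
    // Hypothetical ordered access lengths for two dimensions.
    constexpr int L0 = 2, L1 = 4;

    // Dim 0 always sweeps forward; dim 1 reverses whenever the prefix
    // index (here just i0) is odd -- the `tmp % 2 == 0` test above.
    for(int i0 = 0; i0 < L0; ++i0)
    {
        const bool forward = (i0 % 2 == 0);
        for(int k = 0; k < L1; ++k)
        {
            const int i1 = forward ? k : L1 - 1 - k;
            std::printf("(%d, %d) ", i0, i1);
        }
    }
    std::printf("\n"); // (0,0) (0,1) (0,2) (0,3) (1,3) (1,2) (1,1) (1,0)
    return 0;
}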
@@ -423,10 +470,39 @@ struct ThreadwiseTensorSliceTransfer_v1r4
                 is_dst_valid,
                 dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
 
-            move_tensor_coordinate(
-                dst_desc,
-                dst_coord_,
-                make_tensor_coordinate_step(dst_desc, to_multi_index(data_to_origin_disp_idx)));
+            constexpr auto move_on_dim = [&]() constexpr
+            {
+                StaticallyIndexedArray<bool, nDim> move_on_dim_;
+
+                static_for<0, nDim, 1>{}([&](auto i) {
+                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
+
+                    static_for<i + 1, nDim, 1>{}([&](auto j) {
+                        move_on_dim_(i) &=
+                            ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
+                    });
+                });
+
+                return move_on_dim_;
+            }
+            ();
+
+            // move dst coord
+            static_for<0, nDim, 1>{}([&](auto i) {
+                if constexpr(move_on_dim[i])
+                {
+                    if constexpr(forward_sweep[i])
+                    {
+                        move_tensor_coordinate(
+                            dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
+                    }
+                    else
+                    {
+                        move_tensor_coordinate(
+                            dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
+                    }
+                }
+            });
         });
 
         // move dst coordinate back to slice origin (or not)
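move_on_dim then selects the single dimension that actually steps after each access: dimension i moves only when it is not at its last index and every faster dimension already is, like an odometer carry. A direct standalone transcription of that predicate for a 2x3 grid (lengths made up):

#include <cstdio>

int main()
{
    constexpr int nDim = 2;
    const int len[nDim] = {2, 3}; // hypothetical ordered access lengths

    // Enumerate accesses in static_ford (row-major) order and evaluate the
    // move_on_dim predicate for each dimension.
    for(int i0 = 0; i0 < len[0]; ++i0)
        for(int i1 = 0; i1 < len[1]; ++i1)
        {
            const int idx[nDim] = {i0, i1};
            std::printf("idx = (%d, %d):", i0, i1);
            for(int i = 0; i < nDim; ++i)
            {
                bool move = idx[i] < len[i] - 1;
                for(int j = i + 1; j < nDim; ++j)
                    move = move && (idx[j] == len[j] - 1);
                std::printf(" move_on_dim[%d] = %d", i, move);
            }
            std::printf("\n");
        }
    return 0; // only the final access (1, 2) moves on no dimension at all
}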
@@ -1697,28 +1773,20 @@ struct ThreadwiseTensorSliceTransfer_v5
         constexpr auto src_scalar_per_access = generate_sequence(
             detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
 
         constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
 
         constexpr auto src_dim_access_order = SrcDimAccessOrder{};
 
         constexpr auto ordered_access_lengths =
             container_reorder_given_new2old(access_lengths, src_dim_access_order);
 
         static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
             // position in slice window
             constexpr auto data_to_origin_disp_idx =
                 ordered_access_idx.ReorderGivenOld2New(src_dim_access_order) *
                 src_scalar_per_access;
 
-#if 0
-            if (get_thread_local_1d_id()==0){
-                printf("%d, %d, %d, %d\n",
-                data_to_origin_disp_idx.At(Number<0>{}).value,
-                data_to_origin_disp_idx.At(Number<1>{}).value,
-                data_to_origin_disp_idx.At(Number<2>{}).value,
-                data_to_origin_disp_idx.At(Number<3>{}).value);
-            }
-#endif
             // src coordinate
             constexpr auto src_ref_to_data_disp_idx =
                 src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
@@ -1740,16 +1808,9 @@ struct ThreadwiseTensorSliceTransfer_v5
             // copy data from src_buf into src_tmp_vector
             src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                 src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
-#if 0
-            if (get_thread_local_1d_id()<32){
-                printf("Tid: %02d, Index(%d, %d, %d, %d), offset: %d\n", get_thread_local_1d_id(), src_data_coord.GetIndex().At(Number<0>{}),
-                src_data_coord.GetIndex().At(Number<1>{}),
-                src_data_coord.GetIndex().At(Number<2>{}),
-                src_data_coord.GetIndex().At(Number<3>{}), src_data_coord.GetOffset());
-            }
-#endif
+
             // Set data to scratch
-            src_thread_scratch_.template SetAsType_Print<src_vector_t>(
+            src_thread_scratch_.template SetAsType<src_vector_t>(
                 data_to_origin_disp_idx, src_tmp_vector.template AsType<src_vector_t>()[I0]);
         });
@@ -1847,8 +1908,10 @@ struct ThreadwiseTensorSliceTransfer_v5
         constexpr auto dst_dim_access_order = DstDimAccessOrder{};
 
+        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
+
         constexpr auto ordered_dst_access_lengths =
-            container_reorder_given_new2old(access_lengths, dst_dim_access_order);
+            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
 
         static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
             // position in slice window
...
@@ -950,22 +950,18 @@ struct XdlopsGemm
                        Sequence<7>{}));
     }
 
-    template <typename CDesc_MBlock_NBlock_M0_N0_M1_N1_M2_N2>
-    __host__ __device__ static constexpr auto MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4(
-        const CDesc_MBlock_NBlock_M0_N0_M1_N1_M2_N2& c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2)
+    template <typename CDesc_M0_N0_M1_N1_M2_N2>
+    __host__ __device__ static constexpr auto
+    MakeCDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
     {
-        const auto MBlock = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I0);
-        const auto NBlock = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I1);
-        const auto M0     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I2);
-        const auto N0     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I3);
-        const auto M1     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I4);
-        const auto N1     = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I5);
+        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
+        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
+        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
+        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
 
         return transform_tensor_descriptor(
-            c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2,
-            make_tuple(make_pass_through_transform(MBlock),
-                       make_pass_through_transform(NBlock),
-                       make_pass_through_transform(M0),
+            c_desc_m0_n0_m1_n1_m2_n2,
+            make_tuple(make_pass_through_transform(M0),
                        make_pass_through_transform(N0),
                        make_pass_through_transform(M1),
                        make_pass_through_transform(N1),
@@ -978,17 +974,13 @@ struct XdlopsGemm
                        Sequence<2>{},
                        Sequence<3>{},
                        Sequence<4>{},
-                       Sequence<5>{},
-                       Sequence<6>{},
-                       Sequence<7>{}),
+                       Sequence<5>{}),
             make_tuple(Sequence<0>{},
+                       Sequence<6>{},
                        Sequence<1>{},
                        Sequence<2>{},
-                       Sequence<8>{},
-                       Sequence<3>{},
-                       Sequence<4>{},
-                       Sequence<5, 6, 9>{},
-                       Sequence<7>{}));
+                       Sequence<3, 4, 7>{},
+                       Sequence<5>{}));
     }
 
     // transposed XDL output supporting C' = B' * A'
...
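How to read the lower/upper Sequence pairs (my paraphrase of the CK convention, not stated in this diff): the i-th transform consumes the input dims listed in the i-th lower Sequence and writes the output dims listed in the i-th upper Sequence, so a multi-element upper Sequence like Sequence<3, 4, 7> routes the factors of one unmerged input dim to scattered output slots. A toy stand-in for that routing (all lengths hypothetical):

#include <cstdio>

int main()
{
    // Three factors produced by unmerging one input dim, and the output
    // dims (an upper Sequence<3, 4, 7>) that they land in.
    const int factor_len[3] = {4, 4, 4};
    const int upper_ids[3]  = {3, 4, 7};

    int out_len[8] = {0}; // 8-d output of the new descriptor
    for(int k = 0; k < 3; ++k)
        out_len[upper_ids[k]] = factor_len[k];

    for(int d = 0; d < 8; ++d)
        std::printf("output dim %d length %d\n", d, out_len[d]);
    return 0;
}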