"tests/pipelines/vscode:/vscode.git/clone" did not exist on "38466c369f5c539ec068e9e19bbe26a0f591ff2e"
Commit 1e339898 authored by aska-0096
Browse files

temp save

parent 11444e4c
......@@ -25,6 +25,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2)
# Register the fp16 stream-k v3 GEMM example and attach it to the umbrella target.
add_example_executable(example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_streamk_v3)
add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp)
# NOTE(review): -save-temps=$PWD and -mllvm -greedy-reverse-local-assignment=1 look
# like local debugging/experiment flags (commit message is "temp save"). $PWD is not
# a CMake variable, so it is passed to the compiler command line verbatim — confirm
# these flags are intended to be committed.
target_compile_options(example_gemm_xdl_fp16_v3 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
......
......@@ -19,7 +19,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmV2Instance =
......@@ -29,13 +29,13 @@ using DeviceGemmV2Instance =
PassThrough, PassThrough, PassThrough, GemmDefault,
256,
224, 256,
64, 8, 2,
64, 8, 8,
16, 16,
7, 8,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 8, 2, 0,
1, 8, 8, 0,
1, 2, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
// clang-format on
......
......@@ -218,6 +218,32 @@ struct StaticTensorTupleOfVectorBuffer
}
}
// Debug variant of SetAsType: stores vector x at the compile-time index Idx, and
// additionally prints the Idx -> linear-offset mapping from thread 0 only.
// X must have the same scalar type as S (enforced by enable_if) and Idx must be a
// compile-time multi-index with Idx::Size() == ndim_. The printf hard-codes four
// index components, so this overload presumably expects a 4-d descriptor — TODO
// confirm against callers.
template <typename X,
typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr void SetAsType_Print(Idx, X x)
{
// Resolve the compile-time multi-index through the tensor descriptor to a
// linear register-buffer offset.
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
constexpr index_t offset = coord.GetOffset();
// Only thread 0 prints, so the "Tid" field always reads 0 here.
if(get_thread_local_1d_id()==0){
printf("Tid: %d, Index: (%d, %d, %d, %d), Offset: %d\n", get_thread_local_1d_id(),
Idx{}.At(Number<0>{}).value,
Idx{}.At(Number<1>{}).value,
Idx{}.At(Number<2>{}).value,
Idx{}.At(Number<3>{}).value, offset);
}
// Guard the store: only write when the coordinate maps to a valid offset
// (compile-time check, so the store is compiled out when invalid).
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
if constexpr(is_valid)
{
data_.template SetAsType<X>(Number<offset>{}, x);
}
}
// Get read access to V. No is_valid check
// Idx is for S, not V. Idx should be aligned with V
template <typename Idx>
......
......@@ -30,6 +30,8 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeA = false,
bool TransposeB = false,
bool TransposeC = false>
struct BlockwiseGemmXdlops_pipeline_base
{
......@@ -152,6 +154,38 @@ struct BlockwiseGemmXdlops_pipeline_base
return make_tuple(c_thread_m, c_thread_n);
}
// Contiguous output tile
// Computes this thread's origin (m, n) inside the block's C tile for the
// "contiguous" output layout, given the repeat indices (m0, n0) and the
// xdlops-internal block indices (xdlops_i, blk_i).
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto CalculateCThreadOriginDataIndexContiguous(Number<m0>,
Number<n0>,
Number<xdlops_i>,
Number<blk_i>)
{
// Decompose the flat thread id into (wave-in-M, wave-in-N).
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
// Per-lane starting position inside the xdlops output block.
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
// M coordinate: (repeat, wave, lane-in-xdl) unmerged as (MRepeat, MWaves, MPerXDL).
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
// N coordinate: factor order here is (NWaves, NPerXDL, NRepeat), i.e. NRepeat is
// the fastest-varying factor, unlike the M side.
// NOTE(review): the variable name says nrepeat_nwave_nperxdl but the factors are
// ordered (NWaves, NPerXDL, NRepeat) — presumably intentional for the contiguous
// layout, but confirm the name/ordering mismatch is not a transposition bug.
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NWaves, NPerXDL, NRepeat))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(waveId_n, blk_idx[I1], n0))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
......@@ -212,6 +246,21 @@ struct BlockwiseGemmXdlops_pipeline_base
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// Contiguous output tile
// Builds the per-thread C descriptor for the contiguous output-tile layout: a
// packed 10-d shape (1, 1, MRepeat, 1, 1, M0, M1, N, NRepeat, M2), where
// (M0, M1, M2, N) are the per-thread block lengths reported by the xdlops
// instruction.
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
{
    // Thread-block lengths of the xdlops C output, ordered (M0, M1, M2, N).
    constexpr auto tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

    constexpr auto len_m0 = tblk_lens[I0];
    constexpr auto len_m1 = tblk_lens[I1];
    constexpr auto len_m2 = tblk_lens[I2];
    constexpr auto len_n  = tblk_lens[I3];

    return make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                          I1,
                                                          Number<MRepeat>{},
                                                          I1,
                                                          I1,
                                                          len_m0,
                                                          len_m1,
                                                          len_n,
                                                          Number<NRepeat>{},
                                                          len_m2));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
......@@ -253,6 +302,23 @@ struct BlockwiseGemmXdlops_pipeline_base
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// Builds the block-level C descriptor for the contiguous output-tile layout by
// first laying out the raw packed shape
// (MBlock=1, NBlock=1, MRepeat, NRepeat, MWaves, NWaves, MPerXDL, NPerXDL)
// and then letting the xdlops instruction rearrange it into the
// MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4 ordering.
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
{
    // Raw packed layout of one block's C tile before the xdlops-specific reshape.
    constexpr auto raw_block_desc = make_naive_tensor_descriptor_packed(
        make_tuple(I1,
                   I1,
                   Number<MRepeat>{},
                   Number<NRepeat>{},
                   Number<MWaves>{},
                   Number<NWaves>{},
                   Number<MPerXDL>{},
                   Number<NPerXDL>{}));

    return xdlops_gemm.MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4(raw_block_desc);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
......@@ -327,28 +393,78 @@ struct BlockwiseGemmXdlops_pipeline_base
// Per-thread C accumulator layout: (MRepeat, NRepeat, registers-per-xdlops), packed.
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
// Threadwise copy of one A slice (1,1,1,KPack) from the LDS block descriptor into
// the thread-local A buffer; vectorized along dim 3 with A_K1-wide accesses.
// NOTE(review): in this scraped diff, a_thread_copy_/b_thread_copy_ are also
// declared further down via AThreadCopySelector/BThreadCopySelector — the two
// declarations cannot coexist; presumably this v4-based pair is the removed side
// of the diff. Confirm which variant the commit actually keeps.
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
// Threadwise copy of one B slice (1,1,1,KPack), mirroring AThreadCopy with
// B_K1-wide vector accesses.
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeDataType,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
// Selects the threadwise-copy type for the A operand based on the TransposeA
// template parameter. NOTE(review): explicit specialization of a member template
// at class scope is a compiler extension (accepted by clang/hipcc, not standard
// C++) — confirm this matches the project's supported compilers.
template <bool Transpose>
struct AThreadCopySelector;
// Non-transposed A path: copies an (MRepeat,1,1,KPack) slice per Run, with both
// source access order and destination order along dims 0..3, vectorizing dim 3
// with A_K1-wide accesses on source and destination.
template <>
struct AThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<MRepeat, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
3,
3,
A_K1,
A_K1>;
};
// Transposed A path: source dims are visited in order <3,1,2,0> and the source
// vector dimension is 0 with MRepeat-wide reads (destination still vectorizes
// dim 3 with A_K1) — i.e. the copy gathers along the repeat dimension.
template <>
struct AThreadCopySelector<true>
{
using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<MRepeat, 1, 1, KPack>,
Sequence<3, 1, 2, 0>,
Sequence<0, 1, 2, 3>,
0,
3,
MRepeat,
A_K1>;
};
// Same selector scheme for the B operand, keyed on TransposeB.
template <bool Transpose>
struct BThreadCopySelector;
// Non-transposed B path: (NRepeat,1,1,KPack) slice, dim-3 vectorization with
// B_K1-wide accesses on both sides.
template <>
struct BThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
ComputeDataType,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<NRepeat, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
3,
3,
B_K1,
B_K1>;
};
// Transposed B path: mirrors the transposed A case with NRepeat-wide source reads
// along dim 0.
template <>
struct BThreadCopySelector<true>
{
using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
ComputeDataType,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<NRepeat, 1, 1, KPack>,
Sequence<3, 1, 2, 0>,
Sequence<0, 1, 2, 3>,
0,
3,
NRepeat,
B_K1>;
};
// Copy engines instantiated according to the Transpose flags.
typename AThreadCopySelector<TransposeA>::type a_thread_copy_;
typename BThreadCopySelector<TransposeB>::type b_thread_copy_;
};
} // namespace ck
......@@ -40,7 +40,9 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeA,
bool TransposeB>
constexpr auto BlockGemmPipeline_Selector()
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
......@@ -110,7 +112,9 @@ constexpr auto BlockGemmPipeline_Selector()
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
KPack,
TransposeA,
TransposeB>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
......
......@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPacks>
index_t KPack,
bool TransposeA,
bool TransposeB>
struct BlockwiseGemmXdlops_pipeline_v3
{
};
......@@ -55,7 +57,9 @@ template <index_t BlockSize,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack
index_t KPack,
bool TransposeA,
bool TransposeB
// ,bool TransposeC //disable transposec right now...
>
struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
......@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeA,
TransposeB>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
......@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeA,
TransposeB>
{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
......@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
KPack,
TransposeA,
TransposeB>;
using Base::I0;
using Base::I1;
using Base::KRepeat;
......@@ -322,22 +332,19 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
// Local prefetch 1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k0, I0),
b_thread_buf);
});
// a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
// make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
// a_block_buf,
// a_thread_desc_,
// make_tuple(I0, I0, k0, I0),
// a_thread_buf);
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, k0, I0),
b_thread_buf);
});
__builtin_amdgcn_sched_barrier(0);
......@@ -392,22 +399,19 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k0, I0),
b_thread_buf);
});
// a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
// make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
// a_block_buf,
// a_thread_desc_,
// make_tuple(I0, I0, k0, I0),
// a_thread_buf);
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, k0, I0),
b_thread_buf);
});
HotLoopScheduler();
......
......@@ -950,6 +950,47 @@ struct XdlopsGemm
Sequence<7>{}));
}
// Rearranges an 8-d C descriptor (MBlock, NBlock, M0, N0, M1, N1, M2, N2) into a
// 10-d descriptor by splitting the per-thread M2 dimension into the mfma
// instruction's (num_groups_per_blk, num_input_blks, group_size) factors and
// permuting the dimensions into the MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4 order
// (see the output Sequence mapping below).
template <typename CDesc_MBlock_NBlock_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4(
const CDesc_MBlock_NBlock_M0_N0_M1_N1_M2_N2& c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2)
{
// Lengths of the pass-through dimensions of the input descriptor.
const auto MBlock = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto NBlock = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M0 = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N0 = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto M1 = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I4);
const auto N1 = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I5);
return transform_tensor_descriptor(
c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(MBlock),
make_pass_through_transform(NBlock),
make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
// Split input dim 6 (M2) into the three mfma sub-factors.
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
Number<mfma_instr.num_input_blks>{},
Number<mfma_instr.group_size>{})),
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
// Input dimension ids 0..7.
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
// Output placement: N0 moves to position 8; the M2 split lands at
// positions 5, 6 and 9; everything else keeps its relative order.
// NOTE(review): the resulting order is (MBlock, NBlock, M0, M1, N1,
// groups, input_blks, N2, N0, group_size) — the function name labels
// position 4 as N0 while it receives N1; presumably a naming quirk of the
// layout convention, confirm against consumers of this descriptor.
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<8>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6, 9>{},
Sequence<7>{}));
}
// transposed XDL output supporting C' = B' * A'
// M2_N2 -> M2_N2_N3_N4
template <typename CDesc_M0_N0_M1_N1_M2_N2>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment