Commit e5bcd2bb authored by Jing Zhang

debug

parent 41cdd380
@@ -157,10 +157,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
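
The change above reads M and N from the descriptor and derives the M0/N0 lengths at run time (M0 = M / (MWaves * MPerXDL)) instead of using the compile-time MRepeat/NRepeat, so the unmerge tracks the grid actually passed in. A minimal standalone sketch (illustrative values, not CK's API) of how such an unmerge decomposes a flat m index into (m0, m1, m2):

    #include <cassert>
    #include <cstdio>

    int main()
    {
        constexpr int MWaves = 2, MPerXDL = 32; // illustrative values
        const int M  = 256;                     // runtime length, as in the new code
        const int M0 = M / (MWaves * MPerXDL);  // repeat count recovered at runtime
        for(int m : {0, 31, 32, 95, 255})
        {
            const int m2 = m % MPerXDL;            // fastest-varying
            const int m1 = (m / MPerXDL) % MWaves; // wave index
            const int m0 = m / (MPerXDL * MWaves); // repeat index
            assert(m0 < M0);
            std::printf("m=%3d -> (m0,m1,m2)=(%d,%d,%d)\n", m, m0, m1, m2);
        }
        return 0;
    }
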
@@ -288,13 +288,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
// if constexpr(ABlockLdsExtraM)
//{
// return make_naive_tensor_descriptor(
// make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
// make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
//}
// else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
@@ -303,13 +303,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
// if constexpr(BBlockLdsExtraN)
//{
// return make_naive_tensor_descriptor(
// make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
// make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
//}
// else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
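
For context on the branches commented out in these two hunks: the ABlockLdsExtraM/BBlockLdsExtraN descriptors pad the K0 stride from NPerBlock*K1 to (NPerBlock+1)*K1 so that column-wise LDS accesses land in different banks; the debug commit forces the unpadded aligned path instead. A small sketch (illustrative sizes, not CK's API) of how the padding shifts addresses:

    #include <cstdio>

    int main()
    {
        constexpr int NPerBlock = 16, K1 = 8; // illustrative sizes
        auto offset = [](int k0, int n, int k1, int k0_stride) {
            return k0 * k0_stride + n * K1 + k1; // strides as in the descriptors above
        };
        const int packed = NPerBlock * K1;       // aligned path (ignoring max_lds_align rounding)
        const int padded = (NPerBlock + 1) * K1; // ExtraM/ExtraN path, one extra row element
        std::printf("element (1,0,0): packed offset=%d, padded offset=%d\n",
                    offset(1, 0, 0, packed), offset(1, 0, 0, padded));
        return 0;
    }
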
@@ -619,6 +619,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
printf("%d %d %d\n",
get_thread_local_1d_id(),
c_thread_mtx_on_block[I0],
c_thread_mtx_on_block[I1]);
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
@@ -640,6 +645,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid));
c_thread_buf.Fill(get_thread_local_1d_id());
if(get_thread_local_1d_id() == 0)
printf("%d %d %d\n",
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1),
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2));
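
The Fill/printf additions above are a provenance trick: every thread overwrites its accumulator with its own id before the final copy, so the values that land in C show which thread produced each element (and missing ids show whose writes were dropped). A host-side sketch of the same idea (simplified, not CK's API):

    #include <cstdio>
    #include <vector>

    int main()
    {
        constexpr int threads = 4, elems_per_thread = 2;
        std::vector<int> c(threads * elems_per_thread, -1); // -1 = never written
        for(int tid = 0; tid < threads; ++tid)
        {
            int buf[elems_per_thread];
            for(int& v : buf) v = tid;                  // c_thread_buf.Fill(tid)
            for(int e = 0; e < elems_per_thread; ++e)
                c[tid * elems_per_thread + e] = buf[e]; // threadwise copy to C
        }
        for(int v : c) std::printf("%d ", v);           // expect: 0 0 1 1 2 2 3 3
        std::printf("\n");
        return 0;
    }
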
auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
@@ -652,7 +665,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
false>{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
@@ -214,6 +214,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
printf("copy: %d %d\n", dst_coord_.GetOffset(), dst_coord_.GetIndex()[I0]);
}
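
Device-side printf on ROCm works but serializes output and perturbs timing, so it is only practical for the small launches this commit drives (e.g. the 48x16x32 GEMM in the scripts below). A minimal HIP sketch of the pattern (hypothetical kernel, not CK code):

    #include <hip/hip_runtime.h>
    #include <cstdio>

    __global__ void probe()
    {
        // same pattern as the debug printf added above: one line per thread
        printf("thread %d\n", (int)(blockIdx.x * blockDim.x + threadIdx.x));
    }

    int main()
    {
        hipLaunchKernelGGL(probe, dim3(1), dim3(4), 0, 0);
        hipDeviceSynchronize(); // flush device printf before the process exits
        return 0;
    }
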
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
{
@@ -589,6 +589,7 @@ struct XdlopsGemm
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
@@ -599,7 +600,7 @@ struct XdlopsGemm
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
mfma_instr.num_input_blks,
mfma_instr.group_size)),
make_pass_through_transform(mfma_instr.num_threads_per_blk)),
make_pass_through_transform(N2)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
@@ -57,7 +57,7 @@
// AMD buffer addressing
#ifndef CK_USE_AMD_BUFFER_ADDRESSING
#define CK_USE_AMD_BUFFER_ADDRESSING 1
#define CK_USE_AMD_BUFFER_ADDRESSING 0
#endif
// only gfx908 supports native floating point atomic add
@@ -104,6 +104,11 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
[&](auto i) { GetElement(i, true) = invalid_element_value_; });
}
__host__ __device__ void Fill(VecBaseType val)
{
static_for<0, GetNumElements(), 1>{}([&](auto i) { GetElement(i, true) = val; });
}
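
The new Fill member is a compile-time loop over the statically indexed elements. A self-contained sketch of the same mechanism using a C++17 fold in place of CK's static_for (illustrative, not the real StaticBufferOfVectorTypeV2):

    #include <array>
    #include <cstdio>
    #include <utility>

    template <typename T, std::size_t N>
    struct StaticBuf
    {
        std::array<T, N> data{};

        template <std::size_t... Is>
        void FillImpl(T v, std::index_sequence<Is...>) { ((data[Is] = v), ...); }

        // unrolled at compile time, like static_for<0, N, 1> above
        void Fill(T v) { FillImpl(v, std::make_index_sequence<N>{}); }
    };

    int main()
    {
        StaticBuf<int, 4> buf;
        buf.Fill(7); // e.g. a thread id, as in c_thread_buf.Fill(...) above
        for(int v : buf.data) std::printf("%d ", v);
        return 0;
    }
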
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
@@ -27,14 +27,18 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 4, 16, 16, 3, 4, S<1, 3, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 32, 128, 4, 4, 16, 16, 1, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 96, 128, 4, 4, 32, 32, 3, 2, S<1, 3, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
// clang-format on
>;
@@ -287,27 +287,27 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmMPerBlock = 48;
constexpr index_t GemmNPerBlock = 16;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmMPerXDL = 16;
constexpr index_t GemmNPerXDL = 16;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
constexpr index_t MRepeat = 3;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 1, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 48, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 1, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 16, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
@@ -162,23 +162,23 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t MPerBlock = 32;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t MPerXDL = 16;
constexpr index_t NPerXDL = 16;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
@@ -189,6 +189,34 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 64;
constexpr index_t MPerBlock = 48;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 16;
constexpr index_t NPerXDL = 16;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 3;
constexpr index_t NRepeat = 1;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<4, 1, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<1, 48, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 1;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<4, 1, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<1, 16, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 1;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
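
A quick consistency check of the new tuning block, assuming a 64-lane wavefront and the usual CK relations between block tile, XDL size, and repeat counts (the named constants restate the values above):

    constexpr int BlockSize = 64, MPerBlock = 48, NPerBlock = 16;
    constexpr int MPerXDL = 16, NPerXDL = 16, MRepeat = 3, NRepeat = 1;
    constexpr int WaveSize = 64; // gfx908 wavefront

    constexpr int MWaves = MPerBlock / (MRepeat * MPerXDL); // 48 / 48 = 1
    constexpr int NWaves = NPerBlock / (NRepeat * NPerXDL); // 16 / 16 = 1
    static_assert(MWaves * NWaves * WaveSize == BlockSize, "one wave covers the block tile");

    // block-transfer thread clusters must fit in the block:
    static_assert(1 * 48 * 1 <= BlockSize, "A-copy cluster <1,48,1> fits");
    static_assert(1 * 16 * 1 <= BlockSize, "B-copy cluster <1,16,1> fits");

    int main() { return 0; }
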
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
@@ -351,8 +379,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
b_k_n.mDesc.GetStrides()[1],
b_k_n.mDesc.GetStrides()[0]));
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
const auto c_m_n_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
@@ -6,6 +6,15 @@
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
struct OpPassThrough
{
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
{
return v;
}
};
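
OpPassThrough is the identity elementwise functor; it exists so the driver can thread an elementwise-operation hook through GridwiseGemm without changing results. A usage sketch of why the hook is shaped this way (OpRelu is a hypothetical alternative, not part of this commit):

    #include <cstdio>

    struct OpPassThrough
    {
        template <typename T>
        constexpr T operator()(T v) const { return v; }
    };

    struct OpRelu // hypothetical: any per-element epilogue fits the same slot
    {
        template <typename T>
        constexpr T operator()(T v) const { return v > T{0} ? v : T{0}; }
    };

    template <typename Op>
    void epilogue(float* c, int n, Op op)
    {
        for(int i = 0; i < n; ++i) c[i] = op(c[i]); // applied per element on store
    }

    int main()
    {
        float c[3] = {-1.f, 0.5f, 2.f};
        epilogue(c, 3, OpPassThrough{}); // identity: values unchanged
        epilogue(c, 3, OpRelu{});        // clamps negatives
        std::printf("%g %g %g\n", c[0], c[1], c[2]);
        return 0;
    }
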
template <ck::index_t BlockSize,
typename FloatAB,
typename FloatAcc,
@@ -70,6 +79,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
using ElementwiseOperation = OpPassThrough;
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
@@ -79,6 +90,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K,
CMNGridDesc,
ElementwiseOperation,
ElementwiseOperation,
ElementwiseOperation,
MPerBlock,
NPerBlock,
KPerBlock,
@@ -152,6 +166,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
float ave_time = 0;
auto element_op_ = OpPassThrough{};
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if(has_main_k0_block_loop)
{
@@ -162,6 +178,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
remove_reference_t<AGridDesc_K0_M_K1>,
remove_reference_t<BGridDesc_K0_N_K>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
ElementwiseOperation,
ElementwiseOperation,
ElementwiseOperation,
remove_reference_t<Block2CTileMap>,
true>;
@@ -176,6 +195,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
element_op_,
element_op_,
element_op_,
block_2_ctile_map);
}
else
@@ -187,6 +209,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
remove_reference_t<AGridDesc_K0_M_K1>,
remove_reference_t<BGridDesc_K0_N_K>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
ElementwiseOperation,
ElementwiseOperation,
ElementwiseOperation,
remove_reference_t<Block2CTileMap>,
false>;
@@ -201,6 +226,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
element_op_,
element_op_,
element_op_,
block_2_ctile_map);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
@@ -12,10 +12,10 @@
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
//#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
@@ -22,9 +22,9 @@
#include "device_gemm_xdlops_km_nk_nm.hpp"
#define USE_GEMM_XDL_MK_KN_MN 1
#define USE_GEMM_XDL_MK_NK_MN 1
#define USE_GEMM_XDL_KM_KN_MN 1
#define USE_GEMM_XDL_KM_NK_MN 1
#define USE_GEMM_XDL_MK_NK_MN 0
#define USE_GEMM_XDL_KM_KN_MN 0
#define USE_GEMM_XDL_KM_NK_MN 0
#define USE_GEMM_XDL_MK_KN_NM 0
#define USE_GEMM_XDL_MK_NK_NM 0
#define USE_GEMM_XDL_KM_KN_NM 0
@@ -445,8 +445,8 @@ int main(int argc, char* argv[])
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "a : ", a.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "b: ", b.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", c_host.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_device.mData, ",") << std::endl;
}
@@ -15,13 +15,13 @@ include_directories(BEFORE
# device_gemm_instance
set(DEVICE_GEMM_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
)
add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
@@ -31,20 +31,20 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE
install(TARGETS device_gemm_instance LIBRARY DESTINATION lib)
# device_conv_instance
set(DEVICE_CONV_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp;
)
#set(DEVICE_CONV_INSTANCE_SOURCE
##${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp;
##${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp;
#)
add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE})
target_include_directories(device_conv_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_compile_features(device_conv_instance PUBLIC)
set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_conv_instance LIBRARY DESTINATION lib)
#add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE})
#target_include_directories(device_conv_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
#target_compile_features(device_conv_instance PUBLIC)
#set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
#install(TARGETS device_conv_instance LIBRARY DESTINATION lib)
# ck_profiler
set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp)
set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp)
add_executable(ckProfiler ${PROFILER_SOURCE})
target_link_libraries(ckProfiler PRIVATE host_tensor)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance device_conv_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
@@ -66,6 +66,7 @@ int gemm_profiler(int argc, char* argv[])
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
#if 0
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm<ck::half_t,
@@ -210,6 +211,27 @@ int gemm_profiler(int argc, char* argv[])
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
}
#endif
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
}
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
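
The new F32 branch above defaults negative strides to the packed row-major leading dimensions (K for A, N for B and C). A hedged restatement of that defaulting rule (the helper name is illustrative, not CK's API):

    #include <cstdio>

    // negative stride means "use the packed row-major leading dimension"
    int def_stride(int requested, int packed) { return requested < 0 ? packed : requested; }

    int main()
    {
        const int N = 16, K = 32; // sizes from the debug scripts below
        std::printf("StrideA=%d StrideB=%d StrideC=%d\n",
                    def_stride(-1, K),  // A is M x K, row-major
                    def_stride(-1, N),  // B is K x N, row-major
                    def_stride(-1, N)); // C is M x N, row-major
        return 0;
    }
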
@@ -6,7 +6,7 @@
#include <half.hpp>
int gemm_profiler(int, char*[]);
int conv_profiler(int, char*[]);
// int conv_profiler(int, char*[]);
int main(int argc, char* argv[])
{
@@ -14,10 +14,10 @@ int main(int argc, char* argv[])
{
return gemm_profiler(argc, argv);
}
else if(strcmp(argv[1], "conv") == 0)
{
return conv_profiler(argc, argv);
}
// else if(strcmp(argv[1], "conv") == 0)
//{
// return conv_profiler(argc, argv);
//}
else
{
printf("arg1: tensor operation (gemm=GEMM, conv=Convolution)\n");
@@ -10,7 +10,7 @@ cmake
-D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \
-D BUILD_DEV=OFF \
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 -ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O1 -ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
@@ -22,7 +22,7 @@ REPEAT=$6
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 32 1 1 1 48 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 $DESIRED_GRID_SIZE
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
@@ -19,7 +19,8 @@ REPEAT=$6
######### layout algo verify init log repeat M___ N___ K___ M01_ N01_
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 48 16 32 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01
@@ -25,21 +25,21 @@ REPEAT=$7
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256