Commit e5bcd2bb authored by Jing Zhang's avatar Jing Zhang
Browse files

debug

parent 41cdd380
...@@ -157,10 +157,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -157,10 +157,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n, c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)), make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
......
...@@ -288,13 +288,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -288,13 +288,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() { constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM) // if constexpr(ABlockLdsExtraM)
{ //{
return make_naive_tensor_descriptor( // return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), // make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1)); // make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
} //}
else // else
{ {
return make_naive_tensor_descriptor_aligned( return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
...@@ -303,13 +303,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -303,13 +303,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = [&]() { constexpr auto b_block_desc_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN) // if constexpr(BBlockLdsExtraN)
{ //{
return make_naive_tensor_descriptor( // return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), // make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1)); // make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
} //}
else // else
{ {
return make_naive_tensor_descriptor_aligned( return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
...@@ -619,6 +619,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -619,6 +619,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid = const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
printf("%d %d %d\n",
get_thread_local_1d_id(),
c_thread_mtx_on_block[I0],
c_thread_mtx_on_block[I1]);
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
...@@ -640,6 +645,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -640,6 +645,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid)); make_multi_index(n_thread_data_on_grid));
c_thread_buf.Fill(get_thread_local_1d_id());
if(get_thread_local_1d_id() == 0)
printf("%d %d %d\n",
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1),
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2));
auto c_thread_copy = auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc, ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC, FloatC,
...@@ -652,7 +665,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -652,7 +665,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>{ false>{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(m_thread_data_on_grid_idx[I0], make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0], n_thread_data_on_grid_idx[I0],
......
...@@ -214,6 +214,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -214,6 +214,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
dst_coord_.GetOffset(), dst_coord_.GetOffset(),
is_dst_valid, is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]); dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
printf("copy: %d %d\n", dst_coord_.GetOffset(), dst_coord_.GetIndex()[I0]);
} }
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd) else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
{ {
......
...@@ -589,6 +589,7 @@ struct XdlopsGemm ...@@ -589,6 +589,7 @@ struct XdlopsGemm
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2, c_desc_m0_n0_m1_n1_m2_n2,
...@@ -599,7 +600,7 @@ struct XdlopsGemm ...@@ -599,7 +600,7 @@ struct XdlopsGemm
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
mfma_instr.num_input_blks, mfma_instr.num_input_blks,
mfma_instr.group_size)), mfma_instr.group_size)),
make_pass_through_transform(mfma_instr.num_threads_per_blk)), make_pass_through_transform(N2)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
......
...@@ -57,7 +57,7 @@ ...@@ -57,7 +57,7 @@
// AMD buffer addressing // AMD buffer addressing
#ifndef CK_USE_AMD_BUFFER_ADDRESSING #ifndef CK_USE_AMD_BUFFER_ADDRESSING
#define CK_USE_AMD_BUFFER_ADDRESSING 1 #define CK_USE_AMD_BUFFER_ADDRESSING 0
#endif #endif
// only gfx908 support native floating point atomic add // only gfx908 support native floating point atomic add
......
...@@ -104,6 +104,11 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N> ...@@ -104,6 +104,11 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
[&](auto i) { GetElement(i, true) = invalid_element_value_; }); [&](auto i) { GetElement(i, true) = invalid_element_value_; });
} }
__host__ __device__ void Fill(VecBaseType val)
{
static_for<0, GetNumElements(), 1>{}([&](auto i) { GetElement(i, true) = val; });
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
......
...@@ -27,14 +27,18 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple< ...@@ -27,14 +27,18 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 4, 16, 16, 3, 4, S<1, 3, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 32, 128, 4, 4, 16, 16, 1, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 96, 128, 4, 4, 32, 32, 3, 2, S<1, 3, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
......
...@@ -287,27 +287,27 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( ...@@ -287,27 +287,27 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1 #elif 1
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 48;
constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmNPerBlock = 16;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32; constexpr index_t GemmMPerXDL = 16;
constexpr index_t GemmNPerXDL = 32; constexpr index_t GemmNPerXDL = 16;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 3;
constexpr index_t NRepeat = 4; constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 1, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 48, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 1, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 16, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
......
...@@ -162,23 +162,23 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k, ...@@ -162,23 +162,23 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1; constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1 #elif 0
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256; constexpr index_t MPerBlock = 32;
constexpr index_t NPerBlock = 128; constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4; constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32; constexpr index_t MPerXDL = 16;
constexpr index_t NPerXDL = 32; constexpr index_t NPerXDL = 16;
constexpr index_t K1 = 8; constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4; constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2; constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
...@@ -189,6 +189,34 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k, ...@@ -189,6 +189,34 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 64;
constexpr index_t MPerBlock = 48;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 16;
constexpr index_t NPerXDL = 16;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 3;
constexpr index_t NRepeat = 1;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<4, 1, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<1, 48, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 1;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<4, 1, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<1, 16, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 1;
constexpr index_t CThreadTransferDstScalarPerVector = 1; constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0 #elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16 // [M, N, K0, K1] = [128, 256, 4, 8] for fp16
...@@ -351,8 +379,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k, ...@@ -351,8 +379,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
b_k_n.mDesc.GetStrides()[1], b_k_n.mDesc.GetStrides()[1],
b_k_n.mDesc.GetStrides()[0])); b_k_n.mDesc.GetStrides()[0]));
const auto c_m_n_grid_desc = make_naive_tensor_descriptor( const auto c_m_n_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
......
...@@ -6,6 +6,15 @@ ...@@ -6,6 +6,15 @@
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp"
struct OpPassThrough
{
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
{
return v;
}
};
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
...@@ -70,6 +79,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -70,6 +79,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
using ElementwiseOperation = OpPassThrough;
using GridwiseGemm = using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize, GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB, FloatAB,
...@@ -79,6 +90,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -79,6 +90,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K, BGridDesc_K0_N_K,
CMNGridDesc, CMNGridDesc,
ElementwiseOperation,
ElementwiseOperation,
ElementwiseOperation,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
...@@ -152,6 +166,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -152,6 +166,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
float ave_time = 0; float ave_time = 0;
auto element_op_ = OpPassThrough{};
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
{ {
...@@ -162,6 +178,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -162,6 +178,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
remove_reference_t<AGridDesc_K0_M_K1>, remove_reference_t<AGridDesc_K0_M_K1>,
remove_reference_t<BGridDesc_K0_N_K>, remove_reference_t<BGridDesc_K0_N_K>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
ElementwiseOperation,
ElementwiseOperation,
ElementwiseOperation,
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
true>; true>;
...@@ -176,6 +195,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -176,6 +195,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
element_op_,
element_op_,
element_op_,
block_2_ctile_map); block_2_ctile_map);
} }
else else
...@@ -187,6 +209,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -187,6 +209,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
remove_reference_t<AGridDesc_K0_M_K1>, remove_reference_t<AGridDesc_K0_M_K1>,
remove_reference_t<BGridDesc_K0_N_K>, remove_reference_t<BGridDesc_K0_N_K>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
ElementwiseOperation,
ElementwiseOperation,
ElementwiseOperation,
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
false>; false>;
...@@ -201,6 +226,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -201,6 +226,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
element_op_,
element_op_,
element_op_,
block_2_ctile_map); block_2_ctile_map);
} }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "conv_common.hpp" #include "conv_common.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" //#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" //#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" //#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" //#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1 #define USE_DYNAMIC_MODE 1
......
...@@ -22,9 +22,9 @@ ...@@ -22,9 +22,9 @@
#include "device_gemm_xdlops_km_nk_nm.hpp" #include "device_gemm_xdlops_km_nk_nm.hpp"
#define USE_GEMM_XDL_MK_KN_MN 1 #define USE_GEMM_XDL_MK_KN_MN 1
#define USE_GEMM_XDL_MK_NK_MN 1 #define USE_GEMM_XDL_MK_NK_MN 0
#define USE_GEMM_XDL_KM_KN_MN 1 #define USE_GEMM_XDL_KM_KN_MN 0
#define USE_GEMM_XDL_KM_NK_MN 1 #define USE_GEMM_XDL_KM_NK_MN 0
#define USE_GEMM_XDL_MK_KN_NM 0 #define USE_GEMM_XDL_MK_KN_NM 0
#define USE_GEMM_XDL_MK_NK_NM 0 #define USE_GEMM_XDL_MK_NK_NM 0
#define USE_GEMM_XDL_KM_KN_NM 0 #define USE_GEMM_XDL_KM_KN_NM 0
...@@ -445,8 +445,8 @@ int main(int argc, char* argv[]) ...@@ -445,8 +445,8 @@ int main(int argc, char* argv[])
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "a : ", a.mData, ",") << std::endl; // LogRangeAsType<float>(std::cout << "a : ", a.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b.mData, ",") << std::endl; // LogRangeAsType<float>(std::cout << "b: ", b.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", c_host.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "c_host : ", c_host.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_device.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "c_device: ", c_device.mData, ",") << std::endl;
} }
......
...@@ -15,13 +15,13 @@ include_directories(BEFORE ...@@ -15,13 +15,13 @@ include_directories(BEFORE
# device_gemm_instance # device_gemm_instance
set(DEVICE_GEMM_INSTANCE_SOURCE set(DEVICE_GEMM_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp; ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp; #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp; #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp; #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp; #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp; #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp; #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp; #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
) )
add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
...@@ -31,20 +31,20 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ...@@ -31,20 +31,20 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE
install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) install(TARGETS device_gemm_instance LIBRARY DESTINATION lib)
# device_conv_instance # device_conv_instance
set(DEVICE_CONV_INSTANCE_SOURCE #set(DEVICE_CONV_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp; ##${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp; ##${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp;
) #)
add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE}) #add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE})
target_include_directories(device_conv_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) #target_include_directories(device_conv_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_compile_features(device_conv_instance PUBLIC) #target_compile_features(device_conv_instance PUBLIC)
set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) #set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_conv_instance LIBRARY DESTINATION lib) #install(TARGETS device_conv_instance LIBRARY DESTINATION lib)
# ck_profiler # ck_profiler
set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp) set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp)
add_executable(ckProfiler ${PROFILER_SOURCE}) add_executable(ckProfiler ${PROFILER_SOURCE})
target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE host_tensor)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance device_conv_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
...@@ -66,6 +66,7 @@ int gemm_profiler(int argc, char* argv[]) ...@@ -66,6 +66,7 @@ int gemm_profiler(int argc, char* argv[])
const int StrideB = std::stoi(argv[12]); const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]); const int StrideC = std::stoi(argv[13]);
#if 0
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_gemm<ck::half_t, ck::profiler::profile_gemm<ck::half_t,
...@@ -210,6 +211,27 @@ int gemm_profiler(int argc, char* argv[]) ...@@ -210,6 +211,27 @@ int gemm_profiler(int argc, char* argv[])
(StrideB < 0) ? K : StrideB, (StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC); (StrideC < 0) ? N : StrideC);
} }
#endif
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
}
else else
{ {
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <half.hpp> #include <half.hpp>
int gemm_profiler(int, char*[]); int gemm_profiler(int, char*[]);
int conv_profiler(int, char*[]); // int conv_profiler(int, char*[]);
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -14,10 +14,10 @@ int main(int argc, char* argv[]) ...@@ -14,10 +14,10 @@ int main(int argc, char* argv[])
{ {
return gemm_profiler(argc, argv); return gemm_profiler(argc, argv);
} }
else if(strcmp(argv[1], "conv") == 0) // else if(strcmp(argv[1], "conv") == 0)
{ //{
return conv_profiler(argc, argv); // return conv_profiler(argc, argv);
} //}
else else
{ {
printf("arg1: tensor operation (gemm=GEMM, conv=Convolution)\n"); printf("arg1: tensor operation (gemm=GEMM, conv=Convolution)\n");
......
...@@ -10,7 +10,7 @@ cmake ...@@ -10,7 +10,7 @@ cmake
-D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \
-D BUILD_DEV=OFF \ -D BUILD_DEV=OFF \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 -ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O1 -ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
......
...@@ -22,7 +22,7 @@ REPEAT=$6 ...@@ -22,7 +22,7 @@ REPEAT=$6
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ ######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 32 1 1 1 48 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 $DESIRED_GRID_SIZE #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 $DESIRED_GRID_SIZE
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
......
...@@ -19,7 +19,8 @@ REPEAT=$6 ...@@ -19,7 +19,8 @@ REPEAT=$6
######### layout algo verify init log repeat M___ N___ K___ M01_ N01_ ######### layout algo verify init log repeat M___ N___ K___ M01_ N01_
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 48 16 32 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 $M01 $N01 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01
...@@ -25,21 +25,21 @@ REPEAT=$7 ...@@ -25,21 +25,21 @@ REPEAT=$7
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment