Commit 7cb8a89f authored by Jing Zhang

fixed c_output

parent ccaea50e
@@ -71,19 +71,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
         set(target 1)
     endif()
 endforeach()
-
-add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
-add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
-
-add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
-add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
-
-list(APPEND gpu_list gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
-        add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
-        set(target 1)
-    endif()
-endforeach()
@@ -66,10 +66,15 @@ struct BlockwiseGemmWMMA
     // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
     // When not use LDS, each Row read half of whole data from source buffer, exchange the data via
     // permutation
+#ifdef __gfx12__
     static constexpr index_t A_KRow = 2;
     static constexpr index_t B_KRow = 2;
+#else
+    static constexpr index_t A_KRow = 1;
+    static constexpr index_t B_KRow = 1;
+#endif
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
 
     static constexpr auto wmma_gemm =
         WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
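The KRow constants encode the layout difference described in the comment above: on gfx12 the K elements a thread owns are split across two 16-lane rows, so a flat index inside one KPack must be decomposed into (K0, KRow, K1) coordinates. Below is a minimal standalone sketch of that decomposition (not part of the commit; K1 and KPack values are illustrative). The Run() bodies added further down apply exactly this arithmetic inside CalculateOffset.

#include <cstdio>

int main()
{
    constexpr int K1    = 8;  // contiguous elements per vector read (A_K1 / B_K1)
    constexpr int KRow  = 2;  // 2 on gfx12, 1 on other targets (per the #ifdef above)
    constexpr int KPack = 16; // illustrative

    for(int i = 0; i < KPack; ++i)
    {
        const int k0   = i / K1 / KRow;   // outer K repeat within the pack
        const int krow = (i / K1) % KRow; // which 16-lane row owns the element
        const int k1   = i % K1;          // position inside the vector read
        std::printf("i=%2d -> (K0=%d, KRow=%d, K1=%d)\n", i, k0, krow, k1);
    }
    return 0;
}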
@@ -108,7 +113,11 @@ struct BlockwiseGemmWMMA
             const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
 
             //  |KRepeat   |MRepeat|MWave    |KRow  |MLane  |KPack
+#ifdef __gfx12__
             return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0);
+#else
+            return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
+#endif
         }
         else
         {
@@ -125,7 +134,11 @@ struct BlockwiseGemmWMMA
             const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
 
             //  |KRepeat   |NRepeat|Nwave    |KRow  |NLane  |KPack
+#ifdef __gfx12__
             return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0);
+#else
+            return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
+#endif
         }
         else
         {
@@ -213,20 +226,19 @@ struct BlockwiseGemmWMMA
         constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
             wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
 
-        constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
-        constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
         constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
-
-        return make_naive_tensor_descriptor_packed(
+        constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
+        return make_naive_tensor_descriptor(
             //        |MRepeat           |MWave |MSubGroup |NRepeat           |NWave
             //        |NThreadPerSubGroup |MAccVgprs
-            make_tuple(Number<MRepeat>{},
-                       I1,
-                       MSubGroup,
-                       Number<NRepeat>{},
-                       I1,
-                       NThreadPerSubGroup,
-                       MAccVgprs));
+            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
+            make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
+                       Number<NRepeat>{} * MAccVgprs * AccStride,
+                       Number<NRepeat>{} * MAccVgprs * AccStride,
+                       MAccVgprs * AccStride,
+                       MAccVgprs * AccStride,
+                       MAccVgprs * AccStride,
+                       AccStride));
     }
 
     template <typename CGridDesc_M_N>
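The switch from make_naive_tensor_descriptor_packed to an explicitly strided descriptor means adjacent accumulator entries are now AccStride apart instead of contiguous. Under the lengths and strides above, only the MRepeat, NRepeat and accumulator coordinates contribute to the offset, since every other dimension has length 1. A small sketch of the implied offset function (not part of the commit; the NRepeat/MAccVgprs/AccStride values are hypothetical):

#include <cstdio>

// Offset implied by the make_naive_tensor_descriptor call in the hunk above;
// only m0 (MRepeat), n0 (NRepeat) and v (MAccVgprs) contribute, because the
// wave/subgroup dimensions have length 1.
constexpr int c_offset(int m0, int n0, int v, int NRepeat, int MAccVgprs, int AccStride)
{
    return m0 * NRepeat * MAccVgprs * AccStride // MRepeat stride
         + n0 * MAccVgprs * AccStride           // NRepeat stride
         + v * AccStride;                       // accumulator stride
}

int main()
{
    // Hypothetical configuration: NRepeat = 2, MAccVgprs = 8, AccStride = 2.
    std::printf("%d\n", c_offset(1, 1, 3, 2, 8, 2)); // 32 + 16 + 6 = 54
    return 0;
}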
@@ -290,6 +302,7 @@ struct BlockwiseGemmWMMA
     static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
     static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
 
+#ifdef __gfx12__
     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
                         const BBlockBuffer& b_block_buf,
@@ -414,6 +427,140 @@ struct BlockwiseGemmWMMA
             });
         }
     }
+#else
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BBlockBuffer& b_block_buf,
+                        CThreadBuffer& c_thread_buf) const
+    {
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+            a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+            b_thread_desc_.GetElementSpaceSize());
+
+        // basic intrinsic to determine loopover direction
+        if constexpr(MRepeat < NRepeat)
+        {
+            static_for<0, KPerBlock / KPack, 1>{}(
+                [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
+                    static_for<0, MRepeat, 1>{}([&](auto m0) {
+                        // read A
+                        a_thread_copy_.Run(
+                            a_block_desc_k0_m0_m1_m2_k1,
+                            make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
+                            a_block_buf,
+                            a_thread_desc_,
+                            make_tuple(I0, m0, I0, I0, I0, I0),
+                            a_thread_buf);
+
+                        static_for<0, NRepeat, 1>{}([&](auto n0) {
+                            // read B
+                            b_thread_copy_.Run(
+                                b_block_desc_k0_n0_n1_n2_k1,
+                                make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
+                                b_block_buf,
+                                b_thread_desc_,
+                                make_tuple(I0, n0, I0, I0, I0, I0),
+                                b_thread_buf);
+
+                            vector_type<FloatA, KPack> a_thread_vec;
+                            vector_type<FloatB, KPack> b_thread_vec;
+
+                            static_for<0, KPack, 1>{}([&](auto i) {
+                                a_thread_vec.template AsType<FloatA>()(i) =
+                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                        make_tuple(i / A_K1 / A_KRow,
+                                                   m0,
+                                                   0,
+                                                   (i / A_K1) % A_KRow,
+                                                   0,
+                                                   i % A_K1))>{}];
+                                b_thread_vec.template AsType<FloatB>()(i) =
+                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                        make_tuple(i / B_K1 / B_KRow,
+                                                   n0,
+                                                   0,
+                                                   (i / B_K1) % B_KRow,
+                                                   0,
+                                                   i % B_K1))>{}];
+                            });
+
+                            using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
+                            using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
+
+                            constexpr index_t c_offset =
+                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                            wmma_gemm.template Run(
+                                a_thread_vec.template AsType<wmma_input_type_a>(),
+                                b_thread_vec.template AsType<wmma_input_type_b>(),
+                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        });
+                    });
+                });
+        }
+        else
+        {
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
+                                                                        // k=0,kpack*1, ..
+                        // read B
+                        b_thread_copy_.Run(
+                            b_block_desc_k0_n0_n1_n2_k1,
+                            make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
+                            b_block_buf,
+                            b_thread_desc_,
+                            make_tuple(I0, n0, I0, I0, I0, I0),
+                            b_thread_buf);
+                        // read A
+                        a_thread_copy_.Run(
+                            a_block_desc_k0_m0_m1_m2_k1,
+                            make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
+                            a_block_buf,
+                            a_thread_desc_,
+                            make_tuple(I0, m0, I0, I0, I0, I0),
+                            a_thread_buf);
+
+                        vector_type<FloatA, KPack> a_thread_vec;
+                        vector_type<FloatB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto i) {
+                            b_thread_vec.template AsType<FloatB>()(i) =
+                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                    make_tuple(i / B_K1 / B_KRow,
+                                               n0,
+                                               0,
+                                               (i / B_K1) % B_KRow,
+                                               0,
+                                               i % B_K1))>{}];
+                            a_thread_vec.template AsType<FloatA>()(i) =
+                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                    make_tuple(i / A_K1 / A_KRow,
+                                               m0,
+                                               0,
+                                               (i / A_K1) % A_KRow,
+                                               0,
+                                               i % A_K1))>{}];
+                        });
+
+                        using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
+                        using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
+
+                        constexpr index_t c_offset =
+                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        wmma_gemm.template Run(
+                            a_thread_vec.template AsType<wmma_input_type_a>(),
+                            b_thread_vec.template AsType<wmma_input_type_b>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
+                });
+            });
+        }
+    }
+#endif
 
 protected:
     static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
@@ -449,7 +596,11 @@ struct BlockwiseGemmWMMA
         FloatA,
         decltype(a_block_desc_k0_m0_m1_m2_k1),
         decltype(a_thread_desc_),
+#ifdef __gfx12__
         Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
+#else
+        Sequence<KPack / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
+#endif
         Sequence<0, 1, 2, 3, 4, 5>,
         5,
         A_K1,
@@ -484,7 +635,11 @@ struct BlockwiseGemmWMMA
         FloatB,
         decltype(b_block_desc_k0_n0_n1_n2_k1),
         decltype(b_thread_desc_),
+#ifdef __gfx12__
         Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
+#else
+        Sequence<KPack / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
+#endif
         Sequence<0, 1, 2, 3, 4, 5>,
         5,
         B_K1,
...
@@ -136,6 +136,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
     // static constexpr index_t src_b_data_size = 2;
     // static constexpr index_t acc_data_size   = 4;
     // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
+    static constexpr index_t acc_pack_number = 1;
     static constexpr index_t num_thread_per_subgroups = n_per_wmma;
 
     // Wave mode dependent propety
@@ -151,14 +152,11 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
     template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
+        static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
         if constexpr(wave_size == 32)
        {
             intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
         }
-        else if constexpr(wave_size == 64)
-        {
-            static_assert(1, "");
-        }
     }
 };
@@ -340,7 +338,7 @@ struct WmmaSelector
     template <>
     static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
     {
-#if 1
+#ifdef __gfx12__
         return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
 #else
         return WmmaInstr::wmma_f32_16x16x16_f16;
@@ -389,12 +387,10 @@ struct WmmaSelector
         static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
 
-#if 0
         static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
-                              selected_wmma.acc_data_size ==
+                              selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
                           selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
                       "WRONG! Invalid Number of Accumulator Register");
-#endif
     }
 };
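The assertion, previously disabled with #if 0, balances the bytes of accumulator registers per wave against the bytes of one fp32 C tile. Assuming the gfx12 f16 instruction's values (wave_size = 32, acc_data_size = 4, the acc_pack_number = 1 added in this commit, and num_acc_vgprs_per_wave = 8, which is an assumption based on the standard wave32 WMMA layout rather than a value shown in this diff), both sides come to 1024 bytes. A standalone restatement with those numbers plugged in:

// Standalone restatement of the re-enabled accumulator check; values are
// assumed for wmma_f32_16x16x16_f16_gfx12 on wave32, not taken verbatim
// from this diff.
constexpr int wave_size              = 32;
constexpr int num_acc_vgprs_per_wave = 8;  // assumed: 16*16 outputs / 32 lanes
constexpr int acc_data_size          = 4;  // fp32 accumulator, in bytes
constexpr int acc_pack_number        = 1;  // introduced by this commit
constexpr int m_per_wmma             = 16;
constexpr int n_per_wmma             = 16;

static_assert(wave_size * num_acc_vgprs_per_wave * acc_data_size * acc_pack_number ==
                  m_per_wmma * n_per_wmma * 4, // 1024 == 1024 bytes
              "accumulator VGPRs must cover the whole 16x16 fp32 tile");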
@@ -510,7 +506,7 @@ struct WmmaGemm
     __device__ static constexpr index_t GetRegSizePerWmma()
     {
-        return wmma_instr.num_acc_vgprs_per_wave;
+        return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
     }
 
     __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
@@ -569,12 +565,14 @@ struct WmmaGemm
     __host__ __device__ static auto CalculateAThreadOriginDataIndex()
     {
-        return GetLaneIdUnderSubGroup();
+        // return GetLaneIdUnderSubGroup();
+        return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
     }
 
     __host__ __device__ static auto CalculateBThreadOriginDataIndex()
     {
-        return GetLaneIdUnderSubGroup();
+        // return GetLaneIdUnderSubGroup();
+        return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
     }
 
     __device__ static CIndex GetBeginOfThreadBlk()
@@ -600,7 +598,10 @@ struct WmmaGemm
     __host__ __device__ static constexpr auto
     GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
     {
-        return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
+        return make_tuple(I1,
+                          I1,
+                          Number<wmma_instr.num_acc_vgprs_per_wave>{},
+                          Number<wmma_instr.acc_pack_number>{});
     }
 };
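The lengths tuple now carries four entries, so MAccVgprs stays at index I2 and the new AccStride lands at I3, matching the [I2]/[I3] reads in the CalculateCThreadDescriptor hunk earlier in this commit. A toy illustration with std::tuple standing in for ck's compile-time tuple (values are hypothetical):

#include <tuple>

// (MSubGroup, NThreadPerSubGroup, MAccVgprs, AccStride) -- plain ints stand
// in for ck's Number<> constants; the values are hypothetical.
constexpr auto tblk_lens = std::make_tuple(1, 1, 8, 1);

constexpr int MAccVgprs = std::get<2>(tblk_lens); // read via [I2] in the descriptor
constexpr int AccStride = std::get<3>(tblk_lens); // read via [I3], new in this commit

static_assert(MAccVgprs == 8 && AccStride == 1, "illustrative only");

int main() { return 0; }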
...
@@ -2,64 +2,64 @@
 set(PROFILER_SOURCES
     profiler.cpp
     profile_gemm.cpp
-    profile_gemm_splitk.cpp
-    profile_gemm_bias_add_reduce.cpp
-    profile_gemm_add_multiply.cpp
-    profile_gemm_multiply_add.cpp
-    profile_gemm_reduce.cpp
-    profile_batched_gemm.cpp
-    profile_batched_gemm_reduce.cpp
-    profile_conv_fwd.cpp
-    profile_conv_fwd_bias_relu.cpp
-    profile_conv_fwd_bias_relu_add.cpp
-    profile_conv_bwd_data.cpp
-    profile_grouped_conv_fwd.cpp
-    profile_grouped_conv_bwd_weight.cpp
-    profile_reduce.cpp
-    profile_groupnorm_bwd_data.cpp
-    profile_groupnorm_fwd.cpp
-    profile_layernorm_bwd_data.cpp
-    profile_layernorm_bwd_gamma_beta.cpp
-    profile_groupnorm_bwd_gamma_beta.cpp
-    profile_layernorm_fwd.cpp
-    profile_max_pool3d_fwd.cpp
-    profile_avg_pool3d_bwd.cpp
-    profile_max_pool3d_bwd.cpp
-    profile_softmax.cpp
-    profile_batchnorm_fwd.cpp
-    profile_batchnorm_bwd.cpp
-    profile_batchnorm_infer.cpp
-    profile_grouped_conv_bwd_data.cpp
-    profile_conv_tensor_rearrange.cpp
-    profile_transpose.cpp
-    profile_permute_scale.cpp
+    #profile_gemm_splitk.cpp
+    #profile_gemm_bias_add_reduce.cpp
+    #profile_gemm_add_multiply.cpp
+    #profile_gemm_multiply_add.cpp
+    #profile_gemm_reduce.cpp
+    #profile_batched_gemm.cpp
+    #profile_batched_gemm_reduce.cpp
+    #profile_conv_fwd.cpp
+    #profile_conv_fwd_bias_relu.cpp
+    #profile_conv_fwd_bias_relu_add.cpp
+    #profile_conv_bwd_data.cpp
+    #profile_grouped_conv_fwd.cpp
+    #profile_grouped_conv_bwd_weight.cpp
+    #profile_reduce.cpp
+    #profile_groupnorm_bwd_data.cpp
+    #profile_groupnorm_fwd.cpp
+    #profile_layernorm_bwd_data.cpp
+    #profile_layernorm_bwd_gamma_beta.cpp
+    #profile_groupnorm_bwd_gamma_beta.cpp
+    #profile_layernorm_fwd.cpp
+    #profile_max_pool3d_fwd.cpp
+    #profile_avg_pool3d_bwd.cpp
+    #profile_max_pool3d_bwd.cpp
+    #profile_softmax.cpp
+    #profile_batchnorm_fwd.cpp
+    #profile_batchnorm_bwd.cpp
+    #profile_batchnorm_infer.cpp
+    #profile_grouped_conv_bwd_data.cpp
+    #profile_conv_tensor_rearrange.cpp
+    #profile_transpose.cpp
+    #profile_permute_scale.cpp
 )
 
-if(DL_KERNELS)
-    list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp)
-endif()
-
-if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-    list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
-    list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
-    list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
-    list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
-    list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
-endif()
-
-if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
-    list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
-    list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
-endif()
+#if(DL_KERNELS)
+# list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp)
+#endif()
+#
+#if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+# list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
+# list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
+# list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
+# list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
+# list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
+# list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
+# list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp)
+# list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp)
+# list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
+# list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
+# list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
+# list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
+# list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
+# list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
+#endif()
+#
+#if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
+# list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
+# list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
+#endif()
 
 set(PROFILER_EXECUTABLE ckProfiler)
@@ -68,67 +68,67 @@ target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt)
 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
-target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
-
-if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
-endif()
-
-
-
-if(DL_KERNELS)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
-endif()
-
-if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
-endif()
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
+#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
+#
+#if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
+#endif()
+#
+#
+#
+#if(DL_KERNELS)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
+#endif()
+#
+#if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance)
+# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
+#endif()
 
 rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
@@ -11,7 +11,7 @@ cmake
 -D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
 -D CMAKE_BUILD_TYPE=Release \
 -D BUILD_DEV=OFF \
--D GPU_TARGETS="gfx1200" \
+-D GPU_TARGETS="gfx1100" \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
 -D USE_BITINT_EXTENSION_INT4=OFF \
 ${MY_PROJECT_SOURCE}