# Copyright © Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier:  MIT

# Oldest CMake release this project supports; the call also selects the policy defaults.
cmake_minimum_required(VERSION 3.25.2)

# Enable PIC/PIE to ensure compatibility with the plugin loader system (dlopen). This prevents
# potential Thread Local Storage (TLS) model mismatches between the executable and dynamically
# loaded backend plugins.
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Directory-scoped definition: applies to every target created in this directory and below.
add_compile_definitions(__HIP_PLATFORM_AMD__)

# ROCM_PATH is read from the environment at configure time and mirrored into a CMake variable
# of the same name. Configuration stops immediately when the environment variable is absent.
if(DEFINED ENV{ROCM_PATH})
    set(ROCM_PATH "$ENV{ROCM_PATH}")
else()
    message(FATAL_ERROR
        "ROCM_PATH is not set in the environment. "
        "Source dtk/env.sh before running CMake.")
endif()

# Declare the samples project with both the C and C++ languages enabled.
project(hipdnn_samples VERSION 0.1.0 LANGUAGES C CXX)
include(GNUInstallDirs)

# Build every sample as C++17. Marking the standard as required turns an unsupported
# compiler into an error instead of a silent fallback to an older standard.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(hip REQUIRED)
find_package(Threads REQUIRED)

# When an enclosing build has already defined the hipdnn_frontend target, reuse it;
# otherwise look for a package configuration file.
if(NOT TARGET hipdnn_frontend)
    find_package(hipdnn_frontend CONFIG REQUIRED)
endif()

# Make the directory containing this list file an include root for all samples.
include_directories("${CMAKE_CURRENT_LIST_DIR}")

# Collect all sample executables under <build>/bin.
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")

# Lets users building with a compiler that emits extra diagnostics opt out of -Werror.
# Defaults to ON, so the resulting flag list is unchanged unless the option is turned off.
option(HIPDNN_SAMPLES_WARNINGS_AS_ERRORS "Treat compiler warnings in the hipDNN samples as errors" ON)

# Warning flags are placed in front of whatever HIPDNN_WARNING_COMPILE_OPTIONS already holds
# (the variable may be empty or undefined at this point).
list(PREPEND HIPDNN_WARNING_COMPILE_OPTIONS
    -Wall                                    # Enable most common warnings
    -Wextra                                  # Enable additional warnings not covered by -Wall
    -Wpedantic                               # Enforce strict ISO C++ compliance
    -Wshadow                                 # Warn about variable shadowing
    -Wnon-virtual-dtor                       # Warn if a class with virtual functions has a non-virtual destructor
    -Wold-style-cast                         # Warn about C-style casts
    -Wcast-align                             # Warn when a pointer cast increases the required alignment of the target type
    -Woverloaded-virtual                     # Warn if a base class function is hidden by a derived class function with the same name
    -Wconversion                             # Warn about implicit type conversions that may alter a value
    -Wsign-conversion                        # Warn about implicit conversions between signed and unsigned types
    -Wnull-dereference                       # Warn about dereferencing null pointers
    -Wdouble-promotion                       # Warn when a float is implicitly promoted to a double
    -Wformat=2                               # Enable stricter format string checks
    -Winit-self                              # Warn about variables initialized with itself
    -Wunreachable-code                       # Warn about unreachable code
    -Wno-return-type                         # Suppress return-type warnings; needed when building with DTK-25.04.2
    -Wswitch-default                         # Warn if a switch statement does not have a default case
)

# -Werror goes first, matching its position in the unconditional list this replaces.
if(HIPDNN_SAMPLES_WARNINGS_AS_ERRORS)
    list(PREPEND HIPDNN_WARNING_COMPILE_OPTIONS
        -Werror                              # Treat all warnings as errors
    )
endif()

# add_hipdnn_sample(<NAME> <SOURCE> [<extra_source>...])
#
# Defines one sample executable.
#
#   NAME    - name of the executable target to create.
#   SOURCE  - source file for the sample (the callers in this file pass a path
#             relative to this directory).
#   ARGN    - optional additional source files compiled into the same executable.
#             Earlier revisions silently discarded extra arguments; they are now
#             forwarded to add_executable().
#
# The target is compiled with HIPDNN_WARNING_COMPILE_OPTIONS and linked privately
# against hip::host, Threads::Threads and hipdnn_frontend.
#
# Example:
#   add_hipdnn_sample(matmul matmul/Matmul.cpp)
function(add_hipdnn_sample NAME SOURCE)
    add_executable(${NAME} ${SOURCE} ${ARGN})
    target_compile_options(${NAME} PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS})
    target_link_libraries(${NAME} PRIVATE hip::host Threads::Threads hipdnn_frontend)
endfunction()

# Sample registry: one "<target name> <source file>" pair per line, in registration order.
# Lines that are commented out are samples that are currently not built.
set(hipdnn_sample_table
    bn_inference batchnorm/BnInference.cpp
    # bn_finalize batchnorm/BnFinalize.cpp
    bn_training batchnorm/BnTraining.cpp
    bn_backward batchnorm/BnBackward.cpp
    # bn_backward_weight batchnorm/BnBackwardWeight.cpp
    conv_forward convolution/ConvForward.cpp
    conv_backward convolution/ConvBackward.cpp
    conv_wrw convolution/ConvBackwardWeight.cpp
    conv_bias_prelu conv_fusion/ConvBiasPrelu.cpp
    conv_bias_prelu_add conv_fusion/ConvBiasPreluAdd.cpp
    conv_bias_swish_add conv_fusion/ConvBiasSwishAdd.cpp
    conv_bias_swish conv_fusion/ConvBiasSwish.cpp
    conv_bias_relu conv_fusion/ConvBiasRelu.cpp
    conv_bias_add conv_fusion/ConvBiasAdd.cpp
    conv_bias conv_fusion/ConvBias.cpp
    conv_bias_add_relu conv_fusion/ConvBiasAddRelu.cpp
    convbwd_bias_relu conv_fusion/ConvbwdBiasRelu.cpp
    convint8_bias conv_fusion/Int8ConvBias.cpp
    convint8_bias_add conv_fusion/Int8ConvBiasAdd.cpp
    convint8_bias_add_relu conv_fusion/Int8ConvBiasAddRelu.cpp
    convint8_bias_relu conv_fusion/Int8ConvBiasRelu.cpp
    convint8_bias_relu_add conv_fusion/Int8ConvBiasReluAdd.cpp
    convfp16_bias_relu conv_fusion/Fp16ConvBiasRelu.cpp
    ln_inference layernorm/LnInference.cpp
    # ln_backward layernorm/LnBackward.cpp
    rms_forward rmsnorm/RmsnormForward.cpp
    deform_conv_fprop deformconvolution/DeformConvForward.cpp
    deform_conv_dgrad deformconvolution/DeformConvBackward.cpp
    deform_conv_wgrad deformconvolution/DeformConvBackwardWeight.cpp
    gn_training groupnorm/GNTraining.cpp
    gn_inference groupnorm/GNInference.cpp
    gn_backward groupnorm/GNBackward.cpp
    add_layernorm fusion/AddLayernorm.cpp
    gn_swish fusion/GroupnormSwish.cpp
    sdpa_inference sdpa/SDPAInference.cpp
    reduction reduction/Reduction.cpp
    reluBwd_reduction reduction/PointwiseReduction.cpp
    transpose transpose/Transpose.cpp
    # genstats genstats/Genstats.cpp
    reshape_transpose fusion/ReshapeTranspose.cpp
    # resample resample/Resample.cpp
    deform_attn_fprop deformattention/DeformAttnForward.cpp
    deform_attn_dgrad deformattention/DeformAttnBackward.cpp
    instancenorm_inference instancenorm/InstancenormInference.cpp
    instancenorm_backward instancenorm/InstancenormBackward.cpp
    instancenorm_training instancenorm/InstancenormTraining.cpp
    # block_scale_dequantize block_scale/BlockScaleDequantize.cpp
    # block_scale_quantize block_scale/BlockScaleQuantize.cpp
    # slice slice/Slice.cpp
    # rng rng/Rng.cpp
    adamw adamw/Adamw.cpp
    transformer_adamw adamw/TransformerAdamw.cpp
    concatenate concatenate/Concatenate.cpp
    # pw_conv_genstats fusion/PointwiseConvGenstats.cpp
    concat_conv concat_conv_fusion/ConcatConv.cpp
    concat_conv_bias concat_conv_fusion/ConcatConvBias.cpp
    concat_conv_bias_add concat_conv_fusion/ConcatConvBiasAdd.cpp
    concat_conv_bias_leakyRelu concat_conv_fusion/ConcatConvBiasLeakyRelu.cpp
    concat_conv_bias_leakyRelu_add concat_conv_fusion/ConcatConvBiasLeakyReluAdd.cpp
    conv_bias_depthToSpace conv_depthtospace_fusion/ConvBiasDepthToSpace.cpp
    conv_bias_depthToSpace_add conv_depthtospace_fusion/ConvBiasDepthToSpaceAdd.cpp
    conv_bias_add_depthToSpace conv_depthtospace_fusion/ConvBiasAddDepthToSpace.cpp
    conv_bias_depthToSpace_clippedRelu conv_depthtospace_fusion/ConvBiasDepthToSpaceClippedRelu.cpp
    conv_bias_depthToSpace_clippedRelu_add conv_depthtospace_fusion/ConvBiasDepthToSpaceClippedReluAdd.cpp
    conv_depthToSpace conv_depthtospace_fusion/ConvDepthToSpace.cpp
    matmul matmul/Matmul.cpp
    matmul_bias matmul_fusion/MatmulBias.cpp
    matmul_bias_swish matmul_fusion/MatmulBiasSwish.cpp
    rope_forward rope/RopeForward.cpp
    rope_backward rope/RopeBackward.cpp
    pointwise_binary pointwise/BinaryPointwise.cpp
    softmax softmax/Softmax.cpp
    ctc_loss ctc_loss/CtcLoss.cpp
    kthvalue2d kthvalue/Kthvalue2D.cpp
    kthvalue4d kthvalue/Kthvalue4D.cpp
    multi_margin_loss multi_margin_loss/MultiMarginLoss.cpp
    soft_margin_loss soft_margin_loss/SoftMarginLossForward.cpp
    soft_margin_loss_backward soft_margin_loss/SoftMarginLossBackward.cpp
    getitem_indices_backward getitem_backward/GetitemBackwardIndices.cpp
    getitem_slice_backward getitem_backward/GetitemBackwardSlice.cpp
    scale_bias_relu_conv_genstats conv_bn_fusion/ScaleBiasReluConvGenstats.cpp
    scale_bias_relu_convwrw conv_bn_fusion/ScaleBiasReluConvwrw.cpp
    mul_mul_add_add conv_bn_fusion/MulMulAddAdd.cpp
    sub_mul_mul_add_convbwd_relubwd_bnwrw conv_bn_fusion/SubMulMulAddConvbwdRelubwdBnwrw.cpp
    conv_genstats conv_bn_fusion/ConvGenstats.cpp
    scale_bias conv_bn_fusion/ScaleBias.cpp
)

# Walk the table two entries at a time: even indices hold target names, the
# following odd index holds the matching source file.
list(LENGTH hipdnn_sample_table hipdnn_sample_table_size)
math(EXPR hipdnn_sample_last_name_index "${hipdnn_sample_table_size} - 2")
foreach(hipdnn_sample_name_index RANGE 0 ${hipdnn_sample_last_name_index} 2)
    math(EXPR hipdnn_sample_source_index "${hipdnn_sample_name_index} + 1")
    list(GET hipdnn_sample_table ${hipdnn_sample_name_index} hipdnn_sample_name)
    list(GET hipdnn_sample_table ${hipdnn_sample_source_index} hipdnn_sample_source)
    add_hipdnn_sample(${hipdnn_sample_name} ${hipdnn_sample_source})
endforeach()

# Drop the bookkeeping variables so they do not linger in the directory scope.
unset(hipdnn_sample_table)
unset(hipdnn_sample_table_size)
unset(hipdnn_sample_last_name_index)
unset(hipdnn_sample_source_index)
unset(hipdnn_sample_name)
unset(hipdnn_sample_source)