# Top-level build script for the ck_app project (C++ only).
# It selects the enabled data types, locates the composable_kernel and hip
# packages, and then adds each numbered example subdirectory.
cmake_minimum_required(VERSION 3.15)
project(ck_app LANGUAGES CXX)

# Select which data types the subdirectories are compiled for.
#
# DTYPES (optional, user-supplied) is regex-matched against each supported type
# name. For every name it contains, the CK_ENABLE_<TYPE> compile definition is
# added and the CMake variable of the same name is set to "ON":
#   int8 -> CK_ENABLE_INT8    fp8  -> CK_ENABLE_FP8    fp16 -> CK_ENABLE_FP16
#   fp32 -> CK_ENABLE_FP32    fp64 -> CK_ENABLE_FP64   bf16 -> CK_ENABLE_BF16
# When DTYPES is unset or false, all six definitions are added and
# CK_ENABLE_ALL_DTYPES is set to "ON" instead.
#
# The variables are assigned unconditionally so that each one always agrees
# with the compile definition added next to it. Do not wrap them in
# `if(NOT DEFINED ${VAR})`: that form expands VAR first and therefore tests
# whether a variable *named after VAR's value* exists, not VAR itself, so it
# does not protect a preset value.
if(DTYPES)
    add_definitions(-DDTYPES)
    # The item order fixes the order in which the -D flags are added.
    foreach(ck_dtype IN ITEMS int8 fp8 fp16 fp32 fp64 bf16)
        if(DTYPES MATCHES "${ck_dtype}")
            string(TOUPPER "${ck_dtype}" ck_dtype_upper)
            add_definitions(-DCK_ENABLE_${ck_dtype_upper})
            set(CK_ENABLE_${ck_dtype_upper} "ON")
        endif()
    endforeach()
    # Helper variable is not needed beyond the loop; keep the scope clean.
    unset(ck_dtype_upper)
    message("DTYPES macro set to ${DTYPES}")
else()
    add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
    set(CK_ENABLE_ALL_DTYPES "ON")
endif()

# Look up the composable_kernel package and the device operation components
# listed below. The lookup is not REQUIRED, so configuration continues even if
# the package is not found.
find_package(composable_kernel
    COMPONENTS
        device_other_operations
        device_gemm_operations
        device_conv_operations
        device_contraction_operations
        device_reduction_operations
)

# hip is mandatory. /opt/rocm and the HIP_PATH environment variable (read at
# configure time) are passed as extra search locations.
find_package(hip
    REQUIRED
    PATHS
        /opt/rocm
        $ENV{HIP_PATH}
)
message(STATUS "Build with HIP ${hip_VERSION}")

# Add every example subdirectory, in numeric order. The numbering has gaps
# (05-08, 11 and 55-59 do not appear in this list).
foreach(ck_example_dir IN ITEMS
        01_gemm
        02_gemm_bilinear
        03_gemm_bias_relu
        04_gemm_add_add_fastgelu
        09_convnd_fwd
        10_convnd_fwd_multiple_d_multiple_reduce
        12_reduce
        13_pool2d_fwd
        14_gemm_quantization
        15_grouped_gemm
        16_gemm_multi_d_multi_reduces
        17_convnd_bwd_data
        18_batched_gemm_reduce
        19_binary_elementwise
        20_grouped_conv_bwd_weight
        21_gemm_layernorm
        22_cgemm
        23_softmax
        24_batched_gemm
        25_gemm_bias_e_permute
        26_contraction
        27_layernorm2d_fwd
        28_grouped_gemm_bias_e_permute
        29_batched_gemm_bias_e_permute
        30_grouped_conv_fwd_multiple_d
        31_batched_gemm_gemm
        32_batched_gemm_scale_softmax_gemm
        33_multiple_reduce
        34_batchnorm
        35_splitK_gemm
        36_sparse_embedding
        37_batched_gemm_add_add_relu_gemm_add
        38_grouped_conv_bwd_data_multiple_d
        39_permute
        40_conv2d_fwd_quantization
        41_grouped_conv_conv_fwd
        42_groupnorm_fwd
        43_splitk_gemm_bias_e_permute
        44_elementwise_permute
        45_elementwise_normalization
        46_gemm_add_multiply
        47_gemm_bias_softmax_gemm_permute
        48_pool3d_fwd
        49_maxpool2d_bwd
        50_put_element
        51_avgpool3d_bwd
        52_im2col_col2im
        53_layernorm_bwd
        54_groupnorm_bwd
        60_gemm_multi_ABD
        61_contraction_multi_ABD
        62_conv_fwd_activ
        63_layernorm4d_fwd
        64_tensor_transforms)
    add_subdirectory(${ck_example_dir})
endforeach()