"...unconditional_image_generation/train_unconditional.py" did not exist on "8c1f51978c705b49a23526a8311b64716411afe2"
Commit e276fc95 authored by Artur Wojcik's avatar Artur Wojcik
Browse files

merge 'uif2-temp' to uif2-initial

parent 9b3a0d42
add_executable(client_gemm gemm.cpp) add_executable(client_gemm gemm.cpp)
target_link_libraries(client_gemm PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) target_link_libraries(client_gemm PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
target_compile_features(client_gemm PRIVATE cxx_std_17)
...@@ -2,15 +2,12 @@ add_custom_target(client_gemm_fastgelu_examples) ...@@ -2,15 +2,12 @@ add_custom_target(client_gemm_fastgelu_examples)
add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
target_compile_features(client_gemm_add_add_fastgelu PRIVATE cxx_std_17)
add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp) add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp)
target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
target_compile_features(client_gemm_add_fastgelu PRIVATE cxx_std_17)
add_executable(client_gemm_fastgelu gemm_fastgelu.cpp) add_executable(client_gemm_fastgelu gemm_fastgelu.cpp)
target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
target_compile_features(client_gemm_fastgelu PRIVATE cxx_std_17)
add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu
client_gemm_fastgelu) client_gemm_fastgelu)
...@@ -19,15 +16,12 @@ add_custom_target(client_gemm_fastgelu_generic_examples) ...@@ -19,15 +16,12 @@ add_custom_target(client_gemm_fastgelu_generic_examples)
add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp)
target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations) target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations)
target_compile_features(client_gemm_add_add_fastgelu_generic PRIVATE cxx_std_17)
add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp)
target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
target_compile_features(client_gemm_add_fastgelu_generic PRIVATE cxx_std_17)
add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp)
target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
target_compile_features(client_gemm_fastgelu_generic PRIVATE cxx_std_17)
add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic) client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic)
\ No newline at end of file
add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp)
target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
target_compile_features(client_gemm_add_add_reduce_normalize PRIVATE cxx_std_17)
add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp)
target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
target_compile_features(client_gemm_add_relu_add_layernorm_welford PRIVATE cxx_std_17)
add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp)
target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
target_compile_features(client_contraction_scale PRIVATE cxx_std_17)
add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp)
target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
target_compile_features(client_contraction_bilinear PRIVATE cxx_std_17)
add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp)
target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
target_compile_features(client_contraction_scale_fp64 PRIVATE cxx_std_17)
add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp)
target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
target_compile_features(client_contraction_blinear_fp64 PRIVATE cxx_std_17)
add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp)
target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
target_compile_features(contraction_g1m2n3k1_add_xdl-fp16 PRIVATE cxx_std_17)
add_executable(client_layernorm2d_fwd layernorm2d_fwd.cpp) add_executable(client_layernorm2d_fwd layernorm2d_fwd.cpp)
target_link_libraries(client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations) target_link_libraries(client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations)
target_compile_features(client_layernorm2d_fwd PRIVATE cxx_std_17)
add_executable(client_layernorm4d_fwd layernorm4d_fwd.cpp) add_executable(client_layernorm4d_fwd layernorm4d_fwd.cpp)
target_link_libraries(client_layernorm4d_fwd PRIVATE composable_kernel::device_other_operations) target_link_libraries(client_layernorm4d_fwd PRIVATE composable_kernel::device_other_operations)
target_compile_features(client_layernorm4d_fwd PRIVATE cxx_std_17)
add_executable(client_softmax4d softmax4d.cpp) add_executable(client_softmax4d softmax4d.cpp)
target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_other_operations composable_kernel::device_reduction_operations) target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_other_operations composable_kernel::device_reduction_operations)
target_compile_features(client_softmax4d PRIVATE cxx_std_17)
add_executable(client_fused_attention fused_attention.cpp) add_executable(client_fused_attention fused_attention.cpp)
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
target_compile_features(client_fused_attention PRIVATE cxx_std_17)
add_executable(client_fused_attention_bias fused_attention_bias.cpp) add_executable(client_fused_attention_bias fused_attention_bias.cpp)
target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
target_compile_features(client_fused_attention_bias PRIVATE cxx_std_17)
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
target_compile_features(client_conv2d_fwd_bias_tanh_perchangel_quantization PRIVATE cxx_std_17)
add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
target_compile_features(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE cxx_std_17)
add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
target_compile_features(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE cxx_std_17)
add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
target_compile_features(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE cxx_std_17)
add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
target_compile_features(client_conv2d_fwd_perchannel_quantization PRIVATE cxx_std_17)
add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
target_compile_features(client_conv2d_fwd_perlayer_quantization PRIVATE cxx_std_17)
add_executable(client_gemm_quantization gemm_quantization.cpp) add_executable(client_gemm_quantization gemm_quantization.cpp)
target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
target_compile_features(client_gemm_quantization PRIVATE cxx_std_17)
endif() endif()
...@@ -9,9 +9,3 @@ target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_k ...@@ -9,9 +9,3 @@ target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_k
target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations)
target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_conv_operations)
target_link_libraries(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
target_compile_features(client_grouped_conv1d_bwd_weight_fp16 PRIVATE cxx_std_17)
target_compile_features(client_grouped_conv2d_bwd_weight_fp16 PRIVATE cxx_std_17)
target_compile_features(client_grouped_conv3d_bwd_weight_fp16 PRIVATE cxx_std_17)
target_compile_features(client_grouped_conv3d_bwd_weight_fp32 PRIVATE cxx_std_17)
target_compile_features(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 PRIVATE cxx_std_17)
add_executable(client_elementwise_layernorm2d elementwise_layernorm2d.cpp) add_executable(client_elementwise_layernorm2d elementwise_layernorm2d.cpp)
target_link_libraries(client_elementwise_layernorm2d PRIVATE composable_kernel::device_other_operations) target_link_libraries(client_elementwise_layernorm2d PRIVATE composable_kernel::device_other_operations)
target_compile_features(client_elementwise_layernorm2d PRIVATE cxx_std_17)
...@@ -4,6 +4,3 @@ add_executable(client_batchnorm_infer_nhwc batchnorm_infer_nhwc.cpp) ...@@ -4,6 +4,3 @@ add_executable(client_batchnorm_infer_nhwc batchnorm_infer_nhwc.cpp)
target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_other_operations) target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_other_operations)
target_link_libraries(client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_other_operations) target_link_libraries(client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_other_operations)
target_link_libraries(client_batchnorm_infer_nhwc PRIVATE composable_kernel::device_other_operations) target_link_libraries(client_batchnorm_infer_nhwc PRIVATE composable_kernel::device_other_operations)
target_compile_features(client_batchnorm_fwd_nhwc PRIVATE cxx_std_17)
target_compile_features(client_batchnorm_bwd_nhwc PRIVATE cxx_std_17)
target_compile_features(client_batchnorm_infer_nhwc PRIVATE cxx_std_17)
add_executable(client_batchnorm_fwd_instance_id batchnorm_fwd_instance_id.cpp) add_executable(client_batchnorm_fwd_instance_id batchnorm_fwd_instance_id.cpp)
target_link_libraries(client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_other_operations) target_link_libraries(client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_other_operations)
target_compile_features(client_batchnorm_fwd_instance_id PRIVATE cxx_std_17)
cmake_minimum_required(VERSION 3.15) cmake_minimum_required(VERSION 3.15)
project(ck_app LANGUAGES CXX) project(ck_app)
add_compile_options(-std=c++17)
if (DTYPES) if (DTYPES)
add_definitions(-DDTYPES) add_definitions(-DDTYPES)
...@@ -48,60 +49,13 @@ else() ...@@ -48,60 +49,13 @@ else()
endif() endif()
find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations) find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations)
find_package(hip REQUIRED PATHS /opt/rocm $ENV{HIP_PATH}) find_package(hip REQUIRED PATHS /opt/rocm)
message(STATUS "Build with HIP ${hip_VERSION}") message(STATUS "Build with HIP ${hip_VERSION}")
add_subdirectory(01_gemm) # add all example subdir
add_subdirectory(02_gemm_bilinear) file(GLOB dir_list LIST_DIRECTORIES true *)
add_subdirectory(03_gemm_bias_relu) FOREACH(subdir ${dir_list})
add_subdirectory(04_gemm_add_add_fastgelu) IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build"))
add_subdirectory(09_convnd_fwd) add_subdirectory(${subdir})
add_subdirectory(10_convnd_fwd_multiple_d_multiple_reduce) ENDIF()
add_subdirectory(12_reduce) ENDFOREACH()
add_subdirectory(13_pool2d_fwd)
add_subdirectory(14_gemm_quantization)
add_subdirectory(15_grouped_gemm)
add_subdirectory(16_gemm_multi_d_multi_reduces)
add_subdirectory(17_convnd_bwd_data)
add_subdirectory(18_batched_gemm_reduce)
add_subdirectory(19_binary_elementwise)
add_subdirectory(20_grouped_conv_bwd_weight)
add_subdirectory(21_gemm_layernorm)
add_subdirectory(22_cgemm)
add_subdirectory(23_softmax)
add_subdirectory(24_batched_gemm)
add_subdirectory(25_gemm_bias_e_permute)
add_subdirectory(26_contraction)
add_subdirectory(27_layernorm2d_fwd)
add_subdirectory(28_grouped_gemm_bias_e_permute)
add_subdirectory(29_batched_gemm_bias_e_permute)
add_subdirectory(30_grouped_conv_fwd_multiple_d)
add_subdirectory(31_batched_gemm_gemm)
add_subdirectory(32_batched_gemm_scale_softmax_gemm)
add_subdirectory(33_multiple_reduce)
add_subdirectory(34_batchnorm)
add_subdirectory(35_splitK_gemm)
add_subdirectory(36_sparse_embedding)
add_subdirectory(37_batched_gemm_add_add_relu_gemm_add)
add_subdirectory(38_grouped_conv_bwd_data_multiple_d)
add_subdirectory(39_permute)
add_subdirectory(40_conv2d_fwd_quantization)
add_subdirectory(41_grouped_conv_conv_fwd)
add_subdirectory(42_groupnorm_fwd)
add_subdirectory(43_splitk_gemm_bias_e_permute)
add_subdirectory(44_elementwise_permute)
add_subdirectory(45_elementwise_normalization)
add_subdirectory(46_gemm_add_multiply)
add_subdirectory(47_gemm_bias_softmax_gemm_permute)
add_subdirectory(48_pool3d_fwd)
add_subdirectory(49_maxpool2d_bwd)
add_subdirectory(50_put_element)
add_subdirectory(51_avgpool3d_bwd)
add_subdirectory(52_im2col_col2im)
add_subdirectory(53_layernorm_bwd)
add_subdirectory(54_groupnorm_bwd)
add_subdirectory(60_gemm_multi_ABD)
add_subdirectory(61_contraction_multi_ABD)
add_subdirectory(62_conv_fwd_activ)
add_subdirectory(63_layernorm4d_fwd)
add_subdirectory(64_tensor_transforms)
\ No newline at end of file
...@@ -25,18 +25,40 @@ ...@@ -25,18 +25,40 @@
################################################################################ ################################################################################
# - Enable warning all for gcc/clang or use /W4 for visual studio # - Enable warning all for gcc/clang or use /W4 for visual studio
## Strict compile options for Visual C++ compiler ## Strict warning level
set(__default_msvc_compile_options /w) if (MSVC)
# Use the highest warning level for visual studio.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /w")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /w")
# set(CMAKE_CXX_WARNING_LEVEL 4)
# if (CMAKE_CXX_FLAGS MATCHES "/W[0-4]")
# string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
# else ()
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4")
# endif ()
## Strict compile options for GNU/Clang compilers # set(CMAKE_C_WARNING_LEVEL 4)
set(__default_compile_options # if (CMAKE_C_FLAGS MATCHES "/W[0-4]")
-Wall -Wextra # string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
# else ()
# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W4")
# endif ()
else()
foreach(COMPILER C CXX)
set(CMAKE_COMPILER_WARNINGS)
# use -Wall for gcc and clang
list(APPEND CMAKE_COMPILER_WARNINGS
-Wall
-Wextra
-Wcomment -Wcomment
-Wendif-labels -Wendif-labels
-Wformat -Wformat
-Winit-self -Winit-self
-Wreturn-type -Wreturn-type
-Wsequence-point -Wsequence-point
# Shadow is broken on gcc when using lambdas
# -Wshadow
-Wswitch -Wswitch
-Wtrigraphs -Wtrigraphs
-Wundef -Wundef
...@@ -50,11 +72,9 @@ set(__default_compile_options ...@@ -50,11 +72,9 @@ set(__default_compile_options
-Wno-extra-semi-stmt -Wno-extra-semi-stmt
-Wno-unused-template -Wno-unused-template
) )
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
## Strict compile options for Clang compilers list(APPEND CMAKE_COMPILER_WARNINGS
set(__default_clang_compile_options
-Weverything -Weverything
-Wshadow
-Wno-c++98-compat -Wno-c++98-compat
-Wno-c++98-compat-pedantic -Wno-c++98-compat-pedantic
-Wno-conversion -Wno-conversion
...@@ -74,35 +94,21 @@ set(__default_clang_compile_options ...@@ -74,35 +94,21 @@ set(__default_clang_compile_options
-Wno-unused-command-line-argument -Wno-unused-command-line-argument
-Wno-weak-vtables -Wno-weak-vtables
-Wno-covered-switch-default -Wno-covered-switch-default
-Wno-unsafe-buffer-usage) -Wno-unsafe-buffer-usage
)
if(WIN32) else()
list(APPEND __default_clang_compile_options if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX")
-fms-extensions # cmake 3.5.2 does not support >=.
-fms-compatibility if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS "6.1")
-fdelayed-template-parsing) list(APPEND CMAKE_COMPILER_WARNINGS
endif() -Wno-ignored-attributes)
endif()
set(__default_gnu_compile_options endif()
-Wduplicated-branches list(APPEND CMAKE_COMPILER_WARNINGS
-Wduplicated-cond
-Wno-noexcept-type
-Wno-ignored-attributes
-Wodr
-Wshift-negative-value
-Wshift-overflow=2
-Wno-missing-field-initializers -Wno-missing-field-initializers
-Wno-maybe-uninitialized -Wno-deprecated-declarations
-Wno-deprecated-declarations) )
endif()
add_compile_options( add_definitions(${CMAKE_COMPILER_WARNINGS})
"$<$<OR:$<CXX_COMPILER_ID:MSVC>,$<C_COMPILER_ID:MSVC>>:${__default_msvc_compile_options}>" endforeach()
"$<$<OR:$<CXX_COMPILER_ID:GNU,Clang>,$<C_COMPILER_ID:GNU,Clang>>:${__default_compile_options}>" endif ()
"$<$<OR:$<AND:$<CXX_COMPILER_ID:GNU>,$<VERSION_GREATER_EQUAL:$<CXX_COMPILER_VERSION>,7>>,$<AND:$<C_COMPILER_ID:GNU>,$<VERSION_GREATER_EQUAL:$<C_COMPILER_VERSION>,7>>>:${__default_gnu_compile_options}>"
"$<$<OR:$<CXX_COMPILER_ID:Clang>,$<C_COMPILER_ID:Clang>>:${__default_clang_compile_options}>")
unset(__default_msvc_compile_options)
unset(__default_compile_options)
unset(__default_gnu_compile_options)
unset(__default_clang_compile_options)
...@@ -32,15 +32,21 @@ FetchContent_MakeAvailable(googletest) ...@@ -32,15 +32,21 @@ FetchContent_MakeAvailable(googletest)
# Restore the old value of BUILD_SHARED_LIBS # Restore the old value of BUILD_SHARED_LIBS
set(BUILD_SHARED_LIBS ${__build_shared_libs} CACHE BOOL "Type of libraries to build" FORCE) set(BUILD_SHARED_LIBS ${__build_shared_libs} CACHE BOOL "Type of libraries to build" FORCE)
set(GTEST_CXX_FLAGS
-Wno-undef
-Wno-global-constructors
-Wno-zero-as-null-pointer-constant
-Wno-switch-enum
-Wno-float-equal
-Wno-unused-member-function)
if(WIN32) if(WIN32)
list(APPEND GTEST_CMAKE_CXX_FLAGS list(APPEND GTEST_CXX_FLAGS
-Wno-suggest-destructor-override -Wno-suggest-destructor-override
-Wno-suggest-override -Wno-suggest-override
-Wno-nonportable-system-include-path -Wno-nonportable-system-include-path
-Wno-language-extension-token) -Wno-language-extension-token)
endif() endif()
target_compile_options(gtest PRIVATE -Wno-undef) target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS})
target_compile_options(gtest_main PRIVATE -Wno-undef) target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS})
...@@ -79,7 +79,7 @@ std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix) ...@@ -79,7 +79,7 @@ std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
} }
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = false; bool do_verification = 0;
int init_method = 0; int init_method = 0;
bool time_kernel = false; bool time_kernel = false;
......
...@@ -112,7 +112,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification, ...@@ -112,7 +112,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
bool time_kernel, bool time_kernel,
const std::vector<size_t> inOutLengths, const std::vector<size_t> inOutLengths,
bool haveSavedMeanInvVar, bool haveSavedMeanInvVar,
double _epsilon) double epsilon)
{ {
// for NHWC BatchNorm calculation of mean and meansquare // for NHWC BatchNorm calculation of mean and meansquare
constexpr index_t Rank = 4; constexpr index_t Rank = 4;
...@@ -292,7 +292,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification, ...@@ -292,7 +292,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
bnScale_dev.GetDeviceBuffer(), bnScale_dev.GetDeviceBuffer(),
haveSavedMeanInvVar ? savedMean_dev.GetDeviceBuffer() : nullptr, haveSavedMeanInvVar ? savedMean_dev.GetDeviceBuffer() : nullptr,
haveSavedMeanInvVar ? savedInvVar_dev.GetDeviceBuffer() : nullptr, haveSavedMeanInvVar ? savedInvVar_dev.GetDeviceBuffer() : nullptr,
_epsilon, epsilon,
PassThroughOp{}, PassThroughOp{},
dx_dev.GetDeviceBuffer(), dx_dev.GetDeviceBuffer(),
dscale_dev.GetDeviceBuffer(), dscale_dev.GetDeviceBuffer(),
...@@ -371,7 +371,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification, ...@@ -371,7 +371,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
bnScale.mData.data(), bnScale.mData.data(),
haveSavedMeanInvVar ? savedMean.mData.data() : nullptr, haveSavedMeanInvVar ? savedMean.mData.data() : nullptr,
haveSavedMeanInvVar ? savedInvVar.mData.data() : nullptr, haveSavedMeanInvVar ? savedInvVar.mData.data() : nullptr,
_epsilon, epsilon,
PassThroughOp{}, PassThroughOp{},
dx_ref.mData.data(), dx_ref.mData.data(),
dscale_ref.mData.data(), dscale_ref.mData.data(),
......
...@@ -119,7 +119,7 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -119,7 +119,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
int init_method, int init_method,
bool time_kernel, bool time_kernel,
const std::vector<size_t> inOutLengths, const std::vector<size_t> inOutLengths,
double _epsilon) double epsilon)
{ {
// for NHWC BatchNorm calculation of mean and meansquare // for NHWC BatchNorm calculation of mean and meansquare
constexpr int Rank = 4; constexpr int Rank = 4;
...@@ -251,7 +251,7 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -251,7 +251,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(), bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(), bnBias_dev.GetDeviceBuffer(),
_epsilon, epsilon,
estimatedMean_dev.GetDeviceBuffer(), estimatedMean_dev.GetDeviceBuffer(),
estimatedVariance_dev.GetDeviceBuffer(), estimatedVariance_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer()); y_dev.GetDeviceBuffer());
...@@ -289,7 +289,7 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -289,7 +289,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
x.mData.data(), x.mData.data(),
bnScale.mData.data(), bnScale.mData.data(),
bnBias.mData.data(), bnBias.mData.data(),
_epsilon, epsilon,
PassThroughOp{}, PassThroughOp{},
estimatedMean.mData.data(), estimatedMean.mData.data(),
estimatedVariance.mData.data(), estimatedVariance.mData.data(),
......
...@@ -135,8 +135,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -135,8 +135,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
const std::vector<size_t> inOutLengths, const std::vector<size_t> inOutLengths,
bool updateMovingAverage, bool updateMovingAverage,
bool saveMeanAndInvVariance, bool saveMeanAndInvVariance,
double _averageFactor, double averageFactor,
double _epsilon) double epsilon)
{ {
// for NHWC BatchNorm calculation of mean and meansquare // for NHWC BatchNorm calculation of mean and meansquare
constexpr int Rank = 4; constexpr int Rank = 4;
...@@ -310,12 +310,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -310,12 +310,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(), bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(), bnBias_dev.GetDeviceBuffer(),
_epsilon, epsilon,
PassThroughOp{}, PassThroughOp{},
y_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(),
saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr, saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr, saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
_averageFactor, averageFactor,
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr, updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr); updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);
...@@ -392,12 +392,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -392,12 +392,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
x.mData.data(), x.mData.data(),
bnScale.mData.data(), bnScale.mData.data(),
bnBias.mData.data(), bnBias.mData.data(),
_epsilon, epsilon,
PassThroughOp{}, PassThroughOp{},
y_ref.mData.data(), y_ref.mData.data(),
saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr, saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr, saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
_averageFactor, averageFactor,
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr); updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);
......
...@@ -135,8 +135,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -135,8 +135,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
const std::vector<size_t> inOutLengths, const std::vector<size_t> inOutLengths,
bool updateMovingAverage, bool updateMovingAverage,
bool saveMeanAndInvVariance, bool saveMeanAndInvVariance,
double _averageFactor, double averageFactor,
double _epsilon) double epsilon)
{ {
// for NHWC BatchNorm calculation of mean and meansquare // for NHWC BatchNorm calculation of mean and meansquare
constexpr int Rank = 4; constexpr int Rank = 4;
...@@ -310,12 +310,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -310,12 +310,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(), bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(), bnBias_dev.GetDeviceBuffer(),
_epsilon, epsilon,
PassThroughOp{}, PassThroughOp{},
y_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(),
saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr, saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr, saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
_averageFactor, averageFactor,
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr, updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr); updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);
...@@ -392,12 +392,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -392,12 +392,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
x.mData.data(), x.mData.data(),
bnScale.mData.data(), bnScale.mData.data(),
bnBias.mData.data(), bnBias.mData.data(),
_epsilon, epsilon,
PassThroughOp{}, PassThroughOp{},
y_ref.mData.data(), y_ref.mData.data(),
saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr, saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr, saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
_averageFactor, averageFactor,
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr); updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment