Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e276fc95
Commit
e276fc95
authored
Dec 05, 2023
by
Artur Wojcik
Browse files
merge 'uif2-temp' to uif2-initial
parent
9b3a0d42
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
125 additions
and
197 deletions
+125
-197
client_example/01_gemm/CMakeLists.txt
client_example/01_gemm/CMakeLists.txt
+0
-2
client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
+0
-6
client_example/03_gemm_layernorm/CMakeLists.txt
client_example/03_gemm_layernorm/CMakeLists.txt
+0
-2
client_example/04_contraction/CMakeLists.txt
client_example/04_contraction/CMakeLists.txt
+1
-5
client_example/05_layernorm/CMakeLists.txt
client_example/05_layernorm/CMakeLists.txt
+0
-3
client_example/06_softmax/CMakeLists.txt
client_example/06_softmax/CMakeLists.txt
+0
-1
client_example/08_fused_attention/CMakeLists.txt
client_example/08_fused_attention/CMakeLists.txt
+0
-2
client_example/09_quantization/CMakeLists.txt
client_example/09_quantization/CMakeLists.txt
+0
-7
client_example/11_grouped_conv_bwd_weight/CMakeLists.txt
client_example/11_grouped_conv_bwd_weight/CMakeLists.txt
+0
-6
client_example/12_elementwise_normalization/CMakeLists.txt
client_example/12_elementwise_normalization/CMakeLists.txt
+0
-1
client_example/13_batchnorm/CMakeLists.txt
client_example/13_batchnorm/CMakeLists.txt
+0
-3
client_example/14_instance_id/CMakeLists.txt
client_example/14_instance_id/CMakeLists.txt
+0
-1
client_example/CMakeLists.txt
client_example/CMakeLists.txt
+10
-56
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+85
-79
cmake/GTest.cmake
cmake/GTest.cmake
+10
-4
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
+1
-1
example/34_batchnorm/batchnorm_backward_nhwc.cpp
example/34_batchnorm/batchnorm_backward_nhwc.cpp
+3
-3
example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp
example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp
+3
-3
example/34_batchnorm/batchnorm_forward_training_nhwc.cpp
example/34_batchnorm/batchnorm_forward_training_nhwc.cpp
+6
-6
example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp
...34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp
+6
-6
No files found.
client_example/01_gemm/CMakeLists.txt
View file @
e276fc95
add_executable
(
client_gemm gemm.cpp
)
add_executable
(
client_gemm gemm.cpp
)
target_link_libraries
(
client_gemm PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_gemm PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_gemm PRIVATE cxx_std_17
)
client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
View file @
e276fc95
...
@@ -2,15 +2,12 @@ add_custom_target(client_gemm_fastgelu_examples)
...
@@ -2,15 +2,12 @@ add_custom_target(client_gemm_fastgelu_examples)
add_executable
(
client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp
)
add_executable
(
client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp
)
target_link_libraries
(
client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations
)
target_compile_features
(
client_gemm_add_add_fastgelu PRIVATE cxx_std_17
)
add_executable
(
client_gemm_add_fastgelu gemm_add_fastgelu.cpp
)
add_executable
(
client_gemm_add_fastgelu gemm_add_fastgelu.cpp
)
target_link_libraries
(
client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations
)
target_compile_features
(
client_gemm_add_fastgelu PRIVATE cxx_std_17
)
add_executable
(
client_gemm_fastgelu gemm_fastgelu.cpp
)
add_executable
(
client_gemm_fastgelu gemm_fastgelu.cpp
)
target_link_libraries
(
client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations
)
target_compile_features
(
client_gemm_fastgelu PRIVATE cxx_std_17
)
add_dependencies
(
client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu
add_dependencies
(
client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu
client_gemm_fastgelu
)
client_gemm_fastgelu
)
...
@@ -19,15 +16,12 @@ add_custom_target(client_gemm_fastgelu_generic_examples)
...
@@ -19,15 +16,12 @@ add_custom_target(client_gemm_fastgelu_generic_examples)
add_executable
(
client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp
)
add_executable
(
client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp
)
target_link_libraries
(
client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations
)
target_compile_features
(
client_gemm_add_add_fastgelu_generic PRIVATE cxx_std_17
)
add_executable
(
client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp
)
add_executable
(
client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp
)
target_link_libraries
(
client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations
)
target_compile_features
(
client_gemm_add_fastgelu_generic PRIVATE cxx_std_17
)
add_executable
(
client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp
)
add_executable
(
client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp
)
target_link_libraries
(
client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations
)
target_compile_features
(
client_gemm_fastgelu_generic PRIVATE cxx_std_17
)
add_dependencies
(
client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
add_dependencies
(
client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic
)
client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic
)
\ No newline at end of file
client_example/03_gemm_layernorm/CMakeLists.txt
View file @
e276fc95
add_executable
(
client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp
)
add_executable
(
client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp
)
target_link_libraries
(
client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations
)
target_link_libraries
(
client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations
)
target_compile_features
(
client_gemm_add_add_reduce_normalize PRIVATE cxx_std_17
)
add_executable
(
client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp
)
add_executable
(
client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp
)
target_link_libraries
(
client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations
)
target_link_libraries
(
client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations
)
target_compile_features
(
client_gemm_add_relu_add_layernorm_welford PRIVATE cxx_std_17
)
client_example/04_contraction/CMakeLists.txt
View file @
e276fc95
add_executable
(
client_contraction_scale_fp32 contraction_scale_fp32.cpp
)
add_executable
(
client_contraction_scale_fp32 contraction_scale_fp32.cpp
)
target_link_libraries
(
client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_contraction_scale PRIVATE cxx_std_17
)
add_executable
(
client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp
)
add_executable
(
client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp
)
target_link_libraries
(
client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_contraction_bilinear PRIVATE cxx_std_17
)
add_executable
(
client_contraction_scale_fp64 contraction_scale_fp64.cpp
)
add_executable
(
client_contraction_scale_fp64 contraction_scale_fp64.cpp
)
target_link_libraries
(
client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_contraction_scale_fp64 PRIVATE cxx_std_17
)
add_executable
(
client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp
)
add_executable
(
client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp
)
target_link_libraries
(
client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_contraction_blinear_fp64 PRIVATE cxx_std_17
)
add_executable
(
contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp
)
add_executable
(
contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp
)
target_link_libraries
(
contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
contraction_g1m2n3k1_add_xdl-fp16 PRIVATE cxx_std_17
)
client_example/05_layernorm/CMakeLists.txt
View file @
e276fc95
add_executable
(
client_layernorm2d_fwd layernorm2d_fwd.cpp
)
add_executable
(
client_layernorm2d_fwd layernorm2d_fwd.cpp
)
target_link_libraries
(
client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations
)
target_link_libraries
(
client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations
)
target_compile_features
(
client_layernorm2d_fwd PRIVATE cxx_std_17
)
add_executable
(
client_layernorm4d_fwd layernorm4d_fwd.cpp
)
add_executable
(
client_layernorm4d_fwd layernorm4d_fwd.cpp
)
target_link_libraries
(
client_layernorm4d_fwd PRIVATE composable_kernel::device_other_operations
)
target_link_libraries
(
client_layernorm4d_fwd PRIVATE composable_kernel::device_other_operations
)
target_compile_features
(
client_layernorm4d_fwd PRIVATE cxx_std_17
)
client_example/06_softmax/CMakeLists.txt
View file @
e276fc95
add_executable
(
client_softmax4d softmax4d.cpp
)
add_executable
(
client_softmax4d softmax4d.cpp
)
target_link_libraries
(
client_softmax4d PRIVATE composable_kernel::device_other_operations composable_kernel::device_reduction_operations
)
target_link_libraries
(
client_softmax4d PRIVATE composable_kernel::device_other_operations composable_kernel::device_reduction_operations
)
target_compile_features
(
client_softmax4d PRIVATE cxx_std_17
)
client_example/08_fused_attention/CMakeLists.txt
View file @
e276fc95
add_executable
(
client_fused_attention fused_attention.cpp
)
add_executable
(
client_fused_attention fused_attention.cpp
)
target_link_libraries
(
client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_fused_attention PRIVATE cxx_std_17
)
add_executable
(
client_fused_attention_bias fused_attention_bias.cpp
)
add_executable
(
client_fused_attention_bias fused_attention_bias.cpp
)
target_link_libraries
(
client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_fused_attention_bias PRIVATE cxx_std_17
)
client_example/09_quantization/CMakeLists.txt
View file @
e276fc95
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_executable
(
client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp
)
add_executable
(
client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_conv2d_fwd_bias_tanh_perchangel_quantization PRIVATE cxx_std_17
)
add_executable
(
client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp
)
add_executable
(
client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE cxx_std_17
)
add_executable
(
client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp
)
add_executable
(
client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE cxx_std_17
)
add_executable
(
client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp
)
add_executable
(
client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE cxx_std_17
)
add_executable
(
client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp
)
add_executable
(
client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_conv2d_fwd_perchannel_quantization PRIVATE cxx_std_17
)
add_executable
(
client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp
)
add_executable
(
client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_conv2d_fwd_perlayer_quantization PRIVATE cxx_std_17
)
add_executable
(
client_gemm_quantization gemm_quantization.cpp
)
add_executable
(
client_gemm_quantization gemm_quantization.cpp
)
target_link_libraries
(
client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_link_libraries
(
client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations
)
target_compile_features
(
client_gemm_quantization PRIVATE cxx_std_17
)
endif
()
endif
()
client_example/11_grouped_conv_bwd_weight/CMakeLists.txt
View file @
e276fc95
...
@@ -9,9 +9,3 @@ target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_k
...
@@ -9,9 +9,3 @@ target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_k
target_link_libraries
(
client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations
)
target_link_libraries
(
client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations
)
target_link_libraries
(
client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_conv_operations
)
target_link_libraries
(
client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_conv_operations
)
target_link_libraries
(
client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 PRIVATE composable_kernel::device_conv_operations
)
target_link_libraries
(
client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 PRIVATE composable_kernel::device_conv_operations
)
target_compile_features
(
client_grouped_conv1d_bwd_weight_fp16 PRIVATE cxx_std_17
)
target_compile_features
(
client_grouped_conv2d_bwd_weight_fp16 PRIVATE cxx_std_17
)
target_compile_features
(
client_grouped_conv3d_bwd_weight_fp16 PRIVATE cxx_std_17
)
target_compile_features
(
client_grouped_conv3d_bwd_weight_fp32 PRIVATE cxx_std_17
)
target_compile_features
(
client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 PRIVATE cxx_std_17
)
client_example/12_elementwise_normalization/CMakeLists.txt
View file @
e276fc95
add_executable
(
client_elementwise_layernorm2d elementwise_layernorm2d.cpp
)
add_executable
(
client_elementwise_layernorm2d elementwise_layernorm2d.cpp
)
target_link_libraries
(
client_elementwise_layernorm2d PRIVATE composable_kernel::device_other_operations
)
target_link_libraries
(
client_elementwise_layernorm2d PRIVATE composable_kernel::device_other_operations
)
target_compile_features
(
client_elementwise_layernorm2d PRIVATE cxx_std_17
)
client_example/13_batchnorm/CMakeLists.txt
View file @
e276fc95
...
@@ -4,6 +4,3 @@ add_executable(client_batchnorm_infer_nhwc batchnorm_infer_nhwc.cpp)
...
@@ -4,6 +4,3 @@ add_executable(client_batchnorm_infer_nhwc batchnorm_infer_nhwc.cpp)
target_link_libraries
(
client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_other_operations
)
target_link_libraries
(
client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_other_operations
)
target_link_libraries
(
client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_other_operations
)
target_link_libraries
(
client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_other_operations
)
target_link_libraries
(
client_batchnorm_infer_nhwc PRIVATE composable_kernel::device_other_operations
)
target_link_libraries
(
client_batchnorm_infer_nhwc PRIVATE composable_kernel::device_other_operations
)
target_compile_features
(
client_batchnorm_fwd_nhwc PRIVATE cxx_std_17
)
target_compile_features
(
client_batchnorm_bwd_nhwc PRIVATE cxx_std_17
)
target_compile_features
(
client_batchnorm_infer_nhwc PRIVATE cxx_std_17
)
client_example/14_instance_id/CMakeLists.txt
View file @
e276fc95
add_executable
(
client_batchnorm_fwd_instance_id batchnorm_fwd_instance_id.cpp
)
add_executable
(
client_batchnorm_fwd_instance_id batchnorm_fwd_instance_id.cpp
)
target_link_libraries
(
client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_other_operations
)
target_link_libraries
(
client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_other_operations
)
target_compile_features
(
client_batchnorm_fwd_instance_id PRIVATE cxx_std_17
)
client_example/CMakeLists.txt
View file @
e276fc95
cmake_minimum_required
(
VERSION 3.15
)
cmake_minimum_required
(
VERSION 3.15
)
project
(
ck_app LANGUAGES CXX
)
project
(
ck_app
)
add_compile_options
(
-std=c++17
)
if
(
DTYPES
)
if
(
DTYPES
)
add_definitions
(
-DDTYPES
)
add_definitions
(
-DDTYPES
)
...
@@ -48,60 +49,13 @@ else()
...
@@ -48,60 +49,13 @@ else()
endif
()
endif
()
find_package
(
composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations
)
find_package
(
composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations
)
find_package
(
hip REQUIRED PATHS /opt/rocm
$ENV{HIP_PATH}
)
find_package
(
hip REQUIRED PATHS /opt/rocm
)
message
(
STATUS
"Build with HIP
${
hip_VERSION
}
"
)
message
(
STATUS
"Build with HIP
${
hip_VERSION
}
"
)
add_subdirectory
(
01_gemm
)
# add all example subdir
add_subdirectory
(
02_gemm_bilinear
)
file
(
GLOB dir_list LIST_DIRECTORIES true *
)
add_subdirectory
(
03_gemm_bias_relu
)
FOREACH
(
subdir
${
dir_list
}
)
add_subdirectory
(
04_gemm_add_add_fastgelu
)
IF
(
IS_DIRECTORY
"
${
subdir
}
"
AND
(
NOT
"
${
subdir
}
"
MATCHES
"build"
))
add_subdirectory
(
09_convnd_fwd
)
add_subdirectory
(
${
subdir
}
)
add_subdirectory
(
10_convnd_fwd_multiple_d_multiple_reduce
)
ENDIF
()
add_subdirectory
(
12_reduce
)
ENDFOREACH
()
add_subdirectory
(
13_pool2d_fwd
)
add_subdirectory
(
14_gemm_quantization
)
add_subdirectory
(
15_grouped_gemm
)
add_subdirectory
(
16_gemm_multi_d_multi_reduces
)
add_subdirectory
(
17_convnd_bwd_data
)
add_subdirectory
(
18_batched_gemm_reduce
)
add_subdirectory
(
19_binary_elementwise
)
add_subdirectory
(
20_grouped_conv_bwd_weight
)
add_subdirectory
(
21_gemm_layernorm
)
add_subdirectory
(
22_cgemm
)
add_subdirectory
(
23_softmax
)
add_subdirectory
(
24_batched_gemm
)
add_subdirectory
(
25_gemm_bias_e_permute
)
add_subdirectory
(
26_contraction
)
add_subdirectory
(
27_layernorm2d_fwd
)
add_subdirectory
(
28_grouped_gemm_bias_e_permute
)
add_subdirectory
(
29_batched_gemm_bias_e_permute
)
add_subdirectory
(
30_grouped_conv_fwd_multiple_d
)
add_subdirectory
(
31_batched_gemm_gemm
)
add_subdirectory
(
32_batched_gemm_scale_softmax_gemm
)
add_subdirectory
(
33_multiple_reduce
)
add_subdirectory
(
34_batchnorm
)
add_subdirectory
(
35_splitK_gemm
)
add_subdirectory
(
36_sparse_embedding
)
add_subdirectory
(
37_batched_gemm_add_add_relu_gemm_add
)
add_subdirectory
(
38_grouped_conv_bwd_data_multiple_d
)
add_subdirectory
(
39_permute
)
add_subdirectory
(
40_conv2d_fwd_quantization
)
add_subdirectory
(
41_grouped_conv_conv_fwd
)
add_subdirectory
(
42_groupnorm_fwd
)
add_subdirectory
(
43_splitk_gemm_bias_e_permute
)
add_subdirectory
(
44_elementwise_permute
)
add_subdirectory
(
45_elementwise_normalization
)
add_subdirectory
(
46_gemm_add_multiply
)
add_subdirectory
(
47_gemm_bias_softmax_gemm_permute
)
add_subdirectory
(
48_pool3d_fwd
)
add_subdirectory
(
49_maxpool2d_bwd
)
add_subdirectory
(
50_put_element
)
add_subdirectory
(
51_avgpool3d_bwd
)
add_subdirectory
(
52_im2col_col2im
)
add_subdirectory
(
53_layernorm_bwd
)
add_subdirectory
(
54_groupnorm_bwd
)
add_subdirectory
(
60_gemm_multi_ABD
)
add_subdirectory
(
61_contraction_multi_ABD
)
add_subdirectory
(
62_conv_fwd_activ
)
add_subdirectory
(
63_layernorm4d_fwd
)
add_subdirectory
(
64_tensor_transforms
)
\ No newline at end of file
cmake/EnableCompilerWarnings.cmake
View file @
e276fc95
...
@@ -25,84 +25,90 @@
...
@@ -25,84 +25,90 @@
################################################################################
################################################################################
# - Enable warning all for gcc/clang or use /W4 for visual studio
# - Enable warning all for gcc/clang or use /W4 for visual studio
## Strict compile options for Visual C++ compiler
## Strict warning level
set
(
__default_msvc_compile_options /w
)
if
(
MSVC
)
# Use the highest warning level for visual studio.
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
/w"
)
set
(
CMAKE_C_FLAGS
"
${
CMAKE_C_FLAGS
}
/w"
)
# set(CMAKE_CXX_WARNING_LEVEL 4)
# if (CMAKE_CXX_FLAGS MATCHES "/W[0-4]")
# string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
# else ()
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4")
# endif ()
## Strict compile options for GNU/Clang compilers
# set(CMAKE_C_WARNING_LEVEL 4)
set
(
__default_compile_options
# if (CMAKE_C_FLAGS MATCHES "/W[0-4]")
-Wall -Wextra
# string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
-Wcomment
# else ()
-Wendif-labels
# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W4")
-Wformat
# endif ()
-Winit-self
-Wreturn-type
-Wsequence-point
-Wswitch
-Wtrigraphs
-Wundef
-Wuninitialized
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
-Wno-unused-template
)
## Strict compile options for Clang compilers
set
(
__default_clang_compile_options
-Weverything
-Wshadow
-Wno-c++98-compat
-Wno-c++98-compat-pedantic
-Wno-conversion
-Wno-double-promotion
-Wno-exit-time-destructors
-Wno-extra-semi
-Wno-float-conversion
-Wno-gnu-anonymous-struct
-Wno-gnu-zero-variadic-macro-arguments
-Wno-missing-prototypes
-Wno-nested-anon-types
-Wno-padded
-Wno-return-std-move-in-c++11
-Wno-shorten-64-to-32
-Wno-sign-conversion
-Wno-unknown-warning-option
-Wno-unused-command-line-argument
-Wno-weak-vtables
-Wno-covered-switch-default
-Wno-unsafe-buffer-usage
)
if
(
WIN32
)
list
(
APPEND __default_clang_compile_options
-fms-extensions
-fms-compatibility
-fdelayed-template-parsing
)
endif
()
set
(
__default_gnu_compile_options
-Wduplicated-branches
-Wduplicated-cond
-Wno-noexcept-type
-Wno-ignored-attributes
-Wodr
-Wshift-negative-value
-Wshift-overflow=2
-Wno-missing-field-initializers
-Wno-maybe-uninitialized
-Wno-deprecated-declarations
)
add_compile_options
(
"$<$<OR:$<CXX_COMPILER_ID:MSVC>,$<C_COMPILER_ID:MSVC>>:
${
__default_msvc_compile_options
}
>"
"$<$<OR:$<CXX_COMPILER_ID:GNU,Clang>,$<C_COMPILER_ID:GNU,Clang>>:
${
__default_compile_options
}
>"
"$<$<OR:$<AND:$<CXX_COMPILER_ID:GNU>,$<VERSION_GREATER_EQUAL:$<CXX_COMPILER_VERSION>,7>>,$<AND:$<C_COMPILER_ID:GNU>,$<VERSION_GREATER_EQUAL:$<C_COMPILER_VERSION>,7>>>:
${
__default_gnu_compile_options
}
>"
"$<$<OR:$<CXX_COMPILER_ID:Clang>,$<C_COMPILER_ID:Clang>>:
${
__default_clang_compile_options
}
>"
)
unset
(
__default_msvc_compile_options
)
unset
(
__default_compile_options
)
unset
(
__default_gnu_compile_options
)
unset
(
__default_clang_compile_options
)
else
()
foreach
(
COMPILER C CXX
)
set
(
CMAKE_COMPILER_WARNINGS
)
# use -Wall for gcc and clang
list
(
APPEND CMAKE_COMPILER_WARNINGS
-Wall
-Wextra
-Wcomment
-Wendif-labels
-Wformat
-Winit-self
-Wreturn-type
-Wsequence-point
# Shadow is broken on gcc when using lambdas
# -Wshadow
-Wswitch
-Wtrigraphs
-Wundef
-Wuninitialized
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
-Wno-unused-template
)
if
(
CMAKE_
${
COMPILER
}
_COMPILER_ID MATCHES
"Clang"
)
list
(
APPEND CMAKE_COMPILER_WARNINGS
-Weverything
-Wno-c++98-compat
-Wno-c++98-compat-pedantic
-Wno-conversion
-Wno-double-promotion
-Wno-exit-time-destructors
-Wno-extra-semi
-Wno-float-conversion
-Wno-gnu-anonymous-struct
-Wno-gnu-zero-variadic-macro-arguments
-Wno-missing-prototypes
-Wno-nested-anon-types
-Wno-padded
-Wno-return-std-move-in-c++11
-Wno-shorten-64-to-32
-Wno-sign-conversion
-Wno-unknown-warning-option
-Wno-unused-command-line-argument
-Wno-weak-vtables
-Wno-covered-switch-default
-Wno-unsafe-buffer-usage
)
else
()
if
(
CMAKE_
${
COMPILER
}
_COMPILER_ID MATCHES
"GNU"
AND
${
COMPILER
}
MATCHES
"CXX"
)
# cmake 3.5.2 does not support >=.
if
(
NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS
"6.1"
)
list
(
APPEND CMAKE_COMPILER_WARNINGS
-Wno-ignored-attributes
)
endif
()
endif
()
list
(
APPEND CMAKE_COMPILER_WARNINGS
-Wno-missing-field-initializers
-Wno-deprecated-declarations
)
endif
()
add_definitions
(
${
CMAKE_COMPILER_WARNINGS
}
)
endforeach
()
endif
()
cmake/GTest.cmake
View file @
e276fc95
...
@@ -32,15 +32,21 @@ FetchContent_MakeAvailable(googletest)
...
@@ -32,15 +32,21 @@ FetchContent_MakeAvailable(googletest)
# Restore the old value of BUILD_SHARED_LIBS
# Restore the old value of BUILD_SHARED_LIBS
set
(
BUILD_SHARED_LIBS
${
__build_shared_libs
}
CACHE BOOL
"Type of libraries to build"
FORCE
)
set
(
BUILD_SHARED_LIBS
${
__build_shared_libs
}
CACHE BOOL
"Type of libraries to build"
FORCE
)
set
(
GTEST_CXX_FLAGS
-Wno-undef
-Wno-global-constructors
-Wno-zero-as-null-pointer-constant
-Wno-switch-enum
-Wno-float-equal
-Wno-unused-member-function
)
if
(
WIN32
)
if
(
WIN32
)
list
(
APPEND GTEST_
CMAKE_
CXX_FLAGS
list
(
APPEND GTEST_CXX_FLAGS
-Wno-suggest-destructor-override
-Wno-suggest-destructor-override
-Wno-suggest-override
-Wno-suggest-override
-Wno-nonportable-system-include-path
-Wno-nonportable-system-include-path
-Wno-language-extension-token
)
-Wno-language-extension-token
)
endif
()
endif
()
target_compile_options
(
gtest PRIVATE -Wno-undef
)
target_compile_options
(
gtest PRIVATE
${
GTEST_CXX_FLAGS
}
)
target_compile_options
(
gtest_main PRIVATE -Wno-undef
)
target_compile_options
(
gtest_main PRIVATE
${
GTEST_CXX_FLAGS
}
)
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
View file @
e276fc95
...
@@ -79,7 +79,7 @@ std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
...
@@ -79,7 +79,7 @@ std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
}
}
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
bool
do_verification
=
false
;
bool
do_verification
=
0
;
int
init_method
=
0
;
int
init_method
=
0
;
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
...
...
example/34_batchnorm/batchnorm_backward_nhwc.cpp
View file @
e276fc95
...
@@ -112,7 +112,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
...
@@ -112,7 +112,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
bool
time_kernel
,
bool
time_kernel
,
const
std
::
vector
<
size_t
>
inOutLengths
,
const
std
::
vector
<
size_t
>
inOutLengths
,
bool
haveSavedMeanInvVar
,
bool
haveSavedMeanInvVar
,
double
_
epsilon
)
double
epsilon
)
{
{
// for NHWC BatchNorm calculation of mean and meansquare
// for NHWC BatchNorm calculation of mean and meansquare
constexpr
index_t
Rank
=
4
;
constexpr
index_t
Rank
=
4
;
...
@@ -292,7 +292,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
...
@@ -292,7 +292,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
bnScale_dev
.
GetDeviceBuffer
(),
bnScale_dev
.
GetDeviceBuffer
(),
haveSavedMeanInvVar
?
savedMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
haveSavedMeanInvVar
?
savedMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
haveSavedMeanInvVar
?
savedInvVar_dev
.
GetDeviceBuffer
()
:
nullptr
,
haveSavedMeanInvVar
?
savedInvVar_dev
.
GetDeviceBuffer
()
:
nullptr
,
_
epsilon
,
epsilon
,
PassThroughOp
{},
PassThroughOp
{},
dx_dev
.
GetDeviceBuffer
(),
dx_dev
.
GetDeviceBuffer
(),
dscale_dev
.
GetDeviceBuffer
(),
dscale_dev
.
GetDeviceBuffer
(),
...
@@ -371,7 +371,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
...
@@ -371,7 +371,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
bnScale
.
mData
.
data
(),
bnScale
.
mData
.
data
(),
haveSavedMeanInvVar
?
savedMean
.
mData
.
data
()
:
nullptr
,
haveSavedMeanInvVar
?
savedMean
.
mData
.
data
()
:
nullptr
,
haveSavedMeanInvVar
?
savedInvVar
.
mData
.
data
()
:
nullptr
,
haveSavedMeanInvVar
?
savedInvVar
.
mData
.
data
()
:
nullptr
,
_
epsilon
,
epsilon
,
PassThroughOp
{},
PassThroughOp
{},
dx_ref
.
mData
.
data
(),
dx_ref
.
mData
.
data
(),
dscale_ref
.
mData
.
data
(),
dscale_ref
.
mData
.
data
(),
...
...
example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp
View file @
e276fc95
...
@@ -119,7 +119,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
...
@@ -119,7 +119,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
int
init_method
,
int
init_method
,
bool
time_kernel
,
bool
time_kernel
,
const
std
::
vector
<
size_t
>
inOutLengths
,
const
std
::
vector
<
size_t
>
inOutLengths
,
double
_
epsilon
)
double
epsilon
)
{
{
// for NHWC BatchNorm calculation of mean and meansquare
// for NHWC BatchNorm calculation of mean and meansquare
constexpr
int
Rank
=
4
;
constexpr
int
Rank
=
4
;
...
@@ -251,7 +251,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
...
@@ -251,7 +251,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
x_dev
.
GetDeviceBuffer
(),
x_dev
.
GetDeviceBuffer
(),
bnScale_dev
.
GetDeviceBuffer
(),
bnScale_dev
.
GetDeviceBuffer
(),
bnBias_dev
.
GetDeviceBuffer
(),
bnBias_dev
.
GetDeviceBuffer
(),
_
epsilon
,
epsilon
,
estimatedMean_dev
.
GetDeviceBuffer
(),
estimatedMean_dev
.
GetDeviceBuffer
(),
estimatedVariance_dev
.
GetDeviceBuffer
(),
estimatedVariance_dev
.
GetDeviceBuffer
(),
y_dev
.
GetDeviceBuffer
());
y_dev
.
GetDeviceBuffer
());
...
@@ -289,7 +289,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
...
@@ -289,7 +289,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
x
.
mData
.
data
(),
x
.
mData
.
data
(),
bnScale
.
mData
.
data
(),
bnScale
.
mData
.
data
(),
bnBias
.
mData
.
data
(),
bnBias
.
mData
.
data
(),
_
epsilon
,
epsilon
,
PassThroughOp
{},
PassThroughOp
{},
estimatedMean
.
mData
.
data
(),
estimatedMean
.
mData
.
data
(),
estimatedVariance
.
mData
.
data
(),
estimatedVariance
.
mData
.
data
(),
...
...
example/34_batchnorm/batchnorm_forward_training_nhwc.cpp
View file @
e276fc95
...
@@ -135,8 +135,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -135,8 +135,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
const
std
::
vector
<
size_t
>
inOutLengths
,
const
std
::
vector
<
size_t
>
inOutLengths
,
bool
updateMovingAverage
,
bool
updateMovingAverage
,
bool
saveMeanAndInvVariance
,
bool
saveMeanAndInvVariance
,
double
_
averageFactor
,
double
averageFactor
,
double
_
epsilon
)
double
epsilon
)
{
{
// for NHWC BatchNorm calculation of mean and meansquare
// for NHWC BatchNorm calculation of mean and meansquare
constexpr
int
Rank
=
4
;
constexpr
int
Rank
=
4
;
...
@@ -310,12 +310,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -310,12 +310,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
x_dev
.
GetDeviceBuffer
(),
x_dev
.
GetDeviceBuffer
(),
bnScale_dev
.
GetDeviceBuffer
(),
bnScale_dev
.
GetDeviceBuffer
(),
bnBias_dev
.
GetDeviceBuffer
(),
bnBias_dev
.
GetDeviceBuffer
(),
_
epsilon
,
epsilon
,
PassThroughOp
{},
PassThroughOp
{},
y_dev
.
GetDeviceBuffer
(),
y_dev
.
GetDeviceBuffer
(),
saveMeanAndInvVariance
?
resultSaveMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveInvVariance_dev
.
GetDeviceBuffer
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveInvVariance_dev
.
GetDeviceBuffer
()
:
nullptr
,
_
averageFactor
,
averageFactor
,
updateMovingAverage
?
resultRunningMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
updateMovingAverage
?
resultRunningMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
updateMovingAverage
?
resultRunningVariance_dev
.
GetDeviceBuffer
()
:
nullptr
);
updateMovingAverage
?
resultRunningVariance_dev
.
GetDeviceBuffer
()
:
nullptr
);
...
@@ -392,12 +392,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -392,12 +392,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
x
.
mData
.
data
(),
x
.
mData
.
data
(),
bnScale
.
mData
.
data
(),
bnScale
.
mData
.
data
(),
bnBias
.
mData
.
data
(),
bnBias
.
mData
.
data
(),
_
epsilon
,
epsilon
,
PassThroughOp
{},
PassThroughOp
{},
y_ref
.
mData
.
data
(),
y_ref
.
mData
.
data
(),
saveMeanAndInvVariance
?
resultSaveMean_ref
.
mData
.
data
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveMean_ref
.
mData
.
data
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveInvVariance_ref
.
mData
.
data
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveInvVariance_ref
.
mData
.
data
()
:
nullptr
,
_
averageFactor
,
averageFactor
,
updateMovingAverage
?
resultRunningMean_ref
.
mData
.
data
()
:
nullptr
,
updateMovingAverage
?
resultRunningMean_ref
.
mData
.
data
()
:
nullptr
,
updateMovingAverage
?
resultRunningVariance_ref
.
mData
.
data
()
:
nullptr
);
updateMovingAverage
?
resultRunningVariance_ref
.
mData
.
data
()
:
nullptr
);
...
...
example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp
View file @
e276fc95
...
@@ -135,8 +135,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -135,8 +135,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
const
std
::
vector
<
size_t
>
inOutLengths
,
const
std
::
vector
<
size_t
>
inOutLengths
,
bool
updateMovingAverage
,
bool
updateMovingAverage
,
bool
saveMeanAndInvVariance
,
bool
saveMeanAndInvVariance
,
double
_
averageFactor
,
double
averageFactor
,
double
_
epsilon
)
double
epsilon
)
{
{
// for NHWC BatchNorm calculation of mean and meansquare
// for NHWC BatchNorm calculation of mean and meansquare
constexpr
int
Rank
=
4
;
constexpr
int
Rank
=
4
;
...
@@ -310,12 +310,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -310,12 +310,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
x_dev
.
GetDeviceBuffer
(),
x_dev
.
GetDeviceBuffer
(),
bnScale_dev
.
GetDeviceBuffer
(),
bnScale_dev
.
GetDeviceBuffer
(),
bnBias_dev
.
GetDeviceBuffer
(),
bnBias_dev
.
GetDeviceBuffer
(),
_
epsilon
,
epsilon
,
PassThroughOp
{},
PassThroughOp
{},
y_dev
.
GetDeviceBuffer
(),
y_dev
.
GetDeviceBuffer
(),
saveMeanAndInvVariance
?
resultSaveMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveInvVariance_dev
.
GetDeviceBuffer
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveInvVariance_dev
.
GetDeviceBuffer
()
:
nullptr
,
_
averageFactor
,
averageFactor
,
updateMovingAverage
?
resultRunningMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
updateMovingAverage
?
resultRunningMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
updateMovingAverage
?
resultRunningVariance_dev
.
GetDeviceBuffer
()
:
nullptr
);
updateMovingAverage
?
resultRunningVariance_dev
.
GetDeviceBuffer
()
:
nullptr
);
...
@@ -392,12 +392,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -392,12 +392,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
x
.
mData
.
data
(),
x
.
mData
.
data
(),
bnScale
.
mData
.
data
(),
bnScale
.
mData
.
data
(),
bnBias
.
mData
.
data
(),
bnBias
.
mData
.
data
(),
_
epsilon
,
epsilon
,
PassThroughOp
{},
PassThroughOp
{},
y_ref
.
mData
.
data
(),
y_ref
.
mData
.
data
(),
saveMeanAndInvVariance
?
resultSaveMean_ref
.
mData
.
data
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveMean_ref
.
mData
.
data
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveInvVariance_ref
.
mData
.
data
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveInvVariance_ref
.
mData
.
data
()
:
nullptr
,
_
averageFactor
,
averageFactor
,
updateMovingAverage
?
resultRunningMean_ref
.
mData
.
data
()
:
nullptr
,
updateMovingAverage
?
resultRunningMean_ref
.
mData
.
data
()
:
nullptr
,
updateMovingAverage
?
resultRunningVariance_ref
.
mData
.
data
()
:
nullptr
);
updateMovingAverage
?
resultRunningVariance_ref
.
mData
.
data
()
:
nullptr
);
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment