Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9f1b4276
Commit
9f1b4276
authored
Apr 04, 2024
by
Jakub Piasecki
Browse files
resolved conflicts
parents
711857c4
c7010716
Changes
198
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
230 additions
and
180 deletions
+230
-180
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt
...r_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt
...eration_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt
...eration_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt
...tensor_operation_instance/gpu/quantization/CMakeLists.txt
+1
-0
profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
+7
-3
profiler/include/profiler/profile_permute_scale_impl.hpp
profiler/include/profiler/profile_permute_scale_impl.hpp
+12
-10
profiler/src/CMakeLists.txt
profiler/src/CMakeLists.txt
+95
-61
profiler/src/profile_grouped_conv_fwd.cpp
profiler/src/profile_grouped_conv_fwd.cpp
+46
-29
profiler/src/profile_grouped_gemm_two_stage.cpp
profiler/src/profile_grouped_gemm_two_stage.cpp
+3
-3
test/CMakeLists.txt
test/CMakeLists.txt
+24
-1
test/batched_gemm/CMakeLists.txt
test/batched_gemm/CMakeLists.txt
+3
-8
test/batched_gemm/test_batched_gemm_xdl.cpp
test/batched_gemm/test_batched_gemm_xdl.cpp
+0
-0
test/batched_gemm_gemm/CMakeLists.txt
test/batched_gemm_gemm/CMakeLists.txt
+6
-13
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp
+0
-0
test/batched_gemm_reduce/CMakeLists.txt
test/batched_gemm_reduce/CMakeLists.txt
+3
-10
test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp
test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp
+0
-0
test/batched_gemm_softmax_gemm/CMakeLists.txt
test/batched_gemm_softmax_gemm/CMakeLists.txt
+6
-13
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16_xdl.cpp
..._softmax_gemm/test_batched_gemm_softmax_gemm_fp16_xdl.cpp
+0
-0
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
+21
-29
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp
.../test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp
+0
-0
No files found.
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt
View file @
9f1b4276
# ONLY XDL_KERNELS
add_instance_library
(
device_grouped_gemm_bias_instance
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt
View file @
9f1b4276
# ONLY XDL_KERNELS
add_instance_library
(
device_grouped_gemm_fastgelu_instance
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt
View file @
9f1b4276
# ONLY XDL_KERNELS
set
(
GROUPED_GEMM_FIXED_NK_INSTANCES
)
list
(
APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp
...
...
library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt
View file @
9f1b4276
# ONLY XDL_AND_DL_KERNELS
set
(
CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp
)
set
(
CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp
)
set
(
CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp
)
...
...
profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
View file @
9f1b4276
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -31,7 +31,9 @@ template <ck::index_t NDimSpatial,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
>
typename
OutDataType
,
typename
AComputeType
=
InDataType
,
typename
BComputeType
=
AComputeType
>
bool
profile_grouped_conv_fwd_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
...
...
@@ -209,7 +211,9 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
OutElementOp
,
AComputeType
,
BComputeType
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
...
profiler/include/profiler/profile_permute_scale_impl.hpp
View file @
9f1b4276
...
...
@@ -14,6 +14,8 @@
#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
...
...
@@ -21,14 +23,6 @@
#include "ck/library/utility/literals.hpp"
namespace
ck
{
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
ElementOp
>
void
reference_permute_scale
(
HostTensorB
&
b_tensor
,
const
HostTensorA
&
a_tensor
,
ElementOp
tensor_op
)
{
b_tensor
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
tensor_op
(
self
(
idx
),
a_tensor
(
idx
));
});
}
namespace
profiler
{
template
<
typename
ADataType
,
typename
BDataType
,
index_t
NumDim
>
...
...
@@ -46,7 +40,8 @@ bool profile_permute_scale_impl(int do_verification,
using
ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
float
scale
=
2.
f
;
Tensor
<
ADataType
>
a
(
lengths_vector
,
input_strides_vector
);
std
::
array
<
Tensor
<
ADataType
>
,
1
>
as
=
{
Tensor
<
ADataType
>
(
lengths_vector
,
input_strides_vector
)};
Tensor
<
ADataType
>&
a
=
as
[
0
];
Tensor
<
BDataType
>
b
(
lengths_vector
,
output_strides_vector
);
Tensor
<
BDataType
>
host_b
(
lengths_vector
,
output_strides_vector
);
...
...
@@ -83,7 +78,14 @@ bool profile_permute_scale_impl(int do_verification,
if
(
do_verification
)
{
reference_permute_scale
(
host_b
,
a
,
ElementOp
{
scale
});
using
ReferenceElementwiseInstance
=
ck
::
tensor_operation
::
host
::
ReferenceElementwise
<
1
,
ADataType
,
BDataType
,
ElementOp
>
;
auto
ref_elementwise
=
ReferenceElementwiseInstance
{};
auto
ref_invoker
=
ref_elementwise
.
MakeInvoker
();
auto
ref_argument
=
ref_elementwise
.
MakeArgument
(
as
,
host_b
,
ElementOp
{
scale
});
ref_invoker
.
Run
(
ref_argument
);
}
auto
copy
=
[](
const
auto
&
x
,
auto
&
y
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
y
.
begin
());
};
...
...
profiler/src/CMakeLists.txt
View file @
9f1b4276
...
...
@@ -2,19 +2,6 @@
set
(
PROFILER_SOURCES
profiler.cpp
profile_gemm.cpp
profile_gemm_splitk.cpp
profile_gemm_bias_add_reduce.cpp
profile_gemm_add_multiply.cpp
profile_gemm_multiply_add.cpp
profile_gemm_reduce.cpp
profile_batched_gemm.cpp
profile_batched_gemm_reduce.cpp
profile_conv_fwd.cpp
profile_conv_fwd_bias_relu.cpp
profile_conv_fwd_bias_relu_add.cpp
profile_conv_bwd_data.cpp
profile_grouped_conv_fwd.cpp
profile_grouped_conv_bwd_weight.cpp
profile_reduce.cpp
profile_groupnorm_bwd_data.cpp
profile_groupnorm_fwd.cpp
...
...
@@ -29,16 +16,47 @@ set(PROFILER_SOURCES
profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp
profile_grouped_conv_bwd_data.cpp
profile_conv_tensor_rearrange.cpp
profile_transpose.cpp
profile_permute_scale.cpp
)
if
(
DL_KERNELS
)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp
)
if
(
GPU_TARGETS MATCHES
"gfx9"
)
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
list
(
APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp
)
list
(
APPEND PROFILER_SOURCES profile_contraction_scale.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND PROFILER_SOURCES profile_gemm_reduce.cpp
)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp
)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_streamk.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp
)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp
)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp
)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp
)
endif
()
list
(
APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp
)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm.cpp
)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_splitk.cpp
)
list
(
APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp
)
list
(
APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp
)
list
(
APPEND PROFILER_SOURCES profile_conv_fwd.cpp
)
endif
()
<<<<<<< HEAD
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp
)
...
...
@@ -55,11 +73,20 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp
)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp
)
=======
if
(
GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx9"
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp
)
endif
()
list
(
APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp
)
list
(
APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp
)
list
(
APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp
)
>>>>>>> origin/develop
endif
()
if
(
D
TYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPE
S
)
list
(
APPEND PROFILER_SOURCES profile_
contraction_bilinear
.cpp
)
list
(
APPEND PROFILER_SOURCES profile_
contraction_scale
.cpp
)
if
(
D
L_KERNEL
S
)
list
(
APPEND PROFILER_SOURCES profile_
batched_gemm_multi_d
.cpp
)
list
(
APPEND PROFILER_SOURCES profile_
grouped_conv_bwd_weight
.cpp
)
endif
()
set
(
PROFILER_EXECUTABLE ckProfiler
)
...
...
@@ -69,25 +96,6 @@ target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE utility getopt::getopt
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_splitk_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_multiply_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bias_add_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv1d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv1d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv3d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv1d_bwd_weight_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_weight_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_weight_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_bias_relu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_bias_relu_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_normalization_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_normalization_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_normalization_bwd_gamma_beta_instance
)
...
...
@@ -97,39 +105,65 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_pool3d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_avg_pool3d_bwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_max_pool_bwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_image_to_column_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_column_to_image_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_transpose_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_permute_scale_instance
)
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_scale_instance
)
if
(
GPU_TARGETS MATCHES
"gfx9"
)
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_scale_instance
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_add_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_add_relu_gemm_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_streamk_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_relu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_silu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_relu_add_layernorm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_fixed_nk_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_fastgelu_instance
)
endif
()
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_splitk_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_multiply_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bias_add_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_bias_relu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_bias_relu_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv1d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv1d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv3d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv1d_bwd_weight_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_weight_instance
)
endif
()
if
(
GPU_TARGETS MATCHES
"gfx9"
OR GPU_TARGETS MATCHES
"gfx11"
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bilinear_instance
)
endif
()
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_weight_instance
)
endif
()
if
(
DL_KERNELS
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_multi_d_instance
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_relu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_silu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_relu_add_layernorm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_add_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_streamk_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_add_relu_gemm_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_fixed_nk_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv1d_bwd_weight_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_weight_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_weight_instance
)
endif
()
rocm_install
(
TARGETS
${
PROFILER_EXECUTABLE
}
COMPONENT profiler
)
profiler/src/profile_grouped_conv_fwd.cpp
View file @
9f1b4276
...
...
@@ -25,6 +25,7 @@ enum struct ConvDataType
INT8_INT8_INT8
,
// 3
F8_F8_F8
,
// 4
BF8_BF8_F8
,
// 5
F8_BF8_F8
,
// 6
};
#define OP_NAME "grouped_conv_fwd"
...
...
@@ -40,7 +41,8 @@ static void print_helper_msg()
<<
" 2: Input bf16, Weight bf16, Output bf16
\n
"
<<
" 3: Input int8, Weight int8, Output int8
\n
"
<<
" 4: Input fp8, Weight fp8, Output fp8
\n
"
<<
" 5: Input bf8, Weight bf8, Output fp8)
\n
"
<<
" 5: Input bf8, Weight bf8, Output fp8
\n
"
<<
" 6: Input fp8, Weight bf8, Output fp8)
\n
"
<<
"arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]
\n
"
<<
" 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])
\n
"
<<
"arg4: verification (0: no, 1: yes)
\n
"
...
...
@@ -118,7 +120,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
auto
out_layout
,
auto
in_type
,
auto
wei_type
,
auto
out_type
)
{
auto
out_type
,
auto
a_compute_type
,
auto
b_compute_type
)
{
constexpr
ck
::
index_t
NDimSpatial
=
num_dim_spatial_tmp
.
value
;
using
InLayout
=
decltype
(
in_layout
);
...
...
@@ -129,13 +133,18 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using
WeiDataType
=
decltype
(
wei_type
);
using
OutDataType
=
decltype
(
out_type
);
using
AComputeType
=
decltype
(
a_compute_type
);
using
BComputeType
=
decltype
(
b_compute_type
);
bool
pass
=
ck
::
profiler
::
profile_grouped_conv_fwd_impl
<
NDimSpatial
,
InLayout
,
WeiLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
OutDataType
>
(
OutDataType
,
AComputeType
,
BComputeType
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
params
);
return
pass
?
0
:
1
;
...
...
@@ -146,57 +155,59 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
{
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
F32
{},
F32
{},
F32
{});
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
F32
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
{
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
F16
{},
F16
{},
F16
{});
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{});
}
}
else
if
(
num_dim_spatial
==
2
&&
layout
==
ConvLayout
::
GNHWC_GKYXC_GNHWK
)
{
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
{
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
F32
{},
F32
{},
F32
{});
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
F32
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
{
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
F16
{},
F16
{},
F16
{});
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{});
}
}
else
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
GNHWC_GKYXC_GNHWK
)
{
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
{
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
F32
{},
F32
{},
F32
{});
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
F32
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
{
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
F16
{},
F16
{},
F16
{});
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{});
}
}
// NHWGC_GKYXC_NHWGK
...
...
@@ -204,65 +215,71 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
{
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
F32
{},
F32
{},
F32
{});
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
F32
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
{
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
F16
{},
F16
{},
F16
{});
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{});
}
}
else
if
(
num_dim_spatial
==
2
&&
layout
==
ConvLayout
::
NHWGC_GKYXC_NHWGK
)
{
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
{
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
F32
{},
F32
{},
F32
{});
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
F32
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
{
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
F16
{},
F16
{},
F16
{});
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{});
}
}
else
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
NHWGC_GKYXC_NHWGK
)
{
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F32
{},
F32
{},
F32
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F32
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F16
{},
F16
{},
F16
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{});
}
else
if
(
data_type
==
ConvDataType
::
F8_F8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F8
{},
F8
{},
F8
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F8
{},
F8
{},
F8
{},
F8
{},
F8
{});
}
else
if
(
data_type
==
ConvDataType
::
BF8_BF8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF8
{},
BF8
{},
F8
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF8
{},
BF8
{},
F8
{},
BF8
{},
BF8
{});
}
else
if
(
data_type
==
ConvDataType
::
F8_BF8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F8
{},
BF8
{},
F8
{},
F8
{},
BF8
{});
}
}
...
...
profiler/src/profile_grouped_gemm_two_stage.cpp
View file @
9f1b4276
...
...
@@ -17,9 +17,9 @@ enum struct GemmMatrixLayout
enum
struct
GemmDataType
{
F16_F16_F16
,
// 0
BF16_INT8_BF16
,
// 1
BF16_BF16_BF16
// 2
F16_F16_F16
,
// 0
BF16_INT8_BF16
,
// 1
BF16_BF16_BF16
// 2
};
#define OP_NAME "grouped_gemm_two_stage"
...
...
test/CMakeLists.txt
View file @
9f1b4276
...
...
@@ -46,7 +46,18 @@ function(add_test_executable TEST_NAME)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT GPU_TARGETS MATCHES
"gfx9"
AND source MATCHES
"xdl"
)
message
(
"removing xdl test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT GPU_TARGETS MATCHES
"gfx11"
AND source MATCHES
"wmma"
)
message
(
"removing wmma test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
#only continue if there are some source files left on the list
if
(
ARGN
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
...
...
@@ -100,6 +111,18 @@ function(add_gtest_executable TEST_NAME)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT GPU_TARGETS MATCHES
"gfx9"
AND source MATCHES
"xdl"
)
message
(
"removing xdl test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT GPU_TARGETS MATCHES
"gfx11"
AND source MATCHES
"wmma"
)
message
(
"removing wmma test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
#only continue if there are some source files left on the list
if
(
ARGN
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
...
...
test/batched_gemm/CMakeLists.txt
View file @
9f1b4276
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_gtest_executable
(
test_batched_gemm test_batched_gemm.cpp
)
add_gtest_executable
(
test_batched_gemm test_batched_gemm_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm PRIVATE utility device_batched_gemm_instance
)
set
(
target 1
)
endif
()
endforeach
()
\ No newline at end of file
endif
()
test/batched_gemm/test_batched_gemm.cpp
→
test/batched_gemm/test_batched_gemm
_xdl
.cpp
View file @
9f1b4276
File moved
test/batched_gemm_gemm/CMakeLists.txt
View file @
9f1b4276
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
test_batched_gemm_gemm
)
add_gtest_executable
(
test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance
)
add_dependencies
(
test_batched_gemm_gemm test_batched_gemm_gemm_fp16
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
\ No newline at end of file
add_gtest_executable
(
test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16_xdl.cpp
)
if
(
result EQUAL 0
)
add_custom_target
(
test_batched_gemm_gemm
)
target_link_libraries
(
test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance
)
add_dependencies
(
test_batched_gemm_gemm test_batched_gemm_gemm_fp16
)
endif
()
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
→
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16
_xdl
.cpp
View file @
9f1b4276
File moved
test/batched_gemm_reduce/CMakeLists.txt
View file @
9f1b4276
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_test_executable
(
test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance
)
set
(
target 1
)
endif
()
add_test_executable
(
test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance
)
endif
()
endforeach
()
test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp
→
test/batched_gemm_reduce/batched_gemm_reduce_fp16
_xdl
.cpp
View file @
9f1b4276
File moved
test/batched_gemm_softmax_gemm/CMakeLists.txt
View file @
9f1b4276
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
test_batched_gemm_softmax_gemm
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
\ No newline at end of file
add_gtest_executable
(
test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16_xdl.cpp
)
if
(
result EQUAL 0
)
add_custom_target
(
test_batched_gemm_softmax_gemm
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16
)
endif
()
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp
→
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16
_xdl
.cpp
View file @
9f1b4276
File moved
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
View file @
9f1b4276
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
test_batched_gemm_softmax_gemm_permute
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16
)
endif
()
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16
)
endif
()
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16
)
endif
()
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16
)
endif
()
set
(
target 1
)
endif
()
endforeach
()
\ No newline at end of file
add_custom_target
(
test_batched_gemm_softmax_gemm_permute
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16
)
endif
()
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16
)
endif
()
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16
)
endif
()
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16
)
endif
()
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
→
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16
_xdl
.cpp
View file @
9f1b4276
File moved
Prev
1
…
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment