Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7f65ac05
Commit
7f65ac05
authored
Apr 04, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
687d2b7e
7e5c81fe
Changes
234
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
349 additions
and
119 deletions
+349
-119
profiler/src/profile_grouped_conv_fwd.cpp
profiler/src/profile_grouped_conv_fwd.cpp
+52
-28
profiler/src/profile_grouped_gemm_two_stage.cpp
profiler/src/profile_grouped_gemm_two_stage.cpp
+157
-0
profiler/src/profile_permute_scale.cpp
profiler/src/profile_permute_scale.cpp
+24
-5
script/profile_permute_scale.sh
script/profile_permute_scale.sh
+43
-0
test/CMakeLists.txt
test/CMakeLists.txt
+25
-1
test/batched_gemm/CMakeLists.txt
test/batched_gemm/CMakeLists.txt
+3
-8
test/batched_gemm/test_batched_gemm_xdl.cpp
test/batched_gemm/test_batched_gemm_xdl.cpp
+0
-0
test/batched_gemm_gemm/CMakeLists.txt
test/batched_gemm_gemm/CMakeLists.txt
+6
-13
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp
+0
-0
test/batched_gemm_reduce/CMakeLists.txt
test/batched_gemm_reduce/CMakeLists.txt
+3
-10
test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp
test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp
+0
-0
test/batched_gemm_softmax_gemm/CMakeLists.txt
test/batched_gemm_softmax_gemm/CMakeLists.txt
+6
-13
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16_xdl.cpp
..._softmax_gemm/test_batched_gemm_softmax_gemm_fp16_xdl.cpp
+0
-0
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
+21
-29
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp
.../test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp
+0
-0
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp
.../test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp
+0
-0
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp
...rmute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp
+0
-0
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp
...rmute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp
+0
-0
test/contraction/CMakeLists.txt
test/contraction/CMakeLists.txt
+9
-12
test/contraction/test_contraction_interface_xdl.cpp
test/contraction/test_contraction_interface_xdl.cpp
+0
-0
No files found.
profiler/src/profile_grouped_conv_fwd.cpp
View file @
7f65ac05
...
@@ -24,6 +24,8 @@ enum struct ConvDataType
...
@@ -24,6 +24,8 @@ enum struct ConvDataType
BF16_BF16_BF16
,
// 2
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
INT8_INT8_INT8
,
// 3
F8_F8_F8
,
// 4
F8_F8_F8
,
// 4
BF8_BF8_F8
,
// 5
F8_BF8_F8
,
// 6
};
};
#define OP_NAME "grouped_conv_fwd"
#define OP_NAME "grouped_conv_fwd"
...
@@ -38,7 +40,9 @@ static void print_helper_msg()
...
@@ -38,7 +40,9 @@ static void print_helper_msg()
<<
" 1: Input fp16, Weight fp16, Output fp16
\n
"
<<
" 1: Input fp16, Weight fp16, Output fp16
\n
"
<<
" 2: Input bf16, Weight bf16, Output bf16
\n
"
<<
" 2: Input bf16, Weight bf16, Output bf16
\n
"
<<
" 3: Input int8, Weight int8, Output int8
\n
"
<<
" 3: Input int8, Weight int8, Output int8
\n
"
<<
" 4: Input fp8, Weight fp8, Output fp8)
\n
"
<<
" 4: Input fp8, Weight fp8, Output fp8
\n
"
<<
" 5: Input bf8, Weight bf8, Output fp8
\n
"
<<
" 6: Input fp8, Weight bf8, Output fp8)
\n
"
<<
"arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]
\n
"
<<
"arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]
\n
"
<<
" 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])
\n
"
<<
" 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])
\n
"
<<
"arg4: verification (0: no, 1: yes)
\n
"
<<
"arg4: verification (0: no, 1: yes)
\n
"
...
@@ -82,6 +86,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
...
@@ -82,6 +86,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
using
INT8
=
int8_t
;
using
INT8
=
int8_t
;
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
BF8
=
ck
::
bf8_t
;
//
//
using
GNWC
=
ck
::
tensor_layout
::
convolution
::
GNWC
;
using
GNWC
=
ck
::
tensor_layout
::
convolution
::
GNWC
;
...
@@ -115,7 +120,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
...
@@ -115,7 +120,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
auto
out_layout
,
auto
out_layout
,
auto
in_type
,
auto
in_type
,
auto
wei_type
,
auto
wei_type
,
auto
out_type
)
{
auto
out_type
,
auto
a_compute_type
,
auto
b_compute_type
)
{
constexpr
ck
::
index_t
NDimSpatial
=
num_dim_spatial_tmp
.
value
;
constexpr
ck
::
index_t
NDimSpatial
=
num_dim_spatial_tmp
.
value
;
using
InLayout
=
decltype
(
in_layout
);
using
InLayout
=
decltype
(
in_layout
);
...
@@ -126,13 +133,18 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
...
@@ -126,13 +133,18 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using
WeiDataType
=
decltype
(
wei_type
);
using
WeiDataType
=
decltype
(
wei_type
);
using
OutDataType
=
decltype
(
out_type
);
using
OutDataType
=
decltype
(
out_type
);
using
AComputeType
=
decltype
(
a_compute_type
);
using
BComputeType
=
decltype
(
b_compute_type
);
bool
pass
=
ck
::
profiler
::
profile_grouped_conv_fwd_impl
<
NDimSpatial
,
bool
pass
=
ck
::
profiler
::
profile_grouped_conv_fwd_impl
<
NDimSpatial
,
InLayout
,
InLayout
,
WeiLayout
,
WeiLayout
,
OutLayout
,
OutLayout
,
InDataType
,
InDataType
,
WeiDataType
,
WeiDataType
,
OutDataType
>
(
OutDataType
,
AComputeType
,
BComputeType
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
params
);
do_verification
,
init_method
,
do_log
,
time_kernel
,
params
);
return
pass
?
0
:
1
;
return
pass
?
0
:
1
;
...
@@ -143,57 +155,59 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
...
@@ -143,57 +155,59 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
{
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
{
{
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
F32
{},
F32
{},
F32
{});
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
F32
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
{
{
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
F16
{},
F16
{},
F16
{});
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
{
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
{
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I1
,
GNWC
{},
GKXC
{},
GNWK
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{});
}
}
}
}
else
if
(
num_dim_spatial
==
2
&&
layout
==
ConvLayout
::
GNHWC_GKYXC_GNHWK
)
else
if
(
num_dim_spatial
==
2
&&
layout
==
ConvLayout
::
GNHWC_GKYXC_GNHWK
)
{
{
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
{
{
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
F32
{},
F32
{},
F32
{});
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
F32
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
{
{
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
F16
{},
F16
{},
F16
{});
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
{
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
{
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I2
,
GNHWC
{},
GKYXC
{},
GNHWK
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{});
}
}
}
}
else
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
GNHWC_GKYXC_GNHWK
)
else
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
GNHWC_GKYXC_GNHWK
)
{
{
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
{
{
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
F32
{},
F32
{},
F32
{});
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
F32
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
{
{
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
F16
{},
F16
{},
F16
{});
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
{
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
{
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{});
}
}
}
}
// NHWGC_GKYXC_NHWGK
// NHWGC_GKYXC_NHWGK
...
@@ -201,61 +215,71 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
...
@@ -201,61 +215,71 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
{
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
{
{
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
F32
{},
F32
{},
F32
{});
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
F32
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
{
{
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
F16
{},
F16
{},
F16
{});
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
{
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
{
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I1
,
NWGC
{},
GKXC
{},
NWGK
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{});
}
}
}
}
else
if
(
num_dim_spatial
==
2
&&
layout
==
ConvLayout
::
NHWGC_GKYXC_NHWGK
)
else
if
(
num_dim_spatial
==
2
&&
layout
==
ConvLayout
::
NHWGC_GKYXC_NHWGK
)
{
{
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
{
{
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
F32
{},
F32
{},
F32
{});
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
F32
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
{
{
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
F16
{},
F16
{},
F16
{});
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
{
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
{
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{});
}
}
}
}
else
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
NHWGC_GKYXC_NHWGK
)
else
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
NHWGC_GKYXC_NHWGK
)
{
{
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
if
(
data_type
==
ConvDataType
::
F32_F32_F32
)
{
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F32
{},
F32
{},
F32
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F32
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
}
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
else
if
(
data_type
==
ConvDataType
::
F16_F16_F16
)
{
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F16
{},
F16
{},
F16
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
else
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
else
if
(
data_type
==
ConvDataType
::
INT8_INT8_INT8
)
{
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
INT8
{},
INT8
{},
INT8
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{},
INT8
{});
}
}
else
if
(
data_type
==
ConvDataType
::
F8_F8_F8
)
else
if
(
data_type
==
ConvDataType
::
F8_F8_F8
)
{
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F8
{},
F8
{},
F8
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F8
{},
F8
{},
F8
{},
F8
{},
F8
{});
}
else
if
(
data_type
==
ConvDataType
::
BF8_BF8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF8
{},
BF8
{},
F8
{},
BF8
{},
BF8
{});
}
else
if
(
data_type
==
ConvDataType
::
F8_BF8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F8
{},
BF8
{},
F8
{},
F8
{},
BF8
{});
}
}
}
}
...
...
profiler/src/profile_grouped_gemm_two_stage.cpp
0 → 100644
View file @
7f65ac05
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
#include "profiler_operation_registry.hpp"
enum
struct
GemmMatrixLayout
{
MK_KN_MN
,
// 0
};
enum
struct
GemmDataType
{
F16_F16_F16
,
// 0
BF16_INT8_BF16
// 1
};
#define OP_NAME "grouped_gemm_two_stage"
#define OP_DESC "Grouped GEMM TwoStage"
namespace
{
// Splits a comma-separated command-line argument (for example "256,128,64")
// into its integer fields, in order of appearance.
//
// Each field is converted with std::stoi, so a field that std::stoi cannot
// parse (such as an empty field between two commas) makes std::stoi throw.
std::vector<int> argToIntArray(char* input)
{
    std::istringstream stream(input);
    std::vector<int> values;

    // std::getline stops yielding fields once the stream is exhausted, so a
    // trailing comma does not produce an extra empty field.
    for(std::string token; std::getline(stream, token, ',');)
    {
        values.emplace_back(std::stoi(token));
    }

    return values;
}
// Command-line entry point for the OP_NAME profiler operation.
//
// Argument layout (indices into argv):
//   argv[2]        data type, cast to GemmDataType
//   argv[3]        matrix layout, cast to GemmMatrixLayout
//   argv[4..7]     verification, init method, log, time-kernel flags
//   argv[8..13]    comma-separated Ms, Ns, Ks, StrideAs, StrideBs, StrideCs
//   argv[14]       optional kbatch (default 1)
//   argv[15..16]   optional warm-up cycles and iteration count (defaults 1, 10)
//
// A stride given as -1 is replaced by a default: Ks[0] for A, Ns[0] for B and C.
//
// Prints the usage text and exits with status 1 when fewer than 14 arguments
// are given. Throws std::runtime_error when the size lists are too short to be
// indexed, or when the data type / layout combination has no implementation.
// Returns 0 otherwise.
int profile_grouped_gemm_two_stage(int argc, char* argv[])
{
    if(argc < 14)
    {
        std::cout
            << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
            << "arg2: data type (0: fp16; 1: bf16@int8)\n"
            << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n"
            << "arg4: verification (0: no; 1: yes)\n"
            << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
            << "arg6: print tensor value (0: no; 1: yes)\n"
            << "arg7: time kernel (0=no, 1=yes)\n"
            << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
               "64,64 64,64 128,128)\n"
            << "optional:\n"
            << "arg14: kbatch value (default 1)\n"
            << "arg15: number of warm-up cycles (default 1)\n"
            << "arg16: number of iterations (default 10)\n"
            << std::endl;
        exit(1);
    }

    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
    const bool time_kernel     = std::stoi(argv[7]);

    const auto Ms = argToIntArray(argv[8]);
    const auto Ns = argToIntArray(argv[9]);
    const auto Ks = argToIntArray(argv[10]);

    auto StrideAs = argToIntArray(argv[11]);
    auto StrideBs = argToIntArray(argv[12]);
    auto StrideCs = argToIntArray(argv[13]);

    // argv[14] exists whenever at least 15 arguments were passed, including
    // the case where warm-up and iteration counts follow it.
    const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1;

    // The default strides read element 0 of Ns and Ks, and the loop below
    // indexes every stride list once per entry of Ms; reject inputs that
    // would make those accesses run out of bounds.
    if(Ns.empty() || Ks.empty())
    {
        throw std::runtime_error("wrong! Ns and Ks must each contain at least one value");
    }

    const size_t group_count = Ms.size();
    if(StrideAs.size() < group_count || StrideBs.size() < group_count ||
       StrideCs.size() < group_count)
    {
        throw std::runtime_error(
            "wrong! StrideAs, StrideBs and StrideCs need one value per entry of Ms");
    }

    const int DefaultStrideA = Ks[0];
    const int DefaultStrideB = Ns[0];
    const int DefaultStrideC = Ns[0];

    for(size_t i = 0; i < group_count; ++i)
    {
        StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i];
        StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i];
        StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i];
    }

    int n_warmup = 1;
    int n_iter   = 10;
    if(argc >= 17)
    {
        // With 17 arguments the last valid index is 16; argv[17] would be the
        // terminating null pointer.
        n_warmup = std::stoi(argv[15]);
        n_iter   = std::stoi(argv[16]);
    }

    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_grouped_gemm_two_stage_impl<ck::half_t,
                                                          ck::half_t,
                                                          ck::half_t,
                                                          float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
            time_kernel,
            Ms,
            Ns,
            Ks,
            StrideAs,
            StrideBs,
            StrideCs,
            kbatch,
            n_warmup,
            n_iter);
    }
    else if(data_type == GemmDataType::BF16_INT8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
                                                          int8_t,
                                                          ck::bhalf_t,
                                                          float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
            time_kernel,
            Ms,
            Ns,
            Ks,
            StrideAs,
            StrideBs,
            StrideCs,
            kbatch,
            n_warmup,
            n_iter);
    }
    else
    {
        throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
    }

    // NOTE(review): the value returned by the impl call is discarded here, so
    // this always reports success once the call returns - confirm intended.
    return 0;
}
}
// anonymous namespace
REGISTER_PROFILER_OPERATION
(
OP_NAME
,
OP_DESC
,
profile_grouped_gemm_two_stage
);
profiler/src/profile_permute_scale.cpp
View file @
7f65ac05
...
@@ -37,6 +37,20 @@ static void print_helper_msg()
...
@@ -37,6 +37,20 @@ static void print_helper_msg()
// clang-format on
// clang-format on
}
}
// Computes strides for a tensor whose dimensions are laid out in memory in the
// order given by `dims_order`.
//
// `dims_order` is walked from its last entry to its first: the dimension named
// by the last entry gets stride 1, and every earlier entry gets the product of
// the lengths of the dimensions visited before it.
//
// lengths    - length of each dimension, indexed by dimension id
// dims_order - dimension ids; lengths.size() entries are read
// strides    - output, indexed by dimension id; it is written through
//              operator[] and never resized, so the caller must size it first
void init_strides(const std::vector<ck::index_t>& lengths,
                  const std::vector<ck::index_t>& dims_order,
                  std::vector<ck::index_t>& strides)
{
    ck::index_t running = 1;
    ck::index_t pos     = static_cast<ck::index_t>(lengths.size());

    while(pos > 0)
    {
        --pos;
        const ck::index_t axis = dims_order[pos];
        strides[axis]          = running;
        running *= lengths[axis];
    }
}
}
// namespace
}
// namespace
int
profile_permute_scale
(
int
argc
,
char
*
argv
[])
int
profile_permute_scale
(
int
argc
,
char
*
argv
[])
...
@@ -58,16 +72,21 @@ int profile_permute_scale(int argc, char* argv[])
...
@@ -58,16 +72,21 @@ int profile_permute_scale(int argc, char* argv[])
const
int
num_dims
=
dims_argc
/
3
;
const
int
num_dims
=
dims_argc
/
3
;
std
::
vector
<
ck
::
index_t
>
lengths
(
num_dims
);
std
::
vector
<
ck
::
index_t
>
lengths
(
num_dims
);
std
::
vector
<
ck
::
index_t
>
input_
stri
de
s
(
num_dims
);
std
::
vector
<
ck
::
index_t
>
input_
dims_or
de
r
(
num_dims
);
std
::
vector
<
ck
::
index_t
>
output_
stri
de
s
(
num_dims
);
std
::
vector
<
ck
::
index_t
>
output_
dims_or
de
r
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
{
lengths
[
i
]
=
std
::
stoi
(
argv
[
control_argc
+
i
]);
lengths
[
i
]
=
std
::
stoi
(
argv
[
control_argc
+
i
]);
input_
stri
de
s
[
i
]
=
std
::
stoi
(
argv
[
control_argc
+
num_dims
+
i
]);
input_
dims_or
de
r
[
i
]
=
std
::
stoi
(
argv
[
control_argc
+
num_dims
+
i
]);
output_
stri
de
s
[
i
]
=
std
::
stoi
(
argv
[
control_argc
+
2
*
num_dims
+
i
]);
output_
dims_or
de
r
[
i
]
=
std
::
stoi
(
argv
[
control_argc
+
2
*
num_dims
+
i
]);
}
}
std
::
vector
<
ck
::
index_t
>
input_strides
(
num_dims
);
std
::
vector
<
ck
::
index_t
>
output_strides
(
num_dims
);
init_strides
(
lengths
,
input_dims_order
,
input_strides
);
init_strides
(
lengths
,
output_dims_order
,
output_strides
);
using
F32
=
float
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
...
...
script/profile_permute_scale.sh
0 → 100755
View file @
7f65ac05
#!/bin/bash
# Runs the profiler driver over a fixed set of 1D to 6D problem sizes.
#
# usage: profile_permute_scale.sh <op> <datatype> <verify> <init> <log> <time>
#
# The six arguments are forwarded unchanged to the driver ahead of each
# problem description (dimension lengths, then the input order, then the
# output order).

## GPU visibility
export HIP_VISIBLE_DEVICES=0
DRIVER="../build/bin/ckProfiler"
echo "$DRIVER"

# Without all six control arguments the problem sizes below would shift into
# the control-argument positions, so stop early instead.
if [ "$#" -lt 6 ]; then
    echo "usage: $0 <op> <datatype> <verify> <init> <log> <time>" >&2
    exit 1
fi

OP=$1
DATATYPE=$2
VERIFY=$3
INIT=$4
LOG=$5
TIME=$6

# Invokes the driver with the shared control arguments followed by "$@".
run_case() {
    "$DRIVER" "$OP" "$DATATYPE" "$VERIFY" "$INIT" "$LOG" "$TIME" "$@"
}

# 1D
######## op  datatype  verify  init  log  time  dims  in_strides_order  out_strides_order
run_case 67108864 0 0

# 2D
######## op  datatype  verify  init  log  time  dims  in_strides_order  out_strides_order
run_case 8192 8192 0 1 1 0
run_case 8192 8192 1 0 0 1

# 3D
######## op  datatype  verify  init  log  time  dims  in_strides_order  out_strides_order
run_case 8 1024 8192 0 1 2 2 1 0
run_case 8 1024 8192 2 1 0 0 1 2

# 4D
######## op  datatype  verify  init  log  time  dims  in_strides_order  out_strides_order
run_case 8 2 512 8192 0 1 2 3 3 2 1 0
run_case 8 2 512 8192 3 2 1 0 0 1 2 3

# 5D
######## op  datatype  verify  init  log  time  dims  in_strides_order  out_strides_order
run_case 8 2 2 256 8192 0 1 2 3 4 4 3 2 1 0
run_case 8 2 2 256 8192 4 3 2 1 0 0 1 2 3 4

# 6D
######## op  datatype  verify  init  log  time  dims  in_strides_order  out_strides_order
run_case 8 2 2 2 128 8192 0 1 2 3 4 5 5 4 3 2 1 0
run_case 8 2 2 2 128 8192 5 4 3 2 1 0 0 1 2 3 4 5
test/CMakeLists.txt
View file @
7f65ac05
...
@@ -46,7 +46,18 @@ function(add_test_executable TEST_NAME)
...
@@ -46,7 +46,18 @@ function(add_test_executable TEST_NAME)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT GPU_TARGETS MATCHES
"gfx9"
AND source MATCHES
"xdl"
)
message
(
"removing xdl test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT GPU_TARGETS MATCHES
"gfx11"
AND source MATCHES
"wmma"
)
message
(
"removing wmma test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
#only continue if there are some source files left on the list
#only continue if there are some source files left on the list
if
(
ARGN
)
if
(
ARGN
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
...
@@ -100,6 +111,18 @@ function(add_gtest_executable TEST_NAME)
...
@@ -100,6 +111,18 @@ function(add_gtest_executable TEST_NAME)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT GPU_TARGETS MATCHES
"gfx9"
AND source MATCHES
"xdl"
)
message
(
"removing xdl test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT GPU_TARGETS MATCHES
"gfx11"
AND source MATCHES
"wmma"
)
message
(
"removing wmma test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
#only continue if there are some source files left on the list
#only continue if there are some source files left on the list
if
(
ARGN
)
if
(
ARGN
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
...
@@ -117,6 +140,7 @@ function(add_gtest_executable TEST_NAME)
...
@@ -117,6 +140,7 @@ function(add_gtest_executable TEST_NAME)
set
(
result
${
result
}
PARENT_SCOPE
)
set
(
result
${
result
}
PARENT_SCOPE
)
endfunction
()
endfunction
()
add_compile_options
(
-Wno-c++20-extensions
)
add_subdirectory
(
magic_number_division
)
add_subdirectory
(
magic_number_division
)
add_subdirectory
(
space_filling_curve
)
add_subdirectory
(
space_filling_curve
)
add_subdirectory
(
conv_util
)
add_subdirectory
(
conv_util
)
...
...
test/batched_gemm/CMakeLists.txt
View file @
7f65ac05
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
add_gtest_executable
(
test_batched_gemm test_batched_gemm_xdl.cpp
)
set
(
target 0
)
if
(
result EQUAL 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_gtest_executable
(
test_batched_gemm test_batched_gemm.cpp
)
target_link_libraries
(
test_batched_gemm PRIVATE utility device_batched_gemm_instance
)
target_link_libraries
(
test_batched_gemm PRIVATE utility device_batched_gemm_instance
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
\ No newline at end of file
test/batched_gemm/test_batched_gemm.cpp
→
test/batched_gemm/test_batched_gemm
_xdl
.cpp
View file @
7f65ac05
File moved
test/batched_gemm_gemm/CMakeLists.txt
View file @
7f65ac05
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
add_gtest_executable
(
test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16_xdl.cpp
)
set
(
target 0
)
if
(
result EQUAL 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
add_custom_target
(
test_batched_gemm_gemm
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
target_link_libraries
(
test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance
)
add_custom_target
(
test_batched_gemm_gemm
)
add_dependencies
(
test_batched_gemm_gemm test_batched_gemm_gemm_fp16
)
add_gtest_executable
(
test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp
)
endif
()
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance
)
add_dependencies
(
test_batched_gemm_gemm test_batched_gemm_gemm_fp16
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
\ No newline at end of file
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
→
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16
_xdl
.cpp
View file @
7f65ac05
File moved
test/batched_gemm_reduce/CMakeLists.txt
View file @
7f65ac05
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
add_test_executable
(
test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16_xdl.cpp
)
set
(
target 0
)
if
(
result EQUAL 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
target_link_libraries
(
test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_test_executable
(
test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance
)
set
(
target 1
)
endif
()
endif
()
endif
()
endforeach
()
test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp
→
test/batched_gemm_reduce/batched_gemm_reduce_fp16
_xdl
.cpp
View file @
7f65ac05
File moved
test/batched_gemm_softmax_gemm/CMakeLists.txt
View file @
7f65ac05
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16_xdl.cpp
)
set
(
target 0
)
if
(
result EQUAL 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
add_custom_target
(
test_batched_gemm_softmax_gemm
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance
)
add_custom_target
(
test_batched_gemm_softmax_gemm
)
add_dependencies
(
test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp
)
endif
()
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
\ No newline at end of file
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp
→
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16
_xdl
.cpp
View file @
7f65ac05
File moved
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
View file @
7f65ac05
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
add_custom_target
(
test_batched_gemm_softmax_gemm_permute
)
set
(
target 0
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
result EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_custom_target
(
test_batched_gemm_softmax_gemm_permute
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp
)
endif
()
if
(
result EQUAL 0
)
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
if
(
result EQUAL 0
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
endif
()
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16
)
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
)
endif
()
if
(
result EQUAL 0
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
if
(
result EQUAL 0
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
endif
()
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16
)
endif
()
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp
)
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp
)
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16
)
endif
()
endif
()
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16
)
endif
()
set
(
target 1
)
endif
()
endforeach
()
\ No newline at end of file
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
→
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16
_xdl
.cpp
View file @
7f65ac05
File moved
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
→
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16
_xdl
.cpp
View file @
7f65ac05
File moved
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
→
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16
_xdl
.cpp
View file @
7f65ac05
File moved
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp
→
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16
_xdl
.cpp
View file @
7f65ac05
File moved
test/contraction/CMakeLists.txt
View file @
7f65ac05
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
if
((
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
)
OR NOT DEFINED DTYPES
)
set
(
target 0
)
add_gtest_executable
(
test_contraction test_contraction_xdl.cpp
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
result EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
target_link_libraries
(
test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance
)
if
((
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
)
OR NOT DEFINED DTYPES
)
add_gtest_executable
(
test_contraction test_contraction.cpp
)
target_link_libraries
(
test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance
)
add_gtest_executable
(
test_contraction_interface test_contraction_interface.cpp
)
target_link_libraries
(
test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance
)
set
(
target 1
)
endif
()
endif
()
endif
()
endforeach
()
add_gtest_executable
(
test_contraction_interface test_contraction_interface_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance
)
endif
()
endif
()
test/contraction/test_contraction_interface.cpp
→
test/contraction/test_contraction_interface
_xdl
.cpp
View file @
7f65ac05
File moved
Prev
1
…
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment