Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3552041a
Commit
3552041a
authored
Jul 26, 2024
by
danyao12
Browse files
Merge branch 'develop' into ck_tile/fa_bwd_opt
parents
e8927110
733f33af
Changes
273
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
466 additions
and
30 deletions
+466
-30
profiler/src/profile_gemm_universal_streamk.cpp
profiler/src/profile_gemm_universal_streamk.cpp
+156
-0
profiler/src/profile_grouped_conv_fwd_outelementop.cpp
profiler/src/profile_grouped_conv_fwd_outelementop.cpp
+220
-0
script/profile_grouped_conv_fwd_outelementop.sh
script/profile_grouped_conv_fwd_outelementop.sh
+20
-0
test/CMakeLists.txt
test/CMakeLists.txt
+7
-3
test/gemm_universal/test_gemm_universal_util.hpp
test/gemm_universal/test_gemm_universal_util.hpp
+9
-7
test/gemm_universal/test_gemm_universal_xdl.cpp
test/gemm_universal/test_gemm_universal_xdl.cpp
+17
-9
test/grouped_convnd_bwd_data/CMakeLists.txt
test/grouped_convnd_bwd_data/CMakeLists.txt
+4
-4
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp
..._bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp
+8
-0
test/grouped_convnd_bwd_weight/CMakeLists.txt
test/grouped_convnd_bwd_weight/CMakeLists.txt
+4
-4
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp
..._weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp
+8
-0
test/grouped_convnd_fwd/CMakeLists.txt
test/grouped_convnd_fwd/CMakeLists.txt
+1
-1
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
+9
-1
test/wmma_op/wmma_op_util.hpp
test/wmma_op/wmma_op_util.hpp
+3
-1
No files found.
profiler/src/profile_gemm_universal_streamk.cpp
0 → 100644
View file @
3552041a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_universal_streamk_impl.hpp"
#include "profiler_operation_registry.hpp"
enum
struct
GemmMatrixLayout
{
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
KM_KN_MN
,
// 2
KM_NK_MN
,
// 3
};
enum
struct
GemmDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
F8_F16_F16
,
// 4
F16_F8_F16
,
// 5
F16_F16_F16_F8
,
// 6
};
#define OP_NAME "gemm_universal_streamk"
#define OP_DESC "Universal Streamk GEMM"
int
profile_gemm_universal_streamk
(
int
argc
,
char
*
argv
[])
{
if
(
argc
!=
16
&&
argc
!=
19
)
{
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
"comp f8)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 3: A[k, m] * B[n, k] = C[m, n])
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg6: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg14: Stream-k select strategy 0: all DP, 1: 1-tile SK, 2: 2-tile SK
\n
"
);
printf
(
"arg15: Grid-size, -1 for max persistent kernel occupancy
\n
"
);
printf
(
"optional:
\n
"
);
printf
(
"arg16: number of warm-up cycles (default 1)
\n
"
);
printf
(
"arg17: number of iterations (default 10)
\n
"
);
printf
(
"arg18: memory for rotating buffer (default 0, size in MB)
\n
"
);
exit
(
1
);
}
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
7
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
const
int
K
=
std
::
stoi
(
argv
[
10
]);
const
int
StrideA
=
std
::
stoi
(
argv
[
11
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
const
int
Streamk_sel
=
std
::
stoi
(
argv
[
14
]);
const
int
Grid_size
=
std
::
stoi
(
argv
[
15
]);
int
n_warmup
=
20
;
int
n_iter
=
50
;
uint64_t
rotating
=
0
;
if
(
argc
==
19
)
{
n_warmup
=
std
::
stoi
(
argv
[
16
]);
n_iter
=
std
::
stoi
(
argv
[
17
]);
rotating
=
std
::
stoull
(
argv
[
18
])
*
1024
*
1024
;
}
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
// using BF16 = ck::bhalf_t;
// using F8 = ck::f8_t;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
auto
profile
=
[
&
](
auto
a_type
,
auto
b_type
,
auto
acc_type
,
auto
c_type
,
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
using
ADataType
=
decltype
(
a_type
);
using
BDataType
=
decltype
(
b_type
);
using
AccDataType
=
decltype
(
acc_type
);
using
CDataType
=
decltype
(
c_type
);
using
ALayout
=
decltype
(
a_layout
);
using
BLayout
=
decltype
(
b_layout
);
using
CLayout
=
decltype
(
c_layout
);
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideB
=
ck
::
is_same_v
<
BLayout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideC
=
ck
::
is_same_v
<
CLayout
,
Row
>
?
N
:
M
;
bool
pass
=
ck
::
profiler
::
profile_gemm_universal_streamk_impl
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
DefaultStrideA
:
StrideA
,
(
StrideB
<
0
)
?
DefaultStrideB
:
StrideB
,
(
StrideC
<
0
)
?
DefaultStrideC
:
StrideC
,
Streamk_sel
,
Grid_size
,
n_warmup
,
n_iter
,
rotating
);
return
pass
?
0
:
1
;
};
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
}
else
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
return
1
;
}
}
REGISTER_PROFILER_OPERATION
(
OP_NAME
,
OP_DESC
,
profile_gemm_universal_streamk
);
profiler/src/profile_grouped_conv_fwd_outelementop.cpp
0 → 100644
View file @
3552041a
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "profiler_operation_registry.hpp"
#include <iostream>
enum
struct
ConvLayout
{
GNHWC_GKYXC_GNHWK
=
0
,
NHWGC_GKYXC_NHWGK
=
1
};
enum
struct
OutElementOp
{
ConvScale
=
0
,
ConvInvScale
=
1
};
enum
struct
ConvDataType
{
F8_F8_F8
=
0
,
BF8_BF8_F8
=
1
,
F8_BF8_F8
=
2
,
BF8_F8_F8
=
3
};
#define OP_NAME "grouped_conv_fwd_outelementop"
#define OP_DESC "Grouped Convolution Forward+Elementwise Operation"
static
void
print_helper_msg
()
{
// clang-format off
std
::
cout
<<
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
<<
"arg2: data type (0: Input fp8, Weight fp8, Output fp8
\n
"
<<
" 1: Input bf8, Weight bf8, Output fp8
\n
"
<<
" 2: Input fp8, Weight bf8, Output fp8
\n
"
<<
" 3: Input bf8, Weight fp8, Output fp8)
\n
"
<<
"arg3: element-wise operation (0: ConvScale
\n
"
<<
" 1: ConvInvScale)
\n
"
<<
"arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]
\n
"
<<
" 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])
\n
"
<<
"arg5: verification (0: no, 1: yes)
\n
"
<<
"arg6: initialization (0: no init, 1: integer value, 2: decimal value)
\n
"
<<
"arg7: print tensor value (0: no; 1: yes)
\n
"
<<
"arg8: time kernel (0: no, 1: yes)
\n
"
<<
ck
::
utils
::
conv
::
get_conv_param_parser_helper_msg
()
<<
std
::
endl
;
// clang-format on
}
int
grouped_conv_fwd_outelementop
(
int
argc
,
char
*
argv
[])
{
// 9 total, 1 for num_dim_spatial
if
(
argc
<
10
)
{
print_helper_msg
();
return
1
;
}
const
auto
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
op
=
static_cast
<
OutElementOp
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
layout
=
static_cast
<
ConvLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
5
]);
const
int
init_method
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
7
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
8
]);
const
int
num_dim_spatial
=
std
::
stoi
(
argv
[
9
]);
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + 1 for argv[0]
if
(
argc
!=
8
+
1
+
4
+
6
*
num_dim_spatial
+
1
)
{
print_helper_msg
();
return
1
;
}
const
auto
params
=
ck
::
utils
::
conv
::
parse_conv_param
(
num_dim_spatial
,
10
,
argv
);
using
F8
=
ck
::
f8_t
;
using
BF8
=
ck
::
bf8_t
;
using
GKZYXC
=
ck
::
tensor_layout
::
convolution
::
GKZYXC
;
using
NDHWGC
=
ck
::
tensor_layout
::
convolution
::
NDHWGC
;
using
NDHWGK
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
using
ConvScale
=
ck
::
tensor_operation
::
element_wise
::
ConvScale
;
using
ConvInvScale
=
ck
::
tensor_operation
::
element_wise
::
ConvInvscale
;
constexpr
auto
I3
=
ck
::
Number
<
3
>
{};
auto
profile
=
[
&
](
auto
num_dim_spatial_tmp
,
auto
in_layout
,
auto
wei_layout
,
auto
out_layout
,
auto
in_type
,
auto
wei_type
,
auto
out_type
,
auto
out_element_op
,
auto
a_compute_type
,
auto
b_compute_type
)
{
constexpr
ck
::
index_t
NDimSpatial
=
num_dim_spatial_tmp
.
value
;
using
InLayout
=
decltype
(
in_layout
);
using
WeiLayout
=
decltype
(
wei_layout
);
using
OutLayout
=
decltype
(
out_layout
);
using
InDataType
=
decltype
(
in_type
);
using
WeiDataType
=
decltype
(
wei_type
);
using
OutDataType
=
decltype
(
out_type
);
using
OutElementOp
=
decltype
(
out_element_op
);
using
AComputeType
=
decltype
(
a_compute_type
);
using
BComputeType
=
decltype
(
b_compute_type
);
bool
pass
=
ck
::
profiler
::
profile_grouped_conv_fwd_outelementop_impl
<
NDimSpatial
,
InLayout
,
WeiLayout
,
OutLayout
,
InDataType
,
WeiDataType
,
OutDataType
,
OutElementOp
,
AComputeType
,
BComputeType
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
params
);
return
pass
?
0
:
1
;
};
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
NHWGC_GKYXC_NHWGK
)
{
if
(
op
==
OutElementOp
::
ConvScale
)
{
if
(
data_type
==
ConvDataType
::
F8_F8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F8
{},
F8
{},
F8
{},
ConvScale
{},
F8
{},
F8
{});
}
else
if
(
data_type
==
ConvDataType
::
BF8_BF8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF8
{},
BF8
{},
F8
{},
ConvScale
{},
BF8
{},
BF8
{});
}
else
if
(
data_type
==
ConvDataType
::
F8_BF8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F8
{},
BF8
{},
F8
{},
ConvScale
{},
F8
{},
BF8
{});
}
else
if
(
data_type
==
ConvDataType
::
BF8_F8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF8
{},
F8
{},
F8
{},
ConvScale
{},
BF8
{},
F8
{});
}
}
else
if
(
op
==
OutElementOp
::
ConvInvScale
)
{
if
(
data_type
==
ConvDataType
::
F8_F8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F8
{},
F8
{},
F8
{},
ConvInvScale
{},
F8
{},
F8
{});
}
else
if
(
data_type
==
ConvDataType
::
BF8_BF8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF8
{},
BF8
{},
F8
{},
ConvInvScale
{},
BF8
{},
BF8
{});
}
else
if
(
data_type
==
ConvDataType
::
F8_BF8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F8
{},
BF8
{},
F8
{},
ConvInvScale
{},
F8
{},
BF8
{});
}
else
if
(
data_type
==
ConvDataType
::
BF8_F8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF8
{},
F8
{},
F8
{},
ConvInvScale
{},
BF8
{},
F8
{});
}
}
}
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
return
1
;
}
REGISTER_PROFILER_OPERATION
(
OP_NAME
,
OP_DESC
,
grouped_conv_fwd_outelementop
);
script/profile_grouped_conv_fwd_outelementop.sh
0 → 100755
View file @
3552041a
#!/bin/bash
## GPU visibility
export
HIP_VISIBLE_DEVICES
=
0
DRIVER
=
"../build/bin/ckProfiler"
OP
=
$1
DATATYPE
=
$2
OUTELEMENTOP
=
$3
LAYOUT
=
$4
VERIFY
=
$5
INIT
=
$6
LOG
=
$7
TIME
=
$8
N
=
$9
####### op datatype OUTELEMENTOP layout verify init log time Ndims G N K C Z Y X Di Hi Wi Sz Sy Sx Dz Dy Dx Left Pz LeftPy LeftPx RightPz RightPy RightPx
$DRIVER
$OP
$DATATYPE
$OUTELEMENTOP
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
3 32
$N
96 96 3 3 3 28 28 28 1 1 1 1 1 1 1 1 1 1 1 1
$DRIVER
$OP
$DATATYPE
$OUTELEMENTOP
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
3 32
$N
192 192 3 3 3 28 28 28 1 1 1 1 1 1 1 1 1 1 1 1
test/CMakeLists.txt
View file @
3552041a
...
...
@@ -60,7 +60,7 @@ function(add_test_executable TEST_NAME)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT
GPU
_TARGETS MATCHES
"gfx11"
AND NOT
GPU
_TARGETS MATCHES
"gfx12"
AND source MATCHES
"wmma"
)
if
(
NOT
TEST
_TARGETS MATCHES
"gfx11"
AND NOT
TEST
_TARGETS MATCHES
"gfx12"
AND source MATCHES
"wmma"
)
message
(
"removing wmma test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
...
...
@@ -71,6 +71,8 @@ function(add_test_executable TEST_NAME)
list
(
REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
ARGN MATCHES
"_smfmac"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a
)
endif
()
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
...
...
@@ -139,7 +141,7 @@ function(add_gtest_executable TEST_NAME)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT
GPU
_TARGETS MATCHES
"gfx11"
AND NOT
GPU
_TARGETS MATCHES
"gfx12"
AND source MATCHES
"wmma"
)
if
(
NOT
TEST
_TARGETS MATCHES
"gfx11"
AND NOT
TEST
_TARGETS MATCHES
"gfx12"
AND source MATCHES
"wmma"
)
message
(
"removing wmma test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
...
...
@@ -150,6 +152,8 @@ function(add_gtest_executable TEST_NAME)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
ARGN MATCHES
"_smfmac"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a
)
endif
()
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
...
...
@@ -209,7 +213,7 @@ add_subdirectory(wrapper)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
add_subdirectory
(
wmma_op
)
endif
()
if
(
GPU_TARGETS MATCHES
"gfx942"
)
if
(
GPU_TARGETS MATCHES
"gfx942"
AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2
)
# smfmac needs ROCm6.2
add_subdirectory
(
smfmac_op
)
endif
()
add_subdirectory
(
position_embedding
)
test/gemm_universal/test_gemm_universal_util.hpp
View file @
3552041a
// SPDX-License-Identifier: MIT
// Copyright (c) 20
18
-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 20
23
-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -25,12 +25,13 @@ class TestGemmUniversal : public testing::Test
using
F32
=
float
;
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
Row
;
using
ADataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
Row
;
using
ADataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
ComputeDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
public:
static
constexpr
bool
verify_
=
true
;
...
...
@@ -66,6 +67,7 @@ class TestGemmUniversal : public testing::Test
{
bool
pass
=
ck
::
profiler
::
profile_gemm_universal_impl
<
ADataType
,
BDataType
,
ComputeDataType
,
F32
,
CDataType
,
ALayout
,
...
...
test/gemm_universal/test_gemm_universal_xdl.cpp
View file @
3552041a
// SPDX-License-Identifier: MIT
// Copyright (c) 20
18
-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 20
23
-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
...
...
@@ -41,16 +41,24 @@ class TestGemmUniversal_MK_NK
};
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// ADataType, BDataType, CDataType
std
::
tuple
<
F16
,
F16
,
F16
>
,
std
::
tuple
<
F16
,
F8
,
F16
>
,
std
::
tuple
<
F8
,
F16
,
F16
>
,
std
::
tuple
<
BF16
,
BF16
,
BF16
>
using
KernelTypes_MK_KN
=
::
testing
::
Types
<
// ADataType, BDataType, ComputeDataType, CDataType
std
::
tuple
<
F16
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
F16
,
F8
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
BF16
,
BF16
,
BF16
,
BF16
>
>
;
using
KernelTypes_MK_NK
=
::
testing
::
Types
<
// ADataType, BDataType, ComputeDataType, CDataType
std
::
tuple
<
F16
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
F16
,
F8
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
BF16
,
BF16
,
BF16
,
BF16
>
,
std
::
tuple
<
F8
,
F8
,
F8
,
BF16
>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestGemmUniversal_MK_KN
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestGemmUniversal_MK_NK
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestGemmUniversal_MK_KN
,
KernelTypes
_MK_KN
);
TYPED_TEST_SUITE
(
TestGemmUniversal_MK_NK
,
KernelTypes
_MK_NK
);
#include "test_gemm_universal_ut_cases.inc"
test/grouped_convnd_bwd_data/CMakeLists.txt
View file @
3552041a
...
...
@@ -2,11 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_x
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance
)
endif
()
add_gtest_executable
(
test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp
)
add_gtest_executable
(
test_grouped_convnd_bwd_data_interface
_xdl
test_grouped_convnd_bwd_data_interface_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance
)
target_link_libraries
(
test_grouped_convnd_bwd_data_interface
_xdl
PRIVATE utility device_grouped_conv2d_bwd_data_instance
)
endif
()
add_gtest_executable
(
test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp
)
add_gtest_executable
(
test_grouped_convnd_bwd_data_interface
_wmma
test_grouped_convnd_bwd_data_interface_wmma.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance
)
target_link_libraries
(
test_grouped_convnd_bwd_data_interface
_wmma
PRIVATE utility device_grouped_conv2d_bwd_data_instance
)
endif
()
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp
View file @
3552041a
...
...
@@ -52,6 +52,14 @@ class TestGroupedConvndBwdData : public ::testing::Test
ck
::
utils
::
conv
::
ConvParam
conv_param
;
void
SetUp
()
override
{
if
(
!
ck
::
is_gfx11_supported
())
{
GTEST_SKIP
();
}
}
template
<
ck
::
index_t
NDimSpatial
>
bool
Run
()
{
...
...
test/grouped_convnd_bwd_weight/CMakeLists.txt
View file @
3552041a
...
...
@@ -5,13 +5,13 @@ if(GPU_TARGETS MATCHES "gfx9" OR DL_KERNELS)
add_gtest_executable
(
test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp
)
target_link_libraries
(
test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance
)
endif
()
add_gtest_executable
(
test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp
)
add_gtest_executable
(
test_grouped_convnd_bwd_weight_interface
_xdl
test_grouped_convnd_bwd_weight_interface_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_convnd_bwd_weight_interface PRIVATE utility
)
target_link_libraries
(
test_grouped_convnd_bwd_weight_interface
_xdl
PRIVATE utility
)
endif
()
add_gtest_executable
(
test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp
)
add_gtest_executable
(
test_grouped_convnd_bwd_weight_interface
_wmma
test_grouped_convnd_bwd_weight_interface_wmma.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_convnd_bwd_weight_interface PRIVATE utility
)
target_link_libraries
(
test_grouped_convnd_bwd_weight_interface
_wmma
PRIVATE utility
)
endif
()
add_gtest_executable
(
test_grouped_conv_bwd_weight_xdl_bilinear test_grouped_conv_bwd_weight_xdl_bilinear.cpp
)
if
(
result EQUAL 0
)
...
...
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp
View file @
3552041a
...
...
@@ -52,6 +52,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
ck
::
utils
::
conv
::
ConvParam
conv_param
;
void
SetUp
()
override
{
if
(
!
ck
::
is_gfx11_supported
())
{
GTEST_SKIP
();
}
}
template
<
ck
::
index_t
SplitK
>
bool
Run
()
{
...
...
test/grouped_convnd_fwd/CMakeLists.txt
View file @
3552041a
if
(
GPU_TARGETS MATCHES
"gfx9"
OR GPU_TARGETS MATCHES
"gfx11"
)
add_gtest_executable
(
test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
if
(
(
GPU_TARGETS MATCHES
"gfx11"
)
AND
(
NOT GPU_TARGETS MATCHES
"gfx9"
))
target_link_libraries
(
test_grouped_convnd_fwd PRIVATE utility device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance
)
else
()
target_link_libraries
(
test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance
)
...
...
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
View file @
3552041a
...
...
@@ -104,6 +104,7 @@ TYPED_TEST(TestGroupedConvndFwd1d, Test1D)
this
->
conv_params
.
push_back
({
1
,
2
,
32
,
128
,
256
,
{
1
},
{
3
},
{
1
},
{
1
},
{
0
},
{
0
}});
this
->
conv_params
.
push_back
({
1
,
1
,
1
,
1
,
32
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
1
,
1
,
64
,
3
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
96
,
1
,
1
,
1
,
{
3
},
{
512
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
template
Run
<
1
>();
}
...
...
@@ -119,6 +120,8 @@ TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
{
2
,
96
,
1
,
1
,
1
,
{
3
,
3
},
{
120
,
160
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
template
Run
<
2
>();
}
...
...
@@ -137,6 +140,8 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
{
3
,
1
,
1
,
64
,
3
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
96
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
4
,
30
,
160
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
template
Run
<
3
>();
}
...
...
@@ -144,6 +149,9 @@ TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases)
{
// Case larger than 2GB
this
->
conv_params
.
push_back
(
{
2
,
1
,
64
,
4
,
192
,
{
2
,
2
},
{
224
,
224
},
{
224
,
224
},
{
0
,
0
},
{
0
,
0
},
{
0
,
0
}});
{
2
,
1
,
64
,
4
,
192
,
{
2
,
2
},
{
224
,
224
},
{
224
,
224
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
// With supported NumGroupsToMerge > 1
this
->
conv_params
.
push_back
(
{
2
,
32
,
64
,
1
,
1
,
{
2
,
2
},
{
672
,
672
},
{
672
,
672
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
template
Run
<
2
>();
}
test/wmma_op/wmma_op_util.hpp
View file @
3552041a
...
...
@@ -11,6 +11,7 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/utility/amd_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace
ck
{
namespace
wmma_op_util
{
...
...
@@ -373,7 +374,8 @@ struct TestWmma
a
,
b
,
c_host
,
a_element_op
,
b_element_op
,
c_element_op
);
// Act
bool
is_supported
=
ck
::
wmma_op_util
::
RunDeviceGEMM
(
wmma_kernel
,
a
,
b
,
c_device
);
bool
is_supported
=
ck
::
is_gfx11_supported
()
&&
ck
::
wmma_op_util
::
RunDeviceGEMM
(
wmma_kernel
,
a
,
b
,
c_device
);
if
(
is_supported
)
{
...
...
Prev
1
…
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment