gaoqiong / composable_kernel · Commits

Commit 5b7c2432, authored Oct 20, 2022 by Adam Osewski
Merge remote-tracking branch 'rosenrodt/gemm-standalone-bench' into wavelet_model
Parents: 7e493730, 5a995b14

Showing 20 changed files with 642 additions and 979 deletions (+642 -979)
profiler/include/profile_softmax_impl.hpp                                                                    +10  -10
profiler/src/profile_layernorm.cpp                                                                            +6  -25
profiler/src/profile_softmax.cpp                                                                             +41  -43
test/CMakeLists.txt                                                                                           +3   -4
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp                                                        +1   -1
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp  +1  -1
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp                                        +1   -1
test/convnd_bwd_data/convnd_bwd_data.cpp                                                                     +62 -210
test/convnd_bwd_weight/convnd_bwd_weight.cpp                                                                 +61 -176
test/convnd_fwd/convnd_fwd.cpp                                                                               +62 -211
test/gemm/CMakeLists.txt                                                                                     +10   -0
test/gemm/gemm_bf16.cpp                                                                                       +6  -51
test/gemm/gemm_fp16.cpp                                                                                       +6  -51
test/gemm/gemm_fp32.cpp                                                                                       +6  -51
test/gemm/gemm_fp64.cpp                                                                                       +6  -51
test/gemm/gemm_int8.cpp                                                                                       +6  -51
test/gemm/gemm_standalone_xdl_fp16.cpp                                                                      +162   -0
test/gemm/gemm_util.hpp                                                                                      +65  -42
test/gemm/instance/gemm_f16_nn_instance.cpp                                                                  +86   -0
test/gemm/instance/gemm_f16_nn_instance.hpp                                                                  +41   -0
profiler/include/profile_normalization_impl.hpp → profiler/include/profile_softmax_impl.hpp

@@ -69,16 +69,16 @@ template <> std::string type_to_string<int32_t>() { return "int32"; }
 // clang-format on
 template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
-void profile_normalization_impl(int do_verification,
-                                int init_method,
-                                bool do_log,
-                                bool time_kernel,
-                                std::vector<index_t> in_length,
-                                std::vector<index_t> in_strides,
-                                std::vector<index_t> reduce_dims,
-                                AccDataType alpha,
-                                AccDataType beta,
-                                NormType norm_type)
+void profile_softmax_impl(int do_verification,
+                          int init_method,
+                          bool do_log,
+                          bool time_kernel,
+                          std::vector<index_t> in_length,
+                          std::vector<index_t> in_strides,
+                          std::vector<index_t> reduce_dims,
+                          AccDataType alpha,
+                          AccDataType beta,
+                          NormType norm_type)
 {
     if(Rank != in_length.size())
     {
...
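Note (a sketch for context, not part of the diff): the renamed profiler entry point still takes alpha and beta scaling factors alongside the reduced dimensions, so the host-side reference it presumably checks against would be an ordinary softmax with alpha/beta blending. A minimal standalone sketch, assuming y = alpha * softmax(x) + beta * y semantics over a contiguous innermost dimension:

    // Hedged sketch, not taken from the commit: naive host softmax with the
    // alpha/beta blending assumed above, over rows of length `cols`.
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    void host_softmax(const std::vector<float>& x, std::vector<float>& y,
                      std::size_t rows, std::size_t cols, float alpha, float beta)
    {
        for(std::size_t r = 0; r < rows; ++r)
        {
            const float* xr = x.data() + r * cols;
            float* yr       = y.data() + r * cols;

            float m = xr[0];
            for(std::size_t c = 1; c < cols; ++c)
                m = std::max(m, xr[c]); // subtract the row max for numerical stability

            float sum = 0.f;
            for(std::size_t c = 0; c < cols; ++c)
                sum += std::exp(xr[c] - m);

            for(std::size_t c = 0; c < cols; ++c)
                yr[c] = alpha * (std::exp(xr[c] - m) / sum) + beta * yr[c];
        }
    }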
profiler/src/profile_layernorm.cpp

@@ -12,8 +12,7 @@ using ck::index_t;
 struct LayernormArgParser
 {
-    std::unordered_map<std::string, std::vector<int>> long_opts = {
-        {"length", {}}, {"strideXY", {}}, {"strideGamma", {}}, {"strideBeta", {}}};
+    std::unordered_map<std::string, std::vector<int>> long_opts = {{"length", {}}};

     bool parse_opt(int argc, char* argv[], const std::string& key, int i)
     {
...
@@ -52,9 +51,6 @@ void print_help_layernorm()
               << "arg4: print tensor value (0: no; 1: yes)\n"
               << "arg5: time kernel (0=no, 1=yes)\n"
               << "--length: tensor extents (e.g, --length 1024 1024)\n"
-              << "--strideXY: tensor strides (e.g, --strideXY 1024 1)\n"
-              << "--strideGamma: tensor strides (e.g, --strideGamma 1)\n"
-              << "--strideBeta: tensor strides (e.g, --strideBeta 1)\n"
               << std::endl;
 }
...
@@ -77,10 +73,7 @@ int profile_layernorm(int argc, char* argv[])
     // parse the long options
     arg_parser(argc, argv);
     const std::vector<index_t> length = arg_parser.long_opts["length"];
-    const std::vector<index_t> strideXY    = arg_parser.long_opts["strideXY"];
-    const std::vector<index_t> strideGamma = arg_parser.long_opts["strideGamma"];
-    const std::vector<index_t> strideBeta  = arg_parser.long_opts["strideBeta"];

     using F16 = ck::half_t;
     using F32 = float;
...
@@ -88,25 +81,13 @@ int profile_layernorm(int argc, char* argv[])
     if(data_type == ck::DataTypeEnum::Half)
     {
-        ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, rank>(do_verification,
-                                                                            init_method,
-                                                                            do_log,
-                                                                            time_kernel,
-                                                                            length,
-                                                                            strideXY,
-                                                                            strideGamma,
-                                                                            strideBeta);
+        ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, rank>(
+            do_verification, init_method, do_log, time_kernel, length);
     }
     else if(data_type == ck::DataTypeEnum::Float)
     {
-        ck::profiler::profile_layernorm_impl<F32, F32, F32, F32, F32, rank>(do_verification,
-                                                                            init_method,
-                                                                            do_log,
-                                                                            time_kernel,
-                                                                            length,
-                                                                            strideXY,
-                                                                            strideGamma,
-                                                                            strideBeta);
+        ck::profiler::profile_layernorm_impl<F32, F32, F32, F32, F32, rank>(
+            do_verification, init_method, do_log, time_kernel, length);
     }
     else
     {
...
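Aside (an illustration, not part of the diff): in the removed options each stride was paired with the matching extent, so a row-major 1024x1024 input described by --length 1024 1024 --strideXY 1024 1 addresses element (m, n) at flat offset m*1024 + n, while gamma and beta are length-N vectors with unit stride, as the removed --strideGamma 1 / --strideBeta 1 help examples suggest. In code form, using the parser's std::vector<int> value type:

    // Illustration only: flat offset of element (m, n) under the removed strideXY option.
    #include <vector>

    inline int flat_offset(int m, int n, const std::vector<int>& strideXY)
    {
        return m * strideXY[0] + n * strideXY[1];
    }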
profiler/src/profile_normalization.cpp → profiler/src/profile_softmax.cpp

@@ -5,7 +5,7 @@
 #include <vector>
 #include <unordered_map>

-#include "profiler/include/profile_normalization_impl.hpp"
+#include "profiler/include/profile_softmax_impl.hpp"

 using ck::index_t;
 using ck::profiler::NormDataType;
...
@@ -95,30 +95,29 @@ int profile_normalization(int argc, char* argv[])
 {
     if(data_type == NormDataType::F16_F16)
     {
-        ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t, 3>(
-            do_verification,
-            init_method,
-            do_log,
-            time_kernel,
-            length,
-            stride,
-            reduce,
-            float(alpha),
-            float(beta),
-            norm_type);
+        ck::profiler::profile_softmax_impl<ck::half_t, float, ck::half_t, 3>(do_verification,
+                                                                             init_method,
+                                                                             do_log,
+                                                                             time_kernel,
+                                                                             length,
+                                                                             stride,
+                                                                             reduce,
+                                                                             float(alpha),
+                                                                             float(beta),
+                                                                             norm_type);
     }
     else if(data_type == NormDataType::F32_F32)
     {
-        ck::profiler::profile_normalization_impl<float, float, float, 3>(do_verification,
+        ck::profiler::profile_softmax_impl<float, float, float, 3>(do_verification,
             ... (same argument list: init_method, do_log, time_kernel, length, stride,
                  reduce, float(alpha), float(beta), norm_type)
     }
     else
     {
...
@@ -129,30 +128,29 @@ int profile_normalization(int argc, char* argv[])
 {
     if(data_type == NormDataType::F16_F16)
     {
-        ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t, 4>(
+        ck::profiler::profile_softmax_impl<ck::half_t, float, ck::half_t, 4>(do_verification,
             ... (same argument list as the rank-3 dispatch above)
     }
     else if(data_type == NormDataType::F32_F32)
     {
-        ck::profiler::profile_normalization_impl<float, float, float, 4>(do_verification,
+        ck::profiler::profile_softmax_impl<float, float, float, 4>(do_verification,
             ... (same argument list as the rank-3 dispatch above)
     }
     else
     {
...
test/CMakeLists.txt

@@ -6,11 +6,10 @@ include(googletest)
 add_custom_target(tests)

 function(add_test_executable TEST_NAME)
     message("adding test ${TEST_NAME}")
     add_executable(${TEST_NAME} ${ARGN})
     add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
     add_dependencies(tests ${TEST_NAME})
     add_dependencies(check ${TEST_NAME})
     rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
...
@@ -23,6 +22,7 @@ function(add_gtest_executable TEST_NAME)
     add_executable(${TEST_NAME} ${ARGN})
     add_dependencies(tests ${TEST_NAME})
     add_dependencies(check ${TEST_NAME})
     # suppress gtest warnings
     target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef)
     target_link_libraries(${TEST_NAME} PRIVATE gtest_main)
...
@@ -30,7 +30,6 @@ function(add_gtest_executable TEST_NAME)
     rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
 endfunction(add_gtest_executable TEST_NAME)

 add_subdirectory(magic_number_division)
 add_subdirectory(space_filling_curve)
 add_subdirectory(conv_util)
...
@@ -51,5 +50,5 @@ add_subdirectory(convnd_bwd_data)
 add_subdirectory(grouped_convnd_fwd)
 add_subdirectory(block_to_ctile_map)
 add_subdirectory(softmax)
-add_subdirectory(layernorm)
+add_subdirectory(normalization)
 add_subdirectory(data_type)
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp

@@ -5,7 +5,7 @@
 #include <vector>
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
 #include "profiler/include/profile_batched_gemm_gemm_impl.hpp"

 using ck::tensor_operation::device::GemmSpecialization;
...

test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp

@@ -5,7 +5,7 @@
 #include <vector>
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
 #include "profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp"

 using ck::tensor_operation::device::GemmSpecialization;
...

test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp

@@ -5,7 +5,7 @@
 #include <vector>
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
 #include "profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp"

 using ck::tensor_operation::device::GemmSpecialization;
...
test/convnd_bwd_data/convnd_bwd_data.cpp

@@ -5,237 +5,89 @@
 #include <iostream>
 #include <initializer_list>
 #include <vector>
+#include <tuple>

 #include <gtest/gtest.h>

 #include "profiler/include/profile_conv_bwd_data_impl.hpp"

+template <typename Tuple>
 class TestConvndBwdData : public ::testing::Test
 {
     protected:
+    using DataType = std::tuple_element_t<0, Tuple>;
     std::vector<ck::utils::conv::ConvParam> conv_params;
-};

-// 1d
-TEST_F(TestConvndBwdData, Conv1dBwdData)
-{
-    conv_params.clear();
-    conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
-    conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
-    conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
-    for(auto& param : conv_params)
-    {
-        bool pass;
-        // fp32
-        pass = ck::profiler::profile_conv_bwd_data_impl<1,
-                                                        ck::tensor_layout::convolution::NWC,
-                                                        ck::tensor_layout::convolution::KXC,
-                                                        ck::tensor_layout::convolution::NWK,
-                                                        float,
-                                                        float,
-                                                        float>(true,  // do_verification
-                                                               1,     // init_method
-                                                               false, // do_log
-                                                               false, // time_kernel
-                                                               param);
-        EXPECT_TRUE(pass);
-        // fp16 / bf16 / int8: the same call repeated with ck::half_t, ck::bhalf_t and int8_t
-        ...
-    }
-}
+    template <ck::index_t NDimSpatial>
+    void Run()
+    {
+        EXPECT_FALSE(conv_params.empty());
+
+        for(auto& param : conv_params)
+        {
+            bool pass;
+            pass = ck::profiler::profile_conv_bwd_data_impl<
+                NDimSpatial,
+                ck::tuple_element_t<NDimSpatial - 1,
+                                    ck::Tuple<ck::tensor_layout::convolution::NWC,
+                                              ck::tensor_layout::convolution::NHWC,
+                                              ck::tensor_layout::convolution::NDHWC>>,
+                ck::tuple_element_t<NDimSpatial - 1,
+                                    ck::Tuple<ck::tensor_layout::convolution::KXC,
+                                              ck::tensor_layout::convolution::KYXC,
+                                              ck::tensor_layout::convolution::KZYXC>>,
+                ck::tuple_element_t<NDimSpatial - 1,
+                                    ck::Tuple<ck::tensor_layout::convolution::NWK,
+                                              ck::tensor_layout::convolution::NHWK,
+                                              ck::tensor_layout::convolution::NDHWK>>,
+                DataType,
+                DataType,
+                DataType>(true,  // do_verification
+                          1,     // init_method integer value
+                          false, // do_log
+                          false, // time_kernel
+                          param);
+            EXPECT_TRUE(pass);
+        }
+    }
+};

+using KernelTypes = ::testing::Types<std::tuple<float>,
+                                     std::tuple<ck::half_t>,
+                                     std::tuple<ck::bhalf_t>,
+                                     std::tuple<std::int8_t>>;

+TYPED_TEST_SUITE(TestConvndBwdData, KernelTypes);

+// 1d
+TYPED_TEST(TestConvndBwdData, Conv1dBwdData)
+{
+    this->conv_params.clear();
+    this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
+    this->conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
+    this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
+    this->template Run<1>();
+}

 // 2d
-TEST_F(TestConvndBwdData, Conv2dBwdData)
+TYPED_TEST(TestConvndBwdData, Conv2dBwdData)
 {
-    conv_params.clear();
-    conv_params.push_back({2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
-    conv_params.push_back({2, 1, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
-    conv_params.push_back({2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
-    for(auto& param : conv_params)
-    {
-        bool pass;
-        // fp32 / fp16 / bf16 / int8 calls to profile_conv_bwd_data_impl<2, NHWC, KYXC, NHWK, ...>
-        ...
-        EXPECT_TRUE(pass);
-    }
+    this->conv_params.clear();
+    this->conv_params.push_back(
+        {2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
+    this->conv_params.push_back(
+        {2, 1, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
+    this->conv_params.push_back(
+        {2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
+    this->template Run<2>();
 }

 // 3d
-TEST_F(TestConvndBwdData, Conv3dBwdData)
+TYPED_TEST(TestConvndBwdData, Conv3dBwdData)
 {
-    conv_params.clear();
-    conv_params.push_back(
-        {3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
-    conv_params.push_back(
-        {3, 1, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
-    conv_params.push_back(
-        {3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
-    for(auto& param : conv_params)
-    {
-        bool pass;
-        // fp32 / fp16 / bf16 / int8 calls to profile_conv_bwd_data_impl<3, NDHWC, KZYXC, NDHWK, ...>
-        ...
-        EXPECT_TRUE(pass);
-    }
+    this->conv_params.clear();
+    this->conv_params.push_back(
+        {3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+    this->conv_params.push_back(
+        {3, 1, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
+    this->conv_params.push_back(
+        {3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+    this->template Run<3>();
 }
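Note (a standalone sketch, not part of the diff): the refactoring above relies on gtest typed tests, where one fixture body runs once per entry in a ::testing::Types<> list, and members of the now-templated fixture must be reached through this-> because the base class has become a dependent type. A minimal example assuming nothing beyond gtest itself, with hypothetical names:

    #include <gtest/gtest.h>
    #include <tuple>
    #include <vector>

    // Hypothetical fixture for illustration only; the names mirror nothing in ck.
    template <typename Tuple>
    class SmokeTest : public ::testing::Test
    {
        protected:
        using DataType = std::tuple_element_t<0, Tuple>;
        std::vector<DataType> values;
    };

    using SmokeTypes = ::testing::Types<std::tuple<float>, std::tuple<double>>;
    TYPED_TEST_SUITE(SmokeTest, SmokeTypes);

    TYPED_TEST(SmokeTest, StartsEmpty)
    {
        // 'this->' is required: 'values' lives in a dependent base class.
        EXPECT_TRUE(this->values.empty());
    }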
test/convnd_bwd_weight/convnd_bwd_weight.cpp

@@ -5,201 +5,86 @@
 #include <iostream>
 #include <initializer_list>
 #include <vector>
+#include <tuple>

 #include <gtest/gtest.h>

 #include "profiler/include/profile_conv_bwd_weight_impl.hpp"

+template <typename Tuple>
 class TestConvndBwdWeight : public ::testing::Test
 {
     protected:
+    using DataType = std::tuple_element_t<0, Tuple>;
     std::vector<ck::utils::conv::ConvParam> conv_params;
-};
+    ck::index_t split_k{2};

-// 1d
-TEST_F(TestConvndBwdWeight, Conv1dBwdWeight)
-{
-    conv_params.clear();
-    conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
-    conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
-    conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
-    for(auto& param : conv_params)
-    {
-        bool pass;
-        // fp32
-        pass = ck::profiler::profile_conv_bwd_weight_impl<1,
-                                                          ck::tensor_layout::convolution::NWC,
-                                                          ck::tensor_layout::convolution::KXC,
-                                                          ck::tensor_layout::convolution::NWK,
-                                                          float,
-                                                          float,
-                                                          float>(true,  // do_verification
-                                                                 1,     // init_method
-                                                                 false, // do_log
-                                                                 false, // time_kernel
-                                                                 param,
-                                                                 2);
-        EXPECT_TRUE(pass);
-        // fp16 / bf16: the same call repeated with ck::half_t and ck::bhalf_t
-        ...
-    }
-}
+    template <ck::index_t NDimSpatial>
+    void Run()
+    {
+        EXPECT_FALSE(conv_params.empty());
+
+        for(auto& param : conv_params)
+        {
+            bool pass;
+            pass = ck::profiler::profile_conv_bwd_weight_impl<
+                NDimSpatial,
+                ck::tuple_element_t<NDimSpatial - 1,
+                                    ck::Tuple<ck::tensor_layout::convolution::NWC,
+                                              ck::tensor_layout::convolution::NHWC,
+                                              ck::tensor_layout::convolution::NDHWC>>,
+                ck::tuple_element_t<NDimSpatial - 1,
+                                    ck::Tuple<ck::tensor_layout::convolution::KXC,
+                                              ck::tensor_layout::convolution::KYXC,
+                                              ck::tensor_layout::convolution::KZYXC>>,
+                ck::tuple_element_t<NDimSpatial - 1,
+                                    ck::Tuple<ck::tensor_layout::convolution::NWK,
+                                              ck::tensor_layout::convolution::NHWK,
+                                              ck::tensor_layout::convolution::NDHWK>>,
+                DataType,
+                DataType,
+                DataType>(true,  // do_verification
+                          1,     // init_method integer value
+                          false, // do_log
+                          false, // time_kernel
+                          param,
+                          split_k);
+            EXPECT_TRUE(pass);
+        }
+    }
+};

+using KernelTypes =
+    ::testing::Types<std::tuple<float>, std::tuple<ck::half_t>, std::tuple<ck::bhalf_t>>;

+TYPED_TEST_SUITE(TestConvndBwdWeight, KernelTypes);

+TYPED_TEST(TestConvndBwdWeight, Test1D)
+{
+    this->conv_params.clear();
+    this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
+    this->conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
+    this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
+    this->template Run<1>();
+}

 // 2d
-TEST_F(TestConvndBwdWeight, Conv2dBwdWeight)
+TYPED_TEST(TestConvndBwdWeight, Test2D)
 {
-    conv_params.clear();
-    conv_params.push_back({2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
-    conv_params.push_back({2, 1, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
-    conv_params.push_back({2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
-    for(auto& param : conv_params)
-    {
-        bool pass;
-        // fp32 / fp16 / bf16 calls to profile_conv_bwd_weight_impl<2, NHWC, KYXC, NHWK, ...>
-        ...
-        EXPECT_TRUE(pass);
-    }
+    this->conv_params.clear();
+    this->conv_params.push_back(
+        {2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
+    this->conv_params.push_back(
+        {2, 1, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
+    this->conv_params.push_back(
+        {2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
+    this->template Run<2>();
 }

 // 3d
-TEST_F(TestConvndBwdWeight, Conv3dBwdWeight)
+TYPED_TEST(TestConvndBwdWeight, Test3D)
 {
-    conv_params.clear();
-    conv_params.push_back(
-        {3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
-    conv_params.push_back(
-        {3, 1, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
-    conv_params.push_back(
-        {3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
-    for(auto& param : conv_params)
-    {
-        bool pass;
-        // fp32 / fp16 / bf16 calls to profile_conv_bwd_weight_impl<3, NDHWC, KZYXC, NDHWK, ...>
-        ...
-        EXPECT_TRUE(pass);
-    }
+    this->conv_params.clear();
+    this->conv_params.push_back(
+        {3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+    this->conv_params.push_back(
+        {3, 1, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
+    this->conv_params.push_back(
+        {3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+    this->template Run<3>();
 }
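Note (not part of the diff): the only structural difference from the bwd-data test above is the extra trailing argument. profile_conv_bwd_weight_impl takes a split-K factor (the split_k{2} member here, previously the literal 2 in every call), which presumably splits the reduction dimension of the implicit-GEMM weight-gradient kernel across workgroups; the typed Run<NDimSpatial>() simply forwards the member instead of repeating the literal.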
test/convnd_fwd/convnd_fwd.cpp

@@ -5,237 +5,88 @@
 #include <iostream>
 #include <initializer_list>
 #include <vector>
+#include <tuple>

 #include <gtest/gtest.h>

 #include "profiler/include/profile_conv_fwd_impl.hpp"

+template <typename Tuple>
 class TestConvndFwd : public ::testing::Test
 {
     protected:
+    using DataType = std::tuple_element_t<0, Tuple>;
     std::vector<ck::utils::conv::ConvParam> conv_params;
-};

-// 1d
-TEST_F(TestConvndFwd, Conv1dFwd)
-{
-    conv_params.clear();
-    conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
-    conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
-    conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
-    for(auto& param : conv_params)
-    {
-        bool pass;
-        // fp32
-        pass = ck::profiler::profile_conv_fwd_impl<1,
-                                                   ck::tensor_layout::convolution::NWC,
-                                                   ck::tensor_layout::convolution::KXC,
-                                                   ck::tensor_layout::convolution::NWK,
-                                                   float,
-                                                   float,
-                                                   float>(true,  // do_verification
-                                                          1,     // init_method
-                                                          false, // do_log
-                                                          false, // time_kernel
-                                                          param);
-        EXPECT_TRUE(pass);
-        // fp16 / bf16 / int8: the same call repeated with ck::half_t, ck::bhalf_t and int8_t
-        ...
-    }
-}
+    template <ck::index_t NDimSpatial>
+    void Run()
+    {
+        EXPECT_FALSE(conv_params.empty());
+
+        for(auto& param : conv_params)
+        {
+            bool pass;
+            pass = ck::profiler::profile_conv_fwd_impl<
+                NDimSpatial,
+                ck::tuple_element_t<NDimSpatial - 1,
+                                    ck::Tuple<ck::tensor_layout::convolution::NWC,
+                                              ck::tensor_layout::convolution::NHWC,
+                                              ck::tensor_layout::convolution::NDHWC>>,
+                ck::tuple_element_t<NDimSpatial - 1,
+                                    ck::Tuple<ck::tensor_layout::convolution::KXC,
+                                              ck::tensor_layout::convolution::KYXC,
+                                              ck::tensor_layout::convolution::KZYXC>>,
+                ck::tuple_element_t<NDimSpatial - 1,
+                                    ck::Tuple<ck::tensor_layout::convolution::NWK,
+                                              ck::tensor_layout::convolution::NHWK,
+                                              ck::tensor_layout::convolution::NDHWK>>,
+                DataType,
+                DataType,
+                DataType>(true,  // do_verification
+                          1,     // init_method integer value
+                          false, // do_log
+                          false, // time_kernel
+                          param);
+            EXPECT_TRUE(pass);
+        }
+    }
+};

+using KernelTypes = ::testing::Types<std::tuple<float>,
+                                     std::tuple<ck::half_t>,
+                                     std::tuple<ck::bhalf_t>,
+                                     std::tuple<std::int8_t>>;

+TYPED_TEST_SUITE(TestConvndFwd, KernelTypes);

+// 1d
+TYPED_TEST(TestConvndFwd, Conv1dFwd)
+{
+    this->conv_params.clear();
+    this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
+    this->conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
+    this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
+    this->template Run<1>();
+}

 // 2d
-TEST_F(TestConvndFwd, Conv2dFwd)
+TYPED_TEST(TestConvndFwd, Conv2dFwd)
 {
-    conv_params.clear();
-    conv_params.push_back({2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
-    conv_params.push_back({2, 1, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
-    conv_params.push_back({2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
-    for(auto& param : conv_params)
-    {
-        bool pass;
-        // fp32 / fp16 / bf16 / int8 calls to profile_conv_fwd_impl<2, NHWC, KYXC, NHWK, ...>
-        ...
-        EXPECT_TRUE(pass);
-    }
+    this->conv_params.clear();
+    this->conv_params.push_back(
+        {2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
+    this->conv_params.push_back(
+        {2, 1, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
+    this->conv_params.push_back(
+        {2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
+    this->template Run<2>();
 }

 // 3d
-TEST_F(TestConvndFwd, Conv3dFwd)
+TYPED_TEST(TestConvndFwd, Conv3dFwd)
 {
-    conv_params.clear();
-    conv_params.push_back(
-        {3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
-    conv_params.push_back(
-        {3, 1, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
-    conv_params.push_back(
-        {3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
-    for(auto& param : conv_params)
-    {
-        bool pass;
-        // fp32 / fp16 / bf16 / int8 calls to profile_conv_fwd_impl<3, NDHWC, KZYXC, NDHWK, ...>
-        ...
-        EXPECT_TRUE(pass);
-    }
+    this->conv_params.clear();
+    this->conv_params.push_back(
+        {3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+    this->conv_params.push_back(
+        {3, 1, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
+    this->conv_params.push_back(
+        {3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+    this->template Run<3>();
 }
test/gemm/CMakeLists.txt

@@ -13,3 +13,13 @@ target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
 add_test_executable(test_gemm_int8 gemm_int8.cpp)
 target_link_libraries(test_gemm_int8 PRIVATE utility)
 target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
+
+add_library(gemm_standalone_xdl_fp16_instances STATIC
+    instance/gemm_f16_nn_instance.cpp
+    instance/gemm_f16_nt_instance.cpp
+    instance/gemm_f16_tn_instance.cpp
+    instance/gemm_f16_tt_instance.cpp
+)
+add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp)
+target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility)
+target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/)
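Note (a hedged sketch, not the actual file): instance/gemm_f16_nn_instance.hpp (+41) is listed above but not expanded on this page. Judging only from how gemm_standalone_xdl_fp16.cpp consumes it further down, it presumably declares factory functions that append concrete DeviceGemm instances, type-erased as BaseOperator, to a caller-owned vector. The function names below are taken from the call sites; the include path and everything else is assumed.

    // Hedged sketch of instance/gemm_f16_nn_instance.hpp, not copied from the repository.
    #pragma once

    #include <memory>
    #include <vector>

    #include "ck/tensor_operation/gpu/device/device_base.hpp" // assumed location of BaseOperator

    namespace ck {
    namespace tensor_operation {
    namespace device {
    namespace instance {

    // Each function is assumed to push one or more tile-size-specific XDL GEMM
    // instances into `ops`.
    void add_gemm_f16_nn_256x256(std::vector<std::unique_ptr<BaseOperator>>& ops);
    void add_gemm_f16_nn_256x128(std::vector<std::unique_ptr<BaseOperator>>& ops);
    void add_gemm_f16_nn_128x128(std::vector<std::unique_ptr<BaseOperator>>& ops);
    void add_gemm_f16_nn_128x64(std::vector<std::unique_ptr<BaseOperator>>& ops);

    } // namespace instance
    } // namespace device
    } // namespace tensor_operation
    } // namespace ck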
test/gemm/gemm_bf16.cpp

@@ -24,56 +24,11 @@
 #include "test/gemm/gemm_util.hpp"

-int main()
-{
-    using ADataType   = ck::bhalf_t;
-    using BDataType   = ck::bhalf_t;
-    using CDataType   = ck::bhalf_t;
-    using AccDataType = float;
-
-    using Row = ck::tensor_layout::gemm::RowMajor;
-    using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-    auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
-        bool pass = true;
-
-        using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
-                                                                  decltype(b_layout),
-                                                                  decltype(c_layout),
-                                                                  ADataType,
-                                                                  BDataType,
-                                                                  CDataType,
-                                                                  PassThrough,
-                                                                  PassThrough,
-                                                                  PassThrough>;
-
-        const auto gemmPtrs = ck::tensor_operation::device::instance::
-            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
-
-        for(auto& gemmPtr : gemmPtrs)
-        {
-            pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
-                                            ADataType,
-                                            BDataType,
-                                            CDataType,
-                                            AccDataType,
-                                            decltype(a_layout),
-                                            decltype(b_layout),
-                                            decltype(c_layout),
-                                            PassThrough,
-                                            PassThrough,
-                                            PassThrough>{}(gemmPtr);
-        }
-
-        return pass;
-    };
-
-    bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
-                test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
-
-    std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
-    return pass ? 0 : 1;
-}
+using ADataType   = ck::bhalf_t;
+using BDataType   = ck::bhalf_t;
+using CDataType   = ck::bhalf_t;
+using AccDataType = float;
+
+#include "run_gemm_test.inc"
+
+int main() { return run_gemm_test(); }
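Note (a hedged sketch, not part of the diff): run_gemm_test.inc is not among the files expanded on this page, so its contents are not visible here. Presumably it defines run_gemm_test() in terms of the ADataType/BDataType/CDataType/AccDataType aliases each gemm_*.cpp sets up before including it, roughly the body that was just deleted above; the exact verification call has to match whatever gemm_util.hpp now exposes.

    // Hedged sketch of what run_gemm_test.inc presumably provides; inferred from the
    // deleted main() above, not copied from the repository. Assumes the including
    // .cpp has already pulled in gemm_util.hpp, <memory> and the ck headers.
    inline int run_gemm_test()
    {
        using Row         = ck::tensor_layout::gemm::RowMajor;
        using Col         = ck::tensor_layout::gemm::ColumnMajor;
        using PassThrough = ck::tensor_operation::element_wise::PassThrough;

        auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
            using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
                                                                      decltype(b_layout),
                                                                      decltype(c_layout),
                                                                      ADataType,
                                                                      BDataType,
                                                                      CDataType,
                                                                      PassThrough,
                                                                      PassThrough,
                                                                      PassThrough>;

            const auto gemmPtrs = ck::tensor_operation::device::instance::
                DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

            bool pass = true;
            for(auto& gemmPtr : gemmPtrs)
            {
                // run the gemm_util verification/benchmark on gemmPtr here
                // (exact TestGemm signature assumed to follow the new gemm_util.hpp)
                pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
                                                ADataType, BDataType, CDataType, AccDataType,
                                                decltype(a_layout), decltype(b_layout),
                                                decltype(c_layout), PassThrough, PassThrough,
                                                PassThrough>{}(gemmPtr);
            }
            return pass;
        };

        bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
                    test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
        return pass ? 0 : 1;
    }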
test/gemm/gemm_fp16.cpp

@@ -24,56 +24,11 @@
(Same restructuring as gemm_bf16.cpp above: the old main() body is deleted and replaced by file-scope aliases plus the shared include; only the data types differ.)
 #include "test/gemm/gemm_util.hpp"

+using ADataType   = ck::half_t;
+using BDataType   = ck::half_t;
+using CDataType   = ck::half_t;
+using AccDataType = float;
+
+#include "run_gemm_test.inc"
+
+int main() { return run_gemm_test(); }

test/gemm/gemm_fp32.cpp

@@ -24,56 +24,11 @@
(Same restructuring; new aliases:)
+using ADataType   = float;
+using BDataType   = float;
+using CDataType   = float;
+using AccDataType = float;
+
+#include "run_gemm_test.inc"
+
+int main() { return run_gemm_test(); }

test/gemm/gemm_fp64.cpp

@@ -24,56 +24,11 @@
(Same restructuring; new aliases:)
+using ADataType   = double;
+using BDataType   = double;
+using CDataType   = double;
+using AccDataType = double;
+
+#include "run_gemm_test.inc"
+
+int main() { return run_gemm_test(); }

test/gemm/gemm_int8.cpp

@@ -24,56 +24,11 @@
(Same restructuring; new aliases:)
+using ADataType   = int8_t;
+using BDataType   = int8_t;
+using CDataType   = int8_t;
+using AccDataType = int32_t;
+
+#include "run_gemm_test.inc"
+
+int main() { return run_gemm_test(); }
test/gemm/gemm_standalone_xdl_fp16.cpp
0 → 100644
View file @
5b7c2432
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_util.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "gemm_f16_nn_instance.hpp"
#include "gemm_f16_nt_instance.hpp"
#include "gemm_f16_tn_instance.hpp"
#include "gemm_f16_tt_instance.hpp"
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
F16
=
ck
::
half_t
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
float
;
using
CDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
ck
::
gemm_util
::
GemmParams
;
using
ck
::
tensor_operation
::
device
::
BaseOperator
;
using
ck
::
tensor_operation
::
device
::
DeviceGemm
;
using
namespace
ck
::
tensor_operation
::
device
::
instance
;
using
DeviceGemmNN
=
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>
;
using
DeviceGemmNT
=
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>
;
using
DeviceGemmTN
=
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>
;
using
DeviceGemmTT
=
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>
;
struct
LayoutConfig
{
bool
ARowMajor
;
bool
BRowMajor
;
bool
CRowMajor
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
// Class DeviceGemm is templated by layout and precision types so it is not an option to contain
// them in a single vector. Instead we use abstract BaseOperator class and dynamic_cast() it
// upon invocation.
// And since DeviceGemm does not expose template arg information, an extra book keeping class
// LayoutConfig is used for determining which type a BaseOperator instance should be cast to.
using
OpFactoryFn
=
void
(
*
)(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
);
std
::
vector
<
std
::
tuple
<
GemmParams
,
LayoutConfig
,
OpFactoryFn
>>
problems
=
{
// clang-format off
// 104 tiles
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_f16_nn_256x256
},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_f16_nn_256x128
},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_f16_nn_128x128
},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_f16_nn_128x64
},
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_f16_nt_256x256
},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_f16_nt_256x128
},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
         add_gemm_f16_nt_128x128},
        {GemmParams{1024, 832, 4096},  LayoutConfig{false, true, true},  add_gemm_f16_nt_128x64},
        {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true},  add_gemm_f16_tn_256x128},
        {GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true},  add_gemm_f16_tn_256x128},
        {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true},  add_gemm_f16_tn_128x128},
        {GemmParams{1024, 832, 4096},  LayoutConfig{true, false, true},  add_gemm_f16_tn_128x64},
        {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true},   add_gemm_f16_tt_256x256},
        {GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true},   add_gemm_f16_tt_256x128},
        {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true},   add_gemm_f16_tt_128x128},
        {GemmParams{1024, 832, 4096},  LayoutConfig{true, true, true},   add_gemm_f16_tt_128x64},
        // 110 tiles
        {GemmParams{2560, 2816, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
        {GemmParams{2560, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
        {GemmParams{1280, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
        {GemmParams{1280, 704, 4096},  LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
        {GemmParams{2560, 2816, 4096}, LayoutConfig{false, true, true},  add_gemm_f16_nt_256x256},
        {GemmParams{2560, 1408, 4096}, LayoutConfig{false, true, true},  add_gemm_f16_nt_256x128},
        {GemmParams{1280, 1408, 4096}, LayoutConfig{false, true, true},  add_gemm_f16_nt_128x128},
        {GemmParams{1280, 704, 4096},  LayoutConfig{false, true, true},  add_gemm_f16_nt_128x64},
        {GemmParams{2560, 2816, 4096}, LayoutConfig{true, false, true},  add_gemm_f16_tn_256x128},
        {GemmParams{2560, 1408, 4096}, LayoutConfig{true, false, true},  add_gemm_f16_tn_256x128},
        {GemmParams{1280, 1408, 4096}, LayoutConfig{true, false, true},  add_gemm_f16_tn_128x128},
        {GemmParams{1280, 704, 4096},  LayoutConfig{true, false, true},  add_gemm_f16_tn_128x64},
        {GemmParams{2560, 2816, 4096}, LayoutConfig{true, true, true},   add_gemm_f16_tt_256x256},
        {GemmParams{2560, 1408, 4096}, LayoutConfig{true, true, true},   add_gemm_f16_tt_256x128},
        {GemmParams{1280, 1408, 4096}, LayoutConfig{true, true, true},   add_gemm_f16_tt_128x128},
        {GemmParams{1280, 704, 4096},  LayoutConfig{true, true, true},   add_gemm_f16_tt_128x64},
        // clang-format on
    };

    bool do_verification = true;
    bool time_kernel     = true;

    if(argc == 1)
    {
        // use default
    }
    else if(argc == 3)
    {
        do_verification = std::stoi(argv[1]);
        time_kernel     = std::stoi(argv[2]);
    }
    else
    {
        std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
                  << "arg2: time kernel (0=no, 1=yes)" << std::endl;
        return 0;
    }

    bool pass = true;

    for(auto& p : problems)
    {
        GemmParams& problem_size          = std::get<0>(p);
        const LayoutConfig& layout_config = std::get<1>(p);
        const auto& factory               = std::get<2>(p);

        std::vector<std::unique_ptr<BaseOperator>> ops;
        factory(ops);

        // overwrite strides
        problem_size.StrideA = layout_config.ARowMajor ? problem_size.K : problem_size.M;
        problem_size.StrideB = layout_config.BRowMajor ? problem_size.N : problem_size.K;
        problem_size.StrideC = layout_config.CRowMajor ? problem_size.N : problem_size.M;

        if(!layout_config.ARowMajor && !layout_config.BRowMajor)
        {
            auto op_ptr = dynamic_cast<DeviceGemmNN*>(ops[0].get());
            pass &= ck::gemm_util::TestGemm<AccDataType>{}(
                op_ptr, problem_size, do_verification, time_kernel);
        }
        else if(!layout_config.ARowMajor && layout_config.BRowMajor)
        {
            auto op_ptr = dynamic_cast<DeviceGemmNT*>(ops[0].get());
            pass &= ck::gemm_util::TestGemm<AccDataType>{}(
                op_ptr, problem_size, do_verification, time_kernel);
        }
        else if(layout_config.ARowMajor && !layout_config.BRowMajor)
        {
            auto op_ptr = dynamic_cast<DeviceGemmTN*>(ops[0].get());
            pass &= ck::gemm_util::TestGemm<AccDataType>{}(
                op_ptr, problem_size, do_verification, time_kernel);
        }
        else if(layout_config.ARowMajor && layout_config.BRowMajor)
        {
            auto op_ptr = dynamic_cast<DeviceGemmTT*>(ops[0].get());
            pass &= ck::gemm_util::TestGemm<AccDataType>{}(
                op_ptr, problem_size, do_verification, time_kernel);
        }
    }

    std::cout << (pass ? "ALL TESTS PASSED" : "SOME TESTS FAILED") << std::endl;

    return pass ? 0 : 1;
}
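The loop above derives each leading dimension from the corresponding layout flag before dispatching to the selected instance. As a minimal standalone sketch of that stride rule (the Extent struct and leading_dimension helper are illustrative names, not part of this commit):

    // Sketch only: the convention behind the "overwrite strides" block above.
    // A row-major matrix has a leading dimension equal to its column count; a
    // column-major matrix has a leading dimension equal to its row count.
    #include <cstdint>

    struct Extent
    {
        std::int64_t rows;
        std::int64_t cols;
    };

    inline std::int64_t leading_dimension(const Extent& e, bool row_major)
    {
        return row_major ? e.cols : e.rows;
    }

    // Example: A is M x K, so ARowMajor == true gives StrideA == K and
    // ARowMajor == false gives StrideA == M, matching problem_size.StrideA above.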
test/gemm/gemm_util.hpp
View file @ 5b7c2432
...
@@ -16,21 +16,13 @@ namespace gemm_util {

 struct GemmParams
 {
-    GemmParams()
-        : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0)
-    {
-    }
-
-    ck::index_t M;
-    ck::index_t N;
-    ck::index_t K;
-    ck::index_t StrideA;
-    ck::index_t StrideB;
-    ck::index_t StrideC;
-    float alpha;
-    float beta;
+    ck::index_t M       = 1024;
+    ck::index_t N       = 1024;
+    ck::index_t K       = 1024;
+    ck::index_t StrideA = 1024;
+    ck::index_t StrideB = 1024;
+    ck::index_t StrideC = 1024;
 };

 template <typename GemmInstance,
@@ -69,7 +61,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
                    Tensor<CDataType>& C,
                    AElementwiseOperation a_element_op,
                    BElementwiseOperation b_element_op,
-                   CElementwiseOperation c_element_op)
+                   CElementwiseOperation c_element_op,
+                   bool time_kernel)
 {
     DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
     DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
@@ -94,7 +87,20 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
     {
         a_m_k_device_buf.ToDevice(A.mData.data());
         b_k_n_device_buf.ToDevice(B.mData.data());
-        invoker_ptr->Run(argument_ptr.get());
+        float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+        std::size_t flop      = std::size_t(2) * params.M * params.N * params.K;
+        std::size_t num_btype = sizeof(ADataType) * params.M * params.K +
+                                sizeof(BDataType) * params.K * params.N +
+                                sizeof(CDataType) * params.M * params.N;
+
+        float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+        float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
+                  << " GB/s, " << std::endl;
+
         c_m_n_device_buf.FromDevice(C.mData.data());

         return true;
@@ -109,19 +115,15 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
     }
 }

-template <typename DeviceGemmPtr_,
-          typename ADataType,
-          typename BDataType,
-          typename CDataType,
-          typename AccDataType,
-          typename ALayout,
-          typename BLayout,
-          typename CLayout,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
+template <typename AccDataType>
 struct TestGemm
 {
+    template <typename ADataType,
+              typename BDataType,
+              typename CDataType,
+              typename ALayout,
+              typename BLayout,
+              typename CLayout>
     auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
     {
         auto f_host_tensor_descriptor =
@@ -156,25 +158,42 @@ struct TestGemm
         f_generate_tensor_value(a_m_k, ADataType{});
         f_generate_tensor_value(b_k_n, BDataType{});

-        std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
-        std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
-        std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
-
         return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
     }

-    auto operator()(const DeviceGemmPtr_& gemmPtr)
+    template <template <class...> class DeviceGemmPtr_,
+              typename ALayout,
+              typename BLayout,
+              typename CLayout,
+              typename ADataType,
+              typename BDataType,
+              typename CDataType,
+              typename AElementwiseOperation,
+              typename BElementwiseOperation,
+              typename CElementwiseOperation>
+    auto operator()(DeviceGemmPtr_<ALayout,
+                                   BLayout,
+                                   CLayout,
+                                   ADataType,
+                                   BDataType,
+                                   CDataType,
+                                   AElementwiseOperation,
+                                   BElementwiseOperation,
+                                   CElementwiseOperation>* gemmPtr,
+                    const GemmParams& params = GemmParams{},
+                    bool do_verification     = true,
+                    bool time_kernel         = false)
     {
         std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
                   << ", CLayout = " << CLayout{}.name << std::endl;
         std::cout << gemmPtr->GetTypeString() << std::endl;

         // Arrange
-        ck::gemm_util::GemmParams params;
-        params.M       = 1024;
-        params.N       = 1024;
-        params.K       = 1024;
-        params.StrideA = 1024;
-        params.StrideB = 1024;
-        params.StrideC = 1024;
-        auto host_tensors = PrepareGemmTensor(params);
+        auto host_tensors =
+            PrepareGemmTensor<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(params);
         const Tensor<ADataType>& a = std::get<0>(host_tensors);
         const Tensor<BDataType>& b = std::get<1>(host_tensors);
...
@@ -193,14 +212,18 @@ struct TestGemm
                                                          AElementwiseOperation,
                                                          BElementwiseOperation,
                                                          CElementwiseOperation>;
-        ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
-            a, b, c_host, a_element_op, b_element_op, c_element_op);
+        if(do_verification)
+        {
+            ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
+                a, b, c_host, a_element_op, b_element_op, c_element_op);
+        }

         // Act
         bool is_supported = ck::gemm_util::RunDeviceGEMM(
-            gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
+            gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op, time_kernel);

-        if(is_supported)
+        if(is_supported && do_verification)
         {
             // Assert
             bool res = false;
...
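After this change, TestGemm is parameterized only by the accumulator type; operand layouts and element types are deduced from the device GEMM pointer passed to operator(), and the problem size, verification switch, and timing switch all become arguments. A minimal usage sketch follows, assuming it is compiled inside this test target; the DeviceGemmNN alias is spelled out explicitly here and is assumed to match the one the standalone test uses for column-major A and B.

    // Sketch only: drive one NN f16 instance through the refactored TestGemm.
    #include <memory>
    #include <vector>

    #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
    #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
    #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

    #include "gemm_util.hpp"
    #include "instance/gemm_f16_nn_instance.hpp" // assumed include path within test/gemm

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Interface type for the dynamic_cast target; assumed to match the alias the
    // standalone test defines for (col-major A, col-major B, row-major C) in f16.
    using DeviceGemmNN =
        ck::tensor_operation::device::DeviceGemm<ck::tensor_layout::gemm::ColumnMajor,
                                                 ck::tensor_layout::gemm::ColumnMajor,
                                                 ck::tensor_layout::gemm::RowMajor,
                                                 ck::half_t,
                                                 ck::half_t,
                                                 ck::half_t,
                                                 PassThrough,
                                                 PassThrough,
                                                 PassThrough>;

    bool run_one_nn_case()
    {
        std::vector<std::unique_ptr<ck::tensor_operation::device::BaseOperator>> ops;
        ck::tensor_operation::device::instance::add_gemm_f16_nn_128x128(ops);

        ck::gemm_util::GemmParams params; // defaults: M = N = K = Stride* = 1024

        auto* op_ptr = dynamic_cast<DeviceGemmNN*>(ops[0].get());

        return ck::gemm_util::TestGemm<float>{}(
            op_ptr, params, /*do_verification=*/true, /*time_kernel=*/false);
    }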
test/gemm/instance/gemm_f16_nn_instance.cpp
0 → 100644
View file @ 5b7c2432
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "gemm_f16_nn_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using gemm_f16_nn_256x256 = std::tuple<
    // clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGemm_Xdl_CShuffle<Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 32, 2, 8, 32, 32, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;
using gemm_f16_nn_256x128 = std::tuple<
    // clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGemm_Xdl_CShuffle<Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;
using gemm_f16_nn_128x128 = std::tuple<
    // clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGemm_Xdl_CShuffle<Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;
using gemm_f16_nn_128x64 = std::tuple<
    // clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGemm_Xdl_CShuffle<Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;
void add_gemm_f16_nn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_nn_256x256{});
}

void add_gemm_f16_nn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_nn_256x128{});
}

void add_gemm_f16_nn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_nn_128x128{});
}

void add_gemm_f16_nn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_nn_128x64{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
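Each gemm_f16_nn_* alias above is a single-element tuple holding one tuned tile shape, and each add_gemm_f16_nn_* function appends the corresponding operator instance to a caller-supplied list, which is how the standalone benchmark obtains ops[0]. A small sketch of enumerating everything this file registers (sketch only; the function name is illustrative and <iostream>, <memory>, <vector> are assumed):

    // Sketch only: collect all four NN f16 tile configurations and print their
    // type strings via the BaseOperator interface.
    #include <iostream>
    #include <memory>
    #include <vector>

    #include "gemm_f16_nn_instance.hpp"

    void list_nn_f16_instances()
    {
        namespace inst = ck::tensor_operation::device::instance;

        std::vector<std::unique_ptr<ck::tensor_operation::device::BaseOperator>> instances;
        inst::add_gemm_f16_nn_256x256(instances);
        inst::add_gemm_f16_nn_256x128(instances);
        inst::add_gemm_f16_nn_128x128(instances);
        inst::add_gemm_f16_nn_128x64(instances);

        for(const auto& op : instances)
        {
            std::cout << op->GetTypeString() << std::endl; // human-readable kernel configuration
        }
    }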
test/gemm/instance/gemm_f16_nn_instance.hpp
0 → 100644
View file @ 5b7c2432
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

void add_gemm_f16_nn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_nn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_nn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_nn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck