gaoqiong / composable_kernel · Commits

Commit 24af0144 (unverified)
Authored Nov 12, 2022 by Po Yen Chen; committed by GitHub, Nov 12, 2022

Merge branch 'develop' into gemm_layernorm_welford

Parents: 961f5e9e, b79bbbc2
Showing 20 changed files with 890 additions and 312 deletions.
test/gemm/gemm_standalone_xdl_fp16.cpp                        +162  -0
test/gemm/gemm_util.hpp                                        +75  -51
test/gemm/instance/gemm_f16_nn_instance.cpp                    +86  -0
test/gemm/instance/gemm_f16_nn_instance.hpp                    +41  -0
test/gemm/instance/gemm_f16_nt_instance.cpp                    +86  -0
test/gemm/instance/gemm_f16_nt_instance.hpp                    +41  -0
test/gemm/instance/gemm_f16_tn_instance.cpp                    +86  -0
test/gemm/instance/gemm_f16_tn_instance.hpp                    +41  -0
test/gemm/instance/gemm_f16_tt_instance.cpp                    +86  -0
test/gemm/instance/gemm_f16_tt_instance.hpp                    +41  -0
test/gemm/run_gemm_test.inc                                    +41  -0
test/gemm_split_k/gemm_split_k.cpp                              +5  -4
test/grouped_convnd_bwd_weight/CMakeLists.txt                   +2  -0
test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp   +91  -0
test/layernorm/test_layernorm2d_fp16.cpp                        +0  -29
test/layernorm/test_layernorm2d_fp32.cpp                        +0  -29
test/layernorm/test_layernorm2d_util.hpp                        +0  -179
test/normalization/CMakeLists.txt                               +4  -4
test/normalization/test_groupnorm_fp16.cpp                      +1  -8
test/normalization/test_groupnorm_fp32.cpp                      +1  -8
test/gemm/gemm_standalone_xdl_fp16.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "gemm_util.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"

#include "gemm_f16_nn_instance.hpp"
#include "gemm_f16_nt_instance.hpp"
#include "gemm_f16_tn_instance.hpp"
#include "gemm_f16_tt_instance.hpp"

using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using F16 = ck::half_t;

using ADataType   = F16;
using BDataType   = F16;
using AccDataType = float;
using CDataType   = F16;

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

using ck::gemm_util::GemmParams;
using ck::tensor_operation::device::BaseOperator;
using ck::tensor_operation::device::DeviceGemm;
using namespace ck::tensor_operation::device::instance;

using DeviceGemmNN =
    DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;
using DeviceGemmNT =
    DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;
using DeviceGemmTN =
    DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;
using DeviceGemmTT =
    DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;

struct LayoutConfig
{
    bool ARowMajor;
    bool BRowMajor;
    bool CRowMajor;
};

int main(int argc, char* argv[])
{
    // DeviceGemm is templated by layout and precision types, so instances of different
    // layouts cannot be held in a single homogeneous vector. Instead we store them through
    // the abstract BaseOperator base class and dynamic_cast() back to the concrete type upon
    // invocation. And since BaseOperator does not expose the template arguments, an extra
    // bookkeeping struct, LayoutConfig, records which concrete type each instance should be
    // cast to.
    using OpFactoryFn = void (*)(std::vector<std::unique_ptr<BaseOperator>>&);

    std::vector<std::tuple<GemmParams, LayoutConfig, OpFactoryFn>> problems = {
        // clang-format off
        //  GemmParams{M, N, K}        LayoutConfig{A, B, C row-major?}  instance factory
        // 104 tiles
        {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
        {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
        {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
        {GemmParams{1024,  832, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
        {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_256x256},
        {GemmParams{2048, 1664, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_256x128},
        {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_128x128},
        {GemmParams{1024,  832, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_128x64},
        {GemmParams{2048, 3328, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_256x256},
        {GemmParams{2048, 1664, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_256x128},
        {GemmParams{1024, 1664, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_128x128},
        {GemmParams{1024,  832, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_128x64},
        {GemmParams{2048, 3328, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_256x256},
        {GemmParams{2048, 1664, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_256x128},
        {GemmParams{1024, 1664, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_128x128},
        {GemmParams{1024,  832, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_128x64},
        // 110 tiles
        {GemmParams{2560, 2816, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
        {GemmParams{2560, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
        {GemmParams{1280, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
        {GemmParams{1280,  704, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
        {GemmParams{2560, 2816, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_256x256},
        {GemmParams{2560, 1408, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_256x128},
        {GemmParams{1280, 1408, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_128x128},
        {GemmParams{1280,  704, 4096}, LayoutConfig{false, true,  true}, add_gemm_f16_nt_128x64},
        {GemmParams{2560, 2816, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_256x256},
        {GemmParams{2560, 1408, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_256x128},
        {GemmParams{1280, 1408, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_128x128},
        {GemmParams{1280,  704, 4096}, LayoutConfig{true,  false, true}, add_gemm_f16_tn_128x64},
        {GemmParams{2560, 2816, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_256x256},
        {GemmParams{2560, 1408, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_256x128},
        {GemmParams{1280, 1408, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_128x128},
        {GemmParams{1280,  704, 4096}, LayoutConfig{true,  true,  true}, add_gemm_f16_tt_128x64},
        // clang-format on
    };

    bool do_verification = true;
    bool time_kernel     = true;

    if(argc == 1)
    {
        // use default
    }
    else if(argc == 3)
    {
        do_verification = std::stoi(argv[1]);
        time_kernel     = std::stoi(argv[2]);
    }
    else
    {
        std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
                  << "arg2: time kernel (0=no, 1=yes)" << std::endl;
        return 0;
    }

    bool pass = true;
    for(auto& p : problems)
    {
        GemmParams& problem_size          = std::get<0>(p);
        const LayoutConfig& layout_config = std::get<1>(p);
        const auto& factory               = std::get<2>(p);

        std::vector<std::unique_ptr<BaseOperator>> ops;
        factory(ops);

        // overwrite strides
        problem_size.StrideA = layout_config.ARowMajor ? problem_size.K : problem_size.M;
        problem_size.StrideB = layout_config.BRowMajor ? problem_size.N : problem_size.K;
        problem_size.StrideC = layout_config.CRowMajor ? problem_size.N : problem_size.M;

        if(!layout_config.ARowMajor && !layout_config.BRowMajor)
        {
            auto op_ptr = dynamic_cast<DeviceGemmNN*>(ops[0].get());
            pass &= ck::gemm_util::TestGemm<AccDataType>{}(
                op_ptr, problem_size, do_verification, time_kernel);
        }
        else if(!layout_config.ARowMajor && layout_config.BRowMajor)
        {
            auto op_ptr = dynamic_cast<DeviceGemmNT*>(ops[0].get());
            pass &= ck::gemm_util::TestGemm<AccDataType>{}(
                op_ptr, problem_size, do_verification, time_kernel);
        }
        else if(layout_config.ARowMajor && !layout_config.BRowMajor)
        {
            auto op_ptr = dynamic_cast<DeviceGemmTN*>(ops[0].get());
            pass &= ck::gemm_util::TestGemm<AccDataType>{}(
                op_ptr, problem_size, do_verification, time_kernel);
        }
        else if(layout_config.ARowMajor && layout_config.BRowMajor)
        {
            auto op_ptr = dynamic_cast<DeviceGemmTT*>(ops[0].get());
            pass &= ck::gemm_util::TestGemm<AccDataType>{}(
                op_ptr, problem_size, do_verification, time_kernel);
        }
    }

    std::cout << (pass ? "ALL TESTS PASSED" : "SOME TESTS FAILED") << std::endl;
    return pass ? 0 : 1;
}
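For reference, the two optional arguments map directly onto do_verification and time_kernel above. Assuming the build produces a target named after the source file (an assumption, e.g. test_gemm_standalone_xdl_fp16), running `./test_gemm_standalone_xdl_fp16 1 1` enables both verification and kernel timing, running it with no arguments keeps the defaults (both enabled), and any other argument count prints the usage text and exits.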
test/gemm/gemm_util.hpp (modified)
@@ -9,6 +9,7 @@
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
 
 namespace ck {
@@ -16,21 +17,13 @@ namespace gemm_util {
 struct GemmParams
 {
-    GemmParams()
-        : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0)
-    {
-    }
-
-    ck::index_t M;
-    ck::index_t N;
-    ck::index_t K;
+    ck::index_t M = 1024;
+    ck::index_t N = 1024;
+    ck::index_t K = 1024;
 
-    ck::index_t StrideA;
-    ck::index_t StrideB;
-    ck::index_t StrideC;
-
-    float alpha;
-    float beta;
+    ck::index_t StrideA = 1024;
+    ck::index_t StrideB = 1024;
+    ck::index_t StrideC = 1024;
 };
 
 template <typename GemmInstance,
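A note on the GemmParams change above: with the user-declared constructor removed and default member initializers in place, GemmParams becomes an aggregate under C++17 (default member initializers no longer disqualify aggregates as they did in C++14). This is exactly what the new standalone test relies on for its brace form, for example:

    // Listed values initialize M, N, K in declaration order; StrideA/StrideB/StrideC
    // keep their in-class defaults (1024) until main() overwrites them from LayoutConfig.
    GemmParams params{2048, 3328, 4096};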
@@ -69,7 +62,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
                    Tensor<CDataType>& C,
                    AElementwiseOperation a_element_op,
                    BElementwiseOperation b_element_op,
-                   CElementwiseOperation c_element_op)
+                   CElementwiseOperation c_element_op,
+                   bool time_kernel)
 {
     DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
     DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
@@ -94,7 +88,20 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
     {
         a_m_k_device_buf.ToDevice(A.mData.data());
         b_k_n_device_buf.ToDevice(B.mData.data());
-        invoker_ptr->Run(argument_ptr.get());
+        float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+        std::size_t flop = std::size_t(2) * params.M * params.N * params.K;
+
+        std::size_t num_btype = sizeof(ADataType) * params.M * params.K +
+                                sizeof(BDataType) * params.K * params.N +
+                                sizeof(CDataType) * params.M * params.N;
+
+        float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
+        float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
+                  << " GB/s, " << std::endl;
+
         c_m_n_device_buf.FromDevice(C.mData.data());
         return true;
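To sanity-check the units in the new performance computation: Run() returns the averaged kernel time in milliseconds, so dividing FLOPs by 1.E9 and by milliseconds yields TFLOP/s, and dividing bytes by 1.E6 and by milliseconds yields GB/s. A worked example for the default 1024 x 1024 x 1024 fp16 problem (the 1.0 ms timing is illustrative, not from the diff):

    // M = N = K = 1024, sizeof(half_t) == 2, assumed ave_time = 1.0 ms:
    //   flop       = 2 * 1024^3              = 2'147'483'648
    //   num_btype  = 2 B * (M*K + K*N + M*N) = 2 * 3 * 1'048'576 = 6'291'456 B
    //   tflops     = 2.147e9 / 1e9 / 1.0     ~ 2.15 TFLOP/s
    //   gb_per_sec = 6.29e6  / 1e6 / 1.0     ~ 6.29 GB/s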
@@ -109,32 +116,28 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
     }
 }
 
-template <typename DeviceGemmPtr_,
-          typename ADataType,
-          typename BDataType,
-          typename CDataType,
-          typename AccDataType,
-          typename ALayout,
-          typename BLayout,
-          typename CLayout,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
+template <typename AccDataType>
 struct TestGemm
 {
+    template <typename ADataType,
+              typename BDataType,
+              typename CDataType,
+              typename ALayout,
+              typename BLayout,
+              typename CLayout>
     auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
     {
         auto f_host_tensor_descriptor =
             [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+                using namespace ck::literals;
+
                 if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
                 {
-                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                                std::vector<std::size_t>({stride, 1}));
+                    return HostTensorDescriptor({row, col}, {stride, 1_uz});
                 }
                 else
                 {
-                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                                std::vector<std::size_t>({1, stride}));
+                    return HostTensorDescriptor({row, col}, {1_uz, stride});
                }
             };
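The `1_uz` literal comes from the newly included ck/library/utility/literals.hpp; it makes both elements of each braced initializer deduce to std::size_t (instead of mixing std::size_t and int), so the terse two-argument HostTensorDescriptor form stays well-formed without narrowing. A user-defined literal of roughly this shape would suffice (a sketch of the idea, not necessarily CK's exact definition):

    #include <cstddef>

    namespace ck {
    namespace literals {
    // Maps e.g. 1_uz to std::size_t{1}.
    constexpr std::size_t operator""_uz(unsigned long long x) { return x; }
    } // namespace literals
    } // namespace ck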
@@ -156,25 +159,42 @@ struct TestGemm
         f_generate_tensor_value(a_m_k, ADataType{});
         f_generate_tensor_value(b_k_n, BDataType{});
 
         std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
         std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
         std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
 
         return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
     }
 
-    auto operator()(const DeviceGemmPtr_& gemmPtr)
+    template <template <class...> class DeviceGemmPtr_,
+              typename ALayout,
+              typename BLayout,
+              typename CLayout,
+              typename ADataType,
+              typename BDataType,
+              typename CDataType,
+              typename AElementwiseOperation,
+              typename BElementwiseOperation,
+              typename CElementwiseOperation>
+    auto operator()(DeviceGemmPtr_<ALayout,
+                                   BLayout,
+                                   CLayout,
+                                   ADataType,
+                                   BDataType,
+                                   CDataType,
+                                   AElementwiseOperation,
+                                   BElementwiseOperation,
+                                   CElementwiseOperation>* gemmPtr,
+                    const GemmParams& params = GemmParams{},
+                    bool do_verification     = true,
+                    bool time_kernel         = false)
     {
         std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
                   << ", CLayout = " << CLayout{}.name << std::endl;
         std::cout << gemmPtr->GetTypeString() << std::endl;
 
         // Arrange
-        ck::gemm_util::GemmParams params;
-        params.M       = 1024;
-        params.N       = 1024;
-        params.K       = 1024;
-        params.StrideA = 1024;
-        params.StrideB = 1024;
-        params.StrideC = 1024;
-
-        auto host_tensors = PrepareGemmTensor(params);
+        auto host_tensors =
+            PrepareGemmTensor<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(params);
 
         const Tensor<ADataType>& a = std::get<0>(host_tensors);
         const Tensor<BDataType>& b = std::get<1>(host_tensors);
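The new operator() spells the pointer's type out through a template-template parameter, which is what lets callers pass a plain DeviceGemm<...>* and have the layouts, data types, and elementwise operations all deduced from the pointer, something the old opaque DeviceGemmPtr_ value parameter could not provide. A stripped-down illustration of the deduction mechanism (hypothetical names, standard C++):

    #include <iostream>

    template <class A, class B>
    struct Op // stand-in for DeviceGemm<...>
    {
    };

    template <template <class...> class P, class A, class B>
    void dispatch(P<A, B>* /*op*/) // P, A, and B are deduced from the pointer type
    {
        std::cout << "deduced both template arguments\n";
    }

    int main()
    {
        Op<int, float> op;
        dispatch(&op); // P = Op, A = int, B = float, all deduced automatically
    }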
@@ -193,40 +213,44 @@ struct TestGemm
                                                          AElementwiseOperation,
                                                          BElementwiseOperation,
                                                          CElementwiseOperation>;
-        ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
-            a, b, c_host, a_element_op, b_element_op, c_element_op);
+        if(do_verification)
+        {
+            ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
+                a, b, c_host, a_element_op, b_element_op, c_element_op);
+        }
 
         // Act
         bool is_supported = ck::gemm_util::RunDeviceGEMM(
-            gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
+            gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op, time_kernel);
 
-        if(is_supported)
+        if(is_supported && do_verification)
         {
             // Assert
             bool res = false;
 
             if(std::is_same<CDataType, float>::value)
             {
-                res = ck::utils::check_err(c_device.mData, c_host.mData);
+                res = ck::utils::check_err(c_device, c_host);
                 std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
             }
             else if(std::is_same<CDataType, ck::half_t>::value)
             {
-                res = ck::utils::check_err(c_device.mData, c_host.mData);
+                res = ck::utils::check_err(c_device, c_host);
                 std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
             }
             else if(std::is_same<CDataType, ck::bhalf_t>::value)
             {
-                res = ck::utils::check_err(c_device.mData, c_host.mData);
+                res = ck::utils::check_err(c_device, c_host);
                 std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
             }
             else if(std::is_same<CDataType, int8_t>::value)
             {
-                res = ck::utils::check_err(c_device.mData, c_host.mData);
+                res = ck::utils::check_err(c_device, c_host);
                 std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
             }
             else if(std::is_same<CDataType, double>::value)
             {
-                res = ck::utils::check_err(c_device.mData, c_host.mData);
+                res = ck::utils::check_err(c_device, c_host);
                 std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
             }
test/gemm/instance/gemm_f16_nn_instance.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

#include "gemm_f16_nn_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// DeviceGemm_Xdl_CShuffle template arguments, condensed from the original column header:
//   ALayout, BLayout, CLayout | AData, BData, CData, AccData, CShuffleData types |
//   A/B/C elementwise ops | GemmSpecialization | NumGemmKPrefetchStage | BlockSize |
//   MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave |
//   ABlockTransfer: ThreadClusterLengths_K0_M_K1, ThreadClusterArrangeOrder, SrcAccessOrder,
//   SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_K1, AddExtraM |
//   BBlockTransfer: the same fields for K0_N_K1, AddExtraN |
//   CShuffle: MXdlPerWavePerShuffle, NXdlPerWavePerShuffle |
//   CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl,
//   CBlockTransferScalarPerVector_NWaveNPerXdl

using gemm_f16_nn_256x256 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Col, Col, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 256, 256, 32, 2, 8, 32, 32, 4, 4,
                            S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

using gemm_f16_nn_256x128 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Col, Col, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 256, 128, 32, 2, 8, 32, 32, 4, 2,
                            S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

using gemm_f16_nn_128x128 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Col, Col, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 128, 128, 32, 2, 8, 32, 32, 2, 2,
                            S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

using gemm_f16_nn_128x64 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Col, Col, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 128, 64, 32, 2, 8, 32, 32, 2, 1,
                            S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

void add_gemm_f16_nn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_nn_256x256{});
}

void add_gemm_f16_nn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_nn_256x128{});
}

void add_gemm_f16_nn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_nn_128x128{});
}

void add_gemm_f16_nn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_nn_128x64{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
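Each `add_gemm_f16_nn_*` factory default-constructs the single element of its tuple type and appends it, type-erased behind BaseOperator, to the caller's vector via add_device_operation_instances from the CK instance library. Conceptually that helper behaves like the following sketch (a simplification under assumed semantics, not CK's actual implementation; names are illustrative):

    #include <memory>
    #include <tuple>
    #include <vector>

    // Sketch: append a heap-allocated copy of every element of `ops`
    // to `instances`, erased to the common base class `Base`.
    template <typename Base, typename... Ops>
    void add_operation_instances_sketch(std::vector<std::unique_ptr<Base>>& instances,
                                        const std::tuple<Ops...>& ops)
    {
        std::apply(
            [&](const Ops&... op) { (instances.push_back(std::make_unique<Ops>(op)), ...); },
            ops);
    }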
test/gemm/instance/gemm_f16_nn_instance.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

void add_gemm_f16_nn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_nn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_nn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_nn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
test/gemm/instance/gemm_f16_nt_instance.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

#include "gemm_f16_nt_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// DeviceGemm_Xdl_CShuffle template arguments follow the same order as in
// gemm_f16_nn_instance.cpp (layouts, data types, elementwise ops, specialization,
// then blocking and block-transfer parameters).

using gemm_f16_nt_256x256 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Col, Row, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 256, 256, 32, 2, 2, 32, 32, 4, 4,
                            S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

using gemm_f16_nt_256x128 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Col, Row, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 256, 128, 32, 2, 2, 32, 32, 4, 2,
                            S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

using gemm_f16_nt_128x128 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Col, Row, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 128, 128, 32, 2, 2, 32, 32, 2, 2,
                            S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

using gemm_f16_nt_128x64 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Col, Row, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 128, 64, 32, 2, 2, 32, 32, 2, 1,
                            S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

void add_gemm_f16_nt_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_nt_256x256{});
}

void add_gemm_f16_nt_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_nt_256x128{});
}

void add_gemm_f16_nt_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_nt_128x128{});
}

void add_gemm_f16_nt_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_nt_128x64{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
test/gemm/instance/gemm_f16_nt_instance.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

void add_gemm_f16_nt_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_nt_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_nt_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_nt_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
test/gemm/instance/gemm_f16_tn_instance.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

#include "gemm_f16_tn_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// DeviceGemm_Xdl_CShuffle template arguments follow the same order as in
// gemm_f16_nn_instance.cpp (layouts, data types, elementwise ops, specialization,
// then blocking and block-transfer parameters).

using gemm_f16_tn_256x256 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Row, Col, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 256, 256, 32, 8, 8, 32, 32, 4, 4,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

using gemm_f16_tn_256x128 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Row, Col, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

using gemm_f16_tn_128x128 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Row, Col, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 128, 128, 32, 8, 8, 32, 32, 2, 2,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

using gemm_f16_tn_128x64 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Row, Col, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 128, 64, 32, 8, 8, 32, 32, 2, 1,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_tn_256x256{});
}

void add_gemm_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_tn_256x128{});
}

void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_tn_128x128{});
}

void add_gemm_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_tn_128x64{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
test/gemm/instance/gemm_f16_tn_instance.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
test/gemm/instance/gemm_f16_tt_instance.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

#include "gemm_f16_tt_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// DeviceGemm_Xdl_CShuffle template arguments follow the same order as in
// gemm_f16_nn_instance.cpp (layouts, data types, elementwise ops, specialization,
// then blocking and block-transfer parameters).

using gemm_f16_tt_256x256 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 256, 256, 32, 8, 2, 32, 32, 4, 4,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

using gemm_f16_tt_256x128 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 256, 128, 32, 8, 2, 32, 32, 4, 2,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

using gemm_f16_tt_128x128 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 128, 128, 32, 8, 2, 32, 32, 2, 2,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

using gemm_f16_tt_128x64 = std::tuple<
    // clang-format off
    DeviceGemm_Xdl_CShuffle<Row, Row, Row, F16, F16, F16, F32, F16,
                            PassThrough, PassThrough, PassThrough, GemmDefault, 1,
                            256, 128, 64, 32, 8, 2, 32, 32, 2, 1,
                            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
                            S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0,
                            1, 1, S<1, 32, 1, 8>, 8>
    // clang-format on
    >;

void add_gemm_f16_tt_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_tt_256x256{});
}

void add_gemm_f16_tt_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_tt_256x128{});
}

void add_gemm_f16_tt_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_tt_128x128{});
}

void add_gemm_f16_tt_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, gemm_f16_tt_128x64{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
test/gemm/instance/gemm_f16_tt_instance.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

void add_gemm_f16_tt_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_tt_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_tt_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_tt_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
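These declarations are the hook a test uses to collect tuning instances. As a rough illustration (assumed usage, not part of the commit), a caller might gather every f16 TT-layout instance into one vector and walk it; `list_tt_instances` below is hypothetical:

// Hypothetical usage sketch: collect all f16 TT-layout instances declared
// above and print each one's configuration string.
#include <iostream>
#include <memory>
#include <vector>

#include "gemm_f16_tt_instance.hpp"

void list_tt_instances()
{
    std::vector<std::unique_ptr<ck::tensor_operation::device::BaseOperator>> instances;
    ck::tensor_operation::device::instance::add_gemm_f16_tt_256x256(instances);
    ck::tensor_operation::device::instance::add_gemm_f16_tt_256x128(instances);
    ck::tensor_operation::device::instance::add_gemm_f16_tt_128x128(instances);
    ck::tensor_operation::device::instance::add_gemm_f16_tt_128x64(instances);

    for(const auto& op : instances)
    {
        std::cout << op->GetTypeString() << '\n'; // each instance names its tile config
    }
}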
test/gemm/run_gemm_test.inc (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

int run_gemm_test()
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
        bool pass = true;

        using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
                                                                  decltype(b_layout),
                                                                  decltype(c_layout),
                                                                  ADataType,
                                                                  BDataType,
                                                                  CDataType,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>;

        const auto gemmPtrs =
            ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
                DeviceOp>::GetInstances();

        for(auto& gemmPtr : gemmPtrs)
        {
            pass &= ck::gemm_util::TestGemm<AccDataType>{}(gemmPtr.get());
        }

        return pass;
    };

    bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
                test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});

    std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;

    return pass ? 0 : 1;
}
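Note that the .inc file refers to ADataType, BDataType, CDataType, and AccDataType unqualified: each standalone test translation unit defines those aliases first and then includes this file. A minimal hypothetical driver TU (an assumption, not part of the commit) would look like:

// Hypothetical driver sketch: the type aliases must already be visible when
// this file is included, since run_gemm_test() uses them unqualified.
#include "run_gemm_test.inc"

int main() { return run_gemm_test(); }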
test/gemm_split_k/gemm_split_k.cpp (modified)
...
@@ -16,6 +16,7 @@
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
 #include "ck/library/utility/host_gemm.hpp"
...
@@ -93,15 +94,15 @@ int test_gemm(const gemmArgs& args)
     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, bool row_major) {
+            using namespace ck::literals;
+
             if(row_major)
             {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({stride, 1}));
+                return HostTensorDescriptor({row, col}, {stride, 1_uz});
             }
             else
             {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({1, stride}));
+                return HostTensorDescriptor({row, col}, {1_uz, stride});
             }
         };
...
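The `1_uz` literal (from the newly included literals.hpp) turns the integer constant into a std::size_t, presumably so both elements of each braced list share one type; mixing `stride` (std::size_t) with a plain `int` constant would otherwise break template deduction of the initializer_list parameter. A minimal sketch of how such a literal is typically defined, assuming this is roughly what ck::literals provides (the real definition may differ):

// Assumed shape of the _uz user-defined literal; illustrative only.
#include <cstddef>

namespace ck::literals {
constexpr std::size_t operator""_uz(unsigned long long n)
{
    return static_cast<std::size_t>(n);
}
} // namespace ck::literals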
test/grouped_convnd_bwd_weight/CMakeLists.txt (new file, 0 → 100644)
add_gtest_executable(test_grouped_convnd_bwd_weight grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight
                      PRIVATE utility
                              device_grouped_conv1d_bwd_weight_instance
                              device_grouped_conv2d_bwd_weight_instance
                              device_grouped_conv3d_bwd_weight_instance)
test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>

#include <gtest/gtest.h>

#include "profiler/include/profile_grouped_conv_bwd_weight_impl.hpp"

template <typename Tuple>
class TestGroupedConvndBwdWeight : public ::testing::Test
{
    protected:
    using DataType = std::tuple_element_t<0, Tuple>;
    std::vector<ck::utils::conv::ConvParam> conv_params;
    ck::index_t split_k{2};

    template <ck::index_t NDimSpatial>
    void Run()
    {
        // Guard against an accidentally empty parameter list; inside the loop
        // body this check could never fire.
        EXPECT_FALSE(conv_params.empty());

        for(auto& param : conv_params)
        {
            bool pass = ck::profiler::profile_grouped_conv_bwd_weight_impl<
                NDimSpatial,
                ck::tuple_element_t<NDimSpatial - 1,
                                    ck::Tuple<ck::tensor_layout::convolution::GNWC,
                                              ck::tensor_layout::convolution::GNHWC,
                                              ck::tensor_layout::convolution::GNDHWC>>,
                ck::tuple_element_t<NDimSpatial - 1,
                                    ck::Tuple<ck::tensor_layout::convolution::GKXC,
                                              ck::tensor_layout::convolution::GKYXC,
                                              ck::tensor_layout::convolution::GKZYXC>>,
                ck::tuple_element_t<NDimSpatial - 1,
                                    ck::Tuple<ck::tensor_layout::convolution::GNWK,
                                              ck::tensor_layout::convolution::GNHWK,
                                              ck::tensor_layout::convolution::GNDHWK>>,
                DataType,
                DataType,
                DataType>(true,  // do_verification
                          1,     // init_method: integer value
                          false, // do_log
                          false, // time_kernel
                          param,
                          split_k);
            EXPECT_TRUE(pass);
        }
    }
};
using KernelTypes =
    ::testing::Types<std::tuple<float>, std::tuple<ck::half_t>, std::tuple<ck::bhalf_t>>;

TYPED_TEST_SUITE(TestGroupedConvndBwdWeight, KernelTypes);

TYPED_TEST(TestGroupedConvndBwdWeight, Test1D)
{
    // ConvParam fields: {NDimSpatial, G, N, K, C, filter lengths,
    //                    input lengths, strides, dilations, left pads, right pads}
    this->conv_params.clear();
    this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
    this->conv_params.push_back({1, 4, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
    this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
    this->template Run<1>();
}

TYPED_TEST(TestGroupedConvndBwdWeight, Test2D)
{
    this->conv_params.clear();
    this->conv_params.push_back(
        {2, 4, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
    this->conv_params.push_back(
        {2, 4, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    this->conv_params.push_back(
        {2, 4, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
    this->template Run<2>();
}

TYPED_TEST(TestGroupedConvndBwdWeight, Test3D)
{
    this->conv_params.clear();
    this->conv_params.push_back(
        {3, 4, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
    this->conv_params.push_back(
        {3, 4, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
    this->conv_params.push_back(
        {3, 4, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
    this->template Run<3>();
}
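The spatial sizes in each parameter set determine the output length through the usual convolution formula. The helper below is a hypothetical illustration of that arithmetic, not code from the repository:

// Standard convolution output-length arithmetic (illustrative only).
// For the first 1D case above: (14 + 0 + 0 - 1*(1 - 1) - 1) / 2 + 1 = 7.
#include <cstdint>

std::int64_t conv_out_length(std::int64_t in_len,
                             std::int64_t filter_len,
                             std::int64_t stride,
                             std::int64_t dilation,
                             std::int64_t left_pad,
                             std::int64_t right_pad)
{
    const auto effective_filter = dilation * (filter_len - 1) + 1;
    return (in_len + left_pad + right_pad - effective_filter) / stride + 1;
}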
test/layernorm/test_layernorm2d_fp16.cpp (deleted, 100644 → 0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"

#include "test_layernorm2d_util.hpp"

template <ck::index_t N>
using I = ck::Number<N>;

template <typename Tuple>
class TestLayernorm2dFP16 : public ck::TestLayernorm2d<Tuple>
{
};

// clang-format off
using KernelTypes = ::testing::Types<
    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<8>, I<32>,  I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<8>, I<32>,  I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<4>, I<64>,  I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<4>, I<64>,  I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>
    >;
// clang-format on

TYPED_TEST_SUITE(TestLayernorm2dFP16, KernelTypes);

TYPED_TEST(TestLayernorm2dFP16, Test_FP16) { this->Run(); }
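Note the pattern in the cluster columns: every configuration keeps MThreadClusterSize * KThreadClusterSize equal to BlockSize (8*32, 4*64, 2*128, and 1*256 all give 256). A hypothetical compile-time guard for that invariant, not present in the original file, could read:

// Illustrative sanity check: the M and K thread cluster dimensions in each
// tuple above must tile the 256-thread workgroup exactly.
static_assert(8 * 32 == 256 && 4 * 64 == 256 && 2 * 128 == 256 && 1 * 256 == 256,
              "thread cluster dimensions must tile the workgroup");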
test/layernorm/test_layernorm2d_fp32.cpp (deleted, 100644 → 0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"

#include "test_layernorm2d_util.hpp"

template <ck::index_t N>
using I = ck::Number<N>;

template <typename Tuple>
class TestLayernorm2dFP32 : public ck::TestLayernorm2d<Tuple>
{
};

// clang-format off
using KernelTypes = ::testing::Types<
    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<8>, I<32>,  I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<8>, I<32>,  I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<4>, I<64>,  I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<4>, I<64>,  I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>
    >;
// clang-format on

TYPED_TEST_SUITE(TestLayernorm2dFP32, KernelTypes);

TYPED_TEST(TestLayernorm2dFP32, Test_FP32) { this->Run(); }
test/layernorm/test_layernorm2d_util.hpp (deleted, 100644 → 0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <vector>

#include <gtest/gtest.h>

#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"

namespace ck {

template <typename Range>
std::string serialize_range(const Range& range)
{
    std::stringstream ss;
    for(auto& r : range)
    {
        ss << r << ", ";
    }
    std::string str = ss.str();
    return std::string(str.begin(), str.end() - 2);
}

template <typename Tuple>
class TestLayernorm2d : public ::testing::Test
{
    protected:
    using XDataType     = std::tuple_element_t<0, Tuple>;
    using GammaDataType = std::tuple_element_t<1, Tuple>;
    using BetaDataType  = std::tuple_element_t<2, Tuple>;
    using AccDataType   = std::tuple_element_t<3, Tuple>;
    using YDataType     = std::tuple_element_t<4, Tuple>;

    static constexpr index_t Rank               = std::tuple_element_t<5, Tuple>{}.value;
    static constexpr index_t NumReduceDim       = std::tuple_element_t<6, Tuple>{}.value;
    static constexpr index_t BlockSize          = std::tuple_element_t<7, Tuple>{}.value;
    static constexpr index_t MThreadClusterSize = std::tuple_element_t<8, Tuple>{}.value;
    static constexpr index_t KThreadClusterSize = std::tuple_element_t<9, Tuple>{}.value;
    static constexpr index_t MThreadSliceSize   = std::tuple_element_t<10, Tuple>{}.value;
    static constexpr index_t KThreadSliceSize   = std::tuple_element_t<11, Tuple>{}.value;
    static constexpr index_t XYSrcVectorDim     = std::tuple_element_t<12, Tuple>{}.value;
    static constexpr index_t XSrcVectorSize     = std::tuple_element_t<13, Tuple>{}.value;
    static constexpr index_t GammaSrcVectorDim  = std::tuple_element_t<14, Tuple>{}.value;
    static constexpr index_t GammaSrcVectorSize = std::tuple_element_t<15, Tuple>{}.value;
    static constexpr index_t BetaSrcVectorDim   = std::tuple_element_t<16, Tuple>{}.value;
    static constexpr index_t BetaSrcVectorSize  = std::tuple_element_t<17, Tuple>{}.value;
    static constexpr index_t YDstVectorSize     = std::tuple_element_t<18, Tuple>{}.value;

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using ReferenceInstance = tensor_operation::host::ReferenceLayernorm<XDataType,
                                                                         GammaDataType,
                                                                         BetaDataType,
                                                                         YDataType,
                                                                         AccDataType,
                                                                         PassThrough,
                                                                         Rank,
                                                                         NumReduceDim>;

    using DeviceInstance = tensor_operation::device::DeviceLayernormImpl<XDataType,
                                                                         GammaDataType,
                                                                         BetaDataType,
                                                                         AccDataType,
                                                                         YDataType,
                                                                         PassThrough,
                                                                         Rank,
                                                                         NumReduceDim,
                                                                         BlockSize,
                                                                         MThreadClusterSize,
                                                                         KThreadClusterSize,
                                                                         MThreadSliceSize,
                                                                         KThreadSliceSize,
                                                                         XYSrcVectorDim,
                                                                         XSrcVectorSize,
                                                                         GammaSrcVectorDim,
                                                                         GammaSrcVectorSize,
                                                                         BetaSrcVectorDim,
                                                                         BetaSrcVectorSize,
                                                                         YDstVectorSize>;

    TestLayernorm2d() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {}

    void RunSingle(const std::vector<index_t>& lengths,
                   const std::vector<index_t>& reduceDims,
                   const std::vector<index_t>& GammaLength,
                   const std::vector<index_t>& GammaStride,
                   const std::vector<index_t>& BetaLength,
                   const std::vector<index_t>& BetaStride)
    {
        Tensor<XDataType> x(lengths);
        Tensor<GammaDataType> gamma(GammaLength);
        Tensor<BetaDataType> beta(BetaLength);
        Tensor<YDataType> y(lengths);
        Tensor<YDataType> y_ref(lengths);

        x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
        gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
        beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});

        DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
        DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
        DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
        DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());

        x_dev.ToDevice(x.mData.data());
        gamma_dev.ToDevice(gamma.mData.data());
        beta_dev.ToDevice(beta.mData.data());

        auto device_instance = DeviceInstance{};
        auto argument_ptr    = device_instance.MakeArgumentPointer(
            lengths,
            std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
            GammaStride,
            BetaStride,
            std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
            reduceDims,
            1e-4,
            x_dev.GetDeviceBuffer(),
            gamma_dev.GetDeviceBuffer(),
            beta_dev.GetDeviceBuffer(),
            y_dev.GetDeviceBuffer(),
            PassThrough{});

        if(!device_instance.IsSupportedArgument(argument_ptr.get()))
        {
            return;
        }

        auto invoker_ptr = device_instance.MakeInvokerPointer();
        invoker_ptr->Run(argument_ptr.get());

        ref_instance_invoker_.Run(
            {x, gamma, beta, y_ref, PassThrough{}, lengths, reduceDims, 1e-4});

        y_dev.FromDevice(y.mData.data());

        bool pass;
        if(std::is_same<XDataType, int8_t>::value)
        {
            EXPECT_TRUE(pass = ck::utils::check_err(
                            y.mData, y_ref.mData, "Error: Incorrect results!", 0, 1));
        }
        else
        {
            EXPECT_TRUE(pass = ck::utils::check_err(
                            y.mData, y_ref.mData, "Error: Incorrect results d1", 1e-3, 1e-3));
        }

        if(!pass)
        {
            FAIL() << "Failure in input lengths = [" << serialize_range(lengths) << "], "
                   << "reduce dim = [" << serialize_range(reduceDims) << "].";
        }
    }

    void Run()
    {
        std::vector<std::vector<index_t>> lengths = {
            {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}};

        for(auto length : lengths)
        {
            this->RunSingle(length, {1}, {length[1]}, {0, 1}, {length[1]}, {0, 1});
        }
    }

    typename ReferenceInstance::Invoker ref_instance_invoker_;
};

} // namespace ck
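The reference instance computes the textbook layer normalization, y = gamma * (x - mean) / sqrt(var + eps) + beta, with mean and variance taken over the reduce dimension. A minimal host-side sketch of that computation for the 2D case (an illustration; ReferenceLayernorm is the authoritative implementation):

// Naive 2D layernorm over the K (column) dimension; illustrative only.
#include <cmath>
#include <vector>

void layernorm_rows(const std::vector<float>& x,
                    const std::vector<float>& gamma,
                    const std::vector<float>& beta,
                    std::vector<float>& y,
                    int M, int K, float eps)
{
    for(int m = 0; m < M; ++m)
    {
        float mean = 0.f, var = 0.f;
        for(int k = 0; k < K; ++k)
            mean += x[m * K + k];
        mean /= K;
        for(int k = 0; k < K; ++k)
        {
            const float d = x[m * K + k] - mean;
            var += d * d;
        }
        var /= K;
        const float inv_std = 1.f / std::sqrt(var + eps);
        for(int k = 0; k < K; ++k)
            y[m * K + k] = gamma[k] * (x[m * K + k] - mean) * inv_std + beta[k];
    }
}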
test/layernorm/CMakeLists.txt → test/normalization/CMakeLists.txt (renamed)
...
@@ -5,8 +5,9 @@ add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp)
 add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp)
 add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp)
-target_link_libraries(test_layernorm2d_fp32 PRIVATE utility)
-target_link_libraries(test_layernorm2d_fp16 PRIVATE utility)
+target_link_libraries(test_layernorm2d_fp32 PRIVATE utility device_normalization_instance)
+target_link_libraries(test_layernorm2d_fp16 PRIVATE utility device_normalization_instance)
+target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance)
+target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance)
...
@@ -14,4 +15,3 @@ add_dependencies(test_layernorm test_layernorm2d_fp32)
 add_dependencies(test_layernorm test_layernorm2d_fp16)
 add_dependencies(test_layernorm test_groupnorm_fp16)
 add_dependencies(test_layernorm test_groupnorm_fp32)
test/layernorm/test_groupnorm_fp16.cpp → test/normalization/test_groupnorm_fp16.cpp (renamed)
...
@@ -20,7 +20,7 @@ class TestGroupnorm : public ::testing::Test
     void Run()
     {
-        // N, H, W, G, C
+        // [N, H, W, G, C], reduce H, W, C
         std::vector<std::vector<ck::index_t>> lengths = {{1, 1, 1, 1, 1},
                                                          {1, 2, 3, 4, 5},
                                                          {256, 9, 9, 9, 9},
...
@@ -45,13 +45,6 @@ class TestGroupnorm : public ::testing::Test
 using KernelTypes = ::testing::Types<
     // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType
     std::tuple<F16, F16, F16, F32, F16>,
     std::tuple<F16, F16, F16, F32, F16>,
     std::tuple<F16, F16, F16, F32, F16>,
     std::tuple<F16, F16, F16, F32, F16>,
     std::tuple<F16, F16, F16, F32, F16>,
     std::tuple<F16, F16, F16, F32, F16>,
     std::tuple<F16, F16, F16, F32, F16>,
     std::tuple<F16, F16, F16, F32, F16>>;

 TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
...
test/layernorm/test_groupnorm_fp32.cpp → test/normalization/test_groupnorm_fp32.cpp (renamed)
...
@@ -20,7 +20,7 @@ class TestGroupnorm : public ::testing::Test
     void Run()
     {
-        // N, H, W, G, C
+        // [N, H, W, G, C], reduce H, W, C
        std::vector<std::vector<ck::index_t>> lengths = {{1, 1, 1, 1, 1},
                                                          {1, 2, 3, 4, 5},
                                                          {256, 9, 9, 9, 9},
...
@@ -43,13 +43,6 @@ class TestGroupnorm : public ::testing::Test
 using KernelTypes = ::testing::Types<
     // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType
     std::tuple<F32, F32, F32, F32, F32>,
     std::tuple<F32, F32, F32, F32, F32>,
     std::tuple<F32, F32, F32, F32, F32>,
     std::tuple<F32, F32, F32, F32, F32>,
     std::tuple<F32, F32, F32, F32, F32>,
     std::tuple<F32, F32, F32, F32, F32>,
     std::tuple<F32, F32, F32, F32, F32>,
     std::tuple<F32, F32, F32, F32, F32>>;

 TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
...
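The updated comment spells out the reduction: for a [N, H, W, G, C] input, groupnorm normalizes each (n, g) slice over H, W, and C, so the {256, 9, 9, 9, 9} case computes 256 * 9 = 2304 independent mean/variance pairs, each over 9 * 9 * 9 = 729 elements. A small sanity check on that bookkeeping (illustrative, not repository code):

// Group count and elements per group for a groupnorm problem shaped
// [N, H, W, G, C]; purely illustrative arithmetic.
long long gn_groups(long long N, long long G) { return N * G; }
long long gn_elems_per_group(long long H, long long W, long long C) { return H * W * C; }
// e.g. lengths {256, 9, 9, 9, 9}: gn_groups = 2304, gn_elems_per_group = 729.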