Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8f41bd8e
Commit
8f41bd8e
authored
Apr 11, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
7f65ac05
d7f05fb9
Changes
144
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
188 additions
and
119 deletions
+188
-119
profiler/src/profile_contraction_scale.cpp
profiler/src/profile_contraction_scale.cpp
+93
-54
profiler/src/profile_grouped_conv_fwd.cpp
profiler/src/profile_grouped_conv_fwd.cpp
+7
-1
test/contraction/test_contraction_interface_xdl.cpp
test/contraction/test_contraction_interface_xdl.cpp
+1
-13
test/contraction/test_contraction_xdl.cpp
test/contraction/test_contraction_xdl.cpp
+87
-51
No files found.
profiler/src/profile_contraction_scale.cpp
View file @
8f41bd8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -19,7 +19,8 @@ static void print_helper_msg()
...
@@ -19,7 +19,8 @@ static void print_helper_msg()
std
::
cout
<<
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
std
::
cout
<<
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
<<
"arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)
\n
"
<<
"arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)
\n
"
<<
"arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)
\n
"
<<
"arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)
\n
"
<<
"arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
<<
"arg4: Number of dimension for M, N and K (one for all)
\n
"
<<
"arg5: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
\n
"
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
\n
"
<<
" 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
<<
" 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
\n
"
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
\n
"
...
@@ -27,22 +28,22 @@ static void print_helper_msg()
...
@@ -27,22 +28,22 @@ static void print_helper_msg()
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
\n
"
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
\n
"
<<
" 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
<<
" 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])
\n
"
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])
\n
"
<<
"arg
5
: verification (0: no; 1: yes)
\n
"
<<
"arg
6
: verification (0: no; 1: yes)
\n
"
<<
"arg
6
: initialization (0: no init; 1: integer value; 2: decimal "
<<
"arg
7
: initialization (0: no init; 1: integer value; 2: decimal "
<<
"value)
\n
"
<<
"value)
\n
"
<<
"arg
7
: print tensor value (0: no; 1: yes)
\n
"
<<
"arg
8
: print tensor value (0: no; 1: yes)
\n
"
<<
"arg
8
: time kernel (0: no, 1: yes)
\n
"
<<
"arg
9
: time kernel (0: no, 1: yes)
\n
"
<<
"arg
9
: alpha
\n
"
<<
"arg
10
: alpha
\n
"
<<
"arg1
0
to 1
5
: M0, M1, N0, N1, K0, K1
\n
"
<<
"arg1
1
to 1
6/28
: M0, M1, N0, N1, K0, K1
\n
"
<<
"arg1
6
to 3
1
: Strides for A, B,
D and
E (skip for default)
\n
"
<<
"arg1
7/29
to 3
2/63
: Strides for A, B, E (skip for default)
\n
"
<<
std
::
endl
;
<<
std
::
endl
;
}
}
int
profile_contraction_scale
(
int
argc
,
char
*
argv
[])
int
profile_contraction_scale
(
int
argc
,
char
*
argv
[])
{
{
const
bool
default_strides
=
argc
==
1
6
;
const
bool
default_strides
=
argc
==
1
7
||
argc
==
29
;
if
(
argc
!=
3
2
&&
argc
!=
1
6
)
if
(
argc
!=
2
9
&&
argc
!=
6
5
&&
!
default_strides
)
{
{
print_helper_msg
();
print_helper_msg
();
exit
(
1
);
exit
(
1
);
...
@@ -50,31 +51,30 @@ int profile_contraction_scale(int argc, char* argv[])
...
@@ -50,31 +51,30 @@ int profile_contraction_scale(int argc, char* argv[])
const
auto
data_type
=
static_cast
<
ContractionDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
data_type
=
static_cast
<
ContractionDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
compute_data_type
=
static_cast
<
ContractionComputeDataType
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
compute_data_type
=
static_cast
<
ContractionComputeDataType
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
layout
=
static_cast
<
ContractionMatrixLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
ck
::
index_t
NumDimMNK
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
5
]);
const
auto
layout
=
static_cast
<
ContractionMatrixLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
ck
::
index_t
init_method
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
7
]);
const
ck
::
index_t
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
8
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
float
alpha
=
std
::
stof
(
argv
[
9
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
9
]);
const
float
alpha
=
std
::
stof
(
argv
[
10
]);
std
::
vector
<
ck
::
index_t
>
M
;
std
::
vector
<
ck
::
index_t
>
M
;
std
::
vector
<
ck
::
index_t
>
N
;
std
::
vector
<
ck
::
index_t
>
N
;
std
::
vector
<
ck
::
index_t
>
K
;
std
::
vector
<
ck
::
index_t
>
K
;
const
ck
::
index_t
dims_arg_num
=
10
;
const
ck
::
index_t
dims_arg_num
=
11
;
collect_index_params
(
argv
,
M
,
dims_arg_num
,
2
);
collect_index_params
(
argv
,
M
,
dims_arg_num
,
NumDimMNK
);
collect_index_params
(
argv
,
N
,
dims_arg_num
+
2
,
2
);
collect_index_params
(
argv
,
N
,
dims_arg_num
+
NumDimMNK
,
NumDimMNK
);
collect_index_params
(
argv
,
K
,
dims_arg_num
+
4
,
2
);
collect_index_params
(
argv
,
K
,
dims_arg_num
+
NumDimMNK
*
2
,
NumDimMNK
);
std
::
vector
<
ck
::
index_t
>
StridesA
;
std
::
vector
<
ck
::
index_t
>
StridesA
(
NumDimMNK
*
2
);
std
::
vector
<
ck
::
index_t
>
StridesB
;
std
::
vector
<
ck
::
index_t
>
StridesB
(
NumDimMNK
*
2
);
std
::
vector
<
ck
::
index_t
>
StridesE
;
std
::
vector
<
ck
::
index_t
>
StridesE
(
NumDimMNK
*
2
);
std
::
vector
<
ck
::
index_t
>
StridesD
;
if
(
!
default_strides
)
if
(
!
default_strides
)
{
{
collect_index_params
(
argv
,
StridesA
,
dims_arg_num
+
6
,
4
);
collect_index_params
(
argv
,
StridesA
,
dims_arg_num
+
NumDimMNK
*
3
,
NumDimMNK
*
2
);
collect_index_params
(
argv
,
StridesB
,
dims_arg_num
+
10
,
4
);
collect_index_params
(
argv
,
StridesB
,
dims_arg_num
+
NumDimMNK
*
5
,
NumDimMNK
*
2
);
collect_index_params
(
argv
,
StridesE
,
dims_arg_num
+
14
,
4
);
collect_index_params
(
argv
,
StridesE
,
dims_arg_num
+
NumDimMNK
*
7
,
NumDimMNK
*
2
);
collect_index_params
(
argv
,
StridesD
,
dims_arg_num
+
18
,
4
);
}
}
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
...
@@ -93,32 +93,71 @@ int profile_contraction_scale(int argc, char* argv[])
...
@@ -93,32 +93,71 @@ int profile_contraction_scale(int argc, char* argv[])
if
(
default_strides
)
if
(
default_strides
)
{
{
assign_default_strides
(
a_layout
,
StridesA
,
{
M
[
0
],
M
[
1
],
K
[
0
],
K
[
1
]});
auto
merge_dims
=
[](
const
std
::
vector
<
ck
::
index_t
>&
dims01
,
assign_default_strides
(
b_layout
,
StridesB
,
{
N
[
0
],
N
[
1
],
K
[
0
],
K
[
1
]});
const
std
::
vector
<
ck
::
index_t
>&
dims23
)
{
assign_default_strides
(
cde_layout
,
StridesE
,
{
M
[
0
],
M
[
1
],
N
[
0
],
N
[
1
]});
std
::
vector
<
ck
::
index_t
>
dims_szt
(
dims01
.
begin
(),
dims01
.
end
());
assign_default_strides
(
cde_layout
,
StridesD
,
{
M
[
0
],
M
[
1
],
N
[
0
],
N
[
1
]});
dims_szt
.
insert
(
dims_szt
.
end
(),
dims23
.
begin
(),
dims23
.
end
());
return
dims_szt
;
};
assign_default_strides
(
a_layout
,
StridesA
,
merge_dims
(
M
,
K
));
assign_default_strides
(
b_layout
,
StridesB
,
merge_dims
(
N
,
K
));
assign_default_strides
(
cde_layout
,
StridesE
,
merge_dims
(
M
,
N
));
}
}
bool
pass
=
ck
::
profiler
::
profile_contraction_impl
<
ALayout
,
if
(
NumDimMNK
==
2
)
BLayout
,
{
CDELayout
,
bool
pass
=
ck
::
profiler
::
profile_contraction_impl
<
2
,
DataType
,
ALayout
,
ComputeDataType
,
BLayout
,
ck
::
Tuple
<>
,
CDELayout
,
Scale
>
(
do_verification
,
DataType
,
init_method
,
ComputeDataType
,
do_log
,
ck
::
Tuple
<>
,
time_kernel
,
Scale
>
(
do_verification
,
Scale
{
alpha
},
init_method
,
M
,
do_log
,
N
,
time_kernel
,
K
,
Scale
{
alpha
},
StridesA
,
M
,
StridesB
,
N
,
StridesE
,
K
,
StridesD
);
StridesA
,
StridesB
,
return
pass
;
StridesE
,
StridesE
);
return
pass
;
}
else
if
(
NumDimMNK
==
6
)
{
bool
pass
=
ck
::
profiler
::
profile_contraction_impl
<
6
,
ALayout
,
BLayout
,
CDELayout
,
DataType
,
ComputeDataType
,
ck
::
Tuple
<>
,
Scale
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
Scale
{
alpha
},
M
,
N
,
K
,
StridesA
,
StridesB
,
StridesE
,
StridesE
);
return
pass
;
}
else
{
throw
std
::
runtime_error
(
"Not supported NumDimMNK"
);
return
false
;
}
};
};
auto
run_profile_for_datatype
=
[
&
](
auto
type
,
auto
compute_type
)
{
auto
run_profile_for_datatype
=
[
&
](
auto
type
,
auto
compute_type
)
{
...
...
profiler/src/profile_grouped_conv_fwd.cpp
View file @
8f41bd8e
...
@@ -26,6 +26,7 @@ enum struct ConvDataType
...
@@ -26,6 +26,7 @@ enum struct ConvDataType
F8_F8_F8
,
// 4
F8_F8_F8
,
// 4
BF8_BF8_F8
,
// 5
BF8_BF8_F8
,
// 5
F8_BF8_F8
,
// 6
F8_BF8_F8
,
// 6
BF8_F8_F8
,
// 7
};
};
#define OP_NAME "grouped_conv_fwd"
#define OP_NAME "grouped_conv_fwd"
...
@@ -42,7 +43,8 @@ static void print_helper_msg()
...
@@ -42,7 +43,8 @@ static void print_helper_msg()
<<
" 3: Input int8, Weight int8, Output int8
\n
"
<<
" 3: Input int8, Weight int8, Output int8
\n
"
<<
" 4: Input fp8, Weight fp8, Output fp8
\n
"
<<
" 4: Input fp8, Weight fp8, Output fp8
\n
"
<<
" 5: Input bf8, Weight bf8, Output fp8
\n
"
<<
" 5: Input bf8, Weight bf8, Output fp8
\n
"
<<
" 6: Input fp8, Weight bf8, Output fp8)
\n
"
<<
" 6: Input fp8, Weight bf8, Output fp8
\n
"
<<
" 7: Input bf8, Weight fp8, Output fp8)
\n
"
<<
"arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]
\n
"
<<
"arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]
\n
"
<<
" 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])
\n
"
<<
" 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])
\n
"
<<
"arg4: verification (0: no, 1: yes)
\n
"
<<
"arg4: verification (0: no, 1: yes)
\n
"
...
@@ -281,6 +283,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
...
@@ -281,6 +283,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F8
{},
BF8
{},
F8
{},
F8
{},
BF8
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F8
{},
BF8
{},
F8
{},
F8
{},
BF8
{});
}
}
else
if
(
data_type
==
ConvDataType
::
BF8_F8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF8
{},
F8
{},
F8
{},
BF8
{},
F8
{});
}
}
}
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
test/contraction/test_contraction_interface_xdl.cpp
View file @
8f41bd8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#include <stdexcept>
#include <stdexcept>
#include <vector>
#include <vector>
...
@@ -125,18 +125,6 @@ class ContractionDeviceOpWrapper
...
@@ -125,18 +125,6 @@ class ContractionDeviceOpWrapper
}
}
};
};
TEST
(
TestContractionInterface
,
IncorrectNumDims
)
{
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
Dims
=
{{
4
,
4
},
{
4
,
4
,
4
,
4
},
{
4
,
4
,
4
,
4
,
4
,
4
}};
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
Strides
=
{{
1
,
1
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
,
1
}};
ContractionDeviceOpWrapper
<
F32
,
F32
,
F32
,
F32
,
1
>
wrapper_1d
;
ContractionDeviceOpWrapper
<
F32
,
F32
,
F32
,
F32
,
2
>
wrapper_2d
;
ContractionDeviceOpWrapper
<
F32
,
F32
,
F32
,
F32
,
3
>
wrapper_3d
;
EXPECT_FALSE
(
wrapper_1d
.
IsSupportedInstance
(
Dims
[
0
],
Strides
[
0
]));
EXPECT_TRUE
(
wrapper_2d
.
IsSupportedInstance
(
Dims
[
1
],
Strides
[
1
]));
EXPECT_FALSE
(
wrapper_3d
.
IsSupportedInstance
(
Dims
[
2
],
Strides
[
2
]));
}
TEST
(
TestContractionInterface
,
IncorrectDataTypes
)
TEST
(
TestContractionInterface
,
IncorrectDataTypes
)
{
{
std
::
vector
<
ck
::
index_t
>
Dims
=
{
4
,
4
,
4
,
4
};
std
::
vector
<
ck
::
index_t
>
Dims
=
{
4
,
4
,
4
,
4
};
...
...
test/contraction/test_contraction_xdl.cpp
View file @
8f41bd8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <cstdlib>
#include <iostream>
#include <iostream>
...
@@ -23,8 +23,11 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -23,8 +23,11 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
Bilinear
=
ck
::
tensor_operation
::
element_wise
::
Bilinear
;
using
Bilinear
=
ck
::
tensor_operation
::
element_wise
::
Bilinear
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
template
<
ck
::
index_t
NDims
>
struct
Dimensions
struct
Dimensions
{
{
constexpr
static
ck
::
index_t
NumDimMNK
=
NDims
;
std
::
vector
<
ck
::
index_t
>
M
;
std
::
vector
<
ck
::
index_t
>
M
;
std
::
vector
<
ck
::
index_t
>
N
;
std
::
vector
<
ck
::
index_t
>
N
;
std
::
vector
<
ck
::
index_t
>
K
;
std
::
vector
<
ck
::
index_t
>
K
;
...
@@ -42,53 +45,58 @@ class TestContraction : public ::testing::Test
...
@@ -42,53 +45,58 @@ class TestContraction : public ::testing::Test
using
ComputeDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
ComputeDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDElementOp
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
CDElementOp
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
std
::
vector
<
Dimensions
>
dimension_list
=
{{{
32
,
32
},
{
32
,
32
},
{
32
,
32
}},
{{
16
,
16
},
{
32
,
32
},
{
16
,
16
}}};
std
::
vector
<
ck
::
index_t
>
init_methods
=
{
1
,
2
};
std
::
vector
<
ck
::
index_t
>
init_methods
=
{
1
,
2
};
std
::
unique_ptr
<
CDElementOp
>
p_cd_element_op
;
std
::
unique_ptr
<
CDElementOp
>
p_cd_element_op
;
void
Run
()
template
<
ck
::
index_t
NumDim
>
void
Run
(
Dimensions
<
NumDim
>
dimension_params
)
{
{
for
(
auto
&
dimension_params
:
dimension_list
)
constexpr
ck
::
index_t
NumDimMNK
=
ck
::
remove_cvref_t
<
decltype
(
dimension_params
)
>::
NumDimMNK
;
std
::
vector
<
ck
::
index_t
>
StridesA
(
2
*
NumDim
);
std
::
vector
<
ck
::
index_t
>
StridesB
(
2
*
NumDim
);
std
::
vector
<
ck
::
index_t
>
StridesC
(
2
*
NumDim
);
std
::
vector
<
ck
::
index_t
>
StridesD
(
2
*
NumDim
);
const
auto
&
M
=
dimension_params
.
M
;
const
auto
&
N
=
dimension_params
.
N
;
const
auto
&
K
=
dimension_params
.
K
;
auto
merge_dims
=
[](
const
std
::
vector
<
ck
::
index_t
>&
dims01
,
const
std
::
vector
<
ck
::
index_t
>&
dims23
)
{
std
::
vector
<
ck
::
index_t
>
dims_szt
(
dims01
.
begin
(),
dims01
.
end
());
dims_szt
.
insert
(
dims_szt
.
end
(),
dims23
.
begin
(),
dims23
.
end
());
return
dims_szt
;
};
assign_default_strides
(
ALayout
{},
StridesA
,
merge_dims
(
M
,
K
));
assign_default_strides
(
BLayout
{},
StridesB
,
merge_dims
(
N
,
K
));
assign_default_strides
(
CDLayout
{},
StridesC
,
merge_dims
(
M
,
N
));
assign_default_strides
(
CDLayout
{},
StridesD
,
merge_dims
(
M
,
N
));
for
(
const
ck
::
index_t
init_method
:
init_methods
)
{
{
std
::
vector
<
ck
::
index_t
>
StridesA
;
bool
pass
=
std
::
vector
<
ck
::
index_t
>
StridesB
;
ck
::
profiler
::
profile_contraction_impl
<
NumDimMNK
,
std
::
vector
<
ck
::
index_t
>
StridesC
;
ALayout
,
std
::
vector
<
ck
::
index_t
>
StridesD
;
BLayout
,
CDLayout
,
const
auto
&
M
=
dimension_params
.
M
;
DataType
,
const
auto
&
N
=
dimension_params
.
N
;
ComputeDataType
,
const
auto
&
K
=
dimension_params
.
K
;
DTupleDataType
,
CDElementOp
>
(
true
/*do_verification*/
,
assign_default_strides
(
ALayout
{},
StridesA
,
{
M
[
0
],
M
[
1
],
K
[
0
],
K
[
1
]});
init_method
,
assign_default_strides
(
BLayout
{},
StridesB
,
{
N
[
0
],
N
[
1
],
K
[
0
],
K
[
1
]});
false
/*do_logs*/
,
assign_default_strides
(
CDLayout
{},
StridesC
,
{
M
[
0
],
M
[
1
],
N
[
0
],
N
[
1
]});
false
/*time_kernel*/
,
assign_default_strides
(
CDLayout
{},
StridesD
,
{
M
[
0
],
M
[
1
],
N
[
0
],
N
[
1
]});
*
p_cd_element_op
,
dimension_params
.
M
,
for
(
const
ck
::
index_t
init_method
:
init_methods
)
dimension_params
.
N
,
{
dimension_params
.
K
,
bool
pass
=
StridesA
,
ck
::
profiler
::
profile_contraction_impl
<
ALayout
,
StridesB
,
BLayout
,
StridesC
,
CDLayout
,
StridesD
);
DataType
,
EXPECT_TRUE
(
pass
);
ComputeDataType
,
DTupleDataType
,
CDElementOp
>
(
true
/*do_verification*/
,
init_method
,
false
/*do_logs*/
,
false
/*time_kernel*/
,
*
p_cd_element_op
,
dimension_params
.
M
,
dimension_params
.
N
,
dimension_params
.
K
,
StridesA
,
StridesB
,
StridesC
,
StridesD
);
EXPECT_TRUE
(
pass
);
}
}
}
}
}
};
};
...
@@ -122,17 +130,31 @@ TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);
...
@@ -122,17 +130,31 @@ TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);
TYPED_TEST
(
TestContractionBilinear
,
bilinear
)
TYPED_TEST
(
TestContractionBilinear
,
bilinear
)
{
{
this
->
p_cd_element_op
=
std
::
make_unique
<
Bilinear
>
(
1.
f
,
1.
f
);
this
->
p_cd_element_op
=
std
::
make_unique
<
Bilinear
>
(
1.
f
,
1.
f
);
this
->
Run
();
this
->
template
Run
<
6
>({{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
2
,
2
,
2
,
2
,
4
}});
this
->
template
Run
<
6
>({{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
2
,
2
,
4
}});
this
->
template
Run
<
2
>({{
16
,
8
},
{
16
,
8
},
{
16
,
8
}});
this
->
template
Run
<
2
>({{
8
,
16
},
{
16
,
8
},
{
8
,
16
}});
this
->
p_cd_element_op
=
std
::
make_unique
<
Bilinear
>
(
-
0.5
f
,
0.5
f
);
this
->
p_cd_element_op
=
std
::
make_unique
<
Bilinear
>
(
-
0.5
f
,
0.5
f
);
this
->
Run
();
this
->
template
Run
<
6
>({{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
2
,
2
,
2
,
2
,
4
}});
this
->
template
Run
<
6
>({{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
2
,
2
,
4
}});
this
->
template
Run
<
2
>({{
16
,
8
},
{
16
,
8
},
{
16
,
8
}});
this
->
template
Run
<
2
>({{
8
,
16
},
{
16
,
8
},
{
8
,
16
}});
}
}
TYPED_TEST
(
TestContractionScale
,
scale
)
TYPED_TEST
(
TestContractionScale
,
scale
)
{
{
this
->
p_cd_element_op
=
std
::
make_unique
<
Scale
>
(
1.
f
);
this
->
p_cd_element_op
=
std
::
make_unique
<
Scale
>
(
1.
f
);
this
->
Run
();
this
->
template
Run
<
6
>({{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
2
,
2
,
2
,
2
,
4
}});
this
->
template
Run
<
6
>({{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
2
,
2
,
4
}});
this
->
template
Run
<
2
>({{
16
,
8
},
{
16
,
8
},
{
16
,
8
}});
this
->
template
Run
<
2
>({{
8
,
16
},
{
16
,
8
},
{
8
,
16
}});
this
->
p_cd_element_op
=
std
::
make_unique
<
Scale
>
(
0.5
f
);
this
->
p_cd_element_op
=
std
::
make_unique
<
Scale
>
(
0.5
f
);
this
->
Run
();
this
->
template
Run
<
6
>({{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
2
,
2
,
2
,
2
,
4
}});
this
->
template
Run
<
6
>({{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
2
,
2
,
4
}});
this
->
template
Run
<
2
>({{
16
,
8
},
{
16
,
8
},
{
16
,
8
}});
this
->
template
Run
<
2
>({{
8
,
16
},
{
16
,
8
},
{
8
,
16
}});
}
}
template
<
typename
Tuple
>
template
<
typename
Tuple
>
...
@@ -165,15 +187,29 @@ TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecis
...
@@ -165,15 +187,29 @@ TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecis
TYPED_TEST
(
TestContractionBilinearMixedPrecision
,
bilinear
)
TYPED_TEST
(
TestContractionBilinearMixedPrecision
,
bilinear
)
{
{
this
->
p_cd_element_op
=
std
::
make_unique
<
Bilinear
>
(
1.
f
,
1.
f
);
this
->
p_cd_element_op
=
std
::
make_unique
<
Bilinear
>
(
1.
f
,
1.
f
);
this
->
Run
();
this
->
template
Run
<
6
>({{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
2
,
2
,
2
,
2
,
4
}});
this
->
template
Run
<
6
>({{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
2
,
2
,
4
}});
this
->
template
Run
<
2
>({{
16
,
8
},
{
16
,
8
},
{
16
,
8
}});
this
->
template
Run
<
2
>({{
8
,
16
},
{
16
,
8
},
{
8
,
16
}});
this
->
p_cd_element_op
=
std
::
make_unique
<
Bilinear
>
(
-
0.5
f
,
0.5
f
);
this
->
p_cd_element_op
=
std
::
make_unique
<
Bilinear
>
(
-
0.5
f
,
0.5
f
);
this
->
Run
();
this
->
template
Run
<
6
>({{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
2
,
2
,
2
,
2
,
4
}});
this
->
template
Run
<
6
>({{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
2
,
2
,
4
}});
this
->
template
Run
<
2
>({{
16
,
8
},
{
16
,
8
},
{
16
,
8
}});
this
->
template
Run
<
2
>({{
8
,
16
},
{
16
,
8
},
{
8
,
16
}});
}
}
TYPED_TEST
(
TestContractionScaleMixedPrecision
,
scale
)
TYPED_TEST
(
TestContractionScaleMixedPrecision
,
scale
)
{
{
this
->
p_cd_element_op
=
std
::
make_unique
<
Scale
>
(
1.
f
);
this
->
p_cd_element_op
=
std
::
make_unique
<
Scale
>
(
1.
f
);
this
->
Run
();
this
->
template
Run
<
6
>({{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
2
,
2
,
2
,
2
,
4
}});
this
->
template
Run
<
6
>({{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
2
,
2
,
4
}});
this
->
template
Run
<
2
>({{
16
,
8
},
{
16
,
8
},
{
16
,
8
}});
this
->
template
Run
<
2
>({{
8
,
16
},
{
16
,
8
},
{
8
,
16
}});
this
->
p_cd_element_op
=
std
::
make_unique
<
Scale
>
(
0.5
f
);
this
->
p_cd_element_op
=
std
::
make_unique
<
Scale
>
(
0.5
f
);
this
->
Run
();
this
->
template
Run
<
6
>({{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
3
,
2
,
3
,
2
,
3
},
{
2
,
2
,
2
,
2
,
2
,
4
}});
this
->
template
Run
<
6
>({{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
3
,
2
,
3
},
{
1
,
1
,
1
,
2
,
2
,
4
}});
this
->
template
Run
<
2
>({{
16
,
8
},
{
16
,
8
},
{
16
,
8
}});
this
->
template
Run
<
2
>({{
8
,
16
},
{
16
,
8
},
{
8
,
16
}});
}
}
Prev
1
…
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment