Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
408534d4
Unverified
Commit
408534d4
authored
Aug 09, 2024
by
Rostyslav Geyyer
Committed by
GitHub
Aug 09, 2024
Browse files
Merge branch 'develop' into lwpck-1815
parents
a8efb3f0
da214a5a
Changes
204
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
95 additions
and
73 deletions
+95
-73
test/gemm_universal/test_gemm_universal_util.hpp
test/gemm_universal/test_gemm_universal_util.hpp
+9
-7
test/gemm_universal/test_gemm_universal_xdl.cpp
test/gemm_universal/test_gemm_universal_xdl.cpp
+17
-9
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
+33
-25
test/smfmac_op/smfmac_op_xdl.cpp
test/smfmac_op/smfmac_op_xdl.cpp
+36
-32
No files found.
test/gemm_universal/test_gemm_universal_util.hpp
View file @
408534d4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 20
18
-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 20
23
-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -25,12 +25,13 @@ class TestGemmUniversal : public testing::Test
...
@@ -25,12 +25,13 @@ class TestGemmUniversal : public testing::Test
using
F32
=
float
;
using
F32
=
float
;
protected:
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
Row
;
using
CLayout
=
Row
;
using
ADataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
ComputeDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
public:
public:
static
constexpr
bool
verify_
=
true
;
static
constexpr
bool
verify_
=
true
;
...
@@ -66,6 +67,7 @@ class TestGemmUniversal : public testing::Test
...
@@ -66,6 +67,7 @@ class TestGemmUniversal : public testing::Test
{
{
bool
pass
=
ck
::
profiler
::
profile_gemm_universal_impl
<
ADataType
,
bool
pass
=
ck
::
profiler
::
profile_gemm_universal_impl
<
ADataType
,
BDataType
,
BDataType
,
ComputeDataType
,
F32
,
F32
,
CDataType
,
CDataType
,
ALayout
,
ALayout
,
...
...
test/gemm_universal/test_gemm_universal_xdl.cpp
View file @
408534d4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 20
18
-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 20
23
-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <tuple>
...
@@ -41,16 +41,24 @@ class TestGemmUniversal_MK_NK
...
@@ -41,16 +41,24 @@ class TestGemmUniversal_MK_NK
};
};
// clang-format off
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
using
KernelTypes_MK_KN
=
::
testing
::
Types
<
// ADataType, BDataType, CDataType
// ADataType, BDataType, ComputeDataType, CDataType
std
::
tuple
<
F16
,
F16
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
F16
,
F8
,
F16
>
,
std
::
tuple
<
F16
,
F8
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
BF16
,
BF16
,
BF16
>
std
::
tuple
<
BF16
,
BF16
,
BF16
,
BF16
>
>
;
using
KernelTypes_MK_NK
=
::
testing
::
Types
<
// ADataType, BDataType, ComputeDataType, CDataType
std
::
tuple
<
F16
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
F16
,
F8
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
BF16
,
BF16
,
BF16
,
BF16
>
,
std
::
tuple
<
F8
,
F8
,
F8
,
BF16
>
>
;
>
;
// clang-format on
// clang-format on
TYPED_TEST_SUITE
(
TestGemmUniversal_MK_KN
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestGemmUniversal_MK_KN
,
KernelTypes
_MK_KN
);
TYPED_TEST_SUITE
(
TestGemmUniversal_MK_NK
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestGemmUniversal_MK_NK
,
KernelTypes
_MK_NK
);
#include "test_gemm_universal_ut_cases.inc"
#include "test_gemm_universal_ut_cases.inc"
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
View file @
408534d4
...
@@ -17,6 +17,7 @@ class TestGroupedConvndFwd : public ::testing::Test
...
@@ -17,6 +17,7 @@ class TestGroupedConvndFwd : public ::testing::Test
using
InLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
InLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
WeiLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
WeiLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
OutLayout
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
OutLayout
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
IndexType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
std
::
vector
<
ck
::
utils
::
conv
::
ConvParam
>
conv_params
;
std
::
vector
<
ck
::
utils
::
conv
::
ConvParam
>
conv_params
;
...
@@ -33,7 +34,10 @@ class TestGroupedConvndFwd : public ::testing::Test
...
@@ -33,7 +34,10 @@ class TestGroupedConvndFwd : public ::testing::Test
OutLayout
,
OutLayout
,
DataType
,
DataType
,
DataType
,
DataType
,
DataType
>
(
DataType
,
DataType
,
DataType
,
IndexType
>
(
true
,
// do_verification
true
,
// do_verification
1
,
// init_method: integer value
1
,
// init_method: integer value
false
,
// do_log
false
,
// do_log
...
@@ -46,30 +50,31 @@ class TestGroupedConvndFwd : public ::testing::Test
...
@@ -46,30 +50,31 @@ class TestGroupedConvndFwd : public ::testing::Test
using
namespace
ck
::
tensor_layout
::
convolution
;
using
namespace
ck
::
tensor_layout
::
convolution
;
using
KernelTypes1d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNWC
,
GKXC
,
GNWK
>
,
using
KernelTypes1d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNWC
,
GKXC
,
GNWK
,
ck
::
index_t
>
,
std
::
tuple
<
ck
::
half_t
,
GNWC
,
GKXC
,
GNWK
>
,
std
::
tuple
<
ck
::
half_t
,
GNWC
,
GKXC
,
GNWK
,
ck
::
index_t
>
,
std
::
tuple
<
ck
::
bhalf_t
,
GNWC
,
GKXC
,
GNWK
>
,
std
::
tuple
<
ck
::
bhalf_t
,
GNWC
,
GKXC
,
GNWK
,
ck
::
index_t
>
,
std
::
tuple
<
int8_t
,
GNWC
,
GKXC
,
GNWK
>>
;
std
::
tuple
<
int8_t
,
GNWC
,
GKXC
,
GNWK
,
ck
::
index_t
>>
;
using
KernelTypes2d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNHWC
,
GKYXC
,
GNHWK
>
,
using
KernelTypes2d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
index_t
>
,
std
::
tuple
<
ck
::
half_t
,
GNHWC
,
GKYXC
,
GNHWK
>
,
std
::
tuple
<
ck
::
half_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
index_t
>
,
std
::
tuple
<
ck
::
bhalf_t
,
GNHWC
,
GKYXC
,
GNHWK
>
,
std
::
tuple
<
ck
::
bhalf_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
index_t
>
,
std
::
tuple
<
int8_t
,
GNHWC
,
GKYXC
,
GNHWK
>
,
std
::
tuple
<
int8_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
index_t
>
,
std
::
tuple
<
float
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
float
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
index_t
>
,
std
::
tuple
<
ck
::
half_t
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
ck
::
half_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
index_t
>
,
std
::
tuple
<
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
index_t
>
,
std
::
tuple
<
int8_t
,
NHWGC
,
GKYXC
,
NHWGK
>>
;
std
::
tuple
<
int8_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
index_t
>>
;
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
>
,
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
index_t
>
,
std
::
tuple
<
ck
::
half_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
>
,
std
::
tuple
<
ck
::
half_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
index_t
>
,
std
::
tuple
<
ck
::
bhalf_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
>
,
std
::
tuple
<
ck
::
bhalf_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
index_t
>
,
std
::
tuple
<
int8_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
>
,
std
::
tuple
<
int8_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
index_t
>
,
std
::
tuple
<
float
,
NDHWGC
,
GKZYXC
,
NDHWGK
>
,
std
::
tuple
<
float
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
index_t
>
,
std
::
tuple
<
ck
::
half_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
>
,
std
::
tuple
<
ck
::
half_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
index_t
>
,
std
::
tuple
<
ck
::
bhalf_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
>
,
std
::
tuple
<
ck
::
bhalf_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
index_t
>
,
std
::
tuple
<
int8_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
>>
;
std
::
tuple
<
int8_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
index_t
>>
;
using
KernelTypes2dLargeCases
=
::
testing
::
Types
<
std
::
tuple
<
float
,
NHWGC
,
GKYXC
,
NHWGK
>>
;
using
KernelTypes2dLargeCases
=
::
testing
::
Types
<
std
::
tuple
<
float
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
long_index_t
>>
;
template
<
typename
Tuple
>
template
<
typename
Tuple
>
class
TestGroupedConvndFwd1d
:
public
TestGroupedConvndFwd
<
Tuple
>
class
TestGroupedConvndFwd1d
:
public
TestGroupedConvndFwd
<
Tuple
>
...
@@ -153,5 +158,8 @@ TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases)
...
@@ -153,5 +158,8 @@ TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases)
// With supported NumGroupsToMerge > 1
// With supported NumGroupsToMerge > 1
this
->
conv_params
.
push_back
(
this
->
conv_params
.
push_back
(
{
2
,
32
,
64
,
1
,
1
,
{
2
,
2
},
{
672
,
672
},
{
672
,
672
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
{
2
,
32
,
64
,
1
,
1
,
{
2
,
2
},
{
672
,
672
},
{
672
,
672
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
// When image is larger than 2GB
this
->
conv_params
.
push_back
(
{
2
,
1
,
1
,
256
,
256
,
{
3
,
3
},
{
4096
,
2048
},
{
1024
,
1024
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
}});
this
->
template
Run
<
2
>();
this
->
template
Run
<
2
>();
}
}
test/smfmac_op/smfmac_op_xdl.cpp
View file @
408534d4
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "test/smfmac_op/smfmac_op_util.hpp"
#include "test/smfmac_op/smfmac_op_util.hpp"
#include "ck/host_utility/device_prop.hpp"
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
...
@@ -38,40 +39,43 @@ class TestSmfmac : public ::testing::Test
...
@@ -38,40 +39,43 @@ class TestSmfmac : public ::testing::Test
void
Run
()
void
Run
()
{
{
bool
pass
=
true
;
bool
pass
=
true
;
constexpr
auto
matmul_default
=
ck
::
smfmac_op_util
::
matmul
<
Src1Type
,
if
(
ck
::
get_device_name
()
==
"gfx942"
)
Src1VecSize
,
{
Src2Type
,
constexpr
auto
matmul_default
=
ck
::
smfmac_op_util
::
matmul
<
Src1Type
,
Src2VecSize
,
Src1VecSize
,
GPUAccType
,
Src2Type
,
AccVecSize
,
Src2VecSize
,
DstType
,
GPUAccType
,
M
,
AccVecSize
,
N
,
DstType
,
K
>
;
M
,
N
,
K
>
;
constexpr
auto
smfmac_kernel_container
=
std
::
make_tuple
(
matmul_default
);
constexpr
auto
smfmac_kernel_container
=
std
::
make_tuple
(
matmul_default
);
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
decltype
(
smfmac_kernel_container
)
>
,
1
>
{}([
&
](
auto
i
)
{
pass
&=
ck
::
smfmac_op_util
::
TestSmfmac
<
std
::
tuple_element_t
<
i
.
value
,
decltype
(
smfmac_kernel_container
)
>
,
Src1Type
,
Src2Type
,
DstType
,
GPUAccType
,
CPUAccType
,
decltype
(
Row
{}),
decltype
(
Row
{}),
decltype
(
Row
{}),
PassThrough
,
PassThrough
,
PassThrough
,
AccVecSize
,
M
,
N
,
K
>
{}(
std
::
get
<
ck
::
Number
<
i
>
{}
>
(
smfmac_kernel_container
));
});
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
decltype
(
smfmac_kernel_container
)
>
,
1
>
{}(
[
&
](
auto
i
)
{
pass
&=
ck
::
smfmac_op_util
::
TestSmfmac
<
std
::
tuple_element_t
<
i
.
value
,
decltype
(
smfmac_kernel_container
)
>
,
Src1Type
,
Src2Type
,
DstType
,
GPUAccType
,
CPUAccType
,
decltype
(
Row
{}),
decltype
(
Row
{}),
decltype
(
Row
{}),
PassThrough
,
PassThrough
,
PassThrough
,
AccVecSize
,
M
,
N
,
K
>
{}(
std
::
get
<
ck
::
Number
<
i
>
{}
>
(
smfmac_kernel_container
));
});
}
EXPECT_TRUE
(
pass
);
EXPECT_TRUE
(
pass
);
}
}
};
};
...
...
Prev
1
…
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment