Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
af06f68e
Commit
af06f68e
authored
Jan 30, 2025
by
Andriy Roshchenko
Browse files
Ensure correct naming
parent
2bd601e1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
39 deletions
+19
-39
test/mx_mfma_op/mx_mfma_op.cpp
test/mx_mfma_op/mx_mfma_op.cpp
+17
-36
test/mx_mfma_op/mx_mfma_op.hpp
test/mx_mfma_op/mx_mfma_op.hpp
+2
-3
No files found.
test/mx_mfma_op/mx_mfma_op.cpp
View file @
af06f68e
...
@@ -15,8 +15,8 @@ using ck::type_convert;
...
@@ -15,8 +15,8 @@ using ck::type_convert;
*
*
* @param init - selects initialization algorithm for A and B tensors
* @param init - selects initialization algorithm for A and B tensors
*/
*/
template
<
typename
AType
,
typename
BType
,
typename
CType
,
ck
::
mx_mfma_test
::
MFMA_F8F6F4
mfma
>
template
<
typename
AType
,
typename
BType
,
typename
CType
,
ck
::
MFMA_F8F6F4
mfma
>
bool
run_test
(
ck
::
index_t
init
)
bool
run_
mfma_
test
(
ck
::
index_t
init
)
{
{
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
@@ -30,23 +30,22 @@ bool run_test(ck::index_t init)
...
@@ -30,23 +30,22 @@ bool run_test(ck::index_t init)
constexpr
auto
BLOCK_N
=
mfma_instr
.
n_per_blk
;
constexpr
auto
BLOCK_N
=
mfma_instr
.
n_per_blk
;
constexpr
auto
BLOCK_K
=
mfma_instr
.
num_input_blks
*
mfma_instr
.
k_per_blk
;
constexpr
auto
BLOCK_K
=
mfma_instr
.
num_input_blks
*
mfma_instr
.
k_per_blk
;
const
auto
mx_mfma_kernel
=
const
auto
mx_mfma_kernel
=
ck
::
matmul
<
AType
,
BType
,
CType
,
AccType
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
>
;
ck
::
mx_mfma_test
::
matmul
<
AType
,
BType
,
CType
,
AccType
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
>
;
bool
pass
=
true
;
bool
pass
=
true
;
pass
=
ck
::
mx_
mfma_test
::
TestMFMA
<
decltype
(
mx_mfma_kernel
),
pass
=
ck
::
mfma_test
::
TestMFMA
<
decltype
(
mx_mfma_kernel
),
AType
,
AType
,
BType
,
BType
,
CType
,
CType
,
AccType
,
AccType
,
CPUAccType
,
CPUAccType
,
ALayout
,
ALayout
,
BLayout
,
BLayout
,
CLayout
,
CLayout
,
BLOCK_M
,
BLOCK_M
,
BLOCK_N
,
BLOCK_N
,
BLOCK_K
>
{}(
mx_mfma_kernel
,
init
);
BLOCK_K
>
{}(
mx_mfma_kernel
,
init
);
return
pass
;
return
pass
;
}
}
...
@@ -54,31 +53,13 @@ bool run_test(ck::index_t init)
...
@@ -54,31 +53,13 @@ bool run_test(ck::index_t init)
TEST
(
MFMA
,
FP8MFMA16x16x128
)
TEST
(
MFMA
,
FP8MFMA16x16x128
)
{
{
auto
AB_init
=
0
;
auto
AB_init
=
0
;
auto
pass
=
run_test
<
f8_t
,
f8_t
,
half_t
,
ck
::
mx_mfma_test
::
MFMA_F8F6F4
::
F32_16x16x128
>
(
AB_init
);
auto
pass
=
run_
mfma_
test
<
f8_t
,
f8_t
,
half_t
,
ck
::
MFMA_F8F6F4
::
F32_16x16x128
>
(
AB_init
);
EXPECT_TRUE
(
pass
);
EXPECT_TRUE
(
pass
);
}
}
TEST
(
MFMA
,
FP8MFMA32x32x64
)
TEST
(
MFMA
,
FP8MFMA32x32x64
)
{
{
auto
AB_init
=
0
;
auto
AB_init
=
0
;
auto
pass
=
run_test
<
f8_t
,
f8_t
,
float
,
ck
::
mx_mfma_test
::
MFMA_F8F6F4
::
F32_32x32x64
>
(
AB_init
);
auto
pass
=
run_
mfma_
test
<
f8_t
,
f8_t
,
float
,
ck
::
MFMA_F8F6F4
::
F32_32x32x64
>
(
AB_init
);
EXPECT_TRUE
(
pass
);
EXPECT_TRUE
(
pass
);
}
}
// TEST(MXMFMA, FP8MFMA32x32x64)
// {
// EXPECT_TRUE(run_test<f8, 1, f8, 1, float, 1, float, float, 32, 32, 64>());
// }
// TEST(MXMFMA, BF8MFMA16x16x128)
// {
// EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 16, 16, 128>());
// }
// TEST(MXMFMA, BF8MFMA32x32x64)
// {
// EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 32, 32, 64>());
// }
// TEST(MXMFMA, MXFP8xMXFP8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
// TEST(MXMFMA, MXBF8xMXBF8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
test/mx_mfma_op/mx_mfma_op.hpp
View file @
af06f68e
...
@@ -11,7 +11,6 @@
...
@@ -11,7 +11,6 @@
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
namespace
ck
{
namespace
ck
{
namespace
mx_mfma_test
{
// MFMA instructions supported in this test
// MFMA instructions supported in this test
enum
class
MFMA_F8F6F4
enum
class
MFMA_F8F6F4
...
@@ -353,7 +352,6 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
...
@@ -353,7 +352,6 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
auto
storeC
=
store_C_col_major
<
CType
,
CFragT
,
BLOCK_M
,
BLOCK_N
>
{};
auto
storeC
=
store_C_col_major
<
CType
,
CFragT
,
BLOCK_M
,
BLOCK_N
>
{};
storeC
(
c
,
fragC
);
storeC
(
c
,
fragC
);
}
}
/**
/**
* @brief Structure to hold dimension parameters for GEMM tensors.
* @brief Structure to hold dimension parameters for GEMM tensors.
*
*
...
@@ -375,6 +373,7 @@ struct GemmParams
...
@@ -375,6 +373,7 @@ struct GemmParams
ck
::
index_t
StrideC
=
-
1
;
ck
::
index_t
StrideC
=
-
1
;
};
};
namespace
mfma_test
{
template
<
typename
GemmInstance
,
template
<
typename
GemmInstance
,
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
...
@@ -564,5 +563,5 @@ struct TestMFMA
...
@@ -564,5 +563,5 @@ struct TestMFMA
}
}
};
};
}
// namespace
mx_
mfma_test
}
// namespace mfma_test
}
// namespace ck
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment