gaoqiong / composable_kernel_ROCM · Commits

Commit dec32dc6, authored Jan 31, 2025 by ThomasNing
Finish the feature and merge with develop on the computeV2
Parents: 71352c44, c5fff071

Showing 15 changed files with 407 additions and 172 deletions (+407 -172)
profiler/src/profile_gemm_multiply_multiply.cpp       +8    -1
profiler/src/profile_grouped_gemm_fixed_nk.cpp        +100  -63
pyproject.toml                                        +5    -2
python/ck4inductor/universal_gemm/gen_instances.py    +7    -6
python/test/test_gen_instances.py                     +46   -0
test/CMakeLists.txt                                   +46   -0
test/ck_tile/batched_gemm/test_batched_gemm.cpp       +1    -1
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp  +4    -4
test/ck_tile/gemm/test_gemm_pipeline.cpp              +15   -13
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc     +26   -5
test/ck_tile/gemm/test_gemm_pipeline_util.hpp         +98   -72
test/ck_tile/grouped_gemm/test_grouped_gemm.cpp       +1    -1
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp  +1    -4
test/data_type/CMakeLists.txt                         +1    -0
test/data_type/test_bhalf.cpp                         +48   -0
profiler/src/profile_gemm_multiply_multiply.cpp  (+8 -1)

@@ -28,6 +28,7 @@ enum struct GemmDataType
     F16_F16_F16_F8,  // 6
     F8_F8_BF16,      // 7
     INT8_INT8_BF16,  // 8
+    F8_F8_F16,       // 9
 };

 #define OP_NAME "gemm_multiply_multiply"

@@ -40,7 +41,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
     printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
     printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
            "f16->f8; 7: f8->bf16, "
-           "comp f8; 8: int8->bf16)\n");
+           "comp f8; 8: int8->bf16; 9: f8->f16, comp f8;)\n");
     printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
     printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
     printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");

@@ -89,6 +90,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
     using F32  = float;
     using BF16 = ck::bhalf_t;
+    using F16  = ck::half_t;
     using F8   = ck::f8_t;
     using I8   = int8_t;
     using I32  = int;

@@ -165,6 +167,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
         return profile(
             F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{});
     }
+    else if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        return profile(
+            F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Col{}, Row{});
+    }
     else if(data_type == GemmDataType::INT8_INT8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
     {
         return profile(
profiler/src/profile_grouped_gemm_fixed_nk.cpp  (+100 -63)

@@ -17,11 +17,11 @@ enum struct GemmMatrixLayout
 enum struct GemmDataType
 {
     BF16_I8_BF16,   // 0
     F16_F16_F16,    // 1
     F16_F8_F16,     // 2
     F16_I8_F16,     // 3
+    BF16_BF16_BF16, // 4
 };

 #define OP_NAME "grouped_gemm_fixed_nk"

@@ -39,7 +39,6 @@ std::vector<int> argToIntArray(char* input)
     {
         out.push_back(std::stoi(item));
     }
     return out;
 }

@@ -83,14 +82,6 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
     const auto StrideCs = argToIntArray(argv[13]);
     const int kbatch    = argc >= 15 ? std::stoi(argv[14]) : 1;

-    using F32  = float;
-    using F16  = ck::half_t;
-#if defined(CK_ENABLE_FP8)
-    using F8   = ck::f8_t;
-#endif
-    using BF16 = ck::bhalf_t;
-    using I8   = int8_t;

     int n_warmup = 1;
     int n_iter   = 10;
     if(argc == 17)

@@ -99,13 +90,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
         n_iter = std::stoi(argv[16]);
     }

-#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
-    if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
+    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
-        ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16, I8, BF16, F32,
+        ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t, ck::half_t, ck::half_t, float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(

@@ -123,12 +113,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
             n_warmup,
             n_iter);
     }
-    else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
     {
-        ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16, I8, BF16, F32,
+        ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t, ck::half_t, ck::half_t, float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::ColumnMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(

@@ -146,14 +136,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
             n_warmup,
             n_iter);
     }
-#endif
-#if defined(CK_ENABLE_FP16)
-    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+#if defined(CK_ENABLE_FP8)
+    else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
-        ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, F16, F16, F32,
+        ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t, ck::f8_t, ck::half_t, float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(

@@ -171,12 +160,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
             n_warmup,
             n_iter);
     }
-    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+    else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
     {
-        ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, F16, F16, F32,
+        ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t, ck::f8_t, ck::half_t, float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::ColumnMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(

@@ -194,14 +183,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
             n_warmup,
             n_iter);
     }
-#endif
-#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)
-    else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+#endif // CK_ENABLE_FP8
+#if defined(CK_ENABLE_INT8)
+    else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
-        ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, F8, F16, F32,
+        ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t, int8_t, ck::half_t, float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(

@@ -219,12 +208,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
             n_warmup,
             n_iter);
     }
-    else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+    else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
     {
-        ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, F8, F16, F32,
+        ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t, int8_t, ck::half_t, float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::ColumnMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(

@@ -242,14 +231,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
             n_warmup,
             n_iter);
     }
-#endif
-#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
-    else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+#endif // CK_ENABLE_INT8
+#if defined(CK_ENABLE_BF16)
+    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
-        ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, I8, F16, F32,
+        ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(

@@ -267,12 +256,59 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
             n_warmup,
             n_iter);
     }
-    else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
-    {
-        ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, I8, F16, F32,
-                                                         ck::tensor_layout::gemm::RowMajor,
-                                                         ck::tensor_layout::gemm::ColumnMajor,
-                                                         ck::tensor_layout::gemm::RowMajor>(
+    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, float,
+                                                         ck::tensor_layout::gemm::RowMajor,
+                                                         ck::tensor_layout::gemm::ColumnMajor,
+                                                         ck::tensor_layout::gemm::RowMajor>(
+            do_verification, init_method, do_log, time_kernel,
+            Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch, n_warmup, n_iter);
+    }
+#if defined(CK_ENABLE_INT8)
+    else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t, int8_t, ck::bhalf_t, float,
+                                                         ck::tensor_layout::gemm::RowMajor,
+                                                         ck::tensor_layout::gemm::RowMajor,
+                                                         ck::tensor_layout::gemm::RowMajor>(
+            do_verification, init_method, do_log, time_kernel,
+            Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch, n_warmup, n_iter);
+    }
+    else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t, int8_t, ck::bhalf_t, float,
+                                                         ck::tensor_layout::gemm::RowMajor,
+                                                         ck::tensor_layout::gemm::ColumnMajor,
+                                                         ck::tensor_layout::gemm::RowMajor>(

@@ -286,11 +322,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
             StrideAs,
             StrideBs,
             StrideCs,
-            1,
+            kbatch,
             n_warmup,
             n_iter);
     }
-#endif
+#endif // CK_ENABLE_INT8
+#endif // CK_ENABLE_BF16
     else
     {
         throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
pyproject.toml  (+5 -2)

@@ -21,16 +21,19 @@ dependencies = []
 "Bug Tracker" = "https://github.com/rocm/composable_kernel/issues"

 [tool.setuptools]
-packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library"]
+packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library", "ck4inductor.universal_gemm", "ck4inductor.batched_universal_gemm", "ck4inductor.grouped_conv_fwd"]

 [tool.setuptools.package-dir]
 ck4inductor = "python/ck4inductor"
+"ck4inductor.universal_gemm" = "python/ck4inductor/universal_gemm"
+"ck4inductor.batched_universal_gemm" = "python/ck4inductor/batched_universal_gemm"
+"ck4inductor.grouped_conv_fwd" = "python/ck4inductor/grouped_conv_fwd"
 "ck4inductor.include" = "include"
 "ck4inductor.library" = "library"

 [tool.setuptools.package-data]
 "ck4inductor.include" = ["ck/**/*.hpp"]
-"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"]
+"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp", "src/tensor_operation_instance/gpu/gemm_universal_batched/**/*.hpp", "include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/**/*.hpp"]

 [tool.setuptools.dynamic]
 version = { attr = "setuptools_scm.get_version" }
python/ck4inductor/universal_gemm/gen_instances.py  (+7 -6)

@@ -68,12 +68,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
         template_args.insert(2, tuple())  # ds layout
         template_args.insert(6, tuple())  # ds dtype
+        try:
             new_instance = CKGemmOperation(
                 *template_args,  # type: ignore[arg-type]
             )
             op_instances.append(new_instance)
+        except TypeError as e:
+            log.debug(f"{e} when parsing {line}")
     return op_instances
python/test/test_gen_instances.py  (new file, +46)

# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

import logging
import unittest

from ck4inductor.universal_gemm.gen_instances import (
    gen_ops_library as gen_gemm_ops_library,
)
from ck4inductor.universal_gemm.gen_instances import (
    gen_ops_preselected as gen_gemm_ops_preselected,
)
from ck4inductor.grouped_conv_fwd.gen_instances import (
    gen_conv_ops_library as gen_conv_ops_library,
)
from ck4inductor.batched_universal_gemm.gen_instances import (
    gen_ops_library as gen_batched_gemm_ops_library,
)

log = logging.getLogger(__name__)


class TestGenInstances(unittest.TestCase):
    def test_gen_gemm_instances(self):
        instances = gen_gemm_ops_library()
        log.debug("%d gemm instances from library" % len(instances))
        self.assertTrue(instances)

    def test_preselected_gemm_instances(self):
        instances = gen_gemm_ops_preselected()
        log.debug("%d preselected gemm instances" % len(instances))
        self.assertTrue(instances)

    def test_gen_conv_instances(self):
        instances = gen_conv_ops_library()
        log.debug("%d gemm instances from library" % len(instances))
        self.assertTrue(instances)

    def test_gen_batched_gemm_instances(self):
        instances = gen_batched_gemm_ops_library()
        log.debug("%d gemm instances from library" % len(instances))
        self.assertTrue(instances)
test/CMakeLists.txt  (+46 -0)

@@ -7,6 +7,34 @@ include(gtest)
 add_custom_target(tests)

+# list of tests that are labelled as REGRESSION_TEST for make regression (runtime more than 30 seconds)
+# all other tests are labelled as SMOKE_TEST
+set(REGRESSION_TESTS
+    test_gemm_standalone_xdl_fp16
+    test_gemm_fp16
+    test_gemm_splitk
+    test_batched_gemm
+    test_gemm_universal
+    test_batched_gemm_softmax_gemm_fp16
+    test_batched_gemm_softmax_gemm_permute_fp16
+    test_batched_gemm_bias_softmax_gemm_permute_fp16
+    test_batched_gemm_softmax_gemm_permute_bf16
+    test_batched_gemm_bias_softmax_gemm_permute_bf16
+    test_grouped_gemm_splitk
+    test_reduce_no_index
+    test_reduce_with_index
+    test_convnd_fwd
+    test_convnd_bwd_data
+    test_grouped_convnd_fwd
+    test_grouped_convnd_bwd_weight
+    test_softmax_rank3
+    test_softmax_rank4
+    test_batchnorm_fwd_rank_4
+    test_batchnorm_bwd_rank_4
+    test_grouped_convnd_bwd_data_xdl
+    test_conv_tensor_rearrange
+)
+
 function(add_test_executable TEST_NAME)
     message("adding test ${TEST_NAME}")
     set(result 1)

@@ -88,6 +116,15 @@ function(add_test_executable TEST_NAME)
     endif()
     #message("add_test returns ${result}")
     set(result ${result} PARENT_SCOPE)
+    if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
+        message("adding to SMOKE TEST FILTER ${TEST_NAME}")
+        set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
+        add_dependencies(smoke ${TEST_NAME})
+    elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
+        message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
+        set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
+        add_dependencies(regression ${TEST_NAME})
+    endif()
 endfunction()

 function(add_gtest_executable TEST_NAME)

@@ -168,6 +205,15 @@ function(add_gtest_executable TEST_NAME)
     endif()
     #message("add_gtest returns ${result}")
     set(result ${result} PARENT_SCOPE)
+    if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
+        #message("adding to smoke test FILTER ${TEST_NAME}")
+        set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
+        add_dependencies(smoke ${TEST_NAME})
+    elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
+        #message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
+        set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
+        add_dependencies(regression ${TEST_NAME})
+    endif()
 endfunction()

 add_compile_options(-Wno-c++20-extensions)
test/ck_tile/batched_gemm/test_batched_gemm.cpp  (+1 -1)

@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
 // clang-format off
 using KernelTypes = ::testing::Types<
     // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
-    std::tuple< Row, Row, Row, F16, F16, F32, F16>,
+    // std::tuple< Row, Row, Row, F16, F16, F32, F16>,
     //std::tuple< Col, Row, Row, F16, F16, F32, F16>,
     std::tuple< Row, Col, Row, F16, F16, F32, F16> //,
     //std::tuple< Col, Col, Row, F16, F16, F32, F16>
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp  (+4 -4)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include <sstream>

@@ -61,7 +61,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
                                ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

-    using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
+    using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;

     using GemmEpilogue = std::conditional_t<
         CShuffleEpilogue,

@@ -73,8 +73,8 @@ class TestCkTileBatchedGemm : public ::testing::Test
                                                       kOutputRank,
                                                       1,
                                                       0,
-                                                      TilePartitioner::kM,
-                                                      TilePartitioner::kN>>,
+                                                      TilePartitioner::MPerBlock,
+                                                      TilePartitioner::NPerBlock>>,
         ck_tile::Default2DEpilogue<
             ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
test/ck_tile/gemm/test_gemm_pipeline.cpp  (+15 -13)

@@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
 using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

 using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
                                              ck_tile::GemmPipelineScheduler::Intrawave>;
-using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
-                                             ck_tile::GemmPipelineScheduler::Interwave>;
-using Mem  = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
+// using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
+//                                              ck_tile::GemmPipelineScheduler::Interwave>;
+// using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
 using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;

+// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
 // clang-format off
 using KernelTypes = ::testing::Types<
     // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
-    std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
+    // std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
     std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
-    std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
+    // std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
-    std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
+    // std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
     std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
-    std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
+    // std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
-    std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
+    // std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
     std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
-    std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
+    // std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
-    std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
+    // std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
-    std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
+    std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>
-    std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
+    // std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
     >;
 // clang-format on
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc  (+26 -5)

@@ -10,22 +10,43 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM)
     constexpr int K = 320;

     for(int M : Ms)
-        this->Run(M, N, K);
+    {
+        if constexpr(std::is_same_v<typename TestFixture::ALayout,
+                                    ck_tile::tensor_layout::gemm::ColumnMajor>)
+            EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
+        else
+            this->Run(M, N, K);
+    }
 }

 TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
 {
     std::vector<int> Ms{127, 255, 312, 799, 1573};
     constexpr int N = 1024;
     constexpr int K = 320;
+    constexpr int VecLoadSize = 8;

     for(int M : Ms)
-        this->Run(M, N, K);
+    {
+        if constexpr(std::is_same_v<typename TestFixture::ALayout,
+                                    ck_tile::tensor_layout::gemm::ColumnMajor>)
+        {
+            // TODO: Can we anyhow deduce used vector load size?
+            if(M % VecLoadSize == 0)
+                this->Run(M, N, K);
+            else
+                EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
+        }
+        else
+        {
+            this->Run(M, N, K);
+        }
+    }
 }

 TYPED_TEST(TestCkTileGemmPipeline, PaddK)
 {
-    std::vector<int> Ms{127};
+    std::vector<int> Ms{128};
     constexpr int N = 1024;
     constexpr int K = 432;
test/ck_tile/gemm/test_gemm_pipeline_util.hpp  (+98 -72)

@@ -16,6 +16,7 @@ enum struct GemmPipelineType
     Mem,
     Comp
 };

 template <typename Tuple>
 class TestCkTileGemmPipeline : public ::testing::Test
 {

@@ -51,6 +52,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
     constexpr bool kPadN = PadN;
     constexpr bool kPadK = PadK;

+    // TODO: For now - but this should also be a test parameter
+    constexpr bool TransposeC = false;
+
     constexpr int kBlockPerCu = 1;

     // ===============================================

@@ -59,20 +63,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
         ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                                ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

-    using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
+    using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;

     using GemmEpilogue = ck_tile::Default2DEpilogue<
         ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;

     using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
+    using GemmUniversalTraits =
+        ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
+    using GemmPipelineProblem =
+        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

-    using BaseGemmPipeline = std::conditional_t<
-        PipelineType == GemmPipelineType::Mem,
-        ck_tile::BaseGemmPipelineAgBgCrMem<
-            ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>,
-        ck_tile::BaseGemmPipelineAgBgCrCompV3<
-            ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
+    using BaseGemmPipeline =
+        std::conditional_t<PipelineType == GemmPipelineType::Mem,
+                           ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
+                           ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;

     const ck_tile::index_t k_grain = args.k_batch * K_Tile;
     const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;

@@ -84,26 +90,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
         constexpr bool has_hot_loop_v = has_hot_loop_.value;
         constexpr auto tail_number_v  = tail_number_.value;

-        using GemmPipeline = std::conditional_t<
-            PipelineType == GemmPipelineType::Mem,
-            ck_tile::GemmPipelineAgBgCrMem<
-                ck_tile::UniversalGemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape,
-                                                      Traits, Scheduler, has_hot_loop_v, tail_number_v>,
-                ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
-            ck_tile::GemmPipelineAgBgCrCompV3<
-                ck_tile::UniversalGemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape,
-                                                      Traits, Scheduler, has_hot_loop_v, tail_number_v>>>;
+        using UniversalGemmProblem =
+            ck_tile::UniversalGemmPipelineProblem<ADataType,
+                                                  BDataType,
+                                                  AccDataType,
+                                                  GemmShape,
+                                                  GemmUniversalTraits,
+                                                  Scheduler,
+                                                  has_hot_loop_v,
+                                                  tail_number_v>;
+
+        using GemmPipeline = std::conditional_t<
+            PipelineType == GemmPipelineType::Mem,
+            ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem,
+                                           ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
+            ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>>;

         using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
         auto kargs   = Kernel::MakeKernelArgs(args);

@@ -129,70 +131,94 @@ class TestCkTileGemmPipeline : public ::testing::Test
         if(has_hot_loop)
         {
-            // Tail pipeline One to Seven
-            if(tail_num == ck_tile::TailNumber::One)
-            {
-                Run(ck_tile::bool_constant<true>{},
-                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
-            }
-            else if(tail_num == ck_tile::TailNumber::Full)
-            {
-                Run(ck_tile::bool_constant<true>{},
-                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
-            }
-            if constexpr(BaseGemmPipeline::PrefetchStages > 2) { /* TailNumber::Two */ }
-            if constexpr(BaseGemmPipeline::PrefetchStages > 3) { /* TailNumber::Three */ }
-            if constexpr(BaseGemmPipeline::PrefetchStages > 4) { /* TailNumber::Four */ }
-            if constexpr(BaseGemmPipeline::PrefetchStages > 5) { /* TailNumber::Five */ }
-            if constexpr(BaseGemmPipeline::PrefetchStages > 6) { /* TailNumber::Six */ }
-            if constexpr(BaseGemmPipeline::PrefetchStages > 7) { /* TailNumber::Seven */ }
+            if constexpr(PipelineType == GemmPipelineType::Comp)
+            {
+                if(tail_num == ck_tile::TailNumber::Full)
+                {
+                    Run(ck_tile::bool_constant<true>{},
+                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
+                }
+                else
+                {
+                    std::ostringstream err;
+                    err << "For compute pipeline tail number should always be Full, but have \""
+                        << tail_num << "\" which is not supported! PrefetchStages: "
+                        << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
+                        << __LINE__ << ", in function: " << __func__;
+                    throw std::runtime_error(err.str());
+                }
+            }
+            if constexpr(PipelineType == GemmPipelineType::Mem)
+            {
+                // Tail pipeline One to Seven
+                if(tail_num == ck_tile::TailNumber::One)
+                {
+                    Run(ck_tile::bool_constant<true>{},
+                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
+                }
+                else if(tail_num == ck_tile::TailNumber::Full)
+                {
+                    Run(ck_tile::bool_constant<true>{},
+                        ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
+                }
+                if constexpr(BaseGemmPipeline::PrefetchStages > 2)
+                {
+                    if(tail_num == ck_tile::TailNumber::Two)
+                    {
+                        Run(ck_tile::bool_constant<true>{},
+                            ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
+                    }
+                }
+                if constexpr(BaseGemmPipeline::PrefetchStages > 3)
+                {
+                    if(tail_num == ck_tile::TailNumber::Three)
+                    {
+                        Run(ck_tile::bool_constant<true>{},
+                            ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
+                    }
+                }
+                if constexpr(BaseGemmPipeline::PrefetchStages > 4)
+                {
+                    if(tail_num == ck_tile::TailNumber::Four)
+                    {
+                        Run(ck_tile::bool_constant<true>{},
+                            ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
+                    }
+                }
+                if constexpr(BaseGemmPipeline::PrefetchStages > 5)
+                {
+                    if(tail_num == ck_tile::TailNumber::Five)
+                    {
+                        Run(ck_tile::bool_constant<true>{},
+                            ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
+                    }
+                }
+                if constexpr(BaseGemmPipeline::PrefetchStages > 6)
+                {
+                    if(tail_num == ck_tile::TailNumber::Six)
+                    {
+                        Run(ck_tile::bool_constant<true>{},
+                            ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
+                    }
+                }
+                if constexpr(BaseGemmPipeline::PrefetchStages > 7)
+                {
+                    if(tail_num == ck_tile::TailNumber::Seven)
+                    {
+                        Run(ck_tile::bool_constant<true>{},
+                            ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
+                    }
+                }
+            }
         }
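For readers skimming the test_gemm_pipeline_util.hpp hunks above, the pattern they introduce is: the pipeline implementation is chosen at compile time with std::conditional_t over the pipeline type, and the compute (Comp) pipeline is only ever launched with TailNumber::Full, with any other tail number reported as a std::runtime_error. The snippet below is a self-contained toy illustration of that dispatch pattern only; it is not CK code, and the enum and struct names in it are placeholders, not library types.

// Toy illustration (assumption-labeled, not part of the commit) of the
// "select pipeline at compile time, validate tail number at run time" pattern.
#include <iostream>
#include <stdexcept>
#include <type_traits>

enum struct GemmPipelineType { Mem, Comp };
enum struct TailNumber { One, Full };

struct MemPipeline  { static constexpr const char* name = "mem";  };  // placeholder
struct CompPipeline { static constexpr const char* name = "comp"; };  // placeholder

// Compile-time selection, mirroring the std::conditional_t use in the test.
template <GemmPipelineType P>
using SelectedPipeline =
    std::conditional_t<P == GemmPipelineType::Mem, MemPipeline, CompPipeline>;

template <GemmPipelineType P>
void run(TailNumber tail)
{
    // The compute pipeline variant accepts only the Full tail number.
    if constexpr(P == GemmPipelineType::Comp)
    {
        if(tail != TailNumber::Full)
            throw std::runtime_error("compute pipeline requires TailNumber::Full");
    }
    std::cout << "running " << SelectedPipeline<P>::name << " pipeline\n";
}

int main()
{
    run<GemmPipelineType::Mem>(TailNumber::One);    // ok: memory pipeline handles short tails
    run<GemmPipelineType::Comp>(TailNumber::Full);  // ok: compute pipeline with Full tail
    try
    {
        run<GemmPipelineType::Comp>(TailNumber::One);  // rejected at run time
    }
    catch(const std::runtime_error& e)
    {
        std::cout << e.what() << "\n";
    }
    return 0;
}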
test/ck_tile/grouped_gemm/test_grouped_gemm.cpp  (+1 -1)

@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
 // clang-format off
 using KernelTypes = ::testing::Types<
     // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
-    std::tuple< Row, Row, Row, F16, F16, F32, F16>,
+    // std::tuple< Row, Row, Row, F16, F16, F32, F16>,
     //std::tuple< Col, Row, Row, F16, F16, F32, F16>,
     std::tuple< Row, Col, Row, F16, F16, F32, F16> //,
     //std::tuple< Col, Col, Row, F16, F16, F32, F16>
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp  (+1 -4)

@@ -96,12 +96,9 @@ class TestCkTileGroupedGemm : public ::testing::Test
                               CodegenGemmShape,
                               CodegenGemmTraits<ALayout, BLayout, CLayout>>;

-    using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
-
     template <typename ALayout, typename BLayout, typename CLayout>
     using CodegenGemmPipeline =
-        ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>,
-                                              CodegenGemmPolicy>;
+        ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;

     template <typename ALayout, typename BLayout, typename CLayout>
     using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
test/data_type/CMakeLists.txt  (+1 -0)

@@ -49,3 +49,4 @@ if(result EQUAL 0)
 endif()

 add_gtest_executable(test_type_convert_const type_convert_const.cpp)
+add_gtest_executable(test_bhalf test_bhalf.cpp)
test/data_type/test_bhalf.cpp  (new file, +48)

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"

#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"

using ck::bhalf_t;
using ck::type_convert;

TEST(BHALF_T, Nan)
{
    const uint16_t binary_bhalf_nan = 0x7FC0;
    const bhalf_t bhalf_nan         = ck::bit_cast<bhalf_t>(binary_bhalf_nan);

    EXPECT_EQ(bhalf_nan, type_convert<bhalf_t>(ck::NumericLimits<float>::QuietNaN()));
}

TEST(BHALF_T, Inf)
{
    const uint16_t binary_bhalf_inf = 0x7F80;
    const bhalf_t bhalf_inf         = ck::bit_cast<bhalf_t>(binary_bhalf_inf);

    EXPECT_EQ(bhalf_inf, type_convert<bhalf_t>(ck::NumericLimits<float>::Infinity()));
}

TEST(BHALF_T, MantisaOverflow)
{
    const float abs_tol = std::pow(2, -7);

    const uint32_t val    = 0x81FFFFFF;
    const float float_val = ck::bit_cast<float>(val);

    ASSERT_NEAR(float_val, type_convert<float>(type_convert<bhalf_t>(float_val)), abs_tol);
}

TEST(BHALF_T, ExpOverflow)
{
    const uint32_t val    = 0xFF800000;
    const float float_val = ck::bit_cast<float>(val);

    ASSERT_EQ(type_convert<float>(type_convert<bhalf_t>(float_val)), float_val);
}

TEST(BHALF_T, MantisaExpOverflow)
{
    const uint32_t val    = 0xFFFFFFFF;
    const float float_val = ck::bit_cast<float>(val);

    ASSERT_TRUE(std::isnan(float_val));
    ASSERT_TRUE(std::isnan(type_convert<float>(type_convert<bhalf_t>(float_val))));
}
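A brief aside on the 2^-7 tolerance used in the MantisaOverflow test above: bhalf_t keeps only the upper 16 bits of an IEEE-754 binary32 value (1 sign bit, 8 exponent bits, 7 mantissa bits), so a float -> bf16 -> float round trip can perturb a finite value by roughly one part in 2^7. The snippet below is a self-contained illustration of that bound; it assumes plain truncation of the low mantissa bits, which is one possible conversion, while the library's type_convert is the authoritative implementation and may round differently.

// Illustration only (not CK code): bound on the float -> bf16 -> float error
// when the conversion simply drops the low 16 bits of the float encoding.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical helper for illustration; keeps sign, exponent, and the top
// 7 mantissa bits, i.e. the bits a bhalf_t can represent.
static float truncate_to_bf16(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFF0000u; // drop the 16 low mantissa bits
    float y;
    std::memcpy(&y, &bits, sizeof(y));
    return y;
}

int main()
{
    const float rel_tol = 0.0078125f; // 2^-7, one unit in the last kept mantissa bit

    const float x = 1.9999999f;       // mantissa pattern that bf16 cannot hold exactly
    const float y = truncate_to_bf16(x);

    // The round-trip error stays within one part in 2^7 of the magnitude.
    assert(std::fabs(x - y) <= rel_tol * std::fabs(x));
    return 0;
}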