Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
efab74a3
Commit
efab74a3
authored
Jan 24, 2025
by
Rostyslav Geyyer
Browse files
Merge branch 'gfx950' into lwpck-2619
parents
86950b3a
bcef33c1
Changes
362
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1755 additions
and
76 deletions
+1755
-76
profiler/src/profile_gemm_b_scale.cpp
profiler/src/profile_gemm_b_scale.cpp
+181
-0
profiler/src/profile_gemm_multiply_multiply.cpp
profiler/src/profile_gemm_multiply_multiply.cpp
+8
-1
profiler/src/profile_gemm_universal.cpp
profiler/src/profile_gemm_universal.cpp
+14
-3
profiler/src/profile_gemm_universal_streamk.cpp
profiler/src/profile_gemm_universal_streamk.cpp
+19
-2
pyproject.toml
pyproject.toml
+5
-2
python/ck4inductor/universal_gemm/gen_instances.py
python/ck4inductor/universal_gemm/gen_instances.py
+7
-6
python/test/test_gen_instances.py
python/test/test_gen_instances.py
+46
-0
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+2
-2
script/process_perf_data.py
script/process_perf_data.py
+14
-0
script/process_perf_data.sh
script/process_perf_data.sh
+16
-0
script/process_qa_data.sh
script/process_qa_data.sh
+16
-0
test/CMakeLists.txt
test/CMakeLists.txt
+53
-1
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+21
-22
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+12
-32
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+14
-1
test/data_type/test_bhalf.cpp
test/data_type/test_bhalf.cpp
+48
-0
test/data_type/test_fp8_ocp.cpp
test/data_type/test_fp8_ocp.cpp
+2
-2
test/data_type/test_mx_bf8.cpp
test/data_type/test_mx_bf8.cpp
+654
-0
test/data_type/test_mx_fp8.cpp
test/data_type/test_mx_fp8.cpp
+616
-0
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
...uped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+7
-2
No files found.
profiler/src/profile_gemm_b_scale.cpp
0 → 100644
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include "profiler/profile_gemm_b_scale_impl.hpp"
#include "profiler_operation_registry.hpp"
enum
struct
GemmMatrixLayout
{
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
KM_KN_MN
,
// 2
KM_NK_MN
,
// 3
};
enum
struct
GemmDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
F8_F16_F16
,
// 4
F16_F8_F16
,
// 5
F16_F16_F16_F8
,
// 6
F8_F8_BF16
,
// 7
F16_I4_F16
,
// 8
};
enum
struct
BScaleBlockTile
{
K_64
,
// 0
K_128
,
// 1
};
#define OP_NAME "gemm_b_scale"
#define OP_DESC "Int4-dequant GEMM"
int
profile_gemm_b_scale
(
int
argc
,
char
*
argv
[])
{
if
(
argc
!=
16
&&
argc
!=
19
)
{
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8; 8: f16@i4)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 3: A[k, m] * B[n, k] = C[m, n])
\n
"
);
printf
(
"arg4: B scale block tile (0: 64, 1: 128):
\n
"
);
printf
(
"arg5: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg6: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg7: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg8: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg9 to 14: M, N, K, StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg15: split k into mulitiple batch
\n
"
);
printf
(
"optional:
\n
"
);
printf
(
"arg16: number of warm-up cycles (default 1)
\n
"
);
printf
(
"arg17: number of iterations (default 10)
\n
"
);
printf
(
"arg18: memory for rotating buffer (default 0, size in MB)
\n
"
);
exit
(
1
);
}
printf
(
"Start profiling
\n
"
);
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
auto
B_scale_block
=
static_cast
<
BScaleBlockTile
>
(
std
::
stoi
(
argv
[
4
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
5
]);
const
int
init_method
=
std
::
stoi
(
argv
[
6
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
7
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
8
]);
const
int
M
=
std
::
stoi
(
argv
[
9
]);
const
int
N
=
std
::
stoi
(
argv
[
10
]);
const
int
K
=
std
::
stoi
(
argv
[
11
]);
const
int
StrideA
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
13
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
14
]);
const
int
KBatch
=
std
::
stoi
(
argv
[
15
]);
printf
(
"M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, KBatch:%d
\n
"
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
KBatch
);
int
n_warmup
=
1
;
int
n_iter
=
10
;
uint64_t
rotating
=
0
;
if
(
argc
==
19
)
{
n_warmup
=
std
::
stoi
(
argv
[
16
]);
n_iter
=
std
::
stoi
(
argv
[
17
]);
rotating
=
std
::
stoull
(
argv
[
18
])
*
1024
*
1024
;
printf
(
"n_warmup:%d, n_iter:%d, rotating:%lu
\n
"
,
n_warmup
,
n_iter
,
rotating
);
}
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
I4
=
ck
::
pk_i4_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
auto
profile
=
[
&
](
auto
a_type
,
auto
b_type
,
auto
b_scale_type
,
auto
comp_type
,
auto
acc_type
,
auto
c_type
,
auto
scale_block_k
,
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
using
ADataType
=
decltype
(
a_type
);
using
BDataType
=
decltype
(
b_type
);
using
BScaleDataType
=
decltype
(
b_scale_type
);
using
ComputeDataType
=
decltype
(
comp_type
);
using
AccDataType
=
decltype
(
acc_type
);
using
CDataType
=
decltype
(
c_type
);
using
ALayout
=
decltype
(
a_layout
);
using
BLayout
=
decltype
(
b_layout
);
using
CLayout
=
decltype
(
c_layout
);
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideB
=
ck
::
is_same_v
<
BLayout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideC
=
ck
::
is_same_v
<
CLayout
,
Row
>
?
N
:
M
;
bool
pass
=
ck
::
profiler
::
profile_gemm_b_scale_impl
<
ADataType
,
BDataType
,
BScaleDataType
,
ComputeDataType
,
AccDataType
,
CDataType
,
scale_block_k
,
ALayout
,
BLayout
,
CLayout
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
DefaultStrideA
:
StrideA
,
(
StrideB
<
0
)
?
DefaultStrideB
:
StrideB
,
(
StrideC
<
0
)
?
DefaultStrideC
:
StrideC
,
KBatch
,
n_warmup
,
n_iter
,
rotating
);
return
pass
?
0
:
1
;
};
if
(
data_type
==
GemmDataType
::
F16_I4_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
&&
B_scale_block
==
BScaleBlockTile
::
K_128
)
{
printf
(
"F16_I4_F16 MK_NK_MN K_128
\n
"
);
return
profile
(
F16
{},
I4
{},
F16
{},
F16
{},
F32
{},
F16
{},
ck
::
Number
<
128
>
{},
Row
{},
Col
{},
Row
{});
}
else
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
return
1
;
}
}
REGISTER_PROFILER_OPERATION
(
OP_NAME
,
OP_DESC
,
profile_gemm_b_scale
);
profiler/src/profile_gemm_multiply_multiply.cpp
View file @
efab74a3
...
...
@@ -28,6 +28,7 @@ enum struct GemmDataType
F16_F16_F16_F8
,
// 6
F8_F8_BF16
,
// 7
INT8_INT8_BF16
,
// 8
F8_F8_F16
,
// 9
};
#define OP_NAME "gemm_multiply_multiply"
...
...
@@ -40,7 +41,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8; 8: int8->bf16)
\n
"
);
"comp f8; 8: int8->bf16
; 9: f8->f16, comp f8;
)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
...
...
@@ -89,6 +90,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
using
F32
=
float
;
using
BF16
=
ck
::
bhalf_t
;
using
F16
=
ck
::
half_t
;
using
F8
=
ck
::
f8_t
;
using
I8
=
int8_t
;
using
I32
=
int
;
...
...
@@ -165,6 +167,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
F32
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F8_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
F32
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
...
...
profiler/src/profile_gemm_universal.cpp
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_universal_impl.hpp"
#include "profiler_operation_registry.hpp"
...
...
@@ -27,6 +27,8 @@ enum struct GemmDataType
F16_F8_F16
,
// 5
F16_F16_F16_F8
,
// 6
F8_F8_BF16
,
// 7
F16_I4_F16
,
// 8
BF16_I4_BF16
,
// 9
};
#define OP_NAME "gemm_universal"
...
...
@@ -39,7 +41,7 @@ int profile_gemm_universal(int argc, char* argv[])
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8
)
\n
"
);
"comp f8
; 8: f16@i4; 9: bf16@i4
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
...
...
@@ -103,6 +105,7 @@ int profile_gemm_universal(int argc, char* argv[])
using
BF16
=
ck
::
bhalf_t
;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using
F8
=
ck
::
f8_t
;
using
I4
=
ck
::
pk_i4_t
;
#endif
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
...
...
@@ -207,6 +210,14 @@ int profile_gemm_universal(int argc, char* argv[])
{
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_I4_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F16
{},
I4
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_I4_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
BF16
{},
I4
{},
BF16
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{});
}
#endif
else
{
...
...
profiler/src/profile_gemm_universal_streamk.cpp
100755 → 100644
View file @
efab74a3
...
...
@@ -85,6 +85,7 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using
F8
=
ck
::
f8_t
;
...
...
@@ -165,6 +166,22 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
}
#endif
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
BF16
{},
BF16
{},
F32
{},
BF16
{},
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
BF16
{},
BF16
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
BF16
{},
BF16
{},
F32
{},
BF16
{},
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
BF16
{},
BF16
{},
F32
{},
BF16
{},
Col
{},
Col
{},
Row
{});
}
else
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
pyproject.toml
View file @
efab74a3
...
...
@@ -21,16 +21,19 @@ dependencies = []
"Bug
Tracker"
=
"https://github.com/rocm/composable_kernel/issues"
[tool.setuptools]
packages
=
[
"ck4inductor"
,
"ck4inductor.include"
,
"ck4inductor.library"
]
packages
=
[
"ck4inductor"
,
"ck4inductor.include"
,
"ck4inductor.library"
,
"ck4inductor.universal_gemm"
,
"ck4inductor.batched_universal_gemm"
,
"ck4inductor.grouped_conv_fwd"
]
[tool.setuptools.package-dir]
ck4inductor
=
"python/ck4inductor"
"ck4inductor.universal_gemm"
=
"python/ck4inductor/universal_gemm"
"ck4inductor.batched_universal_gemm"
=
"python/ck4inductor/batched_universal_gemm"
"ck4inductor.grouped_conv_fwd"
=
"python/ck4inductor/grouped_conv_fwd"
"ck4inductor.include"
=
"include"
"ck4inductor.library"
=
"library"
[tool.setuptools.package-data]
"ck4inductor.include"
=
["ck/**/*.hpp"]
"ck4inductor.library"
=
["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"]
"ck4inductor.library"
=
[
"src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"
,
"src/tensor_operation_instance/gpu/gemm_universal_batched/**/*.hpp"
,
"include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/**/*.hpp"
]
[tool.setuptools.dynamic]
version
=
{
attr
=
"setuptools_scm.get_version"
}
python/ck4inductor/universal_gemm/gen_instances.py
View file @
efab74a3
...
...
@@ -68,12 +68,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
template_args
.
insert
(
2
,
tuple
())
# ds layout
template_args
.
insert
(
6
,
tuple
())
# ds dtype
try
:
new_instance
=
CKGemmOperation
(
*
template_args
,
# type: ignore[arg-type]
)
op_instances
.
append
(
new_instance
)
except
TypeError
as
e
:
log
.
debug
(
f
"
{
e
}
when parsing
{
line
}
"
)
return
op_instances
...
...
python/test/test_gen_instances.py
0 → 100644
View file @
efab74a3
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
import
logging
import
unittest
from
ck4inductor.universal_gemm.gen_instances
import
(
gen_ops_library
as
gen_gemm_ops_library
,
)
from
ck4inductor.universal_gemm.gen_instances
import
(
gen_ops_preselected
as
gen_gemm_ops_preselected
,
)
from
ck4inductor.grouped_conv_fwd.gen_instances
import
(
gen_conv_ops_library
as
gen_conv_ops_library
,
)
from
ck4inductor.batched_universal_gemm.gen_instances
import
(
gen_ops_library
as
gen_batched_gemm_ops_library
,
)
log
=
logging
.
getLogger
(
__name__
)
class
TestGenInstances
(
unittest
.
TestCase
):
def
test_gen_gemm_instances
(
self
):
instances
=
gen_gemm_ops_library
()
log
.
debug
(
"%d gemm instances from library"
%
len
(
instances
))
self
.
assertTrue
(
instances
)
def
test_preselected_gemm_instances
(
self
):
instances
=
gen_gemm_ops_preselected
()
log
.
debug
(
"%d preselected gemm instances"
%
len
(
instances
))
self
.
assertTrue
(
instances
)
def
test_gen_conv_instances
(
self
):
instances
=
gen_conv_ops_library
()
log
.
debug
(
"%d gemm instances from library"
%
len
(
instances
))
self
.
assertTrue
(
instances
)
def
test_gen_batched_gemm_instances
(
self
):
instances
=
gen_batched_gemm_ops_library
()
log
.
debug
(
"%d gemm instances from library"
%
len
(
instances
))
self
.
assertTrue
(
instances
)
script/cmake-ck-dev.sh
View file @
efab74a3
...
...
@@ -15,7 +15,7 @@ else
fi
cmake
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
/
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
...
...
script/process_perf_data.py
View file @
efab74a3
...
...
@@ -149,6 +149,12 @@ def parse_logfile(logfile):
lst
=
line
.
split
()
line_dict
=
dict
(
zip
(
lst
[
1
:],
lst
))
res
.
append
(
line_dict
[
'TFlops,'
])
elif
'perf_tile_gemm_basic'
in
logfile
or
'perf_tile_gemm_mem_pipeline'
in
logfile
:
for
line
in
open
(
logfile
):
if
'TFlops'
in
line
:
lst
=
line
.
split
()
line_dict
=
dict
(
zip
(
lst
[
1
:],
lst
))
res
.
append
(
line_dict
[
'TFlops,'
])
return
res
...
...
@@ -330,6 +336,14 @@ def main():
for
i
in
range
(
1
,
len
(
results
)
+
1
):
testlist
.
append
(
"Test%i"
%
i
)
table_name
=
"ck_fmha_bwd_tflops"
if
'gemm_basic_fp16'
in
filename
:
for
i
in
range
(
1
,
len
(
results
)
+
1
):
testlist
.
append
(
"Test%i"
%
i
)
table_name
=
"ck_tile_gemm_basic_fp16_tflops"
if
'gemm_mem_pipeline_fp16'
in
filename
:
for
i
in
range
(
1
,
len
(
results
)
+
1
):
testlist
.
append
(
"Test%i"
%
i
)
table_name
=
"ck_tile_gemm_mem_pipeline_fp16_tflops"
tflops_base
=
get_baseline
(
table_name
,
conn
)
store_new_test_result
(
table_name
,
results
,
testlist
,
branch_name
,
node_id
,
gpu_arch
,
compute_units
,
rocm_vers
,
hip_vers
,
environment
,
sqlEngine
)
...
...
script/process_perf_data.sh
View file @
efab74a3
...
...
@@ -43,3 +43,19 @@ file=./perf_fmha_bwd_gfx90a.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
fi
file
=
./perf_tile_gemm_basic_fp16_gfx942.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_tile_gemm_basic_fp16_gfx942.log
fi
file
=
./perf_tile_gemm_basic_fp16_gfx90a.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_tile_gemm_basic_fp16_gfx90a.log
fi
file
=
./perf_tile_gemm_mem_pipeline_fp16_gfx942.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_tile_gemm_mem_pipeline_fp16_gfx942.log
fi
file
=
./perf_tile_gemm_mem_pipeline_fp16_gfx90a.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_tile_gemm_mem_pipeline_fp16_gfx90a.log
fi
script/process_qa_data.sh
View file @
efab74a3
...
...
@@ -52,3 +52,19 @@ file=./perf_fmha_bwd_gfx90a.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
fi
file
=
./perf_gemm_basic_gfx942.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_gemm_basic_gfx942.log
fi
file
=
./perf_gemm_basic_gfx90a.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_gemm_basic_gfx90a.log
fi
file
=
./perf_gemm_mem_pipeline_gfx942.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_gemm_mem_pipeline_gfx942.log
fi
file
=
./perf_gemm_mem_pipeline_gfx90a.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_gemm_mem_pipeline_gfx90a.log
fi
test/CMakeLists.txt
View file @
efab74a3
...
...
@@ -7,6 +7,34 @@ include(gtest)
add_custom_target
(
tests
)
# list of tests that are labelled as REGRESSION_TEST for make regression (runtime more than 30 seconds)
# all other tests are labelled as SMOKE_TEST
set
(
REGRESSION_TESTS
test_gemm_standalone_xdl_fp16
test_gemm_fp16
test_gemm_splitk
test_batched_gemm
test_gemm_universal
test_batched_gemm_softmax_gemm_fp16
test_batched_gemm_softmax_gemm_permute_fp16
test_batched_gemm_bias_softmax_gemm_permute_fp16
test_batched_gemm_softmax_gemm_permute_bf16
test_batched_gemm_bias_softmax_gemm_permute_bf16
test_grouped_gemm_splitk
test_reduce_no_index
test_reduce_with_index
test_convnd_fwd
test_convnd_bwd_data
test_grouped_convnd_fwd
test_grouped_convnd_bwd_weight
test_softmax_rank3
test_softmax_rank4
test_batchnorm_fwd_rank_4
test_batchnorm_bwd_rank_4
test_grouped_convnd_bwd_data_xdl
test_conv_tensor_rearrange
)
function
(
add_test_executable TEST_NAME
)
message
(
"adding test
${
TEST_NAME
}
"
)
set
(
result 1
)
...
...
@@ -43,6 +71,12 @@ function(add_test_executable TEST_NAME)
set
(
TEST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DPP_KERNELS AND source MATCHES
"_dpp"
)
message
(
"removing dpp test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
message
(
"removing dl test
${
source
}
"
)
...
...
@@ -82,6 +116,15 @@ function(add_test_executable TEST_NAME)
endif
()
#message("add_test returns ${result}")
set
(
result
${
result
}
PARENT_SCOPE
)
if
(
result EQUAL 0 AND NOT
"
${
TEST_NAME
}
"
IN_LIST REGRESSION_TESTS
)
message
(
"adding to SMOKE TEST FILTER
${
TEST_NAME
}
"
)
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES LABELS
"SMOKE_TEST"
)
add_dependencies
(
smoke
${
TEST_NAME
}
)
elseif
(
result EQUAL 0 AND
"
${
TEST_NAME
}
"
IN_LIST REGRESSION_TESTS
)
message
(
"Adding to REGRESSION TEST FILTER
${
TEST_NAME
}
"
)
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES LABELS
"REGRESSION_TEST"
)
add_dependencies
(
regression
${
TEST_NAME
}
)
endif
()
endfunction
()
function
(
add_gtest_executable TEST_NAME
)
...
...
@@ -162,6 +205,15 @@ function(add_gtest_executable TEST_NAME)
endif
()
#message("add_gtest returns ${result}")
set
(
result
${
result
}
PARENT_SCOPE
)
if
(
result EQUAL 0 AND NOT
"
${
TEST_NAME
}
"
IN_LIST REGRESSION_TESTS
)
#message("adding to smoke test FILTER ${TEST_NAME}")
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES LABELS
"SMOKE_TEST"
)
add_dependencies
(
smoke
${
TEST_NAME
}
)
elseif
(
result EQUAL 0 AND
"
${
TEST_NAME
}
"
IN_LIST REGRESSION_TESTS
)
#message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES LABELS
"REGRESSION_TEST"
)
add_dependencies
(
regression
${
TEST_NAME
}
)
endif
()
endfunction
()
add_compile_options
(
-Wno-c++20-extensions
)
...
...
@@ -206,7 +258,7 @@ add_subdirectory(wrapper)
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx11"
)
add_subdirectory
(
wmma_op
)
endif
()
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx942"
AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2
)
# smfmac needs ROCm6.2
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx942"
OR SUPPORTED_GPU_TARGETS MATCHES
"gfx950"
)
# smfmac needs ROCm6.2
add_subdirectory
(
smfmac_op
)
endif
()
add_subdirectory
(
position_embedding
)
...
...
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
View file @
efab74a3
...
...
@@ -24,12 +24,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmHostArgs
{
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
void
invoke_batched_gemm
(
const
batched_gemm_kargs
&
args
,
const
ck_tile
::
stream_config
&
s
)
void
invoke_batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
...
...
@@ -94,9 +91,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeK
a
rgs
(
args
);
auto
kargs
=
Kernel
::
MakeK
ernelA
rgs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
,
args
.
batch_count
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
s
.
log_level_
>
0
)
...
...
@@ -185,21 +182,23 @@ class TestCkTileBatchedGemm : public ::testing::Test
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
batched_gemm_kargs
kargs
{
a_m_k_dev_buf
.
GetDeviceBuffer
(),
b_k_n_dev_buf
.
GetDeviceBuffer
(),
c_m_n_dev_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideC
,
BatchCount
};
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
kargs
,
ck_tile
::
BatchedGemmHostArgs
args
;
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
k_batch
=
1
;
args
.
M
=
M
;
args
.
N
=
N
;
args
.
K
=
K
;
args
.
stride_A
=
StrideA
;
args
.
stride_B
=
StrideB
;
args
.
stride_C
=
StrideC
;
args
.
batch_stride_A
=
BatchStrideA
;
args
.
batch_stride_B
=
BatchStrideB
;
args
.
batch_stride_C
=
BatchStrideC
;
args
.
batch_count
=
BatchCount
;
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
false
});
std
::
cout
<<
"Run kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
...
...
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
efab74a3
...
...
@@ -31,22 +31,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
static
constexpr
auto
PipelineType
=
std
::
tuple_element_t
<
8
,
Tuple
>::
value
;
// TODO: expose tile size through test t-param ?
struct
gemm_args
{
const
void
*
p_a
;
const
void
*
p_b
;
void
*
p_c
;
ck_tile
::
index_t
kbatch
;
ck_tile
::
index_t
M
;
ck_tile
::
index_t
N
;
ck_tile
::
index_t
K
;
ck_tile
::
index_t
stride_A
;
ck_tile
::
index_t
stride_B
;
ck_tile
::
index_t
stride_C
;
};
template
<
bool
PadM
,
bool
PadN
,
bool
PadK
>
void
invoke_gemm
(
const
gemm_a
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
void
invoke_gemm
(
const
ck_tile
::
GemmHostA
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// TODO: This should be parameterized in tests
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
...
...
@@ -88,7 +74,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>>
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
K_split
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
ck_tile
::
TailNumber
tail_num
=
BaseGemmPipeline
::
GetBlockLoopTailNum
(
num_loop
);
...
...
@@ -117,17 +105,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
has_hot_loop_v
,
tail_number_v
>>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_c
,
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
...
...
@@ -319,11 +299,11 @@ class TestCkTileGemmPipeline : public ::testing::Test
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
gemm_a
rgs
args
;
args
.
p_a
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
p_b
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
p_c
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
kbatch
=
kbatch
;
ck_tile
::
GemmHostA
rgs
args
;
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
k
_
batch
=
kbatch
;
args
.
M
=
M
;
args
.
N
=
N
;
args
.
K
=
K
;
...
...
test/data_type/CMakeLists.txt
View file @
efab74a3
...
...
@@ -12,6 +12,7 @@ endif()
add_custom_target
(
test_fp8
)
if
(
CK_USE_OCP_FP8
)
# add test for ocp data types
add_gtest_executable
(
test_fp8_ocp test_fp8_ocp.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8_ocp PRIVATE utility
)
...
...
@@ -62,16 +63,28 @@ if(GPU_TARGETS MATCHES "gfx950")
endif
()
add_dependencies
(
test_mx_data_types test_bf6
)
add_gtest_executable
(
test_mx_fp8 test_mx_fp8.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_mx_fp8 PRIVATE utility
)
endif
()
add_dependencies
(
test_mx_data_types test_mx_fp8
)
add_gtest_executable
(
test_mx_bf8 test_mx_bf8.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_mx_bf8 PRIVATE utility
)
endif
()
add_dependencies
(
test_mx_data_types test_mx_bf8
)
add_gtest_executable
(
test_e8m0 test_e8m0.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_e8m0 PRIVATE utility
)
endif
()
add_dependencies
(
test_mx_data_types test_e8m0
)
endif
()
add_gtest_executable
(
test_custom_type test_custom_type.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_custom_type PRIVATE utility
)
endif
()
add_gtest_executable
(
test_type_convert_const type_convert_const.cpp
)
add_gtest_executable
(
test_bhalf test_bhalf.cpp
)
test/data_type/test_bhalf.cpp
0 → 100644
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bhalf_t
;
using
ck
::
type_convert
;
TEST
(
BHALF_T
,
Nan
)
{
const
uint16_t
binary_bhalf_nan
=
0x7FC0
;
const
bhalf_t
bhalf_nan
=
ck
::
bit_cast
<
bhalf_t
>
(
binary_bhalf_nan
);
EXPECT_EQ
(
bhalf_nan
,
type_convert
<
bhalf_t
>
(
ck
::
NumericLimits
<
float
>::
QuietNaN
()));
}
TEST
(
BHALF_T
,
Inf
)
{
const
uint16_t
binary_bhalf_inf
=
0x7F80
;
const
bhalf_t
bhalf_inf
=
ck
::
bit_cast
<
bhalf_t
>
(
binary_bhalf_inf
);
EXPECT_EQ
(
bhalf_inf
,
type_convert
<
bhalf_t
>
(
ck
::
NumericLimits
<
float
>::
Infinity
()));
}
TEST
(
BHALF_T
,
MantisaOverflow
)
{
const
float
abs_tol
=
std
::
pow
(
2
,
-
7
);
const
uint32_t
val
=
0x81FFFFFF
;
const
float
float_val
=
ck
::
bit_cast
<
float
>
(
val
);
ASSERT_NEAR
(
float_val
,
type_convert
<
float
>
(
type_convert
<
bhalf_t
>
(
float_val
)),
abs_tol
);
}
TEST
(
BHALF_T
,
ExpOverflow
)
{
const
uint32_t
val
=
0xFF800000
;
const
float
float_val
=
ck
::
bit_cast
<
float
>
(
val
);
ASSERT_EQ
(
type_convert
<
float
>
(
type_convert
<
bhalf_t
>
(
float_val
)),
float_val
);
}
TEST
(
BHALF_T
,
MantisaExpOverflow
)
{
const
uint32_t
val
=
0xFFFFFFFF
;
const
float
float_val
=
ck
::
bit_cast
<
float
>
(
val
);
ASSERT_TRUE
(
std
::
isnan
(
float_val
));
ASSERT_TRUE
(
std
::
isnan
(
type_convert
<
float
>
(
type_convert
<
bhalf_t
>
(
float_val
))));
}
test/data_type/test_fp8_ocp.cpp
View file @
efab74a3
...
...
@@ -60,8 +60,8 @@ TEST(FP8OCP, ConvertFP32Nearest)
float
neg_float
=
-
0.015625
f
;
//-2^-6
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
neg_float
)),
0.0
f
);
// positive subnorm f
loat
value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
// positive subnorm f
p8
value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
// 2^-8
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
pos_float
)),
abs_tol
);
// min subnorm fp8 value to fp8 and back, check if holds
...
...
test/data_type/test_mx_bf8.cpp
0 → 100644
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using
ck
::
bf8_ocp_t
;
using
ck
::
bf8x16_ocp_t
;
using
ck
::
bf8x2_ocp_t
;
using
ck
::
bf8x32_ocp_t
;
using
ck
::
e8m0_bexp_t
;
using
ck
::
float16_t
;
using
ck
::
float2_t
;
using
ck
::
float32_t
;
using
ck
::
mxf8_convert_rne
;
using
ck
::
mxf8_convert_sr
;
using
ck
::
scaled_type_convert
;
using
ck
::
type_convert
;
constexpr
uint64_t
test_size
=
256
*
256
+
2
+
4
+
6
;
/**
 * @brief Tests conversion of BF8 values to float using E8M0 exponent scaling.
 *
 * This function performs a series of conversions from BF8 values to float values using
 * E8M0 exponent scaling. It handles all possible combinations of E8M0 and BF8 values,
 * as well as specific vector and rounding conversions.
 *
 * @param N The maximum number of conversions to perform.
 * @param p_test Pointer to the output array where the converted float values will be stored.
 * @param p_completed Pointer to a variable that tracks the number of completed conversions.
 *
 * @note If either p_test or p_completed is nullptr, the function will return immediately.
 * @note The function will stop converting if the number of conversions reaches N.
 * @note First 256*256 conversions are for all possible combinations of E8M0 and BF8 values that are
 * stored in memory sequentially with BF8 values varying faster.
 *
 * The function performs the following conversions:
 * - All possible combinations of E8M0 and BF8 values. [256x256]
 * - Vector conversions bf8x2 -> f32x2. [2]
 * - Vector conversions f32x2 -> bf8x2 rne. [2]
 * - Vector conversions f32x2 -> bf8x2 sr. [2]
 * - Round to nearest even conversions for specific float values. [6]
 *
 * The results are stored in the p_test array, and the number of completed conversions
 * is updated in the p_completed variable.
 */
__host__ __device__ void
test_mx_bf8_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }

    // i aliases the caller-visible progress counter so partial progress is
    // observable even if we bail out early.
    uint64_t& i = *p_completed;
    i           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    // All possible combinations of E8M0 and BF8
    // (output index = exp_id * 256 + bf8_id; BF8 varies fastest)
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
        {
            uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);

            auto v    = scaled_type_convert<float>(e8m0_bexp_t(exp_id), bf8_ocp_t{bf8_uid});
            p_test[i] = v;
            i++;
            if(i >= N)
            {
                return;
            }
        }
    }

    /// Test vector conversions

    // bf8x2 -> f32x2
    bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
    auto scale = e8m0_bexp_t(8.0f);

    float2_t f32x2 = scaled_type_convert<float2_t>(scale, bf8x2);

    p_test[i++] = f32x2[0];
    if(i >= N)
    {
        return;
    }
    p_test[i++] = f32x2[1];
    if(i >= N)
    {
        return;
    }

    // f32x2 -> bf8x2
    f32x2 = {-8.0f, 4.0f};

    auto scale2 = e8m0_bexp_t(2.0f);

    bf8x2 = mxf8_convert_rne<bf8x2_ocp_t>(f32x2, type_convert<float>(scale2)); // expect {-4, 2}

    p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
    if(i >= N)
    {
        return;
    }

    auto scale4 = e8m0_bexp_t(4.0f);

    bf8x2 = mxf8_convert_sr<bf8x2_ocp_t>(f32x2, type_convert<float>(scale4)); // expect {-2, 1}

    p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-2f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 1f
    if(i >= N)
    {
        return;
    }

    /// Test round to nearest even
    p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(1024.0f, 4.0f)); // 1024/4
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(
        mxf8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN(), 4.0f)); // => NaN
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(
        std::numeric_limits<float>::infinity(), 2.0f)); // => BF8 Inf on device
    if(i >= N)
    {
        return;
    }
    // 31000/0.5 > 57344 => BF8 Inf on device
    p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(31000.0f, 0.5f));
    if(i >= N)
    {
        return;
    }
    // -31000/0.5 < -57344 => -BF8 Inf on device
    p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(-31000.0f, 0.5f));
    if(i >= N)
    {
        return;
    }
    p_test[i++] =
        type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(powf(2.0f, 16.0f), 4.0f)); // 2^16/4 = 65536/4
    if(i >= N)
    {
        return;
    }
}
// Runs the shared conversion driver on the host and validates every produced
// value against reference arithmetic on floats.
TEST(MXBF8, HostScaledConvert)
{
    std::vector<float> out(test_size, -1.0f);
    uint64_t completed = 0;

    test_mx_bf8_scaled_convert(test_size, out.data(), &completed);

    // V = X * P; X - E8M0 scale, P - BF8
    // If X = NaN, then V = NaN regardless of P
    uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        auto idx = e8m0_nan_id * 256 + bf8_id;
        ASSERT_TRUE(std::isnan(out[idx]));
    }

    // If P in {Inf, NaN}, then V = P
    // BF8 (E5M2) special encodings: mantissa != 0 with max exponent => NaN,
    // mantissa == 0 with max exponent => Inf.
    std::set<uint8_t> bf8_spec_ids;
    bf8_spec_ids.insert(0b11111111); // -NaN
    bf8_spec_ids.insert(0b01111111); // +NaN
    bf8_spec_ids.insert(0b11111101); // -NaN
    bf8_spec_ids.insert(0b01111101); // +NaN
    bf8_spec_ids.insert(0b11111110); // -NaN
    bf8_spec_ids.insert(0b01111110); // +NaN
    bf8_spec_ids.insert(0b11111100); // -inf
    bf8_spec_ids.insert(0b01111100); // +inf

    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;
        for(auto bf8_spec_id : bf8_spec_ids)
        {
            auto idx = exp_id * 256 + bf8_spec_id;
            if(std::isnan(type_convert<float>(bf8_ocp_t{bf8_spec_id})))
            {
                ASSERT_TRUE(std::isnan(out[idx]))
                    << "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
                    << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                    << type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
            }
            else
            {
                ASSERT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_spec_id}))
                    << "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
                    << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                    << type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
            }
        }
    }

    // V = X * P; X, P - finite
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;
        for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
        {
            if(bf8_spec_ids.find(bf8_id) != bf8_spec_ids.end())
                continue;
            uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
            auto idx        = exp_id * 256 + bf8_uid;
            ASSERT_FLOAT_EQ(out[idx],
                            type_convert<float>(e8m0_bexp_t(exp_id)) *
                                type_convert<float>(bf8_ocp_t{bf8_uid}))
                << "exp_id: " << exp_id << " bf8_id: " << bf8_uid << std::endl
                << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                << type_convert<float>(bf8_ocp_t{bf8_uid});
        }
    }

    /// Test vector conversions
    // Results after the exhaustive table start at index 256*256 (see test_size).
    auto i = 256 * 256;

    // bf8x2 -> f32x2
    EXPECT_EQ(out[i++], -powf(2.0f, -11.0f));
    EXPECT_EQ(out[i++], powf(2.0f, -13.0f));

    // f32x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    // SR
    EXPECT_EQ(out[i++], -2.0f);
    EXPECT_EQ(out[i++], 1.0f);

    /// Test round to nearest even
    EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
    // Host saturates out-of-range results to the finite BF8 extrema.
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Lowest()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], powf(2.0f, 14.0f)) << "out[i-1]: " << out[i - 1];

    // Every slot must have been written exactly once.
    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Thin kernel wrapper so the shared __host__ __device__ test body can run on the GPU.
__global__ void test_mx_bf8_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    test_mx_bf8_scaled_convert(N, p_test, p_completed);
}
// Runs the shared conversion driver in a 1x1 kernel and validates the results
// copied back from the device. Mirrors HostScaledConvert except that device
// conversion of out-of-range values yields Inf instead of saturating to Max.
TEST(MXBF8, DeviceScaledConvert)
{
    std::vector<float> out(test_size, -1.0f);
    DeviceMem device_out(test_size * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Poison the device buffers so unwritten elements are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_bf8_device_scaled_convert<<<1, 1>>>(
        test_size,
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    // V = X * P; X - E8M0 scale, P - BF8
    // If X = NaN, then V = NaN regardless of P
    uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        auto idx = e8m0_nan_id * 256 + bf8_id;
        ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
    }

    // If P in {Inf, NaN}, then V = P
    std::set<uint8_t> bf8_spec_ids;
    bf8_spec_ids.insert(0b11111111); // -NaN
    bf8_spec_ids.insert(0b01111111); // +NaN
    bf8_spec_ids.insert(0b11111101); // -NaN
    bf8_spec_ids.insert(0b01111101); // +NaN
    bf8_spec_ids.insert(0b11111110); // -NaN
    bf8_spec_ids.insert(0b01111110); // +NaN
    bf8_spec_ids.insert(0b11111100); // -inf
    bf8_spec_ids.insert(0b01111100); // +inf

    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;
        for(auto bf8_spec_id : bf8_spec_ids)
        {
            auto idx = exp_id * 256 + bf8_spec_id;
            if(std::isnan(type_convert<float>(bf8_ocp_t{bf8_spec_id})))
            {
                ASSERT_TRUE(std::isnan(out[idx]))
                    << "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
                    << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                    << type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
            }
            else
            {
                ASSERT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_spec_id}))
                    << "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
                    << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                    << type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
            }
        }
    }

    // V = X * P; X, P - finite
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;
        for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
        {
            if(bf8_spec_ids.find(bf8_id) != bf8_spec_ids.end())
                continue;
            uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
            auto idx        = exp_id * 256 + bf8_uid;
            ASSERT_FLOAT_EQ(out[idx],
                            type_convert<float>(e8m0_bexp_t(exp_id)) *
                                type_convert<float>(bf8_ocp_t{bf8_uid}))
                << "exp_id: " << exp_id << " bf8_id: " << bf8_uid << std::endl
                << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                << type_convert<float>(bf8_ocp_t{bf8_uid});
        }
    }

    /// Test vector conversions
    auto i = 256 * 256;

    // bf8x2 -> f32x2
    EXPECT_EQ(out[i++], -powf(2.0f, -11.0f));
    EXPECT_EQ(out[i++], powf(2.0f, -13.0f));

    // f32x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    // SR
    EXPECT_EQ(out[i++], -2.0f);
    EXPECT_EQ(out[i++], 1.0f);

    /// Test round to nearest even
    EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#if 1
    EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
#else
    // NOTE: Host and Device have different behavior.
    // Device returns Infs, while Host returns Max (saturation to finite value).
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Lowest()))
        << "out[i-1]: " << out[i - 1];
#endif
    EXPECT_EQ(out[i++], powf(2.0f, 14.0f)) << "out[i-1]: " << out[i - 1];

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Deterministic 16-lane test pattern: lane i holds (-1)^i * 2^i,
// i.e. alternating-sign powers of two.
__host__ __device__ float vec16_generator(ck::index_t i)
{
    const float sign      = powf(-1.0f, i);
    const float magnitude = powf(2.0f, i);
    return sign * magnitude;
}
// Converts a fixed 16-lane float pattern to bf8x16 with E8M0 scale 2 via
// scaled_type_convert, then widens each lane back to float into p_test.
// *p_completed counts the number of stores performed.
__global__ void test_mx_bf8x16_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int kLanes = 16;

    if(p_completed == nullptr)
        return;

    uint64_t& count = *p_completed;
    count           = 0;

    if(p_test == nullptr)
        return;

    // Source vector: vec16_generator pattern in every lane.
    float16_t src{};
    ck::static_for<0, kLanes, 1>{}(
        [&](auto j) { src[static_cast<int>(j)] = vec16_generator(j); });

    const auto scale = e8m0_bexp_t(2.0f);

    bf8x16_ocp_t converted{};
    converted = scaled_type_convert<bf8x16_ocp_t>(scale, src);

    // Widen each converted lane back to float for host-side inspection.
    ck::static_for<0, kLanes, 1>{}([&](auto j) {
        p_test[count++] = type_convert<float>(converted.AsType<bf8_ocp_t>()(ck::Number<j>{}));
    });
}
// Round-trips a 16-lane float pattern through a scaled f32x16 -> bf8x16 device
// conversion (scale = 2) and checks each lane equals the input divided by 2.
TEST(MXBF8, DeviceF32x16ToBF8x16ScaledConvert)
{
    constexpr int N = 16;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Poison the device buffers so unwritten elements are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_bf8x16_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec16_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}
// Deterministic 32-lane test pattern: the first 16 lanes repeat the
// vec16_generator pattern; the last 16 repeat it scaled by 1.5.
__host__ __device__ float vec32_generator(ck::index_t i)
{
    const float base = vec16_generator(i % 16);
    return (i < 16) ? base : 1.5f * base;
}
// RNE-converts a fixed 32-lane float pattern to bf8x32 with scale 2
// (mxf8_convert_rne), then widens each lane back to float into p_test.
// *p_completed counts the number of stores performed.
__global__ void test_mx_bf8x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int kLanes = 32;

    if(p_completed == nullptr)
        return;

    uint64_t& count = *p_completed;
    count           = 0;

    if(p_test == nullptr)
        return;

    // Source vector: vec32_generator pattern in every lane.
    float32_t src{};
    ck::static_for<0, kLanes, 1>{}(
        [&](auto j) { src[static_cast<int>(j)] = vec32_generator(j); });

    const auto scale = e8m0_bexp_t(2.0f);

    bf8x32_ocp_t converted{};
    converted = mxf8_convert_rne<bf8x32_ocp_t>(src, type_convert<float>(scale));

    // Widen each converted lane back to float for host-side inspection.
    ck::static_for<0, kLanes, 1>{}([&](auto j) {
        p_test[count++] = type_convert<float>(converted.AsType<bf8_ocp_t>()(ck::Number<j>{}));
    });
}
// Round-trips a 32-lane float pattern through a scaled RNE f32x32 -> bf8x32
// device conversion (scale = 2) and checks each lane equals input / 2.
TEST(MXBF8, DeviceF32x32ToBF8x32ScaledConvert)
{
    constexpr int N = 32;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Poison the device buffers so unwritten elements are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_bf8x32_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}
// Stochastic-rounding variant: converts a fixed 32-lane float pattern to
// bf8x32 with scale 8 (mxf8_convert_sr), then widens each lane back to float
// into p_test. *p_completed counts the number of stores performed.
__global__ void test_mx_bf8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
{
    constexpr int kLanes = 32;

    if(p_completed == nullptr)
        return;

    uint64_t& count = *p_completed;
    count           = 0;

    if(p_test == nullptr)
        return;

    // Source vector: vec32_generator pattern in every lane.
    float32_t src{};
    ck::static_for<0, kLanes, 1>{}(
        [&](auto j) { src[static_cast<int>(j)] = vec32_generator(j); });

    const auto scale = e8m0_bexp_t(8.0f);

    bf8x32_ocp_t converted{};
    converted = mxf8_convert_sr<bf8x32_ocp_t>(src, type_convert<float>(scale));

    // Widen each converted lane back to float for host-side inspection.
    ck::static_for<0, kLanes, 1>{}([&](auto j) {
        p_test[count++] = type_convert<float>(converted.AsType<bf8_ocp_t>()(ck::Number<j>{}));
    });
}
// Round-trips a 32-lane float pattern through a scaled stochastic-rounding
// f32x32 -> bf8x32 device conversion (scale = 8) and checks each lane equals
// input / 8 (the pattern values are exactly representable, so SR is exact).
TEST(MXBF8, DeviceF32x32ToBF8x32ScaledConvertSR)
{
    constexpr int N = 32;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Poison the device buffers so unwritten elements are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_bf8x32_device_scaled_convert_sr<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}
// Inverse direction: packs vec32_generator(i)/16 into a bf8x32 vector, then
// widens it back to f32x32 with E8M0 scale 4 via scaled_type_convert and
// stores the floats into p_test. *p_completed counts the stores performed.
__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int kLanes = 32;

    if(p_completed == nullptr)
        return;

    uint64_t& count = *p_completed;
    count           = 0;

    if(p_test == nullptr)
        return;

    // Fill the packed BF8 source; /16 keeps the pattern within BF8 range.
    bf8x32_ocp_t packed{};
    ck::static_for<0, kLanes, 1>{}([&](auto j) {
        packed.AsType<bf8_ocp_t>()(j) = type_convert<bf8_ocp_t>(vec32_generator(j) / 16.0f);
    });

    const auto scale = e8m0_bexp_t(4.0f);

    float32_t widened{};
    widened = scaled_type_convert<float32_t>(scale, packed);

    ck::static_for<0, kLanes, 1>{}(
        [&](auto j) { p_test[count++] = widened[static_cast<int>(j)]; });
}
// Widens a bf8x32 vector (pattern / 16) back to f32x32 with scale 4 on the
// device and checks each lane equals pattern / 4 (i.e. pattern/16 * 4... the
// kernel stored pattern/16, so scale 4 yields pattern/4).
TEST(MXBF8, DeviceBF8x32ToF32x32ScaledConvert)
{
    constexpr int N = 32;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Poison the device buffers so unwritten elements are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}
test/data_type/test_mx_fp8.cpp
0 → 100644
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using
ck
::
e8m0_bexp_t
;
using
ck
::
f8_ocp_t
;
using
ck
::
f8x16_ocp_t
;
using
ck
::
f8x2_ocp_t
;
using
ck
::
f8x32_ocp_t
;
using
ck
::
float16_t
;
using
ck
::
float2_t
;
using
ck
::
float32_t
;
using
ck
::
mxf8_convert_rne
;
using
ck
::
mxf8_convert_sr
;
using
ck
::
scaled_type_convert
;
using
ck
::
type_convert
;
using
ck
::
fp8_impl
::
fp8x2_storage_t
;
// Number of results produced by test_mx_fp8_scaled_convert:
//   256*256 exhaustive (E8M0 scale, FP8 value) combinations
// + 2 results of the f8x2 -> f32x2 vector conversion
// + 4 results of the f32x2 -> f8x2 vector conversions (2 RNE + 2 SR)
// + 6 scalar round-to-nearest-even conversions
constexpr uint64_t test_size = 256 * 256 + 2 + 4 + 6;
/**
 * @brief Tests conversion of FP8 values to float using E8M0 exponent scaling.
 *
 * This function performs a series of conversions from FP8 values to float values using
 * E8M0 exponent scaling. It handles all possible combinations of E8M0 and FP8 values,
 * as well as specific vector and rounding conversions.
 *
 * @param N The maximum number of conversions to perform.
 * @param p_test Pointer to the output array where the converted float values will be stored.
 * @param p_completed Pointer to a variable that tracks the number of completed conversions.
 *
 * @note If either p_test or p_completed is nullptr, the function will return immediately.
 * @note The function will stop converting if the number of conversions reaches N.
 * @note First 256*256 conversions are for all possible combinations of E8M0 and FP8 values that are
 * stored in memory sequentially with FP8 values varying faster.
 *
 * The function performs the following conversions:
 * - All possible combinations of E8M0 and FP8 values. [256x256]
 * - Vector conversions f8x2 -> f32x2. [2]
 * - Vector conversions f32x2 -> f8x2 rne. [2]
 * - Vector conversions f32x2 -> f8x2 sr. [2]
 * - Round to nearest even conversions for specific float values. [6]
 *
 * The results are stored in the p_test array, and the number of completed conversions
 * is updated in the p_completed variable.
 */
__host__ __device__ void
test_mx_fp8_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }

    // i aliases the caller-visible progress counter so partial progress is
    // observable even if we bail out early.
    uint64_t& i = *p_completed;
    i           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    // All possible combinations of E8M0 and FP8
    // (output index = exp_id * 256 + fp8_id; FP8 varies fastest)
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
        {
            uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);

            auto v    = scaled_type_convert<float>(e8m0_bexp_t(exp_id), f8_ocp_t{fp8_uid});
            p_test[i] = v;
            i++;
            if(i >= N)
            {
                return;
            }
        }
    }

    /// Test vector conversions

    // f8x2 -> f32x2
    f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
    auto scale2 = e8m0_bexp_t(2.0f);

    float2_t f32x2 = scaled_type_convert<float2_t>(scale2, fp8x2);

    p_test[i++] = f32x2[0];
    if(i >= N)
    {
        return;
    }
    p_test[i++] = f32x2[1];
    if(i >= N)
    {
        return;
    }

    // f32x2 -> f8x2
    f32x2 = {-8.0f, 4.0f};

    fp8x2 = mxf8_convert_rne<f8x2_ocp_t>(f32x2, type_convert<float>(scale2)); // expect {-4, 2}

    p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
    if(i >= N)
    {
        return;
    }

    auto scale4 = e8m0_bexp_t(4.0f);

    fp8x2 = mxf8_convert_sr<f8x2_ocp_t>(f32x2, type_convert<float>(scale4)); // expect {-2, 1}

    p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-2f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 1f
    if(i >= N)
    {
        return;
    }

    /// Test round to nearest even
    p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(1024.0f, 4.0f)); // 1024/4
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(
        mxf8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN(), 4.0f)); // => NaN
    if(i >= N)
    {
        return;
    }
    // Inf/2 > 448 => NaN on device
    p_test[i++] = type_convert<float>(
        mxf8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::infinity(), 2.0f));
    if(i >= N)
    {
        return;
    }
    // 256/0.5 > 448 => NaN on device
    p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(256.0f, 0.5f));
    if(i >= N)
    {
        return;
    }
    // -256/0.5 < -448 => NaN on device
    p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(-256.0f, 0.5f));
    if(i >= N)
    {
        return;
    }
    // proper scale selection 2^13 < 10000; 2^8 < 448 => scale = 2^(13-8) = 2^5
    p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(10000.0f, 32.0f)); // 10000/32 = 312.5
    if(i >= N)
    {
        return;
    }
}
// Runs the shared FP8 conversion driver on the host and validates every
// produced value against reference arithmetic on floats.
TEST(MXFP8, HostScaledConvert)
{
    std::vector<float> out(test_size, -1.0f);
    uint64_t completed = 0;

    test_mx_fp8_scaled_convert(test_size, out.data(), &completed);

    // V = X * P; X - E8M0 scale, P - FP8
    // If X = NaN, then V = NaN regardless of P
    uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
    {
        auto idx = e8m0_nan_id * 256 + fp8_id;
        ASSERT_TRUE(std::isnan(out[idx]));
    }

    // If P in {Inf, NaN}, then V = P
    // FP8 (E4M3) has no Inf; only the two all-ones-mantissa NaN encodings.
    std::set<uint8_t> fp8_nan_ids;
    fp8_nan_ids.insert(0b11111111); //-NaN
    fp8_nan_ids.insert(0b01111111); // +NaN
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;
        for(auto fp8_nan_id : fp8_nan_ids)
        {
            auto idx = exp_id * 256 + fp8_nan_id;
            ASSERT_TRUE(std::isnan(out[idx]));
        }
    }

    // V = X * P; X, P - finite
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;
        for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
        {
            if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
                continue;
            uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
            auto idx        = exp_id * 256 + fp8_uid;
            ASSERT_FLOAT_EQ(out[idx],
                            type_convert<float>(e8m0_bexp_t(exp_id)) *
                                type_convert<float>(f8_ocp_t{fp8_uid}))
                << "exp_id: " << exp_id << " fp8_id: " << fp8_id << std::endl
                << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                << type_convert<float>(f8_ocp_t{fp8_uid});
        }
    }

    /// Test vector conversions
    // Results after the exhaustive table start at index 256*256 (see test_size).
    auto i = 256 * 256;

    // f8x2 -> f32x2
    EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
    EXPECT_EQ(out[i++], powf(2.0f, -8.0f));

    // f32x2 -> fp8x2
    // RNE
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    // SR
    EXPECT_EQ(out[i++], -2.0f);
    EXPECT_EQ(out[i++], 1.0f);

    /// Test round to nearest even
    EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
    // Host saturates out-of-range results to the finite FP8 extrema.
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Lowest()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(type_convert<f8_ocp_t>(312.5f)))
        << "out[i-1]: " << out[i - 1];

    // Every slot must have been written exactly once.
    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Thin kernel wrapper so the shared __host__ __device__ test body can run on the GPU.
__global__ void test_mx_fp8_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    test_mx_fp8_scaled_convert(N, p_test, p_completed);
}
// Runs the shared FP8 conversion driver in a 1x1 kernel and validates the
// results copied back from the device. Mirrors HostScaledConvert except that
// device conversion of out-of-range values yields NaN instead of saturating.
TEST(MXFP8, DeviceScaledConvert)
{
    std::vector<float> out(test_size, -1.0f);
    DeviceMem device_out(test_size * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Poison the device buffers so unwritten elements are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_fp8_device_scaled_convert<<<1, 1>>>(
        test_size,
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    // V = X * P; X - E8M0 scale, P - FP8
    // If X = NaN, then V = NaN regardless of P
    uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
    {
        auto idx = e8m0_nan_id * 256 + fp8_id;
        ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
    }

    // If P in {Inf, NaN}, then V = P
    std::set<uint8_t> fp8_nan_ids;
    fp8_nan_ids.insert(0b11111111); //-NaN
    fp8_nan_ids.insert(0b01111111); // +NaN
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;
        for(auto fp8_nan_id : fp8_nan_ids)
        {
            auto idx = exp_id * 256 + fp8_nan_id;
            ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
        }
    }

    // V = X * P; X, P - finite
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;
        for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
        {
            if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
                continue;
            uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
            auto idx        = exp_id * 256 + fp8_uid;
            ASSERT_FLOAT_EQ(out[idx],
                            type_convert<float>(e8m0_bexp_t(exp_id)) *
                                type_convert<float>(f8_ocp_t{fp8_uid}))
                << "exp_id: " << exp_id << " fp8_id: " << fp8_id << std::endl
                << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                << type_convert<float>(f8_ocp_t{fp8_uid});
        }
    }

    /// Test vector conversions
    auto i = 256 * 256;

    // f8x2 -> f32x2
    EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
    EXPECT_EQ(out[i++], powf(2.0f, -8.0f));

    // f32x2 -> fp8x2
    // RNE
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    // SR
    EXPECT_EQ(out[i++], -2.0f);
    EXPECT_EQ(out[i++], 1.0f);

    /// Test round to nearest even
    EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#if 1
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#else
    // NOTE: Host and Device have different behavior.
    // Device returns NaN, while Host returns Max (saturation to finite value).
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Lowest()))
        << "out[i-1]: " << out[i - 1];
#endif
    EXPECT_EQ(out[i++], type_convert<float>(type_convert<f8_ocp_t>(312.5f)))
        << "out[i-1]: " << out[i - 1];

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Deterministic 16-lane test pattern: magnitudes 2^0..2^7, with lanes 0-7
// negative and lanes 8-15 positive.
__host__ __device__ float vec16_generator(ck::index_t i)
{
    const float magnitude = powf(2.0f, i % 8);
    return (i < 8) ? -magnitude : magnitude;
}
// Converts a fixed 16-lane float pattern to f8x16 with E8M0 scale 2 via
// scaled_type_convert, then widens each lane back to float into p_test.
// *p_completed counts the number of stores performed.
__global__ void test_mx_fp8x16_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int kLanes = 16;

    if(p_completed == nullptr)
        return;

    uint64_t& count = *p_completed;
    count           = 0;

    if(p_test == nullptr)
        return;

    // Source vector: vec16_generator pattern in every lane.
    float16_t src{};
    ck::static_for<0, kLanes, 1>{}(
        [&](auto j) { src[static_cast<int>(j)] = vec16_generator(j); });

    const auto scale = e8m0_bexp_t(2.0f);

    f8x16_ocp_t converted{};
    converted = scaled_type_convert<ck::f8x16_ocp_t>(scale, src);

    // Widen each converted lane back to float for host-side inspection.
    ck::static_for<0, kLanes, 1>{}([&](auto j) {
        p_test[count++] = type_convert<float>(converted.AsType<f8_ocp_t>()(ck::Number<j>{}));
    });
}
// Round-trips a 16-lane float pattern through a scaled f32x16 -> f8x16 device
// conversion (scale = 2) and checks each lane equals the input divided by 2.
TEST(MXFP8, DeviceF32x16ToF8x16ScaledConvert)
{
    constexpr int N = 16;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Poison the device buffers so unwritten elements are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_fp8x16_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec16_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}
// Deterministic 32-lane test pattern: the first 16 lanes repeat the
// vec16_generator pattern; the last 16 repeat it scaled by 1.5.
__host__ __device__ float vec32_generator(ck::index_t i)
{
    const float base = vec16_generator(i % 16);
    return (i < 16) ? base : 1.5f * base;
}
// RNE-converts a fixed 32-lane float pattern to f8x32 with scale 2
// (mxf8_convert_rne), then widens each lane back to float into p_test.
// *p_completed counts the number of stores performed.
__global__ void test_mx_fp8x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int kLanes = 32;

    if(p_completed == nullptr)
        return;

    uint64_t& count = *p_completed;
    count           = 0;

    if(p_test == nullptr)
        return;

    // Source vector: vec32_generator pattern in every lane.
    float32_t src{};
    ck::static_for<0, kLanes, 1>{}(
        [&](auto j) { src[static_cast<int>(j)] = vec32_generator(j); });

    const auto scale = e8m0_bexp_t(2.0f);

    f8x32_ocp_t converted{};
    converted = mxf8_convert_rne<f8x32_ocp_t>(src, type_convert<float>(scale));

    // Widen each converted lane back to float for host-side inspection.
    ck::static_for<0, kLanes, 1>{}(
        [&](auto j) { p_test[count++] = type_convert<float>(converted.AsType<f8_ocp_t>()(j)); });
}
// Round-trips a 32-lane float pattern through a scaled RNE f32x32 -> f8x32
// device conversion (scale = 2) and checks each lane equals input / 2.
TEST(MXFP8, DeviceF32x32ToF8x32ScaledConvert)
{
    constexpr int N = 32;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Poison the device buffers so unwritten elements are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_fp8x32_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}
// Stochastic-rounding variant: converts a fixed 32-lane float pattern to
// f8x32 with scale 8 (mxf8_convert_sr), then widens each lane back to float
// into p_test. *p_completed counts the number of stores performed.
__global__ void test_mx_fp8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
{
    constexpr int kLanes = 32;

    if(p_completed == nullptr)
        return;

    uint64_t& count = *p_completed;
    count           = 0;

    if(p_test == nullptr)
        return;

    // Source vector: vec32_generator pattern in every lane.
    float32_t src{};
    ck::static_for<0, kLanes, 1>{}(
        [&](auto j) { src[static_cast<int>(j)] = vec32_generator(j); });

    const auto scale = e8m0_bexp_t(8.0f);

    f8x32_ocp_t converted{};
    converted = mxf8_convert_sr<f8x32_ocp_t>(src, type_convert<float>(scale));

    // Widen each converted lane back to float for host-side inspection.
    ck::static_for<0, kLanes, 1>{}(
        [&](auto j) { p_test[count++] = type_convert<float>(converted.AsType<f8_ocp_t>()(j)); });
}
// Round-trips a 32-lane float pattern through a scaled stochastic-rounding
// f32x32 -> f8x32 device conversion (scale = 8) and checks each lane equals
// input / 8 (the pattern values are exactly representable, so SR is exact).
TEST(MXFP8, DeviceF32x32ToF8x32ScaledConvertSR)
{
    constexpr int N = 32;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Poison the device buffers so unwritten elements are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_fp8x32_device_scaled_convert_sr<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}
// Inverse direction: packs vec32_generator(i)/16 into an f8x32 vector, then
// widens it back to f32x32 with E8M0 scale 4 via scaled_type_convert and
// stores the floats into p_test. *p_completed counts the stores performed.
__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int kLanes = 32;

    if(p_completed == nullptr)
        return;

    uint64_t& count = *p_completed;
    count           = 0;

    if(p_test == nullptr)
        return;

    // Fill the packed FP8 source; /16 keeps the pattern within FP8 range.
    f8x32_ocp_t packed{};
    ck::static_for<0, kLanes, 1>{}([&](auto j) {
        packed.AsType<f8_ocp_t>()(j) = type_convert<f8_ocp_t>(vec32_generator(j) / 16.0f);
    });

    const auto scale = e8m0_bexp_t(4.0f);

    float32_t widened{};
    widened = scaled_type_convert<float32_t>(scale, packed);

    ck::static_for<0, kLanes, 1>{}(
        [&](auto j) { p_test[count++] = widened[static_cast<int>(j)]; });
}
// Widens an f8x32 vector (pattern / 16) back to f32x32 with scale 4 on the
// device and checks each lane equals pattern / 4 (pattern/16 scaled by 4).
TEST(MXFP8, DeviceF8x32ToF32x32ScaledConvert)
{
    constexpr int N = 32;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Poison the device buffers so unwritten elements are detectable.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
View file @
efab74a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
...
...
@@ -43,7 +43,6 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
return
true
;
}
}
if
(
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
())
{
// on gfx11 only support for 3d is implemented
...
...
@@ -143,19 +142,23 @@ using KernelTypes2d = ::testing::Types<
std
::
tuple
<
float
,
float
,
float
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
float
,
float
,
float
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
NGCHW
,
GKYXC
,
NGKHW
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NGCHW
,
GKYXC
,
NGKHW
,
ck
::
Number
<
2
>>>
;
// Typed-test configurations for 3-D grouped convolution backward-weight.
// Each tuple appears to encode (input dtype, weight dtype, output dtype,
// input layout, weight layout, output layout, spatial-dim count ck::Number<3>)
// — confirm ordering against the TestGroupedConvndBwdWeight fixture.
// Covers GNDHWC/GKZYXC/GNDHWK, NDHWGC/GKZYXC/NDHWGK and NGCDHW/GKZYXC/NGKDHW
// layout families over float, half, bhalf and int8 element types.
using KernelTypes3d = ::testing::Types<
    std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
    std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
    std::tuple<int8_t, int8_t, int8_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
    std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
    std::tuple<float, float, float, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
    std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
    std::tuple<int8_t, int8_t, int8_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
    std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NGCDHW, GKZYXC, NGKDHW, ck::Number<3>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, NGCDHW, GKZYXC, NGKDHW, ck::Number<3>>>;
// Register the typed test suite: instantiates TestGroupedConvndBwdWeight1d once
// for every configuration tuple in KernelTypes1d.
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d);
...
...
@@ -179,6 +182,8 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
this
->
conv_params
.
clear
();
this
->
conv_params
.
push_back
(
{
2
,
2
,
64
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
({
2
,
2
,
64
,
3
,
3
,
{
1
,
1
},
{
7
,
7
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
({
2
,
2
,
64
,
5
,
5
,
{
1
,
1
},
{
7
,
7
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
(
{
2
,
2
,
4
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
...
...
Prev
1
…
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment