Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1111bc69
Commit
1111bc69
authored
Aug 23, 2024
by
illsilin
Browse files
Merge remote-tracking branch 'composable_kernel/develop' into merge_from_public
parents
3d61f89a
25935b57
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
182 additions
and
97 deletions
+182
-97
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
...combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
+61
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
...ance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
+2
-1
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
...onvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
+61
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
...d_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
+1
-53
profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
...include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
+0
-4
script/convert_miopen_driver_to_profiler.py
script/convert_miopen_driver_to_profiler.py
+1
-1
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+10
-0
test/data_type/test_bf8.cpp
test/data_type/test_bf8.cpp
+23
-19
test/data_type/test_fp8.cpp
test/data_type/test_fp8.cpp
+23
-19
No files found.
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
0 → 100644
View file @
1111bc69
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F8
,
F8
,
ck
::
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
CombConvScale
,
F8
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwdDefault
,
CombConvScale
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1P0
,
CombConvScale
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1S1P0
,
CombConvScale
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
View file @
1111bc69
# ONLY XDL_KERNELS
# ONLY XDL_KERNELS
set
(
GROUPED_CONV3D_FWD_CONVSCALE_RELU
set
(
GROUPED_CONV3D_FWD_CONVSCALE_RELU
xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
)
xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
)
add_instance_library
(
device_grouped_conv3d_fwd_convscale_relu_instance
${
GROUPED_CONV3D_FWD_CONVSCALE_RELU
}
)
add_instance_library
(
device_grouped_conv3d_fwd_convscale_relu_instance
${
GROUPED_CONV3D_FWD_CONVSCALE_RELU
}
)
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
0 → 100644
View file @
1111bc69
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F8
,
F8
,
ck
::
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
CombConvScaleRelu
,
F8
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwdDefault
,
CombConvScaleRelu
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1P0
,
CombConvScaleRelu
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1S1P0
,
CombConvScaleRelu
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
View file @
1111bc69
...
@@ -3,16 +3,13 @@
...
@@ -3,16 +3,13 @@
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
using
ConvScaleRelu
=
ck
::
tensor_operation
::
element_wise
::
ConvScaleRelu
;
void
add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances
(
void
add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
NDHWGC
,
...
@@ -57,55 +54,6 @@ void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_in
...
@@ -57,55 +54,6 @@ void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_in
ConvFwd1x1S1P0
,
ConvFwd1x1S1P0
,
ConvScaleRelu
>
{});
ConvScaleRelu
>
{});
}
}
namespace
ew
=
ck
::
tensor_operation
::
element_wise
;
using
CombConvScaleRelu
=
ew
::
UnaryCombinedOp
<
ew
::
Scale
,
ew
::
Scale
,
ew
::
Relu
>
;
void
add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F8
,
F8
,
ck
::
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
CombConvScaleRelu
,
F8
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwdDefault
,
CombConvScaleRelu
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1P0
,
CombConvScaleRelu
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1S1P0
,
CombConvScaleRelu
>
{});
}
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
View file @
1111bc69
...
@@ -250,18 +250,14 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
...
@@ -250,18 +250,14 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
{
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"output : "
,
output
.
mData
,
","
)
LogRangeAsType
<
float
>
(
std
::
cout
<<
"output : "
,
output
.
mData
,
","
)
<<
std
::
endl
;
<<
std
::
endl
;
;
LogRangeAsType
<
float
>
(
LogRangeAsType
<
float
>
(
std
::
cout
<<
"weight (device): "
,
weight_device_result
.
mData
,
","
)
std
::
cout
<<
"weight (device): "
,
weight_device_result
.
mData
,
","
)
<<
std
::
endl
;
<<
std
::
endl
;
;
LogRangeAsType
<
float
>
(
LogRangeAsType
<
float
>
(
std
::
cout
<<
"weight (host): "
,
weight_host_result
.
mData
,
","
)
std
::
cout
<<
"weight (host): "
,
weight_host_result
.
mData
,
","
)
<<
std
::
endl
;
<<
std
::
endl
;
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"input: "
,
input
.
mData
,
","
)
LogRangeAsType
<
float
>
(
std
::
cout
<<
"input: "
,
input
.
mData
,
","
)
<<
std
::
endl
;
<<
std
::
endl
;
;
}
}
}
}
}
}
...
...
script/convert_miopen_driver_to_profiler.py
View file @
1111bc69
...
@@ -26,7 +26,7 @@ def run_ck_profiler_cmd(cmd):
...
@@ -26,7 +26,7 @@ def run_ck_profiler_cmd(cmd):
def
parse_data_type
(
args
):
def
parse_data_type
(
args
):
if
args
.
data_type
==
"fp32"
:
if
args
.
data_type
==
"fp32"
:
if
args
.
ck_profier_op
==
"grouped_conv_bwd_weight"
or
\
if
args
.
ck_profier_op
==
"grouped_conv_bwd_weight"
or
\
args
.
ck_profier_op
==
"grouped_conv_bwd_
weight
"
or
\
args
.
ck_profier_op
==
"grouped_conv_bwd_
data
"
or
\
args
.
ck_profier_op
==
"grouped_conv_fwd"
:
args
.
ck_profier_op
==
"grouped_conv_fwd"
:
args
.
data_type
=
0
args
.
data_type
=
0
if
args
.
data_type
==
"fp16"
:
if
args
.
data_type
==
"fp16"
:
...
...
test/data_type/CMakeLists.txt
View file @
1111bc69
if
(
GPU_TARGETS
)
if
(
GPU_TARGETS MATCHES
"gfx10"
OR GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx12"
)
add_definitions
(
-DCK_SKIP_FLAKY_F8_TEST
)
set
(
CK_SKIP_FLAKY_F8_TEST
"ON"
)
endif
()
else
()
add_definitions
(
-DCK_SKIP_FLAKY_F8_TEST
)
set
(
CK_SKIP_FLAKY_F8_TEST
"ON"
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_gtest_executable
(
test_int4 test_int4.cpp
)
add_gtest_executable
(
test_int4 test_int4.cpp
)
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
...
...
test/data_type/test_bf8.cpp
View file @
1111bc69
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bf8_t
;
using
ck
::
bf8_t
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_convert_sr
;
using
ck
::
half_t
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
using
ck
::
type_convert
;
...
@@ -24,33 +25,36 @@ TEST(BF8, ConvertFP32Nearest)
...
@@ -24,33 +25,36 @@ TEST(BF8, ConvertFP32Nearest)
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-6
;
float
abs_tol
=
1e-6
;
// convert 0 float to bf8 and back, check if holds
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
type_convert
<
bf8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
0.0
f
)),
abs_tol
);
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to bf8 and back, check if holds
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
type
_convert
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8
_convert
_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
abs_tol
);
#endif
// convert maximal bf8_t to float and check if equal to 57344.0
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
type
_convert
<
bf8_t
>
(
57344.0
f
)),
abs_tol
);
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8
_convert
_rne
<
bf8_t
>
(
57344.0
f
)),
abs_tol
);
// convert maximal float to bf8 and back, check if clipped to 57344.0
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
57344.0
f
,
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
type
_convert
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
type_convert
<
float
>
(
f8
_convert
_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
abs_tol
);
// convert inf float to bf8_t and check if it is qNan
// convert inf float to bf8_t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
type
_convert
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
f8
_convert
_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
abs_tol
);
// positive norm float value to bf8 and back, check if holds
// positive norm float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
float
pos_float
=
0.0000762939
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
type
_convert
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8
_convert
_rne
<
bf8_t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to bf8 and back, check if holds
// negative norm float value to bf8 and back, check if holds
float
neg_float
=
-
0.0000610351
f
;
float
neg_float
=
-
0.0000610351
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
type
_convert
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8
_convert
_rne
<
bf8_t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to bf8 and back, check if holds
// positive subnorm float value to bf8 and back, check if holds
pos_float
=
0.0000305175
f
;
pos_float
=
0.0000305175
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
type
_convert
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8
_convert
_rne
<
bf8_t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to bf8 and back, check if holds
// negative subnorm float value to bf8 and back, check if holds
neg_float
=
-
0.0000152587
f
;
neg_float
=
-
0.0000152587
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
type
_convert
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8
_convert
_rne
<
bf8_t
>
(
neg_float
)),
abs_tol
);
}
}
TEST
(
BF8
,
ConvertFP32Stochastic
)
TEST
(
BF8
,
ConvertFP32Stochastic
)
...
@@ -92,34 +96,34 @@ TEST(BF8, ConvertFP16Nearest)
...
@@ -92,34 +96,34 @@ TEST(BF8, ConvertFP16Nearest)
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-3
;
float
abs_tol
=
1e-3
;
// convert 0 fp16 to bf8 and back, check if holds
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
type
_convert
<
bf8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8
_convert
_rne
<
bf8_t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to bf8 and back, check if holds
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
type
_convert
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8
_convert
_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
abs_tol
);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
// convert maximal bf8_t to fp16 and check if equal to 57344.0
ASSERT_NEAR
(
ASSERT_NEAR
(
half_t
{
57344.0
},
type_convert
<
half_t
>
(
type
_convert
<
bf8_t
>
(
half_t
{
57344.0
})),
abs_tol
);
half_t
{
57344.0
},
type_convert
<
half_t
>
(
f8
_convert
_rne
<
bf8_t
>
(
half_t
{
57344.0
})),
abs_tol
);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
},
ASSERT_NEAR
(
half_t
{
57344.0
},
type_convert
<
half_t
>
(
type
_convert
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
type_convert
<
half_t
>
(
f8
_convert
_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
abs_tol
);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
type
_convert
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
f8
_convert
_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
abs_tol
);
// positive norm fp16 value to bf8 and back, check if holds
// positive norm fp16 value to bf8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0000762939
};
half_t
pos_half
=
half_t
{
0.0000762939
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
type
_convert
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8
_convert
_rne
<
bf8_t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to bf8 and back, check if holds
// negative norm fp16 value to bf8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0000610351
};
half_t
neg_half
=
half_t
{
-
0.0000610351
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
type
_convert
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8
_convert
_rne
<
bf8_t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to bf8 and back, check if holds
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half
=
half_t
{
0.0000305175
};
pos_half
=
half_t
{
0.0000305175
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
type
_convert
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8
_convert
_rne
<
bf8_t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to bf8 and back, check if holds
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half
=
half_t
{
-
0.0000152587
};
neg_half
=
half_t
{
-
0.0000152587
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
type
_convert
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8
_convert
_rne
<
bf8_t
>
(
neg_half
)),
abs_tol
);
}
}
TEST
(
BF8
,
ConvertFP16Stochastic
)
TEST
(
BF8
,
ConvertFP16Stochastic
)
...
...
test/data_type/test_fp8.cpp
View file @
1111bc69
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_t
;
using
ck
::
f8_t
;
using
ck
::
half_t
;
using
ck
::
half_t
;
...
@@ -24,33 +25,36 @@ TEST(FP8, ConvertFP32Nearest)
...
@@ -24,33 +25,36 @@ TEST(FP8, ConvertFP32Nearest)
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-6
;
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
type_convert
<
f8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
0.0
f
)),
abs_tol
);
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to fp8 and back, check if holds
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
type
_convert
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8
_convert
_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
abs_tol
);
#endif
// convert maximal f8_t to float and check if equal to 240.0
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
type
_convert
<
f8_t
>
(
240.0
f
)),
abs_tol
);
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8
_convert
_rne
<
f8_t
>
(
240.0
f
)),
abs_tol
);
// convert maximal float to fp8 and back, check if clipped to 240.0
// convert maximal float to fp8 and back, check if clipped to 240.0
ASSERT_NEAR
(
240.0
f
,
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
type
_convert
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
type_convert
<
float
>
(
f8
_convert
_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
abs_tol
);
// convert inf float to f8_t and check if it is qNan
// convert inf float to f8_t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
type
_convert
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
f8
_convert
_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
abs_tol
);
// positive norm float value to fp8 and back, check if holds
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
type
_convert
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8
_convert
_rne
<
f8_t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to fp8 and back, check if holds
// negative norm float value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
float
neg_float
=
-
0.015625
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
type
_convert
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8
_convert
_rne
<
f8_t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp8 and back, check if holds
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
type
_convert
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8
_convert
_rne
<
f8_t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp8 and back, check if holds
// negative subnorm float value to fp8 and back, check if holds
neg_float
=
-
0.001953125
f
;
neg_float
=
-
0.001953125
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
type
_convert
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8
_convert
_rne
<
f8_t
>
(
neg_float
)),
abs_tol
);
}
}
TEST
(
FP8
,
ConvertFP32Stochastic
)
TEST
(
FP8
,
ConvertFP32Stochastic
)
...
@@ -92,33 +96,33 @@ TEST(FP8, ConvertFP16Nearest)
...
@@ -92,33 +96,33 @@ TEST(FP8, ConvertFP16Nearest)
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-3
;
float
abs_tol
=
1e-3
;
// convert 0 fp16 to fp8 and back, check if holds
// convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
type
_convert
<
f8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8
_convert
_rne
<
f8_t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to fp8 and back, check if holds
// convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
type
_convert
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8
_convert
_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
abs_tol
);
// convert maximal f8_t to fp16 and check if equal to 240.0
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
type
_convert
<
f8_t
>
(
half_t
{
240.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8
_convert
_rne
<
f8_t
>
(
half_t
{
240.0
})),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
type
_convert
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
type_convert
<
half_t
>
(
f8
_convert
_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
abs_tol
);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
type
_convert
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
f8
_convert
_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
abs_tol
);
// positive norm fp16 value to fp8 and back, check if holds
// positive norm fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.017578125
};
half_t
pos_half
=
half_t
{
0.017578125
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
type
_convert
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8
_convert
_rne
<
f8_t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to fp8 and back, check if holds
// negative norm fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.015625
};
half_t
neg_half
=
half_t
{
-
0.015625
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
type
_convert
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8
_convert
_rne
<
f8_t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to fp8 and back, check if holds
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half
=
half_t
{
0.00390625
};
pos_half
=
half_t
{
0.00390625
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
type
_convert
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8
_convert
_rne
<
f8_t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to fp8 and back, check if holds
// negative subnorm fp16 value to fp8 and back, check if holds
neg_half
=
half_t
{
-
0.001953125
};
neg_half
=
half_t
{
-
0.001953125
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
type
_convert
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8
_convert
_rne
<
f8_t
>
(
neg_half
)),
abs_tol
);
}
}
TEST
(
FP8
,
ConvertFP16Stochastic
)
TEST
(
FP8
,
ConvertFP16Stochastic
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment