Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f8a75f8a
Commit
f8a75f8a
authored
Dec 07, 2023
by
Paul
Browse files
Merge
parents
74448ed6
d00fdf6e
Changes
242
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
835 additions
and
93 deletions
+835
-93
src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
+3
-2
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+39
-29
src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp
...targets/gpu/kernels/include/migraphx/kernels/roialign.hpp
+18
-15
src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp
...nels/include/migraphx/kernels/scatter_reduction_modes.hpp
+83
-0
src/targets/gpu/kernels/include/migraphx/kernels/scatternd.hpp
...argets/gpu/kernels/include/migraphx/kernels/scatternd.hpp
+1
-27
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
...gets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
+1
-0
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+1
-1
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+13
-1
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+35
-10
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-0
src/targets/ref/CMakeLists.txt
src/targets/ref/CMakeLists.txt
+2
-1
src/tf/CMakeLists.txt
src/tf/CMakeLists.txt
+9
-2
src/tmp_dir.cpp
src/tmp_dir.cpp
+11
-1
src/verify_args.cpp
src/verify_args.cpp
+0
-1
test/CMakeLists.txt
test/CMakeLists.txt
+1
-0
test/float_equal.cpp
test/float_equal.cpp
+11
-1
test/fp8e4m3fn.cpp
test/fp8e4m3fn.cpp
+291
-0
test/fp8e4m3fnuz.cpp
test/fp8e4m3fnuz.cpp
+313
-0
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
View file @
f8a75f8a
...
...
@@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp>
#include <migraphx/kernels/vec.hpp>
namespace
migraphx
{
...
...
@@ -53,9 +54,9 @@ __device__ void pad(const index& idx,
if
(
any_of
(
range_multi
.
begin
(),
range_multi
.
end
(),
[
&
](
auto
j
)
{
return
multi
[
j
]
<
offsets
[
j
]
or
input_idx
[
j
]
>=
input_bounds
[
j
];
}))
output
[
multi
]
=
pad_val
;
output
[
multi
]
=
implicit_conversion
(
pad_val
)
;
else
output
[
multi
]
=
input
[
input_idx
];
output
[
multi
]
=
implicit_conversion
(
input
[
input_idx
]
)
;
});
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
f8a75f8a
...
...
@@ -64,7 +64,7 @@ __device__ void dpp_reduce(T& in, Op op)
#if __AMDGCN_WAVEFRONT_SIZE == 32
if
constexpr
(
SubWaveSize
>
16
)
{
out
=
dpp_swizzle
<
dpp_row_bcast
(
15
)
>
(
in
);
out
=
dpp_swizzle
<
0x1e0
>
(
in
);
in
=
op
(
in
,
out
);
}
#else
...
...
@@ -89,9 +89,11 @@ __device__ void dpp_reduce(T& in, Op op)
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
(void)f; \
x = 1
#elif __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
)
\
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
, f)
\
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
...
...
@@ -100,29 +102,42 @@ __device__ void dpp_reduce(T& in, Op op)
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" \
: "=v"(x) \
: "0"(x))
: "0"(x)); \
(void)f
#else
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
)
\
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
, f)
\
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" \
"s_nop 1\n" \
: "=v"(x) \
: "0"(x))
: "0"(x)); \
auto y = dpp_swizzle<0x1e0>(x); \
x = f(x, y)
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(int32_t& x, op) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \
} \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
__device__ inline void dpp_reduce(double& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64, f); \
} \
__device__ inline void dpp_reduce(float& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32, f); \
} \
__device__ inline void dpp_reduce(half& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16, f); \
} \
__device__ inline void dpp_reduce(int32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32, f); \
} \
__device__ inline void dpp_reduce(uint32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32, f); \
}
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
,
_u
)
...
...
@@ -154,14 +169,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
if
(
idx
.
max_nlocal
()
==
idx
.
nlocal_wave
())
return
wave_reduce
(
idx
,
op
,
init
,
n
,
f
);
#if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr
index_int
lanes_per_thread
=
16
;
#else
constexpr
index_int
lanes_per_thread
=
64
;
#endif
constexpr
index_int
lanes_per_thread
=
__AMDGCN_WAVEFRONT_SIZE
;
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
__shared__
type
buffer
[
idx
.
max_nlocal
()
/
lanes_per_thread
];
type
x
=
init
;
type
x
=
type
(
init
)
;
idx
.
local_stride
(
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
dpp_reduce
(
x
,
op
);
...
...
@@ -172,7 +183,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
}
__syncthreads
();
type
y
=
init
;
type
y
=
type
(
init
)
;
for
(
index_int
i
=
0
;
i
<
idx
.
nlocal
()
/
lanes_per_thread
;
i
++
)
{
y
=
op
(
y
,
buffer
[
i
]);
...
...
@@ -299,9 +310,8 @@ struct reducer_base
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
t
=
derived
.
slice
(
x
);
return
make_storage_access
<
typename
decltype
(
t
)
::
type
>
([
=
](
auto
i
,
auto
...)
->
auto
&
{
return
t
[
i
];
});
return
make_storage_access
<
typename
decltype
(
t
)
::
type
>
(
[
=
](
auto
i
,
auto
...)
->
auto
&
{
return
t
[
i
];
});
}
}
...
...
@@ -448,7 +458,7 @@ struct block
{
using
max_iterations
=
decltype
(
idx
.
max_local_stride_iterations
(
n
));
inner_storage
<
R
,
max_iterations
{},
N
>
storage
;
idx
.
local_stride
(
n
,
[
&
](
auto
j
,
auto
d
)
{
storage
(
j
,
d
)
=
f
(
xs
(
j
,
d
)...);
});
idx
.
local_stride
(
n
,
[
&
](
auto
j
,
auto
d
)
{
storage
(
j
,
d
)
=
R
{
f
(
xs
(
j
,
d
)...)
}
;
});
return
storage
;
}
};
...
...
@@ -617,7 +627,7 @@ struct lane
__device__
auto
reduce_impl
(
Op
op
,
T
init
,
Read
read
,
N
n
,
U
&&
x
,
Us
&&
...
xs
)
const
{
using
type
=
remove_reference_t
<
decltype
(
x
(
0
,
_c
<
0
>
))
>
;
type
r
=
init
;
type
r
=
type
(
init
)
;
for
(
index_int
j
=
0
;
j
<
n
;
j
++
)
{
r
=
op
(
r
,
read
(
x
(
j
,
_c
<
0
>
),
xs
(
j
,
_c
<
0
>
)...));
...
...
src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp
View file @
f8a75f8a
...
...
@@ -62,7 +62,7 @@ struct avg_pool
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
T
final
(
T
x
,
index_int
y
)
{
return
(
y
==
0
)
?
0.0
:
(
x
/
y
)
;
return
(
y
==
0
)
?
T
{
0.0
}
:
T
{
x
/
y
}
;
}
};
...
...
@@ -76,7 +76,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
{
if
(
xy
[
ii
]
<
-
1.0
f
or
xy
[
ii
]
>
dims
[
ii
])
{
return
0
;
return
implicit_conversion
(
0
)
;
}
xy
[
ii
]
=
migraphx
::
max
(
xy
[
ii
],
0.0
f
);
...
...
@@ -92,15 +92,16 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
high
[
0
]
*
dims
[
1
]
+
low
[
1
],
high
[
0
]
*
dims
[
1
]
+
high
[
1
]};
float
ly
=
xy
[
0
]
-
low
[
0
];
float
lx
=
xy
[
1
]
-
low
[
1
];
float
hy
=
1.0
f
-
ly
;
float
hx
=
1.0
f
-
lx
;
array
<
typename
Iterator
::
value_type
,
4
>
ws
=
{
hy
*
hx
,
hy
*
lx
,
ly
*
hx
,
ly
*
lx
};
float
ly
=
xy
[
0
]
-
low
[
0
];
float
lx
=
xy
[
1
]
-
low
[
1
];
float
hy
=
1.0
f
-
ly
;
float
hx
=
1.0
f
-
lx
;
// do calculations in floating point and convert final result to required type
array
<
float
,
4
>
ws
=
{
hy
*
hx
,
hy
*
lx
,
ly
*
hx
,
ly
*
lx
};
auto
v01
=
pooling
(
data
[
locs
[
0
]]
*
ws
[
0
],
data
[
locs
[
1
]]
*
ws
[
1
]);
auto
v23
=
pooling
(
data
[
locs
[
2
]]
*
ws
[
2
],
data
[
locs
[
3
]]
*
ws
[
3
]);
return
pooling
(
v01
,
v23
);
return
implicit_conversion
(
pooling
(
v01
,
v23
)
)
;
}
template
<
class
Iterator
,
class
Op
>
...
...
@@ -113,8 +114,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data,
float
roi_offset
,
Op
op
)
{
typename
Iterator
::
value_type
output_val
=
op
.
init
();
const
int64_t
count
=
bin_grid_size
[
0
]
*
bin_grid_size
[
1
];
using
in_dtype
=
typename
Iterator
::
value_type
;
in_dtype
output_val
=
in_dtype
{
op
.
init
()};
const
int64_t
count
=
bin_grid_size
[
0
]
*
bin_grid_size
[
1
];
dfor
(
bin_grid_size
[
0
],
bin_grid_size
[
1
])([
&
](
auto
iy
,
auto
ix
)
{
array
<
index_int
,
2
>
id
=
{
iy
,
ix
};
array
<
float
,
2
>
locs
=
...
...
@@ -148,7 +150,6 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const
auto
x
=
x_t
.
begin
();
const
auto
rois
=
rois_t
.
begin
();
const
auto
ind
=
ind_t
.
begin
();
// input shape
auto
x_lens
=
x_t
.
get_shape
().
lens
;
auto
channel_num
=
x_lens
[
1
];
...
...
@@ -176,10 +177,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const
auto
offset_rois
=
rois
+
(
n
*
roi_column_num
);
const
int
batch_ind
=
ind
[
n
];
array
<
float
,
2
>
roi_starts
=
{
offset_rois
[
1
]
*
s
.
spatial_scale
,
offset_rois
[
0
]
*
s
.
spatial_scale
};
array
<
float
,
2
>
roi_ends
=
{
offset_rois
[
3
]
*
s
.
spatial_scale
,
offset_rois
[
2
]
*
s
.
spatial_scale
};
array
<
float
,
2
>
roi_starts
=
{
static_cast
<
float
>
(
offset_rois
[
1
])
*
static_cast
<
float
>
(
s
.
spatial_scale
),
static_cast
<
float
>
(
offset_rois
[
0
])
*
static_cast
<
float
>
(
s
.
spatial_scale
)};
array
<
float
,
2
>
roi_ends
=
{
static_cast
<
float
>
(
offset_rois
[
3
])
*
static_cast
<
float
>
(
s
.
spatial_scale
),
static_cast
<
float
>
(
offset_rois
[
2
])
*
static_cast
<
float
>
(
s
.
spatial_scale
)};
array
<
float
,
2
>
roi_size
{};
array
<
float
,
2
>
bin_size
{};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#define MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#include <migraphx/kernels/types.hpp>
namespace
migraphx
{
struct
assign_none
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
x
=
y
;
}
};
struct
assign_add
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
atomicAdd
(
&
x
,
y
);
}
};
struct
assign_mul
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
T
old
=
x
;
T
assumed
;
do
{
assumed
=
old
;
old
=
atomicCAS
(
&
x
,
assumed
,
assumed
*
y
);
}
while
(
assumed
!=
old
);
}
};
struct
assign_max
{
template
<
typename
T
,
typename
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
atomicMax
(
&
x
,
y
);
}
};
struct
assign_min
{
template
<
typename
T
,
typename
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
atomicMin
(
&
x
,
y
);
}
};
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/scatternd.hpp
View file @
f8a75f8a
...
...
@@ -26,36 +26,10 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/scatter_reduction_modes.hpp>
namespace
migraphx
{
struct
assign_none
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
x
=
y
;
}
};
struct
assign_add
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
x
+=
y
;
}
};
struct
assign_mul
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
x
*=
y
;
}
};
template
<
class
T
,
class
U
,
class
V
,
class
F
>
__device__
void
scatternd
(
const
T
&
indices_t
,
const
U
&
updates_t
,
const
V
&
output_t
,
F
f
)
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
View file @
f8a75f8a
...
...
@@ -43,7 +43,7 @@ __device__ void softmax(Input input1, Output output)
auto
exp_in
=
r
.
inner
([
&
](
auto
x
)
{
return
migraphx
::
exp
(
x
-
c
);
})(
input
);
auto
batch_sum
=
r
.
reduce
(
op
::
sum
{},
0
,
[](
auto
x
)
{
return
migraphx
::
convert
<
float
>
(
x
);
})(
exp_in
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
x
/
batch_sum
;
})(
output
,
exp_in
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
implicit_conversion
(
x
/
batch_sum
)
;
})(
output
,
exp_in
);
});
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
View file @
f8a75f8a
...
...
@@ -27,6 +27,7 @@
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/iota_iterator.hpp>
#include <migraphx/kernels/float8.hpp>
namespace
migraphx
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
f8a75f8a
...
...
@@ -251,7 +251,7 @@ constexpr T numeric_max()
}
template
<
class
T
>
constexpr
T
numeric_lowest
()
constexpr
auto
numeric_lowest
()
->
decltype
(
numeric_max
<
T
>
())
{
if
constexpr
(
is_integral
<
T
>
{})
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
f8a75f8a
...
...
@@ -207,7 +207,7 @@ struct implicit_conversion_op
template
<
class
U
>
constexpr
operator
U
()
const
{
return
x
;
return
static_cast
<
U
>
(
x
)
;
}
};
...
...
src/targets/gpu/mlir.cpp
View file @
f8a75f8a
...
...
@@ -73,6 +73,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MLIR
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_TUNE_EXHAUSTIVE
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_TUNE_LIMIT
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_TUNING_DB
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_TUNING_CFG
);
...
...
@@ -796,7 +797,9 @@ struct mlir_program
if
(
enabled
(
MIGRAPHX_MLIR_TUNE_EXHAUSTIVE
{}))
tuning_mode
=
RocmlirTuningParamSetKindExhaustive
;
mlir_tuning_space
params
{
mlirRockTuningSpaceCreate
(
mmodule
.
get
(),
tuning_mode
)};
for
(
auto
i
:
range
(
mlirRockTuningGetNumParams
(
params
.
get
())))
const
auto
limit
=
value_of
(
MIGRAPHX_MLIR_TUNE_LIMIT
{},
std
::
numeric_limits
<
std
::
size_t
>::
max
());
for
(
auto
i
:
range
(
std
::
min
<
std
::
size_t
>
(
limit
,
mlirRockTuningGetNumParams
(
params
.
get
()))))
{
mlir_tuning_param
param
{
mlirRockTuningParamCreate
()};
if
(
not
mlirRockTuningParamGet
(
params
.
get
(),
i
,
param
.
get
()))
...
...
@@ -1032,6 +1035,15 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
mlir_program
mp
;
mp
.
set_gpu_properties
(
migraphx_ctx
);
mp
.
parse
(
m
);
const
bool
trace
=
enabled
(
MIGRAPHX_TRACE_MLIR
{});
static
std
::
mutex
mutex
;
if
(
trace
)
{
const
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex
);
auto
mod_op
=
mlirModuleGetOperation
(
mp
.
mmodule
.
get
());
std
::
cout
<<
mlir_print
(
&
mlirOperationPrint
,
mod_op
)
<<
std
::
endl
;
}
return
mp
.
get_tuning_config
(
exhaustive
);
}
...
...
src/targets/gpu/prefuse_ops.cpp
View file @
f8a75f8a
...
...
@@ -28,7 +28,10 @@
#include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
#include <migraphx/gpu/ck.hpp>
#endif
#include <migraphx/gpu/fuse_mlir.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -128,26 +131,49 @@ struct pre_gemm_softmax_gemm : gemm_softmax_gemm
};
MIGRAPHX_REGISTER_OP
(
pre_gemm_softmax_gemm
);
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
auto
is_ck_gemm
(
)
{
if
(
ins
->
name
()
!=
"dot"
)
return
match
::
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
if
(
not
enabled
(
MIGRAPHX_ENABLE_CK
{}))
return
false
;
if
(
ins
->
name
()
!=
"dot"
)
return
false
;
if
(
not
pre_gemm_softmax_gemm
::
is_ck_supported_type
(
ins
->
get_shape
().
type
()))
return
false
;
return
true
;
#else
(
void
)
ins
;
return
false
;
if
(
not
pre_gemm_softmax_gemm
::
is_ck_supported_type
(
ins
->
get_shape
().
type
()))
return
false
;
return
true
;
#endif
});
}
auto
is_mlir_gemm
()
{
return
match
::
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
not
mlir_attention_enabled
())
return
false
;
if
(
ins
->
name
()
!=
"dot"
)
return
false
;
return
std
::
all_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
pre_gemm_softmax_gemm
::
is_mlir_supported_type
(
i
->
get_shape
().
type
());
});
});
}
struct
find_gemm_softmax_gemm
{
auto
matcher
()
const
{
auto
gemm1
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
)(
is_ck_gemm
(
).
bind
(
"gemm1"
)));
auto
gemm1
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
)(
match
::
any_of
(
is_ck_gemm
(),
is_mlir_gemm
()
).
bind
(
"gemm1"
)));
auto
mul
=
match
::
name
(
"mul"
)(
match
::
nargs
(
2
),
match
::
either_arg
(
0
,
1
)(
match
::
is_constant
().
bind
(
"scale"
),
gemm1
));
auto
softmax
=
match
::
name
(
"softmax"
)(
match
::
arg
(
0
)(
mul
)).
bind
(
"softmax"
);
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm2"
))(
match
::
arg
(
0
)(
softmax
));
return
match
::
name
(
"dot"
)(
match
::
any_of
(
is_ck_gemm
(),
is_mlir_gemm
()).
bind
(
"gemm2"
))(
match
::
arg
(
0
)(
softmax
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
...
...
@@ -182,8 +208,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
match
::
find_matches
(
mpm
.
get_module
(),
find_layernorm
{});
mpm
.
run_pass
(
dead_code_elimination
{});
match
::
find_matches
(
mpm
.
get_module
(),
find_add_layernorm
{});
if
(
enabled
(
MIGRAPHX_ENABLE_CK
{}))
match
::
find_matches
(
mpm
,
find_gemm_softmax_gemm
{});
match
::
find_matches
(
mpm
,
find_gemm_softmax_gemm
{});
}
}
// namespace gpu
...
...
src/targets/gpu/target.cpp
View file @
f8a75f8a
...
...
@@ -98,6 +98,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
ctx
.
set_exhaustive_tune_flag
(
options
.
exhaustive_tune
);
std
::
set
<
shape
::
type_t
>
unsupported_types
(
shape
::
types
().
begin
(),
shape
::
types
().
end
());
unsupported_types
.
erase
(
shape
::
type_t
::
float_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
fp8e4m3fnuz_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
half_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
bool_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int8_type
);
...
...
src/targets/ref/CMakeLists.txt
View file @
f8a75f8a
...
...
@@ -33,8 +33,9 @@ rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})
find_path
(
BLAZE_INCLUDE blaze/Blaze.h
)
rocm_clang_tidy_check
(
migraphx_ref
)
target_link_libraries
(
migraphx_ref PRIVATE Threads::Threads
)
target_link_libraries
(
migraphx_ref PUBLIC migraphx
)
target_include_directories
(
migraphx_ref PRIVATE
${
BLAZE_INCLUDE
}
)
target_include_directories
(
migraphx_ref
SYSTEM
PRIVATE
${
BLAZE_INCLUDE
}
)
target_compile_definitions
(
migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS
)
migraphx_generate_export_header
(
migraphx_ref
)
...
...
src/tf/CMakeLists.txt
View file @
f8a75f8a
...
...
@@ -38,7 +38,11 @@ protobuf_generate_cpp(
)
add_library
(
tf-proto STATIC
${
PROTO_SRCS
}
)
target_include_directories
(
tf-proto SYSTEM PUBLIC
${
CMAKE_CURRENT_BINARY_DIR
}
${
PROTOBUF_INCLUDE_DIR
}
)
target_compile_options
(
tf-proto PRIVATE -w
)
if
(
MSVC
)
target_compile_options
(
tf-proto PRIVATE /w
)
else
()
target_compile_options
(
tf-proto PRIVATE -w
)
endif
()
target_link_libraries
(
tf-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
set_target_properties
(
tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On
)
...
...
@@ -49,7 +53,10 @@ target_include_directories(migraphx_tf PRIVATE include)
set_target_properties
(
migraphx_tf PROPERTIES EXPORT_NAME tf
)
rocm_set_soversion
(
migraphx_tf
${
MIGRAPHX_SO_VERSION
}
)
rocm_clang_tidy_check
(
migraphx_tf
)
target_link_libraries
(
migraphx_tf PRIVATE tf-proto
"-Wl,--exclude-libs,ALL"
)
target_link_libraries
(
migraphx_tf PRIVATE tf-proto
)
if
(
NOT WIN32
)
target_link_libraries
(
migraphx_tf PRIVATE
"-Wl,--exclude-libs,ALL"
)
endif
()
target_link_libraries
(
migraphx_tf PUBLIC migraphx
)
rocm_install_targets
(
...
...
src/tmp_dir.cpp
View file @
f8a75f8a
...
...
@@ -31,8 +31,18 @@
#include <sstream>
#include <iostream>
#include <string>
#include <sys/types.h>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#undef getpid
// cppcheck-suppress [definePrefix, defineUpperCase]
#define getpid _getpid
#else
#include <unistd.h>
#include <sys/types.h>
#endif
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/verify_args.cpp
View file @
f8a75f8a
...
...
@@ -88,7 +88,6 @@ bool verify_args(const std::string& name,
if
(
target_nan_idx
>=
0
)
std
::
cout
<<
"Non finite number found in target at "
<<
target_nan_idx
<<
": "
<<
target
[
target_nan_idx
]
<<
std
::
endl
;
std
::
cout
<<
"MIGraphX verification passed successfully."
<<
std
::
endl
;
}
});
return
passed
;
...
...
test/CMakeLists.txt
View file @
f8a75f8a
...
...
@@ -150,6 +150,7 @@ function(test_headers PREFIX)
list
(
REMOVE_ITEM HEADERS
${
CMAKE_SOURCE_DIR
}
/src/targets/gpu/include/migraphx/gpu/ck.hpp
)
endif
()
list
(
REMOVE_ITEM HEADERS
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/float8_impl.hpp
)
foreach
(
HEADER
${
HEADERS
}
)
file
(
RELATIVE_PATH HEADER_REL
${
CMAKE_SOURCE_DIR
}
${
HEADER
}
)
string
(
MAKE_C_IDENTIFIER
${
HEADER_REL
}
TEST_NAME
)
...
...
test/float_equal.cpp
View file @
f8a75f8a
...
...
@@ -22,6 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include "test.hpp"
...
...
@@ -53,7 +54,7 @@ auto test_float_equal(T x, U y)
template
<
class
T
,
class
U
>
void
test_equality
()
{
auto
x1
=
T
(
0.1
);
auto
x1
=
T
(
0.1
25
);
auto
x2
=
U
(
0.0
);
auto
x3
=
U
(
1.0
);
EXPECT
(
test_float_equal
(
x1
,
x1
));
...
...
@@ -71,8 +72,12 @@ void test_equality()
TEST_CASE_REGISTER
(
test_equality
<
double
,
float
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_equality
<
double
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
float
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
float
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
half
,
int
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
half
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_equality
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
int
>
);
template
<
class
T
,
class
U
>
void
test_limits
()
...
...
@@ -110,8 +115,13 @@ void test_limits()
TEST_CASE_REGISTER
(
test_limits
<
double
,
float
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
int
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_limits
<
double
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
float
,
int
>
);
TEST_CASE_REGISTER
(
test_limits
<
float
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
int
,
migraphx
::
half
>
);
TEST_CASE_REGISTER
(
test_limits
<
int
,
migraphx
::
fp8
::
fp8e4m3fnuz
>
);
TEST_CASE_REGISTER
(
test_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
,
migraphx
::
half
>
);
#ifndef _WIN32
// On Windows, types int and long have the same min and max values.
TEST_CASE_REGISTER
(
test_limits
<
long
,
int
>
);
...
...
test/fp8e4m3fn.cpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
float
fp8e4m3fn_to_fp32_value
(
uint8_t
input
)
{
constexpr
std
::
array
<
float
,
256
>
e4m3fnuz_lut
=
{
0.0
,
0.001953125
,
0.00390625
,
0.005859375
,
0.0078125
,
0.009765625
,
0.01171875
,
0.013671875
,
0.015625
,
0.017578125
,
0.01953125
,
0.021484375
,
0.0234375
,
0.025390625
,
0.02734375
,
0.029296875
,
0.03125
,
0.03515625
,
0.0390625
,
0.04296875
,
0.046875
,
0.05078125
,
0.0546875
,
0.05859375
,
0.0625
,
0.0703125
,
0.078125
,
0.0859375
,
0.09375
,
0.1015625
,
0.109375
,
0.1171875
,
0.125
,
0.140625
,
0.15625
,
0.171875
,
0.1875
,
0.203125
,
0.21875
,
0.234375
,
0.25
,
0.28125
,
0.3125
,
0.34375
,
0.375
,
0.40625
,
0.4375
,
0.46875
,
0.5
,
0.5625
,
0.625
,
0.6875
,
0.75
,
0.8125
,
0.875
,
0.9375
,
1.0
,
1.125
,
1.25
,
1.375
,
1.5
,
1.625
,
1.75
,
1.875
,
2.0
,
2.25
,
2.5
,
2.75
,
3.0
,
3.25
,
3.5
,
3.75
,
4.0
,
4.5
,
5.0
,
5.5
,
6.0
,
6.5
,
7.0
,
7.5
,
8.0
,
9.0
,
10.0
,
11.0
,
12.0
,
13.0
,
14.0
,
15.0
,
16.0
,
18.0
,
20.0
,
22.0
,
24.0
,
26.0
,
28.0
,
30.0
,
32.0
,
36.0
,
40.0
,
44.0
,
48.0
,
52.0
,
56.0
,
60.0
,
64.0
,
72.0
,
80.0
,
88.0
,
96.0
,
104.0
,
112.0
,
120.0
,
128.0
,
144.0
,
160.0
,
176.0
,
192.0
,
208.0
,
224.0
,
240.0
,
256.0
,
288.0
,
320.0
,
352.0
,
384.0
,
416.0
,
448.0
,
std
::
numeric_limits
<
float
>::
quiet_NaN
(),
-
0.0
,
-
0.001953125
,
-
0.00390625
,
-
0.005859375
,
-
0.0078125
,
-
0.009765625
,
-
0.01171875
,
-
0.013671875
,
-
0.015625
,
-
0.017578125
,
-
0.01953125
,
-
0.021484375
,
-
0.0234375
,
-
0.025390625
,
-
0.02734375
,
-
0.029296875
,
-
0.03125
,
-
0.03515625
,
-
0.0390625
,
-
0.04296875
,
-
0.046875
,
-
0.05078125
,
-
0.0546875
,
-
0.05859375
,
-
0.0625
,
-
0.0703125
,
-
0.078125
,
-
0.0859375
,
-
0.09375
,
-
0.1015625
,
-
0.109375
,
-
0.1171875
,
-
0.125
,
-
0.140625
,
-
0.15625
,
-
0.171875
,
-
0.1875
,
-
0.203125
,
-
0.21875
,
-
0.234375
,
-
0.25
,
-
0.28125
,
-
0.3125
,
-
0.34375
,
-
0.375
,
-
0.40625
,
-
0.4375
,
-
0.46875
,
-
0.5
,
-
0.5625
,
-
0.625
,
-
0.6875
,
-
0.75
,
-
0.8125
,
-
0.875
,
-
0.9375
,
-
1.0
,
-
1.125
,
-
1.25
,
-
1.375
,
-
1.5
,
-
1.625
,
-
1.75
,
-
1.875
,
-
2.0
,
-
2.25
,
-
2.5
,
-
2.75
,
-
3.0
,
-
3.25
,
-
3.5
,
-
3.75
,
-
4.0
,
-
4.5
,
-
5.0
,
-
5.5
,
-
6.0
,
-
6.5
,
-
7.0
,
-
7.5
,
-
8.0
,
-
9.0
,
-
10.0
,
-
11.0
,
-
12.0
,
-
13.0
,
-
14.0
,
-
15.0
,
-
16.0
,
-
18.0
,
-
20.0
,
-
22.0
,
-
24.0
,
-
26.0
,
-
28.0
,
-
30.0
,
-
32.0
,
-
36.0
,
-
40.0
,
-
44.0
,
-
48.0
,
-
52.0
,
-
56.0
,
-
60.0
,
-
64.0
,
-
72.0
,
-
80.0
,
-
88.0
,
-
96.0
,
-
104.0
,
-
112.0
,
-
120.0
,
-
128.0
,
-
144.0
,
-
160.0
,
-
176.0
,
-
192.0
,
-
208.0
,
-
224.0
,
-
240.0
,
-
256.0
,
-
288.0
,
-
320.0
,
-
352.0
,
-
384.0
,
-
416.0
,
-
448.0
,
std
::
numeric_limits
<
float
>::
quiet_NaN
(),
};
return
e4m3fnuz_lut
[
input
];
}
TEST_CASE
(
test_fp8_cast_to_float
)
{
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
migraphx
::
fp8
::
fp8e4m3fn
fp8_val
(
bit_val
,
migraphx
::
fp8
::
fp8e4m3fn
::
from_bits
());
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e4m3fn_to_fp32_value
(
bit_val
)))
{
return
true
;
}
return
migraphx
::
float_equal
(
float
(
fp8_val
),
fp8e4m3fn_to_fp32_value
(
bit_val
));
})});
}
TEST_CASE
(
test_fp8_cast_from_float
)
{
std
::
unordered_map
<
float
,
uint8_t
>
test_vals
=
{
{{
512
,
0x7e
},
{
-
512
,
0xfe
},
{
448
,
0x7e
},
{
-
448
,
0xfe
},
{
256
,
0x78
},
{
-
256
,
0xf8
},
{
240
,
0x77
},
{
-
240
,
0xf7
},
{
1e-07
,
0x0
},
{
1e+07
,
0x7e
},
{
1
,
0x38
},
{
-
1
,
0xb8
},
{
0.1
,
0x1d
},
{
0.11
,
0x1e
},
{
0.111
,
0x1e
},
{
0.1111
,
0x1e
},
{
-
0.1
,
0x9d
},
{
-
0.11
,
0x9e
},
{
-
0.111
,
0x9e
},
{
-
0.1111
,
0x9e
},
{
0.2
,
0x25
},
{
2
,
0x40
},
{
20
,
0x5a
},
{
200
,
0x74
},
{
-
0.2
,
0xa5
},
{
-
2
,
0xc0
},
{
-
20
,
0xda
},
{
-
200
,
0xf4
},
{
0.5
,
0x30
},
{
-
0.5
,
0xb0
},
{
1.17549e-38
,
0x0
},
{
1.4013e-45
,
0x0
},
{
0.0078125
,
0x4
},
{
-
0.0078125
,
0x84
},
{
0.000976562
,
0x0
},
{
-
0.000976562
,
0x80
},
{
0.000488281
,
0x0
},
{
-
0.000488281
,
0x80
}}};
EXPECT
(
bool
{
std
::
all_of
(
test_vals
.
begin
(),
test_vals
.
end
(),
[](
const
auto
sample
)
{
return
migraphx
::
float_equal
(
migraphx
::
fp8
::
fp8e4m3fn
(
sample
.
first
),
migraphx
::
fp8
::
fp8e4m3fn
(
sample
.
second
,
migraphx
::
fp8
::
fp8e4m3fn
::
from_bits
()));
})});
}
TEST_CASE
(
test_positive_zero
)
{
float
zero
=
0.0
;
migraphx
::
fp8
::
fp8e4m3fn
fp8_zero
(
zero
);
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
}
TEST_CASE
(
test_negative_zero
)
{
float
nzero
=
-
0.0
;
migraphx
::
fp8
::
fp8e4m3fn
fp8_nzero
(
nzero
);
EXPECT
(
fp8_nzero
.
is_zero
());
// negative zero is preserved for fp8e4m3fn
EXPECT
(
migraphx
::
float_equal
(
nzero
,
float
(
fp8_nzero
)));
}
TEST_CASE
(
test_pos_zero_eq_neg_zero
)
{
float
nzero
=
-
0.0
;
float
pzero
=
0.0
;
migraphx
::
fp8
::
fp8e5m2
fp8_nzero
(
nzero
);
migraphx
::
fp8
::
fp8e5m2
fp8_pzero
(
pzero
);
EXPECT
(
fp8_nzero
==
fp8_pzero
);
}
TEST_CASE
(
test_nan_1
)
{
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
migraphx
::
fp8
::
fp8e4m3fn
fp8_nan
(
fnan
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
}
TEST_CASE
(
test_nan_2
)
{
auto
fnan
=
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
quiet_NaN
();
migraphx
::
fp8
::
fp8e4m3fn
fp8_nan
(
fnan
.
data
,
migraphx
::
fp8
::
fp8e4m3fn
::
from_bits
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
TEST_CASE
(
test_infinity_1
)
{
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fn, it gets clipped to max()
migraphx
::
fp8
::
fp8e4m3fn
fp8_max
(
finf
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
());
}
TEST_CASE
(
test_infinity_2
)
{
// neg inf
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fn, it gets clipped to lowest
migraphx
::
fp8
::
fp8e4m3fn
fp8_lowest
(
finf
);
EXPECT
(
bool
{
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
()});
}
TEST_CASE
(
test_numeric_max_1
)
{
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
migraphx
::
fp8
::
fp8e4m3fn
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
());
}
TEST_CASE
(
test_numeric_max_2
)
{
// gets clipped to max
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
();
migraphx
::
fp8
::
fp8e4m3fn
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
());
}
TEST_CASE
(
test_numeric_lowest_1
)
{
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
migraphx
::
fp8
::
fp8e4m3fn
fp8_lowest
(
flowest
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
());
}
TEST_CASE
(
test_numeric_lowest_2
)
{
// gets clipped to lowest
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
();
migraphx
::
fp8
::
fp8e4m3fn
fp8_lowest
(
fmin
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
());
}
TEST_CASE
(
test_max_eq_lowest
)
{
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
lowest
(),
-
1
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
max
()));
}
TEST_CASE
(
test_isfinite
)
{
EXPECT
(
std
::
isfinite
(
migraphx
::
fp8
::
fp8e4m3fn
(
0.0
)));
EXPECT
(
std
::
isfinite
(
migraphx
::
fp8
::
fp8e4m3fn
(
-
0.0
)));
EXPECT
(
not
std
::
isfinite
(
migraphx
::
fp8
::
fp8e4m3fn
(
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
quiet_NaN
())));
}
TEST_CASE
(
test_no_infinity
)
{
EXPECT
(
not
bool
{
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
has_infinity
});
}
TEST_CASE
(
test_binary_ops
)
{
auto
a
=
migraphx
::
fp8
::
fp8e4m3fn
(
-
1.0
);
auto
b
=
migraphx
::
fp8
::
fp8e4m3fn
(
1.0
);
auto
c
=
migraphx
::
fp8
::
fp8e4m3fn
(
0.0
);
auto
d
=
migraphx
::
fp8
::
fp8e4m3fn
(
-
0.0
);
EXPECT
(
migraphx
::
float_equal
((
c
+
d
),
c
));
EXPECT
(
migraphx
::
float_equal
((
c
+
d
),
d
));
EXPECT
(
migraphx
::
float_equal
((
a
+
b
),
c
));
EXPECT
(
migraphx
::
float_equal
((
a
+
b
),
d
));
auto
e
=
migraphx
::
fp8
::
fp8e4m3fn
(
10.0
);
auto
f
=
migraphx
::
fp8
::
fp8e4m3fn
(
-
10.0
);
EXPECT
(
bool
{
e
>
f
});
EXPECT
(
bool
{
f
<
e
});
EXPECT
(
bool
{
f
<=
e
});
EXPECT
(
bool
{
e
>=
f
});
EXPECT
(
bool
{
e
<=
e
});
EXPECT
(
bool
{
f
>=
f
});
EXPECT
(
not
migraphx
::
float_equal
(
f
,
e
));
}
TEST_CASE
(
test_fabs
)
{
auto
a
=
migraphx
::
fp8
::
fp8e4m3fn
(
-
1.0
);
auto
b
=
migraphx
::
fp8
::
fp8e4m3fn
(
1.0
);
EXPECT
(
migraphx
::
float_equal
(
b
,
migraphx
::
fp8
::
fabs
(
a
)));
}
TEST_CASE
(
test_stream_op
)
{
auto
a
=
migraphx
::
fp8
::
fp8e4m3fn
(
-
1.0
);
std
::
stringstream
ss
;
ss
<<
a
;
EXPECT
(
std
::
string
(
"-1"
)
==
ss
.
str
());
ss
=
std
::
stringstream
();
auto
b
=
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fn
>::
quiet_NaN
();
ss
<<
b
;
EXPECT
(
std
::
string
(
"nan"
)
==
ss
.
str
());
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/fp8e4m3fnuz.cpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
float
fp8e4m3fnuz_to_fp32_value
(
uint8_t
input
)
{
constexpr
std
::
array
<
float
,
256
>
e4m3fnuz_lut
=
{
0.0
f
,
0.0009765625
f
,
0.001953125
f
,
0.0029296875
f
,
0.00390625
f
,
0.0048828125
f
,
0.005859375
f
,
0.0068359375
f
,
0.0078125
f
,
0.0087890625
f
,
0.009765625
f
,
0.0107421875
f
,
0.01171875
f
,
0.0126953125
f
,
0.013671875
f
,
0.0146484375
f
,
0.015625
f
,
0.017578125
f
,
0.01953125
f
,
0.021484375
f
,
0.0234375
f
,
0.025390625
f
,
0.02734375
f
,
0.029296875
f
,
0.03125
f
,
0.03515625
f
,
0.0390625
f
,
0.04296875
f
,
0.046875
f
,
0.05078125
f
,
0.0546875
f
,
0.05859375
f
,
0.0625
f
,
0.0703125
f
,
0.078125
f
,
0.0859375
f
,
0.09375
f
,
0.1015625
f
,
0.109375
f
,
0.1171875
f
,
0.125
f
,
0.140625
f
,
0.15625
f
,
0.171875
f
,
0.1875
f
,
0.203125
f
,
0.21875
f
,
0.234375
f
,
0.25
f
,
0.28125
f
,
0.3125
f
,
0.34375
f
,
0.375
f
,
0.40625
f
,
0.4375
f
,
0.46875
f
,
0.5
f
,
0.5625
f
,
0.625
f
,
0.6875
f
,
0.75
f
,
0.8125
f
,
0.875
f
,
0.9375
f
,
1.0
f
,
1.125
f
,
1.25
f
,
1.375
f
,
1.5
f
,
1.625
f
,
1.75
f
,
1.875
f
,
2.0
f
,
2.25
f
,
2.5
f
,
2.75
f
,
3.0
f
,
3.25
f
,
3.5
f
,
3.75
f
,
4.0
f
,
4.5
f
,
5.0
f
,
5.5
f
,
6.0
f
,
6.5
f
,
7.0
f
,
7.5
f
,
8.0
f
,
9.0
f
,
10.0
f
,
11.0
f
,
12.0
f
,
13.0
f
,
14.0
f
,
15.0
f
,
16.0
f
,
18.0
f
,
20.0
f
,
22.0
f
,
24.0
f
,
26.0
f
,
28.0
f
,
30.0
f
,
32.0
f
,
36.0
f
,
40.0
f
,
44.0
f
,
48.0
f
,
52.0
f
,
56.0
f
,
60.0
f
,
64.0
f
,
72.0
f
,
80.0
f
,
88.0
f
,
96.0
f
,
104.0
f
,
112.0
f
,
120.0
f
,
128.0
f
,
144.0
f
,
160.0
f
,
176.0
f
,
192.0
f
,
208.0
f
,
224.0
f
,
240.0
f
,
std
::
numeric_limits
<
float
>::
quiet_NaN
(),
-
0.0009765625
f
,
-
0.001953125
f
,
-
0.0029296875
f
,
-
0.00390625
f
,
-
0.0048828125
f
,
-
0.005859375
f
,
-
0.0068359375
f
,
-
0.0078125
f
,
-
0.0087890625
f
,
-
0.009765625
f
,
-
0.0107421875
f
,
-
0.01171875
f
,
-
0.0126953125
f
,
-
0.013671875
f
,
-
0.0146484375
f
,
-
0.015625
f
,
-
0.017578125
f
,
-
0.01953125
f
,
-
0.021484375
f
,
-
0.0234375
f
,
-
0.025390625
f
,
-
0.02734375
f
,
-
0.029296875
f
,
-
0.03125
f
,
-
0.03515625
f
,
-
0.0390625
f
,
-
0.04296875
f
,
-
0.046875
f
,
-
0.05078125
f
,
-
0.0546875
f
,
-
0.05859375
f
,
-
0.0625
f
,
-
0.0703125
f
,
-
0.078125
f
,
-
0.0859375
f
,
-
0.09375
f
,
-
0.1015625
f
,
-
0.109375
f
,
-
0.1171875
f
,
-
0.125
f
,
-
0.140625
f
,
-
0.15625
f
,
-
0.171875
f
,
-
0.1875
f
,
-
0.203125
f
,
-
0.21875
f
,
-
0.234375
f
,
-
0.25
f
,
-
0.28125
f
,
-
0.3125
f
,
-
0.34375
f
,
-
0.375
f
,
-
0.40625
f
,
-
0.4375
f
,
-
0.46875
f
,
-
0.5
f
,
-
0.5625
f
,
-
0.625
f
,
-
0.6875
f
,
-
0.75
f
,
-
0.8125
f
,
-
0.875
f
,
-
0.9375
f
,
-
1.0
f
,
-
1.125
f
,
-
1.25
f
,
-
1.375
f
,
-
1.5
f
,
-
1.625
f
,
-
1.75
f
,
-
1.875
f
,
-
2.0
f
,
-
2.25
f
,
-
2.5
f
,
-
2.75
f
,
-
3.0
f
,
-
3.25
f
,
-
3.5
f
,
-
3.75
f
,
-
4.0
f
,
-
4.5
f
,
-
5.0
f
,
-
5.5
f
,
-
6.0
f
,
-
6.5
f
,
-
7.0
f
,
-
7.5
f
,
-
8.0
f
,
-
9.0
f
,
-
10.0
f
,
-
11.0
f
,
-
12.0
f
,
-
13.0
f
,
-
14.0
f
,
-
15.0
f
,
-
16.0
f
,
-
18.0
f
,
-
20.0
f
,
-
22.0
f
,
-
24.0
f
,
-
26.0
f
,
-
28.0
f
,
-
30.0
f
,
-
32.0
f
,
-
36.0
f
,
-
40.0
f
,
-
44.0
f
,
-
48.0
f
,
-
52.0
f
,
-
56.0
f
,
-
60.0
f
,
-
64.0
f
,
-
72.0
f
,
-
80.0
f
,
-
88.0
f
,
-
96.0
f
,
-
104.0
f
,
-
112.0
f
,
-
120.0
f
,
-
128.0
f
,
-
144.0
f
,
-
160.0
f
,
-
176.0
f
,
-
192.0
f
,
-
208.0
f
,
-
224.0
f
,
-
240.0
f
,
};
return
e4m3fnuz_lut
[
input
];
}
TEST_CASE
(
test_fp8_cast_to_float
)
{
std
::
vector
<
uint8_t
>
bit_vals
(
256
);
std
::
iota
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
0
);
EXPECT
(
bool
{
std
::
all_of
(
bit_vals
.
begin
(),
bit_vals
.
end
(),
[](
uint8_t
bit_val
)
{
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_val
(
bit_val
,
migraphx
::
fp8
::
fp8e4m3fnuz
::
from_bits
());
if
(
std
::
isnan
(
float
(
fp8_val
))
and
std
::
isnan
(
fp8e4m3fnuz_to_fp32_value
(
bit_val
)))
{
return
true
;
}
return
migraphx
::
float_equal
(
float
(
fp8_val
),
fp8e4m3fnuz_to_fp32_value
(
bit_val
));
})});
}
TEST_CASE
(
test_fp8_cast_from_float
)
{
std
::
unordered_map
<
float
,
uint8_t
>
test_vals
=
{{
256
,
0x7f
},
{
-
256
,
0xff
},
{
240
,
0x7f
},
{
-
240
,
0xff
},
{
1e-07
,
0x0
},
{
1e+07
,
0x7f
},
{
1
,
0x40
},
{
-
1
,
0xc0
},
{
0.1
,
0x25
},
{
0.11
,
0x26
},
{
0.111
,
0x26
},
{
0.1111
,
0x26
},
{
-
0.1
,
0xa5
},
{
-
0.11
,
0xa6
},
{
-
0.111
,
0xa6
},
{
-
0.1111
,
0xa6
},
{
0.2
,
0x2d
},
{
2
,
0x48
},
{
20
,
0x62
},
{
200
,
0x7c
},
{
-
0.2
,
0xad
},
{
-
2
,
0xc8
},
{
-
20
,
0xe2
},
{
-
200
,
0xfc
},
{
0.5
,
0x38
},
{
-
0.5
,
0xb8
},
{
1.17549e-38
,
0x0
},
{
1.4013e-45
,
0x0
},
{
0.00390625
,
0x4
},
{
-
0.00390625
,
0x84
},
{
0.00195312
,
0x2
},
{
-
0.00195312
,
0x82
},
{
0.000976562
,
0x1
},
{
-
0.000976562
,
0x81
},
{
0.000488281
,
0x0
},
{
-
0.000488281
,
0x0
}};
EXPECT
(
bool
{
std
::
all_of
(
test_vals
.
begin
(),
test_vals
.
end
(),
[](
const
auto
sample
)
{
return
migraphx
::
float_equal
(
migraphx
::
fp8
::
fp8e4m3fnuz
(
sample
.
first
),
migraphx
::
fp8
::
fp8e4m3fnuz
(
sample
.
second
,
migraphx
::
fp8
::
fp8e4m3fnuz
::
from_bits
()));
})});
}
TEST_CASE
(
test_positive_zero
)
{
float
zero
=
0.0
;
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_zero
(
zero
);
EXPECT
(
fp8_zero
.
is_zero
());
EXPECT
(
migraphx
::
float_equal
(
zero
,
float
(
fp8_zero
)));
}
TEST_CASE
(
test_negative_zero
)
{
float
nzero
=
-
0.0
;
float
pzero
=
0.0
;
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nzero
(
nzero
);
EXPECT
(
fp8_nzero
.
is_zero
());
// negative zero gets converted to positive zero
EXPECT
(
migraphx
::
float_equal
(
pzero
,
float
(
fp8_nzero
)));
}
TEST_CASE
(
test_nan_1
)
{
float
fnan
=
std
::
numeric_limits
<
float
>::
quiet_NaN
();
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nan
(
fnan
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
}
TEST_CASE
(
test_nan_2
)
{
auto
fnan
=
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
quiet_NaN
();
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nan
(
fnan
.
data
,
migraphx
::
fp8
::
fp8e4m3fnuz
::
from_bits
());
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
fp8_nan
));
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
TEST_CASE
(
test_infinity_1
)
{
float
finf
=
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fnuz it gets clipped to Nans
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
TEST_CASE
(
test_infinity_2
)
{
// neg inf
float
finf
=
-
1.0
*
std
::
numeric_limits
<
float
>::
infinity
();
// no inf in fp8e4m3fnuz it gets clipped to NaNs
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_nan
(
finf
);
EXPECT
(
fp8_nan
.
is_nan
());
EXPECT
(
std
::
isnan
(
float
(
fp8_nan
)));
}
TEST_CASE
(
test_numeric_max_1
)
{
float
fmax
=
std
::
numeric_limits
<
float
>::
max
();
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
max
());
}
TEST_CASE
(
test_numeric_max_2
)
{
// gets clipped to max
float
fmax
=
2
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
max
();
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_max
(
fmax
);
EXPECT
(
fp8_max
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
max
());
}
TEST_CASE
(
test_numeric_lowest_1
)
{
float
flowest
=
std
::
numeric_limits
<
float
>::
lowest
();
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_lowest
(
flowest
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
lowest
());
}
TEST_CASE
(
test_numeric_lowest_2
)
{
// gets clipped to lowest
float
fmin
=
2.0
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
lowest
();
migraphx
::
fp8
::
fp8e4m3fnuz
fp8_lowest
(
fmin
);
EXPECT
(
fp8_lowest
==
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
lowest
());
}
TEST_CASE
(
test_max_eq_lowest
)
{
EXPECT
(
migraphx
::
float_equal
(
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
lowest
(),
-
1
*
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
max
()));
}
TEST_CASE
(
test_isfinite
)
{
EXPECT
(
std
::
isfinite
(
migraphx
::
fp8
::
fp8e4m3fnuz
(
0.0
)));
EXPECT
(
std
::
isfinite
(
migraphx
::
fp8
::
fp8e4m3fnuz
(
-
0.0
)));
EXPECT
(
not
std
::
isfinite
(
migraphx
::
fp8
::
fp8e4m3fnuz
(
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
quiet_NaN
())));
}
TEST_CASE
(
test_no_infinity
)
{
EXPECT
(
not
bool
{
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
has_infinity
});
}
TEST_CASE
(
test_binary_ops
)
{
auto
a
=
migraphx
::
fp8
::
fp8e4m3fnuz
(
-
1.0
);
auto
b
=
migraphx
::
fp8
::
fp8e4m3fnuz
(
1.0
);
auto
c
=
migraphx
::
fp8
::
fp8e4m3fnuz
(
0.0
);
auto
d
=
migraphx
::
fp8
::
fp8e4m3fnuz
(
-
0.0
);
EXPECT
(
migraphx
::
float_equal
((
c
+
d
),
c
));
EXPECT
(
migraphx
::
float_equal
((
c
+
d
),
d
));
EXPECT
(
migraphx
::
float_equal
((
a
+
b
),
c
));
EXPECT
(
migraphx
::
float_equal
((
a
+
b
),
d
));
auto
e
=
migraphx
::
fp8
::
fp8e4m3fnuz
(
10.0
);
auto
f
=
migraphx
::
fp8
::
fp8e4m3fnuz
(
-
10.0
);
EXPECT
(
bool
{
e
>
f
});
EXPECT
(
bool
{
f
<
e
});
EXPECT
(
bool
{
f
<=
e
});
EXPECT
(
bool
{
e
>=
f
});
EXPECT
(
bool
{
e
<=
e
});
EXPECT
(
bool
{
f
>=
f
});
EXPECT
(
not
migraphx
::
float_equal
(
f
,
e
));
}
TEST_CASE
(
test_fabs
)
{
auto
a
=
migraphx
::
fp8
::
fp8e4m3fnuz
(
-
1.0
);
auto
b
=
migraphx
::
fp8
::
fp8e4m3fnuz
(
1.0
);
EXPECT
(
migraphx
::
float_equal
(
b
,
migraphx
::
fp8
::
fabs
(
a
)));
}
TEST_CASE
(
test_stream_op
)
{
auto
a
=
migraphx
::
fp8
::
fp8e4m3fnuz
(
-
1.0
);
std
::
stringstream
ss
;
ss
<<
a
;
EXPECT
(
std
::
string
(
"-1"
)
==
ss
.
str
());
ss
=
std
::
stringstream
();
auto
b
=
std
::
numeric_limits
<
migraphx
::
fp8
::
fp8e4m3fnuz
>::
quiet_NaN
();
ss
<<
b
;
EXPECT
(
std
::
string
(
"nan"
)
==
ss
.
str
());
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Prev
1
2
3
4
5
6
7
8
9
10
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment