Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4935b575
Commit
4935b575
authored
Nov 01, 2023
by
Umang Yadav
Browse files
test_gpu_jit working after removing implicit_conversion_op
parent
16b5e050
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
81 additions
and
47 deletions
+81
-47
src/include/migraphx/fp8e4m3fnuz.hpp
src/include/migraphx/fp8e4m3fnuz.hpp
+41
-25
src/include/migraphx/half.hpp
src/include/migraphx/half.hpp
+13
-0
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+3
-0
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
...argets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
+1
-2
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+3
-4
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+6
-1
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+1
-1
test/gpu/jit.cpp
test/gpu/jit.cpp
+13
-14
No files found.
src/include/migraphx/fp8e4m3fnuz.hpp
View file @
4935b575
...
...
@@ -52,7 +52,7 @@
#include <limits>
#include <sstream>
#include <iostream>
#include <migraphx/
half
.hpp>
#include <migraphx/
config
.hpp>
#include <string>
#include <utility>
...
...
@@ -61,6 +61,7 @@
// therefore, when this file is used from the host side, compilation takes much
// longer. By guarding the __device__ directive we can control that such compilation
// only happens for kernels which include this file.
// hip_runtime.h must be included here; otherwise the compiler complains about __host__ and __device__
#include <hip/hip_runtime.h>
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#else
...
...
@@ -72,7 +73,7 @@
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wreserved-identifier"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
detail
{
inline
MIGRAPHX_HIP_HOST_DEVICE
float
fp32_from_bits
(
uint32_t
w
)
...
...
@@ -102,7 +103,7 @@ inline MIGRAPHX_HIP_HOST_DEVICE uint32_t fp32_to_bits(float f)
*
* @note The implementation doesn't use any floating-point operations.
*/
inline
MIGRAPHX_HIP_HOST_DEVICE
float
fp8e4m3fnuz_to_fp32_value
(
uint8_t
input
)
inline
MIGRAPHX_HIP_HOST_DEVICE
constexpr
float
fp8e4m3fnuz_to_fp32_value
(
uint8_t
input
)
{
constexpr
float
e4m3fnuz_lut
[
256
]
=
{
0.0
f
,
0.0009765625
f
,
0.001953125
f
,
...
...
@@ -275,6 +276,24 @@ inline MIGRAPHX_HIP_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f)
return
result
;
}
/// Temporary expression wrapper that stores a single-precision value.
///
/// The value is fixed at construction and read back through the implicit
/// conversion to `float`; the type carries no other state.
/// Both members are `constexpr` so the wrapper can be used in constant
/// expressions, alongside the `constexpr` conversions declared in this header.
/// NOTE(review): the previous comment called this a "half-precision expression";
/// that wording looks inherited from another header — confirm the intended role here.
/// NOTE(review): presumably `constexpr` also makes these callable from device code
/// under clang's HIP mode without an explicit annotation — confirm.
struct expr
{
    /// Conversion constructor.
    /// \param f single-precision value to wrap
    /// Kept `explicit` so a bare `float` never silently becomes an `expr`.
    explicit constexpr expr(float f) noexcept : value_(f) {}

    /// Conversion to single-precision.
    /// \return the value passed at construction, unchanged
    constexpr operator float() const noexcept { return value_; }

    private:
    /// Internal expression value stored in single-precision.
    float value_;
};
}
// namespace detail
struct
alignas
(
1
)
fp8e4m3fnuz
...
...
@@ -290,16 +309,18 @@ struct alignas(1) fp8e4m3fnuz
/// Builds a value directly from its raw 8-bit encoding.
/// \param bits byte stored into `x` unchanged
/// The unnamed `from_bits_t` argument is never read; it only selects this
/// overload over the numeric-conversion constructors.
MIGRAPHX_HIP_HOST_DEVICE constexpr fp8e4m3fnuz(uint8_t bits, from_bits_t)
    : x(bits)
{
}
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
(
const
fp8e4m3fnuz
&
y
)
=
default
;
/// Converts a single-precision value to fp8e4m3fnuz.
/// \param value float handed to `detail::fp8e4m3fnuz_from_fp32_value`; the byte
///              that helper returns becomes the stored encoding `x`
/// `explicit`, so a float is only converted when the caller asks for it.
/// (Defined in-class, hence implicitly inline.)
explicit MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz(float value)
    : x(detail::fp8e4m3fnuz_from_fp32_value(value))
{
}
inline
explicit
MIGRAPHX_HIP_HOST_DEVICE
fp8e4m3fnuz
(
migraphx
::
half
value
)
:
x
(
detail
::
fp8e4m3fnuz_from_fp32_value
(
float
(
value
)))
{
}
inline
MIGRAPHX_HIP_HOST_DEVICE
operator
float
()
const
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
const
fp8e4m3fnuz
&
rhs
)
=
default
;
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
fp8e4m3fnuz
&&
rhs
)
=
default
;
/// Implicit conversion to single-precision.
/// \return the result of `detail::fp8e4m3fnuz_to_fp32_value` applied to the
///         stored encoding `x`
/// (`constexpr` already implies inline for a function.)
constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const
{
    return detail::fp8e4m3fnuz_to_fp32_value(x);
}
...
...
@@ -310,12 +331,6 @@ struct alignas(1) fp8e4m3fnuz
return
*
this
;
}
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
=
(
migraphx
::
half
rhs
)
{
x
=
detail
::
fp8e4m3fnuz_from_fp32_value
(
float
(
rhs
));
return
*
this
;
}
/// Reports whether the stored encoding is the bit pattern this type treats as NaN.
/// \return true only when `x` equals 0x80 (binary 10000000: top bit set, all
///         other bits clear)
bool MIGRAPHX_HIP_HOST_DEVICE isnan() const
{
    return x == 0x80;
}
fp8e4m3fnuz
&
MIGRAPHX_HIP_HOST_DEVICE
operator
+=
(
float
rhs
)
...
...
@@ -346,6 +361,7 @@ inline std::ostream& operator<<(std::ostream& out, const fp8e4m3fnuz& value)
return
out
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
namespace
std
{
...
...
@@ -429,17 +445,17 @@ struct common_type<migraphx::fp8e4m3fnuz, migraphx::fp8e4m3fnuz>
using
type
=
float
;
};
template
<
>
struct
common_type
<
migraphx
::
fp8e4m3fnuz
,
migraphx
::
half
>
{
using
type
=
float
;
};
template
<
>
struct
common_type
<
migraphx
::
half
,
migraphx
::
fp8e4m3fnuz
>
{
using
type
=
float
;
};
//
template <>
//
struct common_type<migraphx::fp8e4m3fnuz, migraphx::half>
//
{
//
using type = float;
//
};
//
template <>
//
struct common_type<migraphx::half, migraphx::fp8e4m3fnuz>
//
{
//
using type = float;
//
};
}
// namespace std
#pragma clang diagnostic pop
...
...
src/include/migraphx/half.hpp
View file @
4935b575
...
...
@@ -27,6 +27,7 @@
#include <half/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/fp8e4m3fnuz.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -67,6 +68,18 @@ struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
{
};
/// Explicit specialization: the common type of `migraphx::fp8e4m3fnuz` and
/// `migraphx::half` (in that argument order) is `float`.
template <>
struct common_type<migraphx::fp8e4m3fnuz, migraphx::half>
{
    using type = float;
};
/// Explicit specialization: the common type of `migraphx::half` and
/// `migraphx::fp8e4m3fnuz` (in that argument order) is `float`.
template <>
struct common_type<migraphx::half, migraphx::fp8e4m3fnuz>
{
    using type = float;
};
template
<
>
struct
common_type
<
migraphx
::
half
,
migraphx
::
half
>
{
...
...
src/targets/gpu/CMakeLists.txt
View file @
4935b575
...
...
@@ -49,6 +49,8 @@ endif()
file
(
GLOB KERNEL_FILES CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/*.hpp
)
message
(
STATUS
"CMAKE Source Dir is :
${
CMAKE_SOURCE_DIR
}
"
)
list
(
APPEND KERNEL_FILES
${
CMAKE_SOURCE_DIR
}
/src/include/migraphx/fp8e4m3fnuz.hpp
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
if
(
NOT MIGRAPHX_USE_COMPOSABLEKERNEL
)
...
...
@@ -58,6 +60,7 @@ if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/ck.hpp
)
endif
()
include
(
Embed
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
View file @
4935b575
...
...
@@ -36,8 +36,7 @@ namespace migraphx {
/// Applies `f` element-wise over the input tensors and stores each result in `out`.
///
/// `idx.global_stride` is given `out.get_shape().elements()` and a callback; for
/// each index `i` the callback receives, `f(xs[i]...)` is evaluated and assigned
/// to `out[i]`. The result is assigned directly, with no `implicit_conversion`
/// wrapper, so any conversion to the output element type is whatever the
/// assignment `out[i] = ...` itself performs.
///
/// \param idx index object whose `global_stride` drives the iteration
/// \param f   callable invoked with the i-th element of every tensor in `xs`
/// \param out destination tensor; its shape's element count is passed to `global_stride`
/// \param xs  input tensors, each indexed with the same `i` as `out`
template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{
    // Single pass only: the earlier implicit_conversion(...) store is gone, so each
    // element is written once and this function no longer depends on that helper.
    idx.global_stride(out.get_shape().elements(), [&](auto i) { out[i] = f(xs[i]...); });
}
template
<
class
...
Transforms
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
4935b575
...
...
@@ -244,9 +244,8 @@ struct reducer_base
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
t
=
derived
.
slice
(
x
);
return
make_storage_access
<
typename
decltype
(
t
)
::
type
>
([
=
](
auto
i
,
auto
...)
->
auto
&
{
return
t
[
i
];
});
return
make_storage_access
<
typename
decltype
(
t
)
::
type
>
(
[
=
](
auto
i
,
auto
...)
->
auto
&
{
return
t
[
i
];
});
}
}
...
...
@@ -578,7 +577,7 @@ __device__ void fused_reduce(Output output, F f)
}
else
{
r
.
outer
([
&
]
{
output
[
out_idx
]
=
implicit_conversion
(
result
)
;
});
r
.
outer
([
&
]
{
output
[
out_idx
]
=
result
;
});
}
});
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
4935b575
...
...
@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
...
...
@@ -230,7 +231,7 @@ constexpr unsigned long int_max(unsigned long n)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_integral
<
T
>{}
or
is_floating_point
<
T
>
{}
or
is_same
<
T
,
migraphx
::
half
>
{})
>
is_same
<
T
,
migraphx
::
half
>
{}
or
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{}
)
>
constexpr
T
numeric_max
()
{
if
constexpr
(
is_integral
<
T
>
{})
...
...
@@ -246,6 +247,8 @@ constexpr T numeric_max()
return
__FLT_MAX__
;
else
if
constexpr
(
is_same
<
T
,
migraphx
::
half
>
{})
return
__FLT16_MAX__
;
else
if
constexpr
(
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
return
T
{
0x7F
,
migraphx
::
fp8e4m3fnuz
::
from_bits
()};
else
return
0
;
}
...
...
@@ -260,6 +263,8 @@ constexpr T numeric_lowest()
else
return
-
numeric_max
<
T
>
()
-
1
;
}
else
if
constexpr
(
is_same
<
T
,
migraphx
::
fp8e4m3fnuz
>
{})
return
T
{
0xFF
,
migraphx
::
fp8e4m3fnuz
::
from_bits
()};
else
{
return
-
numeric_max
<
T
>
();
...
...
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
View file @
4935b575
...
...
@@ -23,7 +23,7 @@
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <migraphx/fp8e4m3fnuz.hpp>
#include <migraphx/kernels/hip.hpp>
namespace
migraphx
{
...
...
test/gpu/jit.cpp
View file @
4935b575
...
...
@@ -144,7 +144,7 @@ extern "C" {
__global__ void kernel(${type}* p)
{
auto x = *p;
*p =
migraphx::implicit_conversion(
migraphx::${invoke}
)
;
*p = migraphx::${invoke};
}
}
...
...
@@ -345,18 +345,18 @@ TEST_CASE(compile_math)
// clang-format on
};
std
::
vector
<
std
::
string
>
data_types
;
auto
vec_sizes
=
{
2
,
4
,
6
};
//
auto vec_sizes = {2, 4, 6};
for
(
auto
&&
t
:
migraphx
::
shape
::
types
())
{
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
continue
;
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
if
(
t
==
migraphx
::
shape
::
half_type
)
if
(
t
==
migraphx
::
shape
::
half_type
or
t
==
migraphx
::
shape
::
float8_type
)
name
.
insert
(
0
,
"migraphx::"
);
data_types
.
push_back
(
name
);
migraphx
::
transform
(
vec_sizes
,
std
::
back_inserter
(
data_types
),
[
&
](
auto
i
)
{
return
"migraphx::vec<"
+
name
+
", "
+
std
::
to_string
(
i
)
+
">"
;
});
//
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
//
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
//
});
}
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
5
,
2
}};
migraphx
::
gpu
::
hip_compile_options
options
;
...
...
@@ -399,7 +399,7 @@ TEST_CASE(assert_type_min_max)
if
(
contains
({
migraphx
::
shape
::
bool_type
,
migraphx
::
shape
::
tuple_type
},
t
))
continue
;
auto
name
=
migraphx
::
shape
::
cpp_type
(
t
);
if
(
t
==
migraphx
::
shape
::
half_type
)
if
(
t
==
migraphx
::
shape
::
half_type
or
t
==
migraphx
::
shape
::
float8_type
)
name
.
insert
(
0
,
"migraphx::"
);
migraphx
::
shape
::
visit
(
t
,
[
&
](
auto
as
)
{
...
...
@@ -423,7 +423,6 @@ TEST_CASE(assert_type_min_max)
min
=
std
::
to_string
(
as
.
min
());
max
=
std
::
to_string
(
as
.
max
());
}
auto
src
=
migraphx
::
interpolate_string
(
assert_template
,
{{
"type"
,
name
},
{
"max"
,
max
},
{
"min"
,
min
}});
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
5
,
2
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment