Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c4cee345
Unverified
Commit
c4cee345
authored
Dec 01, 2023
by
Umang Yadav
Committed by
GitHub
Dec 01, 2023
Browse files
Merge branch 'develop' into rocblas_fp8
parents
c40a39c3
eafd55de
Changes
143
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
329 additions
and
158 deletions
+329
-158
src/targets/gpu/jit/scatter.hpp
src/targets/gpu/jit/scatter.hpp
+78
-0
src/targets/gpu/jit/scatternd.cpp
src/targets/gpu/jit/scatternd.cpp
+8
-37
src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
...targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
+6
-1
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
+44
-40
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
...targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
+13
-13
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+4
-4
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
+3
-2
src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp
...targets/gpu/kernels/include/migraphx/kernels/roialign.hpp
+10
-12
src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp
...nels/include/migraphx/kernels/scatter_reduction_modes.hpp
+83
-0
src/targets/gpu/kernels/include/migraphx/kernels/scatternd.hpp
...argets/gpu/kernels/include/migraphx/kernels/scatternd.hpp
+1
-27
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+1
-1
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+13
-1
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+35
-10
src/targets/ref/CMakeLists.txt
src/targets/ref/CMakeLists.txt
+2
-1
src/tf/CMakeLists.txt
src/tf/CMakeLists.txt
+9
-2
src/tmp_dir.cpp
src/tmp_dir.cpp
+11
-1
test/gpu/fuse_mlir.cpp
test/gpu/fuse_mlir.cpp
+5
-3
No files found.
src/targets/gpu/jit/scatter.hpp
0 → 100644
View file @
c4cee345
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_JIT_SCATTER_HPP
#define MIGRAPHX_GUARD_JIT_SCATTER_HPP
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
template
<
typename
Derived
>
struct
scatter_compiler
:
compiler
<
Derived
>
{
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
const
auto
inputs
=
to_shapes
(
std
::
vector
<
instruction_ref
>
{
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
()});
hip_compile_options
options
;
options
.
set_launch_params
(
op
.
to_value
(),
compute_global_for
(
ctx
,
inputs
.
at
(
1
).
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
kernel_name
=
derived
().
get_kernel_name
(
op
);
options
.
virtual_inputs
=
inputs
;
// The compiler protests the inequality comparison in assign_mul when pertaining to floating
// point, despite it making sense in the context. Thus the warning removal.
options
.
params
+=
"-Wno-float-equal"
;
const
auto
src
=
derived
().
make_interpolated_string
(
op
);
return
prepend_copy_data_to_output
(
compile_hip_code_object
(
src
,
options
));
}
compiler_replace
prepend_copy_data_to_output
(
const
operation
&
co
)
const
{
return
{
co
,
[](
module
&
m
,
instruction_ref
ins
,
const
operation
&
op
)
{
auto
args
=
ins
->
inputs
();
args
.
back
()
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::copy"
),
args
.
front
(),
args
.
back
());
args
.
erase
(
args
.
begin
());
return
m
.
replace_instruction
(
ins
,
op
,
args
);
}};
}
std
::
string
get_kernel_name
(
const
operation
&
op
)
const
{
return
op
.
name
()
+
"_kernel"
;
}
const
Derived
&
derived
()
const
{
return
static_cast
<
const
Derived
&>
(
*
this
);
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/jit/scatternd.cpp
View file @
c4cee345
...
...
@@ -21,11 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include "scatter.hpp"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -55,46 +51,21 @@ MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void*
)__migraphx__"
;
struct
scatternd_compiler
:
compiler
<
scatternd_compiler
>
struct
scatternd_compiler
:
scatter_
compiler
<
scatternd_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"scatternd_none"
,
"scatternd_add"
,
"scatternd_mul"
};
return
{
"scatternd_none"
,
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_min"
,
"scatternd_max"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
std
::
string
make_interpolated_string
(
const
operation
&
op
)
const
{
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
at
(
1
).
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
kernel_name
=
"scatternd_kernel"
;
options
.
virtual_inputs
=
inputs
;
auto
reduction
=
"assign_"
+
v
.
get
(
"reduction"
,
std
::
string
{
"none"
});
auto
src
=
interpolate_string
(
scatternd_kernel
,
{{
"reduction"
,
reduction
}});
return
compile_hip_code_object
(
src
,
options
);
const
auto
reduction
=
op
.
name
().
substr
(
std
::
char_traits
<
char
>::
length
(
"scatternd_"
));
return
interpolate_string
(
scatternd_kernel
,
{{
"reduction"
,
"assign_"
+
reduction
}});
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
assert
(
starts_with
(
op
.
name
(),
"scatternd_"
));
auto
reduction
=
op
.
name
().
substr
(
10
);
return
insert
(
compile_op
(
ctx
,
to_shapes
(
std
::
vector
<
instruction_ref
>
{
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
()}),
{{
"reduction"
,
reduction
}}));
}
compiler_replace
insert
(
const
operation
&
co
)
const
{
return
{
co
,
[](
module
&
m
,
instruction_ref
ins
,
const
operation
&
op
)
{
auto
args
=
ins
->
inputs
();
args
.
back
()
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::copy"
),
args
.
front
(),
args
.
back
());
args
.
erase
(
args
.
begin
());
return
m
.
replace_instruction
(
ins
,
op
,
args
);
}};
}
std
::
string
get_kernel_name
(
const
operation
&
)
const
{
return
"scatternd_kernel"
;
}
};
}
// namespace gpu
...
...
src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
View file @
c4cee345
...
...
@@ -22,8 +22,13 @@
#ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#include <migraphx/kernels/type_traits.hpp>
namespace
migraphx
{
template
<
typename
To
,
typename
From
>
template
<
typename
To
,
typename
From
,
MIGRAPHX_REQUIRES
(
is_trivially_copyable
<
To
>{}
and
is_trivially_copyable
<
From
>
{})
>
inline
constexpr
To
bit_cast
(
From
fr
)
noexcept
{
static_assert
(
sizeof
(
To
)
==
sizeof
(
From
));
...
...
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
View file @
c4cee345
...
...
@@ -365,15 +365,6 @@ struct float8
inline
__device__
constexpr
float8
&
operator
=
(
const
float8
&
rhs
)
=
default
;
inline
__device__
constexpr
float8
&
operator
=
(
float8
&&
rhs
)
noexcept
=
default
;
inline
__device__
constexpr
bool
operator
==
(
const
float8
&
rhs
)
const
{
if
(
rhs
.
is_nan
()
or
rhs
.
is_inf
()
or
this
->
is_nan
()
or
this
->
is_inf
())
return
false
;
else
if
((
rhs
.
is_zero
()
and
this
->
is_zero
())
or
(
this
->
data
==
rhs
.
data
))
return
true
;
return
false
;
}
inline
__device__
constexpr
bool
operator
<
(
const
float8
&
rhs
)
const
{
const
auto
we
=
static_cast
<
float
>
(
*
this
);
...
...
@@ -403,12 +394,21 @@ using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_FABS(T) \
inline constexpr __device__ T fabs(T v) \
{ \
/*NOLINTNEXTLINE*/
\
v.data = v.data & 0x7f; \
return v; \
#define MIGRAPHX_FP8_OTHER_OPS(T) \
inline constexpr __device__ T fabs(T v) \
{ \
/*NOLINTNEXTLINE*/
\
v.data = v.data & 0x7f; \
return v; \
} \
inline __device__ constexpr bool operator==(const T& lhs, const T& rhs) \
{ \
if(rhs.is_nan() or rhs.is_inf() or lhs.is_nan() or lhs.is_inf()) \
return false; \
else if((rhs.is_zero() and lhs.is_zero()) or (lhs.data == rhs.data)) \
return true; \
return false; \
}
// NOLINTNEXTLINE
...
...
@@ -417,11 +417,10 @@ using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \
MIGRAPHX_FP8_
FAB
S(T)
MIGRAPHX_FP8_
OTHER_OP
S(T)
MIGRAPHX_FP8_GEN_OP_OVERLOADS
(
fp8e5m2
)
MIGRAPHX_FP8_GEN_OP_OVERLOADS
(
fp8e5m2fnuz
)
...
...
@@ -447,7 +446,7 @@ class numeric_limits<fp8e4m3fnuz>
{
return
fp8e4m3fnuz
(
0x7F
,
fp8e4m3fnuz
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
// this is min value that is not DeNorm
alized(DeNorm)
. DeNorm min is 0x01
static
constexpr
__device__
fp8e4m3fnuz
min
()
{
return
fp8e4m3fnuz
(
0x08
,
fp8e4m3fnuz
::
from_bits
());
...
...
@@ -475,7 +474,7 @@ class numeric_limits<fp8e4m3fn>
}
static
constexpr
__device__
fp8e4m3fn
max
()
{
return
fp8e4m3fn
(
0x7E
,
fp8e4m3fn
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
// this is min value that is not DeNorm
alized(DeNorm)
. DeNorm min is 0x01
static
constexpr
__device__
fp8e4m3fn
min
()
{
return
fp8e4m3fn
(
0x08
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
__device__
fp8e4m3fn
lowest
()
...
...
@@ -503,8 +502,10 @@ class numeric_limits<fp8e5m2fnuz>
{
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
static
constexpr
__device__
fp8e5m2fnuz
min
()
{
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
...
...
@@ -529,8 +530,7 @@ class numeric_limits<fp8e5m2>
}
static
constexpr
__device__
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
static
constexpr
__device__
fp8e5m2
min
()
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
lowest
()
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
...
...
@@ -540,23 +540,27 @@ class numeric_limits<fp8e5m2>
}
// namespace fp8
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_MIN_MAX(T) \
template <> \
constexpr T numeric_max<T, void>() \
{ \
return fp8::numeric_limits<T>::max(); \
} \
template <> \
constexpr T numeric_lowest<T>() \
{ \
return fp8::numeric_limits<T>::lowest(); \
}
MIGRAPHX_FP8_MIN_MAX
(
fp8
::
fp8e4m3fnuz
);
MIGRAPHX_FP8_MIN_MAX
(
fp8
::
fp8e5m2fnuz
);
MIGRAPHX_FP8_MIN_MAX
(
fp8
::
fp8e4m3fn
);
MIGRAPHX_FP8_MIN_MAX
(
fp8
::
fp8e5m2
);
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
T
,
fp8
::
fp8e4m3fnuz
>{}
or
is_same
<
T
,
fp8
::
fp8e5m2fnuz
>
{}
or
is_same
<
T
,
fp8
::
fp8e4m3fn
>
{}
or
is_same
<
T
,
fp8
::
fp8e5m2
>
{})
>
constexpr
T
numeric_max
(
migraphx
::
fp8
::
f8_type
unused
=
migraphx
::
fp8
::
f8_type
::
fp8
)
{
// unused parameter is added to make this numeric_max different overload definition
// compared to numeric_max defined in type_traits.hpp
(
void
)(
unused
);
return
fp8
::
numeric_limits
<
T
>::
max
();
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
T
,
fp8
::
fp8e4m3fnuz
>{}
or
is_same
<
T
,
fp8
::
fp8e5m2fnuz
>
{}
or
is_same
<
T
,
fp8
::
fp8e4m3fn
>
{}
or
is_same
<
T
,
fp8
::
fp8e5m2
>
{})
>
constexpr
T
numeric_lowest
(
migraphx
::
fp8
::
f8_type
unused
=
migraphx
::
fp8
::
f8_type
::
fp8
)
{
// unused parameter is added to make this numeric_lowest different overload definition
// compared to numeric_lowest defined in type_traits.hpp
(
void
)(
unused
);
return
fp8
::
numeric_limits
<
T
>::
lowest
();
}
}
// namespace migraphx
// =================================================================================================
#if defined(__clang__)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
View file @
c4cee345
...
...
@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto
indices_shape_lens
=
indices_shape
.
lens
;
auto
data_shape_lens
=
data_shape
.
lens
;
auto
num_slice_dims
=
indices_shape_lens
.
back
();
std
::
size_t
num_slices
=
size_t
num_slices
=
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
op
::
product
{});
std
::
size_t
slice_size
=
accumulate
(
data_shape_lens
.
begin
()
+
num_slice_dims
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
op
::
product
{});
const
std
::
size_t
num_batches
=
size_t
slice_size
=
accumulate
(
data_shape_lens
.
begin
()
+
num_slice_dims
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
op
::
product
{});
const
size_t
num_batches
=
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
op
::
product
{});
const
std
::
size_t
data_batch_stride
=
const
size_t
data_batch_stride
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
op
::
product
{});
const
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
ind
.
global_stride
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
const
auto
*
indices_ptr
=
indices_t
.
data
();
const
std
::
size_t
j
=
i
/
slice_size
;
const
std
::
size_t
batch_idx
=
j
/
num_slices_per_batch
;
const
size_t
j
=
i
/
slice_size
;
const
size_t
batch_idx
=
j
/
num_slices_per_batch
;
auto
*
slice_indices
=
indices_ptr
+
(
j
*
num_slice_dims
);
std
::
size_t
relative_slice_offset
=
0
;
for
(
std
::
size_t
idx
=
0
;
idx
<
num_slice_dims
;
++
idx
)
size_t
relative_slice_offset
=
0
;
for
(
size_t
idx
=
0
;
idx
<
num_slice_dims
;
++
idx
)
{
int64_t
index
=
slice_indices
[
idx
];
const
std
::
size_t
input_dim_idx
=
batch_dims
+
idx
;
const
size_t
input_dim_idx
=
batch_dims
+
idx
;
const
auto
input_dim
=
data_shape_lens
[
input_dim_idx
];
MIGRAPHX_ASSERT
(
index
>=
-
static_cast
<
int64_t
>
(
input_dim
)
and
index
<
static_cast
<
int64_t
>
(
input_dim
));
if
(
index
<
0
)
index
+=
input_dim
;
std
::
size_t
size_from_slice_dims
=
size_t
size_from_slice_dims
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
+
idx
+
1
,
data_shape_lens
.
begin
()
+
batch_dims
+
num_slice_dims
,
slice_size
,
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
View file @
c4cee345
...
...
@@ -54,12 +54,12 @@ __device__ void generic_binary_layernorm(
using
value_type
=
typename
Input1
::
type
;
using
vec_value_type
=
vec_type
<
value_type
>
;
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
constexpr
auto
relements_r
=
static_cast
<
vec_value_type
>
(
1.0
/
relements
);
constexpr
auto
relements_r
=
vec_value_type
{
1.0
/
relements
};
auto
relements_rsqrt
=
sqrt
(
relements_r
);
auto
means
=
r
.
reduce
(
op
::
sum
{},
make_array
<
vec_value_type
>
(
static_cast
<
vec_value_type
>
(
0
),
static_cast
<
vec_value_type
>
(
0
)),
make_array
<
vec_value_type
>
(
vec_value_type
{
0
},
vec_value_type
{
0
}),
[
&
](
auto
x
)
{
auto
x_out
=
x
*
relements_r
;
// dividing x by sqrt(relements) before squaring allows computing
...
...
@@ -71,7 +71,7 @@ __device__ void generic_binary_layernorm(
auto
mean_x
=
means
[
0
];
auto
mean_x2
=
means
[
1
];
auto
variance
=
mean_x2
-
(
mean_x
*
mean_x
);
value_type
eps_val
=
static_cast
<
value_type
>
(
eps
);
value_type
eps_val
=
implicit_conversion
(
eps
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
,
auto
...
xs
)
{
auto
m
=
x
-
mean_x
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
c4cee345
...
...
@@ -290,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where)
template
<
class
T
,
class
U
>
constexpr
auto
convert
(
U
v
)
{
return
vec_transform
(
v
)([](
auto
x
)
{
return
static_cast
<
T
>
(
x
);
});
return
vec_transform
(
v
)([](
auto
x
)
->
T
{
return
static_cast
<
T
>
(
x
);
});
}
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
View file @
c4cee345
...
...
@@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp>
#include <migraphx/kernels/vec.hpp>
namespace
migraphx
{
...
...
@@ -54,9 +55,9 @@ __device__ void pad(const index& idx,
if
(
any_of
(
range_multi
.
begin
(),
range_multi
.
end
(),
[
&
](
auto
j
)
{
return
multi
[
j
]
<
offsets
[
j
]
or
input_idx
[
j
]
>=
input_bounds
[
j
];
}))
output
[
multi
]
=
otype
(
pad_val
);
output
[
multi
]
=
implicit_conversion
(
pad_val
);
else
output
[
multi
]
=
otype
(
input
[
input_idx
]);
output
[
multi
]
=
implicit_conversion
(
input
[
input_idx
]);
});
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp
View file @
c4cee345
...
...
@@ -62,7 +62,7 @@ struct avg_pool
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
T
final
(
T
x
,
index_int
y
)
{
return
(
y
==
0
)
?
static_cast
<
T
>
(
0.0
)
:
static_cast
<
T
>
(
x
/
y
)
;
return
(
y
==
0
)
?
T
{
0.0
}
:
T
{
x
/
y
}
;
}
};
...
...
@@ -77,7 +77,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
{
if
(
xy
[
ii
]
<
-
1.0
f
or
xy
[
ii
]
>
dims
[
ii
])
{
return
static_cast
<
ret_type
>
(
0
);
return
implicit_conversion
(
0
);
}
xy
[
ii
]
=
migraphx
::
max
(
xy
[
ii
],
0.0
f
);
...
...
@@ -93,18 +93,17 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
high
[
0
]
*
dims
[
1
]
+
low
[
1
],
high
[
0
]
*
dims
[
1
]
+
high
[
1
]};
float
ly
=
xy
[
0
]
-
low
[
0
];
float
lx
=
xy
[
1
]
-
low
[
1
];
float
hy
=
1.0
f
-
ly
;
float
hx
=
1.0
f
-
lx
;
array
<
ret_type
,
4
>
ws
=
{
static_cast
<
ret_type
>
(
hy
*
hx
),
static_cast
<
ret_type
>
(
hy
*
lx
),
static_cast
<
ret_type
>
(
ly
*
hx
),
static_cast
<
ret_type
>
(
ly
*
lx
)};
float
ly
=
xy
[
0
]
-
low
[
0
];
float
lx
=
xy
[
1
]
-
low
[
1
];
float
hy
=
1.0
f
-
ly
;
float
hx
=
1.0
f
-
lx
;
// do calculations in floating point and convert final result to required type
array
<
float
,
4
>
ws
=
{
hy
*
hx
,
hy
*
lx
,
ly
*
hx
,
ly
*
lx
};
auto
v01
=
pooling
(
data
[
locs
[
0
]]
*
ws
[
0
],
data
[
locs
[
1
]]
*
ws
[
1
]);
auto
v23
=
pooling
(
data
[
locs
[
2
]]
*
ws
[
2
],
data
[
locs
[
3
]]
*
ws
[
3
]);
return
pooling
(
v01
,
v23
);
return
implicit_conversion
(
pooling
(
v01
,
v23
)
)
;
}
template
<
class
Iterator
,
class
Op
>
...
...
@@ -153,7 +152,6 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const
auto
x
=
x_t
.
begin
();
const
auto
rois
=
rois_t
.
begin
();
const
auto
ind
=
ind_t
.
begin
();
using
ytype
=
typename
W
::
type
;
// input shape
auto
x_lens
=
x_t
.
get_shape
().
lens
;
auto
channel_num
=
x_lens
[
1
];
...
...
src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp
0 → 100644
View file @
c4cee345
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#define MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#include <migraphx/kernels/types.hpp>
namespace
migraphx
{
struct
assign_none
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
x
=
y
;
}
};
struct
assign_add
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
atomicAdd
(
&
x
,
y
);
}
};
struct
assign_mul
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
T
old
=
x
;
T
assumed
;
do
{
assumed
=
old
;
old
=
atomicCAS
(
&
x
,
assumed
,
assumed
*
y
);
}
while
(
assumed
!=
old
);
}
};
struct
assign_max
{
template
<
typename
T
,
typename
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
atomicMax
(
&
x
,
y
);
}
};
struct
assign_min
{
template
<
typename
T
,
typename
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
atomicMin
(
&
x
,
y
);
}
};
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/scatternd.hpp
View file @
c4cee345
...
...
@@ -26,36 +26,10 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/scatter_reduction_modes.hpp>
namespace
migraphx
{
struct
assign_none
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
x
=
y
;
}
};
struct
assign_add
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
x
+=
y
;
}
};
struct
assign_mul
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
void
operator
()(
T
&
x
,
U
y
)
const
{
x
*=
y
;
}
};
template
<
class
T
,
class
U
,
class
V
,
class
F
>
__device__
void
scatternd
(
const
T
&
indices_t
,
const
U
&
updates_t
,
const
V
&
output_t
,
F
f
)
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
View file @
c4cee345
...
...
@@ -44,7 +44,7 @@ __device__ void softmax(Input input1, Output output)
auto
exp_in
=
r
.
inner
([
&
](
auto
x
)
{
return
migraphx
::
exp
(
x
-
c
);
})(
input
);
auto
batch_sum
=
r
.
reduce
(
op
::
sum
{},
0
,
[](
auto
x
)
{
return
migraphx
::
convert
<
float
>
(
x
);
})(
exp_in
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
static_cast
<
otype
>
(
x
/
batch_sum
);
})(
output
,
exp_in
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
implicit_conversion
(
x
/
batch_sum
);
})(
output
,
exp_in
);
});
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
c4cee345
...
...
@@ -251,7 +251,7 @@ constexpr T numeric_max()
}
template
<
class
T
>
constexpr
T
numeric_lowest
()
constexpr
auto
numeric_lowest
()
->
decltype
(
numeric_max
<
T
>
())
{
if
constexpr
(
is_integral
<
T
>
{})
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
c4cee345
...
...
@@ -207,7 +207,7 @@ struct implicit_conversion_op
template
<
class
U
>
constexpr
operator
U
()
const
{
return
x
;
return
static_cast
<
U
>
(
x
)
;
}
};
...
...
src/targets/gpu/mlir.cpp
View file @
c4cee345
...
...
@@ -73,6 +73,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_MLIR
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_TUNE_EXHAUSTIVE
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_TUNE_LIMIT
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_TUNING_DB
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_TUNING_CFG
);
...
...
@@ -796,7 +797,9 @@ struct mlir_program
if
(
enabled
(
MIGRAPHX_MLIR_TUNE_EXHAUSTIVE
{}))
tuning_mode
=
RocmlirTuningParamSetKindExhaustive
;
mlir_tuning_space
params
{
mlirRockTuningSpaceCreate
(
mmodule
.
get
(),
tuning_mode
)};
for
(
auto
i
:
range
(
mlirRockTuningGetNumParams
(
params
.
get
())))
const
auto
limit
=
value_of
(
MIGRAPHX_MLIR_TUNE_LIMIT
{},
std
::
numeric_limits
<
std
::
size_t
>::
max
());
for
(
auto
i
:
range
(
std
::
min
<
std
::
size_t
>
(
limit
,
mlirRockTuningGetNumParams
(
params
.
get
()))))
{
mlir_tuning_param
param
{
mlirRockTuningParamCreate
()};
if
(
not
mlirRockTuningParamGet
(
params
.
get
(),
i
,
param
.
get
()))
...
...
@@ -1032,6 +1035,15 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
mlir_program
mp
;
mp
.
set_gpu_properties
(
migraphx_ctx
);
mp
.
parse
(
m
);
const
bool
trace
=
enabled
(
MIGRAPHX_TRACE_MLIR
{});
static
std
::
mutex
mutex
;
if
(
trace
)
{
const
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex
);
auto
mod_op
=
mlirModuleGetOperation
(
mp
.
mmodule
.
get
());
std
::
cout
<<
mlir_print
(
&
mlirOperationPrint
,
mod_op
)
<<
std
::
endl
;
}
return
mp
.
get_tuning_config
(
exhaustive
);
}
...
...
src/targets/gpu/prefuse_ops.cpp
View file @
c4cee345
...
...
@@ -28,7 +28,10 @@
#include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
#include <migraphx/gpu/ck.hpp>
#endif
#include <migraphx/gpu/fuse_mlir.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -128,26 +131,49 @@ struct pre_gemm_softmax_gemm : gemm_softmax_gemm
};
MIGRAPHX_REGISTER_OP
(
pre_gemm_softmax_gemm
);
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
auto
is_ck_gemm
(
)
{
if
(
ins
->
name
()
!=
"dot"
)
return
match
::
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
if
(
not
enabled
(
MIGRAPHX_ENABLE_CK
{}))
return
false
;
if
(
ins
->
name
()
!=
"dot"
)
return
false
;
if
(
not
pre_gemm_softmax_gemm
::
is_ck_supported_type
(
ins
->
get_shape
().
type
()))
return
false
;
return
true
;
#else
(
void
)
ins
;
return
false
;
if
(
not
pre_gemm_softmax_gemm
::
is_ck_supported_type
(
ins
->
get_shape
().
type
()))
return
false
;
return
true
;
#endif
});
}
auto
is_mlir_gemm
()
{
return
match
::
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
not
mlir_attention_enabled
())
return
false
;
if
(
ins
->
name
()
!=
"dot"
)
return
false
;
return
std
::
all_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
pre_gemm_softmax_gemm
::
is_mlir_supported_type
(
i
->
get_shape
().
type
());
});
});
}
struct
find_gemm_softmax_gemm
{
auto
matcher
()
const
{
auto
gemm1
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
)(
is_ck_gemm
(
).
bind
(
"gemm1"
)));
auto
gemm1
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
)(
match
::
any_of
(
is_ck_gemm
(),
is_mlir_gemm
()
).
bind
(
"gemm1"
)));
auto
mul
=
match
::
name
(
"mul"
)(
match
::
nargs
(
2
),
match
::
either_arg
(
0
,
1
)(
match
::
is_constant
().
bind
(
"scale"
),
gemm1
));
auto
softmax
=
match
::
name
(
"softmax"
)(
match
::
arg
(
0
)(
mul
)).
bind
(
"softmax"
);
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm2"
))(
match
::
arg
(
0
)(
softmax
));
return
match
::
name
(
"dot"
)(
match
::
any_of
(
is_ck_gemm
(),
is_mlir_gemm
()).
bind
(
"gemm2"
))(
match
::
arg
(
0
)(
softmax
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
...
...
@@ -182,8 +208,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
match
::
find_matches
(
mpm
.
get_module
(),
find_layernorm
{});
mpm
.
run_pass
(
dead_code_elimination
{});
match
::
find_matches
(
mpm
.
get_module
(),
find_add_layernorm
{});
if
(
enabled
(
MIGRAPHX_ENABLE_CK
{}))
match
::
find_matches
(
mpm
,
find_gemm_softmax_gemm
{});
match
::
find_matches
(
mpm
,
find_gemm_softmax_gemm
{});
}
}
// namespace gpu
...
...
src/targets/ref/CMakeLists.txt
View file @
c4cee345
...
...
@@ -33,8 +33,9 @@ rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})
find_path
(
BLAZE_INCLUDE blaze/Blaze.h
)
rocm_clang_tidy_check
(
migraphx_ref
)
target_link_libraries
(
migraphx_ref PRIVATE Threads::Threads
)
target_link_libraries
(
migraphx_ref PUBLIC migraphx
)
target_include_directories
(
migraphx_ref PRIVATE
${
BLAZE_INCLUDE
}
)
target_include_directories
(
migraphx_ref
SYSTEM
PRIVATE
${
BLAZE_INCLUDE
}
)
target_compile_definitions
(
migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS
)
migraphx_generate_export_header
(
migraphx_ref
)
...
...
src/tf/CMakeLists.txt
View file @
c4cee345
...
...
@@ -38,7 +38,11 @@ protobuf_generate_cpp(
)
add_library
(
tf-proto STATIC
${
PROTO_SRCS
}
)
target_include_directories
(
tf-proto SYSTEM PUBLIC
${
CMAKE_CURRENT_BINARY_DIR
}
${
PROTOBUF_INCLUDE_DIR
}
)
target_compile_options
(
tf-proto PRIVATE -w
)
if
(
MSVC
)
target_compile_options
(
tf-proto PRIVATE /w
)
else
()
target_compile_options
(
tf-proto PRIVATE -w
)
endif
()
target_link_libraries
(
tf-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
set_target_properties
(
tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On
)
...
...
@@ -49,7 +53,10 @@ target_include_directories(migraphx_tf PRIVATE include)
set_target_properties
(
migraphx_tf PROPERTIES EXPORT_NAME tf
)
rocm_set_soversion
(
migraphx_tf
${
MIGRAPHX_SO_VERSION
}
)
rocm_clang_tidy_check
(
migraphx_tf
)
target_link_libraries
(
migraphx_tf PRIVATE tf-proto
"-Wl,--exclude-libs,ALL"
)
target_link_libraries
(
migraphx_tf PRIVATE tf-proto
)
if
(
NOT WIN32
)
target_link_libraries
(
migraphx_tf PRIVATE
"-Wl,--exclude-libs,ALL"
)
endif
()
target_link_libraries
(
migraphx_tf PUBLIC migraphx
)
rocm_install_targets
(
...
...
src/tmp_dir.cpp
View file @
c4cee345
...
...
@@ -31,8 +31,18 @@
#include <sstream>
#include <iostream>
#include <string>
#include <sys/types.h>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#undef getpid
// cppcheck-suppress [definePrefix, defineUpperCase]
#define getpid _getpid
#else
#include <unistd.h>
#include <sys/types.h>
#endif
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
test/gpu/fuse_mlir.cpp
View file @
c4cee345
...
...
@@ -144,10 +144,12 @@ TEST_CASE(int_quant_dot_tanh_fails)
auto
tanh
=
add_pointwise
(
p1
,
"main:pointwise0"
,
{
dot
},
single_pointwise
(
"tanh"
));
mm
->
add_return
({
tanh
});
}
migraphx
::
program
p2
(
p1
);
// This pass should do nothing as int32_t tanh isn't supported.
// This pass should not fuse as int32_t tanh isn't supported.
run_pass
(
p1
);
EXPECT
(
p1
==
p2
);
auto
*
mm
=
p1
.
get_main_module
();
bool
has_pointwise
=
std
::
any_of
(
mm
->
begin
(),
mm
->
end
(),
[
&
](
const
auto
&
i
)
{
return
i
.
name
()
==
"pointwise"
;
});
EXPECT
(
has_pointwise
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
...
...
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment