Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
80bf741a
Commit
80bf741a
authored
May 19, 2023
by
Alan Turner
Browse files
Merge remote-tracking branch 'origin/develop' into ck-int8-fusion
parents
99626b4c
0e6ee3f7
Changes
136
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
236 additions
and
103 deletions
+236
-103
src/split_single_dyn_dim.cpp
src/split_single_dyn_dim.cpp
+10
-4
src/targets/fpga/include/migraphx/fpga/vitis_ai_adapter.hpp
src/targets/fpga/include/migraphx/fpga/vitis_ai_adapter.hpp
+1
-1
src/targets/fpga/vitis_ai_adapter.cpp
src/targets/fpga/vitis_ai_adapter.cpp
+1
-1
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+1
-5
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+13
-1
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+10
-3
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+9
-4
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+10
-10
src/targets/gpu/device/scatter.cpp
src/targets/gpu/device/scatter.cpp
+16
-12
src/targets/gpu/driver/CMakeLists.txt
src/targets/gpu/driver/CMakeLists.txt
+1
-0
src/targets/gpu/driver/include/migraphx/gpu/driver/action.hpp
...targets/gpu/driver/include/migraphx/gpu/driver/action.hpp
+1
-1
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+128
-28
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+0
-5
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
+6
-0
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
+2
-0
src/targets/gpu/include/migraphx/gpu/hip.hpp
src/targets/gpu/include/migraphx/gpu/hip.hpp
+15
-5
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
+0
-4
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+2
-10
src/targets/gpu/kernels/include/migraphx/kernels/print.hpp
src/targets/gpu/kernels/include/migraphx/kernels/print.hpp
+2
-2
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+8
-7
No files found.
src/split_single_dyn_dim.cpp
View file @
80bf741a
...
@@ -100,10 +100,10 @@ struct find_static_2in_broadcasts
...
@@ -100,10 +100,10 @@ struct find_static_2in_broadcasts
}
// namespace
}
// namespace
/**
/**
* Makes all the shapes in the dynamic_dimension range.
* Makes all the shapes in the dynamic_dimension range.
Probably won't work for `if`
*
Probably won't work for `if`
and `loop` instructions, depending on how the submodules for those
* and `loop` instructions, depending on how the submodules for those
* work. Inserts select_module instruction to the top. Replaces return, bypassing other
* work. Inserts select_module instruction to the top. Replaces return, bypassing other
* instructions.
* instructions.
Skips if the dynamic parameter outputs to a select_module operator.
*/
*/
void
split_single_dyn_dim
::
apply
(
module_pass_manager
&
mpm
)
const
void
split_single_dyn_dim
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
...
@@ -111,7 +111,13 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
...
@@ -111,7 +111,13 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
auto
param_names
=
mm
->
get_parameter_names
();
auto
param_names
=
mm
->
get_parameter_names
();
auto
param_shapes
=
mm
->
get_parameter_shapes
();
auto
param_shapes
=
mm
->
get_parameter_shapes
();
optional
<
dynamic_dimensions_check
>
dd_check
=
has_one_dyn_dim
(
param_shapes
);
optional
<
dynamic_dimensions_check
>
dd_check
=
has_one_dyn_dim
(
param_shapes
);
if
(
dd_check
.
has_value
())
auto
any_sm_next
=
[
&
](
auto
ddc
)
{
auto
p_outputs
=
mm
->
get_parameter
(
ddc
->
dyn_param_str
)
->
outputs
();
return
std
::
any_of
(
p_outputs
.
cbegin
(),
p_outputs
.
cend
(),
[](
auto
ins
)
{
return
ins
->
name
()
==
"select_module"
;
});
};
if
(
dd_check
.
has_value
()
and
not
any_sm_next
(
dd_check
))
{
{
const
auto
&
dyn_param
=
mm
->
get_parameter
(
dd_check
->
dyn_param_str
);
const
auto
&
dyn_param
=
mm
->
get_parameter
(
dd_check
->
dyn_param_str
);
auto
dyn_param_shape
=
mm
->
get_parameter_shape
(
dd_check
->
dyn_param_str
);
auto
dyn_param_shape
=
mm
->
get_parameter_shape
(
dd_check
->
dyn_param_str
);
...
...
src/targets/fpga/include/migraphx/fpga/vitis_ai_adapter.hpp
View file @
80bf741a
...
@@ -41,7 +41,7 @@ class x_model
...
@@ -41,7 +41,7 @@ class x_model
void
set_shape
(
migraphx
::
shape
);
void
set_shape
(
migraphx
::
shape
);
};
};
x_model
create_xmodel
(
migraphx
::
module_ref
mod
);
x_model
create_xmodel
(
migraphx
::
const_
module_ref
mod
);
migraphx
::
argument
execute
(
const
x_model
&
xmodel
,
migraphx
::
argument
execute
(
const
x_model
&
xmodel
,
const
migraphx
::
shape
&
output_shape
,
const
migraphx
::
shape
&
output_shape
,
...
...
src/targets/fpga/vitis_ai_adapter.cpp
View file @
80bf741a
...
@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; };
...
@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; };
void
x_model
::
set_shape
(
migraphx
::
shape
s
)
{
shape
=
s
;
}
void
x_model
::
set_shape
(
migraphx
::
shape
s
)
{
shape
=
s
;
}
x_model
create_xmodel
(
const
migraphx
::
module_ref
mod
)
x_model
create_xmodel
(
migraphx
::
const_
module_ref
mod
)
{
{
std
::
cout
<<
"Calling an external function: create_xmodel!
\n
"
;
std
::
cout
<<
"Calling an external function: create_xmodel!
\n
"
;
x_model
xmodel
;
x_model
xmodel
;
...
...
src/targets/gpu/CMakeLists.txt
View file @
80bf741a
...
@@ -33,11 +33,7 @@ if(NOT TARGET MIOpen)
...
@@ -33,11 +33,7 @@ if(NOT TARGET MIOpen)
message
(
SEND_ERROR
"Cant find miopen"
)
message
(
SEND_ERROR
"Cant find miopen"
)
endif
()
endif
()
if
(
BUILD_DEV
)
set
(
MIGRAPHX_USE_HIPRTC OFF CACHE BOOL
"Use hipClang APIs"
)
set
(
MIGRAPHX_USE_HIPRTC OFF CACHE BOOL
"Use hipRTC APIs"
)
else
()
set
(
MIGRAPHX_USE_HIPRTC ON CACHE BOOL
"Use hipRTC APIs"
)
endif
()
include
(
Embed
)
include
(
Embed
)
file
(
GLOB KERNEL_FILES
${
CONFIGURE_DEPENDS
}
file
(
GLOB KERNEL_FILES
${
CONFIGURE_DEPENDS
}
...
...
src/targets/gpu/compile_gen.cpp
View file @
80bf741a
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/module.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
...
@@ -171,7 +172,8 @@ std::string make_transformer_args(std::vector<std::string> transformers)
...
@@ -171,7 +172,8 @@ std::string make_transformer_args(std::vector<std::string> transformers)
void
generate_pointwise
(
cpp_generator
&
gg
,
const
module
&
pm
,
const
std
::
string
&
name
)
void
generate_pointwise
(
cpp_generator
&
gg
,
const
module
&
pm
,
const
std
::
string
&
name
)
{
{
module
m
=
pm
;
module
m
=
pm
;
run_passes
(
m
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
run_passes
(
m
,
{
rewrite_quantization
{},
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
...
@@ -280,6 +282,14 @@ std::string generate_reduce(const module& m, const std::string& name)
...
@@ -280,6 +282,14 @@ std::string generate_reduce(const module& m, const std::string& name)
not
input
->
get_shape
().
broadcasted
();
not
input
->
get_shape
().
broadcasted
();
});
});
auto
inner_names
=
names
;
auto
inner_names
=
names
;
for
(
auto
input
:
ins
->
inputs
())
{
if
(
input
->
name
()
!=
"@param"
)
continue
;
if
(
contains
(
tensors
,
input
))
continue
;
inner_names
[
input
]
+=
"[out_idx]"
;
}
for
(
auto
input
:
tensors
)
for
(
auto
input
:
tensors
)
inner_names
[
input
]
+=
"_lambda_param"
;
inner_names
[
input
]
+=
"_lambda_param"
;
auto
call_function
=
auto
call_function
=
...
@@ -308,6 +318,8 @@ std::string generate_reduce(const module& m, const std::string& name)
...
@@ -308,6 +318,8 @@ std::string generate_reduce(const module& m, const std::string& name)
});
});
f
.
set_attributes
({
"__device__"
,
"__attribute__((const))"
}).
set_generic_types
(
m
).
set_name
(
name
);
f
.
set_attributes
({
"__device__"
,
"__attribute__((const))"
}).
set_generic_types
(
m
).
set_name
(
name
);
f
.
add_generic_param
(
"r"
);
f
.
add_generic_param
(
"r"
);
f
.
add_generic_param
(
"out_idx"
);
f
.
unused_param
(
"out_idx"
);
g
.
create_function
(
f
);
g
.
create_function
(
f
);
return
g
.
str
();
return
g
.
str
();
}
}
...
...
src/targets/gpu/compile_hip.cpp
View file @
80bf741a
...
@@ -56,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
...
@@ -56,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#ifdef MIGRAPHX_USE_HIPRTC
#ifdef MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_HIPRTC
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS
);
std
::
string
hiprtc_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
)
std
::
string
hiprtc_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
)
{
{
return
"hiprtc: "
+
(
hiprtcGetErrorString
(
err
)
+
(
": "
+
msg
));
return
"hiprtc: "
+
(
hiprtcGetErrorString
(
err
)
+
(
": "
+
msg
));
...
@@ -194,6 +191,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
...
@@ -194,6 +191,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
options
.
push_back
(
"-DMIGRAPHX_HAS_DPP=0"
);
options
.
push_back
(
"-DMIGRAPHX_HAS_DPP=0"
);
options
.
push_back
(
"-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"
);
options
.
push_back
(
"-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"
);
options
.
push_back
(
"-Wno-reserved-identifier"
);
options
.
push_back
(
"-Wno-reserved-identifier"
);
options
.
push_back
(
"-Wno-unused-parameter"
);
options
.
push_back
(
"-Wno-gnu-line-marker"
);
options
.
push_back
(
"-Wno-gnu-line-marker"
);
options
.
push_back
(
"-Wno-old-style-cast"
);
options
.
push_back
(
"-Wno-old-style-cast"
);
}
}
...
@@ -216,6 +214,15 @@ std::vector<std::vector<char>>
...
@@ -216,6 +214,15 @@ std::vector<std::vector<char>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
{
{
std
::
vector
<
hiprtc_src_file
>
hsrcs
{
srcs
.
begin
(),
srcs
.
end
()};
std
::
vector
<
hiprtc_src_file
>
hsrcs
{
srcs
.
begin
(),
srcs
.
end
()};
if
(
enabled
(
MIGRAPHX_GPU_DUMP_SRC
{}))
{
for
(
const
auto
&
src
:
srcs
)
{
if
(
src
.
path
.
extension
()
!=
".cpp"
)
continue
;
std
::
cout
<<
std
::
string
(
src
.
content
.
first
,
src
.
len
())
<<
std
::
endl
;
}
}
auto
p
=
dynamic_loader
::
path
(
&
compile_hip_src_with_hiprtc
);
auto
p
=
dynamic_loader
::
path
(
&
compile_hip_src_with_hiprtc
);
auto
driver
=
p
.
parent_path
().
parent_path
()
/
"bin"
/
"migraphx-hiprtc-driver"
;
auto
driver
=
p
.
parent_path
().
parent_path
()
/
"bin"
/
"migraphx-hiprtc-driver"
;
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
80bf741a
...
@@ -135,10 +135,15 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
...
@@ -135,10 +135,15 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std
::
size_t
max_global
=
ctx
.
get_current_device
().
get_cu_count
()
*
std
::
size_t
max_global
=
ctx
.
get_current_device
().
get_cu_count
()
*
ctx
.
get_current_device
().
get_max_workitems_per_cu
();
ctx
.
get_current_device
().
get_max_workitems_per_cu
();
return
[
n
,
over
,
max_global
](
std
::
size_t
local
)
{
return
[
n
,
over
,
max_global
](
std
::
size_t
local
)
{
std
::
size_t
groups
=
(
n
+
local
-
1
)
/
local
;
std
::
size_t
num_elements
=
n
;
std
::
size_t
groups
=
(
num_elements
+
local
-
1
)
/
local
;
std
::
size_t
max_blocks
=
max_global
/
local
;
std
::
size_t
max_blocks
=
max_global
/
local
;
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
return
std
::
min
(
nglobal
,
n
);
#ifdef MIGRAPHX_USE_HIPRTC
if
(
enabled
(
MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS
{}))
num_elements
=
((
num_elements
+
local
-
1
)
/
local
)
*
local
;
#endif
return
std
::
min
(
nglobal
,
num_elements
);
};
};
}
}
...
...
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
View file @
80bf741a
...
@@ -94,6 +94,10 @@ template <>
...
@@ -94,6 +94,10 @@ template <>
struct
is_hip_type
<
std
::
uint8_t
>
:
std
::
true_type
struct
is_hip_type
<
std
::
uint8_t
>
:
std
::
true_type
{
{
};
};
template
<
>
struct
is_hip_type
<
std
::
int32_t
>
:
std
::
true_type
{
};
template
<
class
T
,
class
V
,
MIGRAPHX_REQUIRES
(
is_hip_type
<
typename
T
::
type
>{})
>
template
<
class
T
,
class
V
,
MIGRAPHX_REQUIRES
(
is_hip_type
<
typename
T
::
type
>{})
>
void
hip_visitor_invoke
(
T
as
,
V
&&
v
)
void
hip_visitor_invoke
(
T
as
,
V
&&
v
)
...
@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
...
@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
if
(
not
std
::
all_of
(
if
(
not
std
::
all_of
(
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
types
.
begin
(),
types
.
end
(),
[
&
](
migraphx
::
shape
::
type_t
t
)
{
return
t
==
s
.
type
();
}))
MIGRAPHX_THROW
(
"Types must be the same"
);
MIGRAPHX_THROW
(
"Types must be the same"
);
std
::
initializer_list
<
index_int
>
ranks
=
{
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
ndim
())...};
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
ndim
();
}))
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
visit_tensor_size
(
s
.
ndim
(),
[
&
](
auto
ndim
)
{
s
.
visit_type
(
hip_visitor
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
}));
s
.
visit_type
(
hip_visitor
([
&
](
auto
as
)
{
v
(
f
(
xs
,
ndim
,
as
)...);
}));
});
});
}
}
...
@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
...
@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
template
<
class
V
,
class
F
,
class
...
Ts
>
template
<
class
V
,
class
F
,
class
...
Ts
>
void
hip_visit_views_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
void
hip_visit_views_impl
(
const
shape
&
s
,
F
f
,
V
&&
v
,
Ts
&&
...
xs
)
{
{
std
::
initializer_list
<
index_int
>
ranks
=
{
std
::
initializer_list
<
index_int
>
ranks
=
{
static_cast
<
index_int
>
(
get_shape
(
xs
).
ndim
())...};
static_cast
<
index_int
>
(
get_shape
(
xs
).
lens
().
size
())...};
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
ndim
();
}))
if
(
not
std
::
all_of
(
ranks
.
begin
(),
ranks
.
end
(),
[
&
](
index_int
r
)
{
return
r
==
s
.
lens
().
size
();
}))
MIGRAPHX_THROW
(
"Ranks must be the same"
);
MIGRAPHX_THROW
(
"Ranks must be the same"
);
visit_tensor_size
(
s
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
visit_tensor_size
(
s
.
ndim
(),
[
&
](
auto
ndim
)
{
v
(
f
(
xs
,
ndim
)...);
});
}
}
template
<
class
F
>
template
<
class
F
>
...
...
src/targets/gpu/device/scatter.cpp
View file @
80bf741a
...
@@ -37,22 +37,26 @@ argument scatter(
...
@@ -37,22 +37,26 @@ argument scatter(
hipStream_t
stream
,
argument
result
,
argument
arg0
,
argument
arg1
,
argument
arg2
,
int64_t
axis
)
hipStream_t
stream
,
argument
result
,
argument
arg0
,
argument
arg1
,
argument
arg2
,
int64_t
axis
)
{
{
auto
ds
=
arg0
.
get_shape
();
auto
ds
=
arg0
.
get_shape
();
auto
inds
=
arg1
.
get_shape
();
auto
s1
=
arg1
.
get_shape
();
auto
axis_dim_size
=
ds
.
lens
()[
axis
];
auto
axis_dim_size
=
ds
.
lens
()[
axis
];
hip_visit_all
(
result
,
arg0
,
inds
)([
&
](
auto
output
,
auto
data
,
auto
s1
)
{
hip_visit_all
(
result
,
arg0
,
arg2
)([
&
](
auto
output
,
auto
data
,
auto
update
)
{
auto
*
output_ptr
=
device_cast
(
output
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
const
auto
*
data_ptr
=
device_cast
(
data
.
data
());
const
auto
*
data_ptr
=
device_cast
(
data
.
data
());
gs_launch
(
stream
,
ds
.
elements
())([
=
](
auto
i
)
__device__
{
output_ptr
[
i
]
=
data_ptr
[
i
];
});
gs_launch
(
stream
,
ds
.
elements
())([
=
](
auto
i
)
__device__
{
output_ptr
[
i
]
=
data_ptr
[
i
];
});
hip_visit_all
(
arg1
,
arg2
)([
&
](
auto
indices
,
auto
update
)
{
hip_visit_all
(
arg1
)([
&
](
auto
indices
)
{
if
constexpr
(
indices
.
get_shape
().
lens
.
size
()
==
output
.
get_shape
().
lens
.
size
())
{
const
auto
*
upd_ptr
=
device_cast
(
update
.
data
());
const
auto
*
upd_ptr
=
device_cast
(
update
.
data
());
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
gs_launch
(
stream
,
ind
s
.
elements
())([
=
](
auto
i
)
__device__
{
gs_launch
(
stream
,
s
1
.
elements
())([
=
](
auto
i
)
__device__
{
auto
out_idx
=
s1
.
multi
(
i
);
auto
out_idx
=
indices
.
get_shape
()
.
multi
(
i
);
auto
index
=
indices_ptr
[
i
];
auto
index
=
indices_ptr
[
i
];
index
=
index
<
0
?
index
+
axis_dim_size
:
index
;
index
=
index
<
0
?
index
+
axis_dim_size
:
index
;
out_idx
[
axis
]
=
index
;
out_idx
[
axis
]
=
index
;
output
[
out_idx
]
=
upd_ptr
[
i
];
output
[
out_idx
]
=
upd_ptr
[
i
];
});
});
}
});
});
});
});
...
...
src/targets/gpu/driver/CMakeLists.txt
View file @
80bf741a
...
@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
...
@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
add_executable
(
gpu-driver
add_executable
(
gpu-driver
${
GPU_DRIVER_SRCS
}
${
GPU_DRIVER_SRCS
}
)
)
rocm_clang_tidy_check
(
gpu-driver
)
target_include_directories
(
gpu-driver PRIVATE include
)
target_include_directories
(
gpu-driver PRIVATE include
)
target_link_libraries
(
gpu-driver PRIVATE migraphx_gpu
)
target_link_libraries
(
gpu-driver PRIVATE migraphx_gpu
)
src/targets/gpu/driver/include/migraphx/gpu/driver/action.hpp
View file @
80bf741a
...
@@ -44,7 +44,7 @@ struct auto_register_action
...
@@ -44,7 +44,7 @@ struct auto_register_action
template
<
class
T
>
template
<
class
T
>
static
void
apply
()
static
void
apply
()
{
{
auto
name
=
get_type_name
<
T
>
();
const
auto
&
name
=
get_type_name
<
T
>
();
register_action
(
name
.
substr
(
name
.
rfind
(
"::"
)
+
2
),
register_action
(
name
.
substr
(
name
.
rfind
(
"::"
)
+
2
),
[](
auto
&&
...
xs
)
{
T
::
apply
(
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
});
[](
auto
&&
...
xs
)
{
T
::
apply
(
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
});
}
}
...
...
src/targets/gpu/fuse_mlir.cpp
View file @
80bf741a
...
@@ -38,6 +38,27 @@ namespace gpu {
...
@@ -38,6 +38,27 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_MLIR
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_MLIR
);
bool
mlir_enabled
()
{
#ifdef MIGRAPHX_MLIR
const
bool
mlir_enabled
=
enabled
(
MIGRAPHX_ENABLE_MLIR
{});
if
(
mlir_enabled
)
{
return
true
;
}
else
{
std
::
cerr
<<
"WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<<
std
::
endl
;
return
false
;
}
#else
return
false
;
#endif
}
#ifdef MIGRAPHX_MLIR
#ifdef MIGRAPHX_MLIR
struct
mlir_op
struct
mlir_op
...
@@ -58,8 +79,41 @@ struct mlir_op
...
@@ -58,8 +79,41 @@ struct mlir_op
MIGRAPHX_THROW
(
"should have one submodule."
);
MIGRAPHX_THROW
(
"should have one submodule."
);
if
(
inputs
.
size
()
<
2
)
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"should have at least two inputs."
);
MIGRAPHX_THROW
(
"should have at least two inputs."
);
auto
n
=
inputs
.
size
();
return
op
.
compute_shape
({
inputs
[
n
-
2
],
inputs
[
n
-
1
]});
module_ref
mod
=
mods
[
0
];
auto
type
=
mod
->
get_output_shapes
().
front
().
type
();
std
::
unordered_map
<
instruction_ref
,
shape
>
ins_shapes
;
size_t
param_cnt
=
0
;
std
::
vector
<
std
::
string
>
names
=
mod
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
for
(
std
::
string
param_name
:
names
)
{
ins_shapes
[
mod
->
get_parameter
(
param_name
)]
=
inputs
[
param_cnt
++
];
}
for
(
auto
ins
:
iterator_for
(
*
mod
))
{
if
(
ins
->
name
()
==
"@param"
)
{
continue
;
}
if
(
ins
->
name
()
==
"@literal"
)
{
ins_shapes
[
ins
]
=
ins
->
get_shape
();
continue
;
}
if
(
ins
->
name
()
==
"@return"
)
{
return
ins_shapes
[
ins
->
inputs
().
at
(
0
)].
with_type
(
type
);
}
std
::
vector
<
shape
>
input_shapes
;
input_shapes
.
resize
(
ins
->
inputs
().
size
());
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
input_shapes
.
begin
(),
[
&
](
auto
in
)
{
return
ins_shapes
[
in
];
});
ins_shapes
[
ins
]
=
ins
->
get_operator
().
compute_shape
(
input_shapes
);
}
MIGRAPHX_THROW
(
"No return found in the submodule"
);
}
}
};
};
MIGRAPHX_REGISTER_OP
(
mlir_op
);
MIGRAPHX_REGISTER_OP
(
mlir_op
);
...
@@ -68,7 +122,7 @@ namespace {
...
@@ -68,7 +122,7 @@ namespace {
MIGRAPHX_PRED_MATCHER
(
is_mlir_conv
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
is_mlir_conv
,
instruction_ref
ins
)
{
{
if
(
ins
->
name
()
!=
"convolution"
)
if
(
ins
->
name
()
!=
"convolution"
and
ins
->
name
()
!=
"quant_convolution"
)
return
false
;
return
false
;
value
v
=
ins
->
get_operator
().
to_value
();
value
v
=
ins
->
get_operator
().
to_value
();
auto
group
=
v
.
at
(
"group"
).
to
<
int
>
();
auto
group
=
v
.
at
(
"group"
).
to
<
int
>
();
...
@@ -89,6 +143,53 @@ struct find_mlir_op
...
@@ -89,6 +143,53 @@ struct find_mlir_op
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
}
}
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
create_param_map_with_literals
(
module_ref
mm
,
const
module
*
pm
,
const
shape
&
shape
)
const
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
for
(
auto
ins
:
iterator_for
(
*
pm
))
{
if
(
ins
->
name
()
!=
"@literal"
)
{
continue
;
}
literal
r
=
ins
->
get_literal
();
instruction_ref
literal
=
mm
->
add_literal
(
r
);
instruction_ref
mbcast
=
mm
->
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
shape
.
lens
()}}),
literal
);
ins_map
[
ins
]
=
mbcast
;
}
return
ins_map
;
}
std
::
tuple
<
instruction_ref
,
std
::
vector
<
instruction_ref
>>
fuse_input_ops_and_gemm_based_op
(
module_ref
mm
,
instruction_ref
gemm_based_op
)
const
{
std
::
vector
<
instruction_ref
>
top_inputs
;
std
::
vector
<
instruction_ref
>
imm_inputs
;
size_t
input_cnt
=
0
;
for
(
instruction_ref
input
:
gemm_based_op
->
inputs
())
{
std
::
vector
<
operation
>
op_stream
;
while
(
contains
({
"slice"
,
"transpose"
,
"contiguous"
,
"reshape"
},
input
->
name
()))
{
op_stream
.
push_back
(
input
->
get_operator
());
input
=
input
->
inputs
().
at
(
0
);
}
top_inputs
.
push_back
(
input
);
instruction_ref
prev_input
=
mm
->
add_parameter
(
"y"
+
std
::
to_string
(
input_cnt
++
),
input
->
get_shape
());
for
(
const
auto
&
op
:
reverse
(
op_stream
))
{
prev_input
=
mm
->
add_instruction
(
op
,
{
prev_input
});
}
imm_inputs
.
push_back
(
prev_input
);
}
instruction_ref
new_gemm_based_op
=
mm
->
add_instruction
(
gemm_based_op
->
get_operator
(),
imm_inputs
);
return
{
new_gemm_based_op
,
top_inputs
};
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
...
@@ -98,33 +199,42 @@ struct find_mlir_op
...
@@ -98,33 +199,42 @@ struct find_mlir_op
auto
names
=
pm
->
get_parameter_names
();
auto
names
=
pm
->
get_parameter_names
();
// Whitelist pointwise operators
// Whitelist pointwise operators
if
(
std
::
any_of
(
pm
->
begin
(),
pm
->
end
(),
[](
const
auto
&
i
)
{
if
(
std
::
any_of
(
pm
->
begin
(),
pm
->
end
(),
[](
const
auto
&
i
)
{
return
not
contains
(
return
not
contains
({
"@literal"
,
{
"@literal"
,
"@param"
,
"@return"
,
"convolution"
,
"dot"
,
"add"
,
"relu"
},
"@param"
,
"@return"
,
"convolution"
,
"quant_convolution"
,
"dot"
,
"add"
,
"relu"
,
"dequantizelinear"
,
"quantizelinear"
,
"mul"
},
i
.
name
());
i
.
name
());
}))
}))
return
;
return
;
// Only fuse with fp32/fp16
// Only fuse with fp32/fp16
/int8/int32
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
i
)
{
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
({
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
},
return
not
contains
({
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
,
shape
::
type_t
::
int32_type
},
i
->
get_shape
().
type
());
i
->
get_shape
().
type
());
}))
}))
return
;
return
;
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
sort
(
names
.
begin
(),
names
.
end
());
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
pm
->
name
());
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
pm
->
name
());
mm
->
set_bypass
();
mm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
;
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
=
auto
x
=
mm
->
add_parameter
(
"x"
+
std
::
to_string
(
names
.
size
()),
create_param_map_with_literals
(
mm
,
pm
,
gemm_based_op
->
get_shape
());
gemm_based_op
->
inputs
().
at
(
0
)
->
get_shape
());
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
gemm_based_op
);
auto
w
=
mm
->
add_parameter
(
"x"
+
std
::
to_string
(
names
.
size
()
+
1
),
gemm_based_op
->
inputs
().
at
(
1
)
->
get_shape
());
auto
conv
=
mm
->
add_instruction
(
gemm_based_op
->
get_operator
(),
{
x
,
w
});
std
::
transform
(
names
.
begin
(),
std
::
transform
(
names
.
begin
(),
names
.
end
(),
names
.
end
(),
ins
->
inputs
().
begin
(),
ins
->
inputs
().
begin
(),
std
::
inserter
(
param_map
,
param_map
.
end
()),
std
::
inserter
(
param_map
,
param_map
.
end
()),
[
&
](
auto
name
,
auto
input
)
{
[
&
,
&
anchor_op
=
anchor_op
](
auto
name
,
auto
input
)
{
if
(
input
==
x_ins
)
if
(
input
==
x_ins
)
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
conv
);
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
anchor_op
);
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
mm
->
add_parameter
(
name
,
input
->
get_shape
()));
mm
->
add_parameter
(
name
,
input
->
get_shape
()));
});
});
...
@@ -135,7 +245,7 @@ struct find_mlir_op
...
@@ -135,7 +245,7 @@ struct find_mlir_op
ins
->
inputs
().
end
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
gemm_based_op
;
});
[
&
](
auto
input
)
{
return
input
!=
gemm_based_op
;
});
inputs
.
insert
(
inputs
.
end
(),
gemm_based_op
->
inputs
()
.
begin
(),
gemm_based_op
->
inputs
()
.
end
());
inputs
.
insert
(
inputs
.
end
(),
top_
inputs
.
begin
(),
top_
inputs
.
end
());
mpm
.
get_module
().
replace_instruction
(
mpm
.
get_module
().
replace_instruction
(
ins
,
mlir_op
{
gemm_based_op
->
get_operator
()},
inputs
,
{
mm
});
ins
,
mlir_op
{
gemm_based_op
->
get_operator
()},
inputs
,
{
mm
});
}
}
...
@@ -148,17 +258,7 @@ struct find_mlir_op
...
@@ -148,17 +258,7 @@ struct find_mlir_op
void
fuse_mlir
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_mlir
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
#ifdef MIGRAPHX_MLIR
#ifdef MIGRAPHX_MLIR
const
bool
mlir_enabled
=
enabled
(
MIGRAPHX_ENABLE_MLIR
{});
if
(
mlir_enabled
)
{
match
::
find_matches
(
mpm
,
find_mlir_op
{});
match
::
find_matches
(
mpm
,
find_mlir_op
{});
}
else
{
std
::
cerr
<<
"WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<<
std
::
endl
;
}
#else
#else
(
void
)
mpm
;
(
void
)
mpm
;
#endif
#endif
...
...
src/targets/gpu/gemm_impl.cpp
View file @
80bf741a
...
@@ -157,13 +157,8 @@ void gemm_impl(context& ctx,
...
@@ -157,13 +157,8 @@ void gemm_impl(context& ctx,
compute_type
=
rocblas_datatype_f32_r
;
compute_type
=
rocblas_datatype_f32_r
;
}
}
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags
flag
=
rocblas_gemm_flags
flag
=
int8_x4_format
?
rocblas_gemm_flags_pack_int8x4
:
rocblas_gemm_flags_none
;
int8_x4_format
?
rocblas_gemm_flags_pack_int8x4
:
rocblas_gemm_flags_none
;
#else
(
void
)
int8_x4_format
;
int
flag
=
0
;
#endif
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
...
...
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
View file @
80bf741a
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/compile_src.hpp>
#include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <string>
#include <string>
#include <utility>
#include <utility>
...
@@ -36,6 +37,11 @@ namespace migraphx {
...
@@ -36,6 +37,11 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
#ifdef MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_HIPRTC
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS
);
#endif
struct
hiprtc_src_file
struct
hiprtc_src_file
{
{
hiprtc_src_file
()
=
default
;
hiprtc_src_file
()
=
default
;
...
...
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
View file @
80bf741a
...
@@ -34,6 +34,8 @@ struct module_pass_manager;
...
@@ -34,6 +34,8 @@ struct module_pass_manager;
namespace
gpu
{
namespace
gpu
{
bool
mlir_enabled
();
struct
fuse_mlir
struct
fuse_mlir
{
{
context
*
ctx
=
nullptr
;
context
*
ctx
=
nullptr
;
...
...
src/targets/gpu/include/migraphx/gpu/hip.hpp
View file @
80bf741a
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/dyn_output.hpp>
#include <utility>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -112,7 +113,7 @@ struct hip_copy_to_gpu
...
@@ -112,7 +113,7 @@ struct hip_copy_to_gpu
std
::
string
name
()
const
{
return
"hip::copy_to_gpu"
;
}
std
::
string
name
()
const
{
return
"hip::copy_to_gpu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
,
2
).
same_type
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
).
same_type
();
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
...
@@ -121,6 +122,10 @@ struct hip_copy_to_gpu
...
@@ -121,6 +122,10 @@ struct hip_copy_to_gpu
if
(
args
.
size
()
==
1
)
if
(
args
.
size
()
==
1
)
return
input
;
return
input
;
argument
result
=
args
[
1
].
share
();
argument
result
=
args
[
1
].
share
();
if
(
result
.
get_shape
().
dynamic
())
{
result
=
result
.
reshape
(
args
[
0
].
get_shape
());
}
gpu_copy
(
ctx
,
input
,
result
);
gpu_copy
(
ctx
,
input
,
result
);
// Associate the input since it was registered with hip
// Associate the input since it was registered with hip
return
{
result
.
get_shape
(),
[
input
,
result
]()
mutable
{
return
result
.
data
();
}};
return
{
result
.
get_shape
(),
[
input
,
result
]()
mutable
{
return
result
.
data
();
}};
...
@@ -138,19 +143,24 @@ struct hip_copy_from_gpu
...
@@ -138,19 +143,24 @@ struct hip_copy_from_gpu
std
::
string
name
()
const
{
return
"hip::copy_from_gpu"
;
}
std
::
string
name
()
const
{
return
"hip::copy_from_gpu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
,
2
).
same_type
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
).
same_type
();
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
compute
(
context
&
ctx
,
const
dyn_output
&
dyn_out
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
if
(
args
.
size
()
==
1
)
if
(
args
.
size
()
==
1
)
{
{
argument
result
=
allocate_gpu
(
out
put_shape
,
true
);
argument
result
=
allocate_gpu
(
dyn_out
.
com
put
ed
_shape
,
true
);
gpu_copy
(
ctx
,
args
[
0
],
result
);
gpu_copy
(
ctx
,
args
[
0
],
result
);
return
result
;
return
result
;
}
}
copy_from_gpu
(
ctx
,
args
[
0
],
args
[
1
]);
argument
input
=
args
[
0
].
share
();
if
(
input
.
get_shape
().
dynamic
())
{
input
=
input
.
reshape
(
args
[
1
].
get_shape
());
}
copy_from_gpu
(
ctx
,
input
,
args
[
1
]);
return
args
[
1
];
return
args
[
1
];
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
args
)
const
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
args
)
const
...
...
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
View file @
80bf741a
...
@@ -28,10 +28,6 @@
...
@@ -28,10 +28,6 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
#include <hip/math_functions.h>
#include <hip/hip_math_constants.h>
#elif defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS)
#include <hip/hip_common.h>
#include <hip/hip_math_constants.h>
#endif
#endif
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
80bf741a
...
@@ -138,7 +138,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
...
@@ -138,7 +138,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
isnan
,
::
__hisnan
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
isnan
,
::
__hisnan
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
//
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sin
,
::
hsin
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
// Use float to compute half overload
// Use float to compute half overload
...
@@ -176,7 +176,7 @@ MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
...
@@ -176,7 +176,7 @@ MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
MIGRAPHX_DEVICE_MATH_HALF2
(
log10
,
::
h2log10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log10
,
::
h2log10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
rsqrt
,
::
h2rsqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
rsqrt
,
::
h2rsqrt
)
//
MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2
(
sin
,
::
h2sin
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sqrt
,
::
h2sqrt
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sqrt
,
::
h2sqrt
)
template
<
class
T
,
class
U
>
template
<
class
T
,
class
U
>
...
@@ -217,14 +217,6 @@ constexpr auto min(const T& a, const U& b)
...
@@ -217,14 +217,6 @@ constexpr auto min(const T& a, const U& b)
return
min
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
return
min
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
}
}
// Sin for half is broken on hip, so use cos instead
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
vec_type
<
T
>,
half
>
{})
>
constexpr
T
sin
(
T
x
)
{
constexpr
const
T
shift
=
HIP_PIO2_F
;
return
migraphx
::
cos
(
shift
-
x
);
}
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/print.hpp
View file @
80bf741a
...
@@ -244,13 +244,13 @@ __device__ void print_once(Ts... xs)
...
@@ -244,13 +244,13 @@ __device__ void print_once(Ts... xs)
template
<
class
...
Ts
>
template
<
class
...
Ts
>
__device__
void
println
(
Ts
...
xs
)
__device__
void
println
(
Ts
...
xs
)
{
{
print_each
(
&
cout
ln
,
xs
...);
print_each
(
&
cout
,
xs
...
,
'\n'
);
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
__device__
void
println_once
(
Ts
...
xs
)
__device__
void
println_once
(
Ts
...
xs
)
{
{
print_each_once
(
&
cout
ln
,
xs
...);
print_each_once
(
&
cout
,
xs
...
,
'\n'
);
}
}
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
80bf741a
...
@@ -79,20 +79,21 @@ __device__ void dpp_reduce(T& in, Op op)
...
@@ -79,20 +79,21 @@ __device__ void dpp_reduce(T& in, Op op)
#endif
#endif
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix
)
\
#define MIGRAPHX_DPP_REDUCE(op, prefix
, sign)
\
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(int32_t& x, op) \
__device__ inline void dpp_reduce(int32_t& x, op) \
{ \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##
_u32);
\
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##
sign##32);
\
} \
} \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
)
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
)
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
)
MIGRAPHX_DPP_REDUCE
(
op
::
product
,
v_mul
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
product
,
v_mul
)
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
,
_i
)
template
<
class
Op
,
class
T
,
class
Index
,
class
F
>
template
<
class
Op
,
class
T
,
class
Index
,
class
F
>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
...
@@ -570,7 +571,7 @@ template <class Algo, class Reduced, class Output, class F>
...
@@ -570,7 +571,7 @@ template <class Algo, class Reduced, class Output, class F>
__device__
void
fused_reduce
(
Output
output
,
F
f
)
__device__
void
fused_reduce
(
Output
output
,
F
f
)
{
{
Algo
::
template
run
<
Reduced
>([
&
](
auto
out_idx
,
auto
r
)
{
Algo
::
template
run
<
Reduced
>([
&
](
auto
out_idx
,
auto
r
)
{
auto
result
=
f
(
r
);
auto
result
=
f
(
r
,
out_idx
);
if
constexpr
(
reduce
::
is_inner_storage
<
decltype
(
result
)
>
{})
if
constexpr
(
reduce
::
is_inner_storage
<
decltype
(
result
)
>
{})
{
{
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
x
;
})(
output
,
result
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
x
;
})(
output
,
result
);
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment