Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
476ed17c
Unverified
Commit
476ed17c
authored
Aug 28, 2023
by
Brian Pickrell
Committed by
GitHub
Aug 28, 2023
Browse files
Merge branch 'develop' into rand_uniform
parents
f4f9d711
6f1c947f
Changes
96
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
161 additions
and
134 deletions
+161
-134
src/py/CMakeLists.txt
src/py/CMakeLists.txt
+1
-0
src/py/include/migraphx/py.hpp
src/py/include/migraphx/py.hpp
+2
-1
src/py/py_loader.cpp
src/py/py_loader.cpp
+1
-1
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+5
-7
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+6
-5
src/sqlite.cpp
src/sqlite.cpp
+1
-0
src/targets/cpu/target.cpp
src/targets/cpu/target.cpp
+1
-1
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
...targets/gpu/device/include/migraphx/gpu/device/launch.hpp
+7
-1
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
+2
-2
src/targets/gpu/device/topk.cpp
src/targets/gpu/device/topk.cpp
+2
-2
src/targets/gpu/device_name.cpp
src/targets/gpu/device_name.cpp
+1
-15
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+35
-30
src/targets/gpu/include/migraphx/gpu/context.hpp
src/targets/gpu/include/migraphx/gpu/context.hpp
+2
-8
src/targets/gpu/include/migraphx/gpu/device_name.hpp
src/targets/gpu/include/migraphx/gpu/device_name.hpp
+0
-2
src/targets/gpu/include/migraphx/gpu/hip.hpp
src/targets/gpu/include/migraphx/gpu/hip.hpp
+1
-1
src/targets/gpu/include/migraphx/gpu/mlir.hpp
src/targets/gpu/include/migraphx/gpu/mlir.hpp
+3
-2
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+2
-1
src/targets/gpu/jit/mlir.cpp
src/targets/gpu/jit/mlir.cpp
+6
-4
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+82
-50
src/targets/gpu/time_op.cpp
src/targets/gpu/time_op.cpp
+1
-1
No files found.
src/py/CMakeLists.txt
View file @
476ed17c
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
option
(
MIGRAPHX_ENABLE_PYTHON
"Enable python bindings"
ON
)
option
(
MIGRAPHX_ENABLE_PYTHON
"Enable python bindings"
ON
)
add_library
(
migraphx_py py_loader.cpp
)
add_library
(
migraphx_py py_loader.cpp
)
migraphx_generate_export_header
(
migraphx_py
)
target_include_directories
(
migraphx_py PRIVATE include
)
target_include_directories
(
migraphx_py PRIVATE include
)
target_link_libraries
(
migraphx_py PUBLIC migraphx
)
target_link_libraries
(
migraphx_py PUBLIC migraphx
)
rocm_install_targets
(
TARGETS migraphx_py INCLUDE include
)
rocm_install_targets
(
TARGETS migraphx_py INCLUDE include
)
...
...
src/py/include/migraphx/py.hpp
View file @
476ed17c
...
@@ -26,11 +26,12 @@
...
@@ -26,11 +26,12 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/py/export.h>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
program
load_py
(
const
std
::
string
&
filename
);
MIGRAPHX_PY_EXPORT
program
load_py
(
const
std
::
string
&
filename
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/py/py_loader.cpp
View file @
476ed17c
...
@@ -64,7 +64,7 @@ static dynamic_loader py_lib()
...
@@ -64,7 +64,7 @@ static dynamic_loader py_lib()
return
lib
;
return
lib
;
}
}
program
load_py
(
const
std
::
string
&
filename
)
MIGRAPHX_PY_EXPORT
program
load_py
(
const
std
::
string
&
filename
)
{
{
static
auto
f
=
py_lib
().
get_function
<
program
(
const
std
::
string
&
)
>
(
"migraphx_load_py"
);
static
auto
f
=
py_lib
().
get_function
<
program
(
const
std
::
string
&
)
>
(
"migraphx_load_py"
);
return
f
(
filename
);
return
f
(
filename
);
...
...
src/rewrite_quantization.cpp
View file @
476ed17c
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -62,12 +63,9 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -62,12 +63,9 @@ void apply_quantizelinear(module& m, instruction_ref ins)
min_quant
=
qt
.
min
();
min_quant
=
qt
.
min
();
});
});
auto
s
=
add_zero_point
->
get_shape
();
auto
s
=
add_zero_point
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
auto
min_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
min_quant
}});
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
auto
max_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
max_quant
}});
auto
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
auto
saturate
=
insert_common_op
(
m
,
ins
,
make_op
(
"clip"
),
{
add_zero_point
,
min_arg
,
max_arg
});
auto
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
auto
saturate
=
m
.
insert_instruction
(
ins
,
make_op
(
"clip"
),
add_zero_point
,
min_arg
,
max_arg
);
m
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
}
}
...
...
src/simplify_algebra.cpp
View file @
476ed17c
...
@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
...
@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
};
};
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
qdots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"quant_dot"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
return
(
dots
>=
2
or
convs
>=
2
);
return
(
dots
>=
2
or
convs
>=
2
or
qdots
>=
2
);
}
}
struct
find_conv_dot_horiz_fusion
struct
find_conv_dot_horiz_fusion
...
@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto
pred
=
[](
auto
i
,
auto
j
)
{
auto
pred
=
[](
auto
i
,
auto
j
)
{
if
(
i
->
get_operator
()
!=
j
->
get_operator
())
if
(
i
->
get_operator
()
!=
j
->
get_operator
())
return
false
;
return
false
;
if
(
not
contains
({
"dot"
,
"convolution"
},
i
->
name
()))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
i
->
name
()))
return
true
;
return
true
;
auto
x
=
i
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
x
=
i
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
...
@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return
false
;
return
false
;
// Check that non-axes match
// Check that non-axes match
int
axis
=
1
;
int
axis
=
1
;
if
(
i
->
name
()
==
"dot"
)
if
(
i
->
name
()
==
"dot"
or
i
->
name
()
==
"quant_dot"
)
{
{
axis
=
x
.
size
()
-
1
;
axis
=
x
.
size
()
-
1
;
}
}
...
@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if
(
std
::
distance
(
start
,
last
)
<
2
)
if
(
std
::
distance
(
start
,
last
)
<
2
)
return
;
return
;
auto
&&
name
=
(
*
start
)
->
name
();
auto
&&
name
=
(
*
start
)
->
name
();
if
(
not
contains
({
"dot"
,
"convolution"
},
name
))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
name
))
return
;
return
;
auto
op
=
(
*
start
)
->
get_operator
();
auto
op
=
(
*
start
)
->
get_operator
();
int
group
=
1
;
int
group
=
1
;
...
@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start
,
last
,
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
at
(
1
);
});
start
,
last
,
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
at
(
1
);
});
int
axis
=
1
;
int
axis
=
1
;
int
concat_axis
=
0
;
int
concat_axis
=
0
;
if
(
name
==
"dot"
)
if
(
name
==
"dot"
or
name
==
"quant_dot"
)
{
{
axis
=
int
(
args
.
front
()
->
get_shape
().
lens
().
size
()
-
1
);
axis
=
int
(
args
.
front
()
->
get_shape
().
lens
().
size
()
-
1
);
concat_axis
=
axis
;
concat_axis
=
axis
;
...
...
src/sqlite.cpp
View file @
476ed17c
...
@@ -48,6 +48,7 @@ struct sqlite_impl
...
@@ -48,6 +48,7 @@ struct sqlite_impl
template
<
class
F
>
template
<
class
F
>
void
exec
(
const
char
*
sql
,
F
f
)
void
exec
(
const
char
*
sql
,
F
f
)
{
{
// cppcheck-suppress constParameterPointer
auto
callback
=
[](
void
*
obj
,
auto
...
xs
)
->
int
{
auto
callback
=
[](
void
*
obj
,
auto
...
xs
)
->
int
{
try
try
{
{
...
...
src/targets/cpu/target.cpp
View file @
476ed17c
...
@@ -61,7 +61,7 @@ namespace cpu {
...
@@ -61,7 +61,7 @@ namespace cpu {
std
::
string
target
::
name
()
const
{
return
"cpu"
;
}
std
::
string
target
::
name
()
const
{
return
"cpu"
;
}
// cppcheck-suppress constParameter
// cppcheck-suppress constParameter
Reference
std
::
vector
<
pass
>
target
::
get_passes
(
migraphx
::
context
&
gctx
,
const
compile_options
&
)
const
std
::
vector
<
pass
>
target
::
get_passes
(
migraphx
::
context
&
gctx
,
const
compile_options
&
)
const
{
{
auto
&
ctx
=
any_cast
<
context
>
(
gctx
);
auto
&
ctx
=
any_cast
<
context
>
(
gctx
);
...
...
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
View file @
476ed17c
...
@@ -81,6 +81,12 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
...
@@ -81,6 +81,12 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
dim3
nthreads
(
local
);
dim3
nthreads
(
local
);
// cppcheck-suppress UseDeviceLaunch
// cppcheck-suppress UseDeviceLaunch
hipLaunchKernelGGL
((
launcher
<
f_type
>
),
nblocks
,
nthreads
,
0
,
stream
,
f
);
hipLaunchKernelGGL
((
launcher
<
f_type
>
),
nblocks
,
nthreads
,
0
,
stream
,
f
);
hipError_t
kernel_launch_status
=
hipGetLastError
();
if
(
kernel_launch_status
!=
hipSuccess
)
{
MIGRAPHX_THROW
(
"MIGraphX device kernel failed to launch with error: "
+
std
::
string
(
hipGetErrorString
(
kernel_launch_status
)));
}
};
};
}
}
...
...
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
View file @
476ed17c
...
@@ -124,7 +124,7 @@ void nary_broadcast_vec_impl(
...
@@ -124,7 +124,7 @@ void nary_broadcast_vec_impl(
buffer
[
i
]
=
binput
.
data
()[
i
];
buffer
[
i
]
=
binput
.
data
()[
i
];
}
}
__syncthreads
();
__syncthreads
();
auto
*
bp
=
as_pointer
(
buffer
);
const
auto
*
bp
=
as_pointer
(
buffer
);
// Process the data
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
nelements
;
i
+=
nglobal
)
for
(
size_t
i
=
idx
.
global
;
i
<
nelements
;
i
+=
nglobal
)
{
{
...
@@ -219,7 +219,7 @@ void nary_double_broadcast_vec_impl(
...
@@ -219,7 +219,7 @@ void nary_double_broadcast_vec_impl(
buffer
[
i
+
bdim_vec_len
]
=
binput2
.
data
()[
i
];
buffer
[
i
+
bdim_vec_len
]
=
binput2
.
data
()[
i
];
}
}
__syncthreads
();
__syncthreads
();
auto
*
bp
=
as_pointer
(
buffer
);
const
auto
*
bp
=
as_pointer
(
buffer
);
// Process the data
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
nelements
;
i
+=
nglobal
)
for
(
size_t
i
=
idx
.
global
;
i
<
nelements
;
i
+=
nglobal
)
{
{
...
...
src/targets/gpu/device/topk.cpp
View file @
476ed17c
...
@@ -72,12 +72,12 @@ struct hip_heap_vector
...
@@ -72,12 +72,12 @@ struct hip_heap_vector
index_int
l
=
2
*
index
+
1
;
index_int
l
=
2
*
index
+
1
;
index_int
r
=
2
*
index
+
2
;
index_int
r
=
2
*
index
+
2
;
if
(
l
<
n
&&
compare
(
data
[
data_index
(
l
)],
data
[
data_index
(
index
)]))
if
(
l
<
n
and
compare
(
data
[
data_index
(
l
)],
data
[
data_index
(
index
)]))
{
{
index
=
l
;
index
=
l
;
}
}
if
(
r
<
n
&&
compare
(
data
[
data_index
(
r
)],
data
[
data_index
(
index
)]))
if
(
r
<
n
and
compare
(
data
[
data_index
(
r
)],
data
[
data_index
(
index
)]))
{
{
index
=
r
;
index
=
r
;
if
(
compare
(
data
[
data_index
(
l
)],
data
[
data_index
(
r
)]))
if
(
compare
(
data
[
data_index
(
l
)],
data
[
data_index
(
r
)]))
...
...
src/targets/gpu/device_name.cpp
View file @
476ed17c
...
@@ -31,20 +31,6 @@ namespace migraphx {
...
@@ -31,20 +31,6 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
template
<
class
HipDeviceProp
>
std
::
string
get_arch_name
(
rank
<
0
>
,
const
HipDeviceProp
&
props
)
{
return
"gfx"
+
std
::
to_string
(
props
.
gcnArch
);
}
template
<
class
HipDeviceProp
>
auto
get_arch_name
(
rank
<
1
>
,
const
HipDeviceProp
&
props
)
->
decltype
(
std
::
string
(
props
.
gcnArchName
))
{
return
std
::
string
(
props
.
gcnArchName
);
}
std
::
string
get_arch_name
(
const
hipDeviceProp_t
&
props
)
{
return
get_arch_name
(
rank
<
1
>
{},
props
);
}
int
get_device_id
()
int
get_device_id
()
{
{
int
device
;
int
device
;
...
@@ -60,7 +46,7 @@ std::string get_device_name()
...
@@ -60,7 +46,7 @@ std::string get_device_name()
auto
status
=
hipGetDeviceProperties
(
&
props
,
get_device_id
());
auto
status
=
hipGetDeviceProperties
(
&
props
,
get_device_id
());
if
(
status
!=
hipSuccess
)
if
(
status
!=
hipSuccess
)
MIGRAPHX_THROW
(
"Failed to get device properties"
);
MIGRAPHX_THROW
(
"Failed to get device properties"
);
return
get_a
rch
_n
ame
(
props
)
;
return
props
.
gcnA
rch
N
ame
;
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/fuse_mlir.cpp
View file @
476ed17c
...
@@ -86,7 +86,7 @@ struct mlir_op
...
@@ -86,7 +86,7 @@ struct mlir_op
size_t
param_cnt
=
0
;
size_t
param_cnt
=
0
;
std
::
vector
<
std
::
string
>
names
=
mod
->
get_parameter_names
();
std
::
vector
<
std
::
string
>
names
=
mod
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
sort
(
names
.
begin
(),
names
.
end
());
for
(
std
::
string
param_name
:
names
)
for
(
const
std
::
string
&
param_name
:
names
)
{
{
ins_shapes
[
mod
->
get_parameter
(
param_name
)]
=
inputs
[
param_cnt
++
];
ins_shapes
[
mod
->
get_parameter
(
param_name
)]
=
inputs
[
param_cnt
++
];
}
}
...
@@ -210,7 +210,8 @@ struct find_mlir_op
...
@@ -210,7 +210,8 @@ struct find_mlir_op
return
false
;
return
false
;
}
}
const
std
::
initializer_list
<
std
::
string
>
any_type_ops
=
{
"@literal"
,
"@param"
,
"@return"
};
const
std
::
initializer_list
<
std
::
string
>
any_type_ops
=
{
"@literal"
,
"@param"
,
"@return"
};
const
std
::
initializer_list
<
std
::
string
>
no_bool_ops
=
{
"convolution"
,
const
std
::
initializer_list
<
std
::
string
>
no_bool_ops
=
{
"convolution"
,
"quant_convolution"
,
"quant_convolution"
,
"dot"
,
"dot"
,
"quant_dot"
,
"quant_dot"
,
...
@@ -225,27 +226,31 @@ struct find_mlir_op
...
@@ -225,27 +226,31 @@ struct find_mlir_op
"quantizelinear"
,
"quantizelinear"
,
"dequantizelinear"
,
"dequantizelinear"
,
"abs"
,
"abs"
,
"neg"
};
"neg"
,
const
std
::
initializer_list
<
std
::
string
>
fp_only_ops
=
{
"ceil"
,
};
const
std
::
initializer_list
<
std
::
string
>
fp_only_ops
=
{
"ceil"
,
"erf"
,
"erf"
,
"exp"
,
"exp"
,
"floor"
,
"floor"
,
"log"
,
"log"
,
"recip"
,
"recip"
,
"rsqrt"
,
"rsqrt"
,
"sigmoid"
// There are bugs in MLIR right now for models using sigmoid so disable it for now
// "sigmoid",
"softmax"
,
"softmax"
,
"tanh"
};
"tanh"
,
};
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
result_type
);
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
result_type
);
if
(
contains
(
any_type_ops
,
name
))
if
(
contains
(
any_type_ops
,
name
))
return
true
;
return
true
;
if
(
result_type
!=
type_t
::
bool_type
&&
contains
(
no_bool_ops
,
name
))
if
(
result_type
!=
type_t
::
bool_type
and
contains
(
no_bool_ops
,
name
))
return
true
;
return
true
;
if
(
is_float
&&
contains
(
fp_only_ops
,
name
))
if
(
is_float
and
contains
(
fp_only_ops
,
name
))
return
true
;
return
true
;
// Only conversions between floating types are known to be unambigiously
// Only conversions between floating types are known to be unambigiously
// supported.
// supported.
if
(
is_float
&&
name
==
"convert"
)
if
(
is_float
and
name
==
"convert"
)
{
{
return
std
::
all_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[](
const
auto
&
arg
)
{
return
std
::
all_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[](
const
auto
&
arg
)
{
return
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
arg
->
get_shape
().
type
());
return
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
arg
->
get_shape
().
type
());
...
...
src/targets/gpu/include/migraphx/gpu/context.hpp
View file @
476ed17c
...
@@ -46,13 +46,7 @@ using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
...
@@ -46,13 +46,7 @@ using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
struct
hip_device
struct
hip_device
{
{
hip_device
()
hip_device
()
:
device_props
{}
{
add_stream
();
}
{
device_props
.
gcnArchName
[
0
]
=
'\0'
;
device_props
.
gcnArch
=
0
;
device_props
.
multiProcessorCount
=
0
;
add_stream
();
}
hip_device
(
std
::
size_t
id
,
std
::
size_t
n
)
:
device_id
(
id
)
hip_device
(
std
::
size_t
id
,
std
::
size_t
n
)
:
device_id
(
id
)
{
{
...
@@ -171,7 +165,7 @@ struct hip_device
...
@@ -171,7 +165,7 @@ struct hip_device
std
::
size_t
stream_id
()
const
{
return
current_stream
;
}
std
::
size_t
stream_id
()
const
{
return
current_stream
;
}
std
::
string
get_device_name
()
const
{
return
get_arch_name
(
device_props
)
;
}
std
::
string
get_device_name
()
const
{
return
device_props
.
gcnArchName
;
}
std
::
string
get_gfx_name
()
const
{
return
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
}
std
::
string
get_gfx_name
()
const
{
return
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
}
...
...
src/targets/gpu/include/migraphx/gpu/device_name.hpp
View file @
476ed17c
...
@@ -33,8 +33,6 @@ namespace migraphx {
...
@@ -33,8 +33,6 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
MIGRAPHX_GPU_EXPORT
std
::
string
get_arch_name
(
const
hipDeviceProp_t
&
props
);
MIGRAPHX_GPU_EXPORT
std
::
string
get_device_name
();
MIGRAPHX_GPU_EXPORT
std
::
string
get_device_name
();
MIGRAPHX_GPU_EXPORT
int
get_device_id
();
MIGRAPHX_GPU_EXPORT
int
get_device_id
();
...
...
src/targets/gpu/include/migraphx/gpu/hip.hpp
View file @
476ed17c
...
@@ -92,7 +92,7 @@ struct hip_sync_stream
...
@@ -92,7 +92,7 @@ struct hip_sync_stream
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
argument
compute
(
const
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
gpu_sync
(
ctx
);
gpu_sync
(
ctx
);
if
(
args
.
empty
())
if
(
args
.
empty
())
...
...
src/targets/gpu/include/migraphx/gpu/mlir.hpp
View file @
476ed17c
...
@@ -37,7 +37,7 @@ struct module;
...
@@ -37,7 +37,7 @@ struct module;
namespace
gpu
{
namespace
gpu
{
MIGRAPHX_GPU_EXPORT
std
::
string
dump_mlir
(
const
module
&
m
);
MIGRAPHX_GPU_EXPORT
std
::
string
dump_mlir
(
const
module
&
m
);
MIGRAPHX_GPU_EXPORT
code_object_op
compile_mlir
(
const
context
&
ctx
,
MIGRAPHX_GPU_EXPORT
code_object_op
compile_mlir
(
const
context
&
migraphx_
ctx
,
module
m
,
module
m
,
const
std
::
vector
<
instruction_ref
>&
inputs
,
const
std
::
vector
<
instruction_ref
>&
inputs
,
const
value
&
solution
);
const
value
&
solution
);
...
@@ -47,7 +47,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
...
@@ -47,7 +47,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
code_object_op
co
,
code_object_op
co
,
const
std
::
vector
<
instruction_ref
>&
inputs
);
const
std
::
vector
<
instruction_ref
>&
inputs
);
MIGRAPHX_GPU_EXPORT
tuning_config
get_tuning_config_mlir
(
module
m
,
MIGRAPHX_GPU_EXPORT
tuning_config
get_tuning_config_mlir
(
const
context
&
migraphx_ctx
,
module
m
,
const
std
::
vector
<
shape
>&
inputs
);
const
std
::
vector
<
shape
>&
inputs
);
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/jit/ck_gemm.cpp
View file @
476ed17c
...
@@ -300,7 +300,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -300,7 +300,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
c_shape
=
inputs
.
back
();
const
auto
&
c_shape
=
inputs
.
back
();
auto
rank
=
a_shape
.
lens
().
size
();
// cppcheck-suppress unreadVariable
auto
rank
=
a_shape
.
ndim
();
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
m
=
c_shape
.
lens
()[
rank
-
2
];
auto
m
=
c_shape
.
lens
()[
rank
-
2
];
...
...
src/targets/gpu/jit/mlir.cpp
View file @
476ed17c
...
@@ -37,7 +37,7 @@ struct mlir_compiler : compiler<mlir_compiler>
...
@@ -37,7 +37,7 @@ struct mlir_compiler : compiler<mlir_compiler>
operation
compile_op
(
context
&
,
const
std
::
vector
<
shape
>&
,
const
value
&
)
const
{
return
{};
}
operation
compile_op
(
context
&
,
const
std
::
vector
<
shape
>&
,
const
value
&
)
const
{
return
{};
}
compiler_replace
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
,
const
value
&
solution
)
const
compile
(
const
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
,
const
value
&
solution
)
const
{
{
auto
*
smod
=
ins
->
module_inputs
().
front
();
auto
*
smod
=
ins
->
module_inputs
().
front
();
assert
(
smod
->
get_parameter_names
().
size
()
==
ins
->
inputs
().
size
()
-
1
);
assert
(
smod
->
get_parameter_names
().
size
()
==
ins
->
inputs
().
size
()
-
1
);
...
@@ -52,14 +52,16 @@ struct mlir_compiler : compiler<mlir_compiler>
...
@@ -52,14 +52,16 @@ struct mlir_compiler : compiler<mlir_compiler>
}};
}};
}
}
optional
<
tuning_config
>
optional
<
tuning_config
>
get_tuning_config
(
const
context
&
ctx
,
get_tuning_config
(
context
&
,
instruction_ref
ins
,
const
operation
&
,
bool
exhaustive
)
const
instruction_ref
ins
,
const
operation
&
,
bool
exhaustive
)
const
{
{
if
(
not
exhaustive
)
if
(
not
exhaustive
)
return
nullopt
;
return
nullopt
;
auto
shapes
=
to_shapes
(
ins
->
inputs
());
auto
shapes
=
to_shapes
(
ins
->
inputs
());
auto
*
smod
=
ins
->
module_inputs
().
front
();
auto
*
smod
=
ins
->
module_inputs
().
front
();
return
get_tuning_config_mlir
(
*
smod
,
shapes
);
return
get_tuning_config_mlir
(
ctx
,
*
smod
,
shapes
);
}
}
};
};
...
...
src/targets/gpu/mlir.cpp
View file @
476ed17c
...
@@ -36,7 +36,10 @@
...
@@ -36,7 +36,10 @@
#include <mutex>
#include <mutex>
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
#warning "Incompatible version of rocMLIR library used, disabling"
#warning "Incompatible version of rocMLIR library used, disabling"
// Only undefine when not using cppcheck
#ifndef CPPCHECK
#undef MIGRAPHX_MLIR
#undef MIGRAPHX_MLIR
#endif
#else
#else
#include <mlir-c/RegisterRocMLIR.h>
#include <mlir-c/RegisterRocMLIR.h>
#endif
#endif
...
@@ -173,12 +176,6 @@ std::string mlir_print(F f, T x)
...
@@ -173,12 +176,6 @@ std::string mlir_print(F f, T x)
return
ss
.
str
();
return
ss
.
str
();
}
}
bool
has_xdlops
(
const
std
::
string
&
target_arch
)
{
const
auto
device_name
=
trim
(
split_string
(
target_arch
,
':'
).
front
());
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx908"
);
}
struct
mlir_program
struct
mlir_program
{
{
mlir_program
()
mlir_program
()
...
@@ -513,7 +510,8 @@ struct mlir_program
...
@@ -513,7 +510,8 @@ struct mlir_program
ops
.
add_attributes
({{
"function_type"
,
make_function_type
(
inputs
,
outputs
)},
ops
.
add_attributes
({{
"function_type"
,
make_function_type
(
inputs
,
outputs
)},
{
"sym_name"
,
sym_name
},
{
"sym_name"
,
sym_name
},
{
"kernel"
,
std
::
string
(
"mixr"
)},
{
"kernel"
,
std
::
string
(
"mixr"
)},
{
"arch"
,
target_arch
}});
{
"arch"
,
target_arch
},
{
"num_cu"
,
num_cu
}});
ops
.
add_region
(
std
::
move
(
region
));
ops
.
add_region
(
std
::
move
(
region
));
insert
(
body
,
std
::
move
(
ops
));
insert
(
body
,
std
::
move
(
ops
));
...
@@ -596,9 +594,6 @@ struct mlir_program
...
@@ -596,9 +594,6 @@ struct mlir_program
{
{
pp
=
pp
=
problem_params
{
ins
->
get_operator
(),
to_shapes
(
ins
->
inputs
()),
ins
->
get_shape
()};
problem_params
{
ins
->
get_operator
(),
to_shapes
(
ins
->
inputs
()),
ins
->
get_shape
()};
// check if HW supports xdlops
if
(
has_xdlops
(
target_arch
))
ops
.
add_attributes
({{
"xdlopsV2"
,
true
}});
}
}
std
::
vector
<
MlirValue
>
inputs
;
std
::
vector
<
MlirValue
>
inputs
;
...
@@ -647,7 +642,12 @@ struct mlir_program
...
@@ -647,7 +642,12 @@ struct mlir_program
return
op
;
return
op
;
}
}
void
find_target
()
{
target_arch
=
get_device_name
();
}
void
set_gpu_properties
(
const
context
&
migraphx_ctx
)
{
const
auto
&
device
=
migraphx_ctx
.
get_current_device
();
target_arch
=
device
.
get_device_name
();
num_cu
=
device
.
get_cu_count
();
}
std
::
pair
<
std
::
size_t
,
std
::
size_t
>
get_launch_params
()
const
std
::
pair
<
std
::
size_t
,
std
::
size_t
>
get_launch_params
()
const
{
{
...
@@ -661,7 +661,7 @@ struct mlir_program
...
@@ -661,7 +661,7 @@ struct mlir_program
value
::
binary
get_binary
()
const
value
::
binary
get_binary
()
const
{
{
in
t
size
=
0
;
size_
t
size
=
0
;
mlirGetBinary
(
mmodule
.
get
(),
&
size
,
nullptr
);
mlirGetBinary
(
mmodule
.
get
(),
&
size
,
nullptr
);
value
::
binary
result
(
size
);
value
::
binary
result
(
size
);
if
(
mlirGetBinary
(
mmodule
.
get
(),
&
size
,
reinterpret_cast
<
char
*>
(
result
.
data
())))
if
(
mlirGetBinary
(
mmodule
.
get
(),
&
size
,
reinterpret_cast
<
char
*>
(
result
.
data
())))
...
@@ -669,30 +669,41 @@ struct mlir_program
...
@@ -669,30 +669,41 @@ struct mlir_program
MIGRAPHX_THROW
(
"Failed to compile mlir program"
);
MIGRAPHX_THROW
(
"Failed to compile mlir program"
);
}
}
void
set_tuning
(
const
value
&
v
)
void
set_tuning
(
const
value
&
v
)
MIGRAPHX_TIDY_CONST
{
{
auto
str
=
v
.
to
<
std
::
string
>
();
const
auto
*
str
=
v
.
if_string
();
// We need to make a copy of the buffer since mlirRockTuningSetFromStr may modify the string
if
(
str
==
nullptr
)
std
::
vector
<
char
>
buffer
(
str
.
begin
(),
str
.
end
());
MIGRAPHX_THROW
(
"mlir tuning solutions must be strings"
);
buffer
.
push_back
(
0
);
if
(
not
mlirRockTuningSetFromStr
(
mmodule
.
get
(),
make_mlir_string_ref
(
*
str
)))
if
(
not
mlirRockTuningSetFromStr
(
mmodule
.
get
(),
buffer
.
data
()))
MIGRAPHX_THROW
(
"Failed setting tuning key: "
+
*
str
);
MIGRAPHX_THROW
(
"Failed setting tuning key: "
+
str
);
}
}
tuning_config
get_tuning_config
()
MIGRAPHX_TIDY_CONST
tuning_config
get_tuning_config
()
MIGRAPHX_TIDY_CONST
{
{
tuning_config
tc
;
tuning_config
tc
;
run_high_level_pipeline
();
run_high_level_pipeline
();
mlir_tuning_space
params
{
mlirRockTuningSpaceCreate
(
mmodule
.
get
())};
mlir_tuning_space
params
{
for
(
auto
i
:
range
(
mlirRockTuningGetNumParamsFull
(
params
.
get
())))
mlirRockTuningSpaceCreate
(
mmodule
.
get
(),
RocmlirTuningParamSetKindFull
)};
for
(
auto
i
:
range
(
mlirRockTuningGetNumParams
(
params
.
get
())))
{
{
mlir_tuning_param
param
{
mlirRockTuningParamCreate
()};
mlir_tuning_param
param
{
mlirRockTuningParamCreate
()};
if
(
not
mlirRockTuningParamGet
(
params
.
get
(),
i
,
param
.
get
()))
if
(
not
mlirRockTuningParamGet
(
params
.
get
(),
i
,
param
.
get
()))
MIGRAPHX_THROW
(
"Incorrect mlir tuning parameter: "
+
std
::
to_string
(
i
));
MIGRAPHX_THROW
(
"Incorrect mlir tuning parameter: "
+
std
::
to_string
(
i
));
tc
.
solutions
.
push_back
(
std
::
string
{
mlirRockTuningGetParamStr
(
param
.
get
())});
std
::
array
<
char
,
ROCMLIR_TUNING_KEY_BUFSZ
>
perf_key
;
}
size_t
perf_key_bytes
=
mlir_tuning_table
tuning_table
{
mlirRockTuningTableCreate
()};
mlirRockTuningParamToString
(
param
.
get
(),
perf_key
.
data
(),
perf_key
.
size
());
tc
.
problem
=
std
::
string
{
mlirRockTuningGetKey
(
tuning_table
.
get
(),
mmodule
.
get
())};
if
(
perf_key_bytes
>
perf_key
.
size
())
MIGRAPHX_THROW
(
"Tuning perf key was "
+
std
::
to_string
(
perf_key_bytes
)
+
" bytes and thus too long"
);
tc
.
solutions
.
emplace_back
(
perf_key
.
begin
(),
perf_key
.
begin
()
+
perf_key_bytes
);
}
std
::
array
<
char
,
ROCMLIR_TUNING_KEY_BUFSZ
>
tuning_key
;
size_t
tuning_key_bytes
=
mlirRockTuningGetKey
(
mmodule
.
get
(),
tuning_key
.
data
(),
tuning_key
.
size
());
if
(
tuning_key_bytes
>
tuning_key
.
size
())
MIGRAPHX_THROW
(
"Tuning table key was "
+
std
::
to_string
(
tuning_key_bytes
)
+
" bytes and thus too long"
);
tc
.
problem
=
std
::
string
(
tuning_key
.
begin
(),
tuning_key
.
begin
()
+
tuning_key_bytes
);
return
tc
;
return
tc
;
}
}
...
@@ -700,10 +711,10 @@ struct mlir_program
...
@@ -700,10 +711,10 @@ struct mlir_program
// This function appends to tuning cfg file that could be
// This function appends to tuning cfg file that could be
// used with rocMLIR tuning scripts.
// used with rocMLIR tuning scripts.
void
dump_tuning_cfg
(
const
char
*
prob_config
)
const
void
dump_tuning_cfg
(
const
std
::
string
&
prob_config
)
const
{
{
std
::
string
tuning_cfg_path
=
string_value_of
(
MIGRAPHX_MLIR_TUNING_CFG
{});
std
::
string
tuning_cfg_path
=
string_value_of
(
MIGRAPHX_MLIR_TUNING_CFG
{});
if
(
!
tuning_cfg_path
.
empty
())
if
(
not
tuning_cfg_path
.
empty
())
{
{
std
::
vector
<
std
::
string
>
tokens
=
split_string
(
prob_config
,
'\t'
);
std
::
vector
<
std
::
string
>
tokens
=
split_string
(
prob_config
,
'\t'
);
std
::
string
prob
=
tokens
[
1
];
std
::
string
prob
=
tokens
[
1
];
...
@@ -720,51 +731,66 @@ struct mlir_program
...
@@ -720,51 +731,66 @@ struct mlir_program
}
}
}
}
static
mlir_tuning_table
create
_tuning_table
()
static
std
::
pair
<
mlir_tuning_table
,
bool
>
load
_tuning_table
()
{
{
mlir_tuning_table
tuning_table
{
mlirRockTuningTableCreate
()};
mlir_tuning_table
tuning_table
{
mlirRockTuningTableCreate
()};
bool
found_table
=
false
;
std
::
string
tuning_db_path
=
string_value_of
(
MIGRAPHX_MLIR_TUNING_DB
{});
std
::
string
tuning_db_path
=
string_value_of
(
MIGRAPHX_MLIR_TUNING_DB
{});
if
(
!
tuning_db_path
.
empty
())
if
(
not
tuning_db_path
.
empty
())
{
{
std
::
ifstream
tuning_db_tsv
(
tuning_db_path
);
std
::
ifstream
tuning_db_tsv
(
tuning_db_path
);
if
(
tuning_db_tsv
)
if
(
tuning_db_tsv
)
{
{
found_table
=
true
;
std
::
string
line
;
std
::
string
line
;
while
(
std
::
getline
(
tuning_db_tsv
,
line
))
while
(
std
::
getline
(
tuning_db_tsv
,
line
))
{
{
std
::
vector
<
std
::
string
>
tokens
=
split_string
(
line
,
'\t'
);
std
::
vector
<
std
::
string
>
tokens
=
split_string
(
line
,
'\t'
);
std
::
string
arch
=
tokens
[
0
];
std
::
string
arch
=
tokens
[
0
];
std
::
string
prob
=
tokens
[
1
];
std
::
string
num_cu
=
tokens
[
1
];
std
::
string
perf
=
tokens
[
2
];
std
::
string
prob
=
tokens
[
2
];
std
::
string
key
=
arch
.
append
(
"
\t
"
).
append
(
prob
);
std
::
string
perf
=
tokens
[
3
];
mlirRockTuningUpdateTable
(
tuning_table
.
get
(),
key
.
c_str
(),
perf
.
c_str
(),
1.0
);
std
::
string
key
=
arch
.
append
(
"
\t
"
).
append
(
num_cu
).
append
(
"
\t
"
).
append
(
prob
);
mlirRockTuningUpdateTable
(
tuning_table
.
get
(),
make_mlir_string_ref
(
key
),
make_mlir_string_ref
(
perf
),
1.0
);
}
}
}
}
}
}
else
else
{
{
found_table
=
false
;
std
::
cerr
std
::
cerr
<<
"WARNING: MLIR tuning db not found. Please set MIGRAPHX_MLIR_TUNING_DB for "
<<
"WARNING: MLIR tuning db not found. Please set MIGRAPHX_MLIR_TUNING_DB for "
"optimal performance."
"optimal performance."
<<
std
::
endl
;
<<
std
::
endl
;
}
}
return
tuning_table
;
return
std
::
make_pair
(
std
::
move
(
tuning_table
),
found_table
)
;
}
}
bool
get_module_tuned
()
const
bool
get_module_tuned
()
const
{
{
static
mlir_tuning_table
tuning_table
=
create_tuning_table
();
static
std
::
pair
<
mlir_tuning_table
,
bool
>
tuning_table
=
load_tuning_table
();
// The tuning table as currently implemented is currently not
if
(
not
mlirRockTuningSetFromTable
(
tuning_table
.
first
.
get
(),
mmodule
.
get
()))
// thread safe. This will be fixed in the future. For now,
{
// stick a mutex around all tuning table interaction.
std
::
array
<
char
,
ROCMLIR_TUNING_KEY_BUFSZ
>
prob_config
;
static
std
::
mutex
lock
;
size_t
prob_config_bytes
=
std
::
lock_guard
<
std
::
mutex
>
guard
(
lock
);
mlirRockTuningGetKey
(
mmodule
.
get
(),
prob_config
.
data
(),
prob_config
.
size
());
if
(
!
mlirRockTuningSetFromTable
(
tuning_table
.
get
(),
mmodule
.
get
()))
if
(
prob_config_bytes
>=
prob_config
.
size
())
{
{
const
char
*
prob_config
=
mlirRockTuningGetKey
(
tuning_table
.
get
(),
mmodule
.
get
());
std
::
cerr
<<
"MLIR tuning key overflowed buffer, needed "
<<
prob_config_bytes
std
::
stringstream
key
(
prob_config
);
<<
" bytes"
<<
std
::
endl
;
std
::
cerr
<<
"fails to set param on"
<<
prob_config
<<
std
::
endl
;
return
false
;
dump_tuning_cfg
(
prob_config
);
}
std
::
string
prob_config_str
(
prob_config
.
begin
(),
prob_config
.
begin
()
+
prob_config_bytes
);
if
(
tuning_table
.
second
)
{
std
::
cerr
<<
"NOTE: MLIR tuning table did not include a key for "
<<
prob_config_str
<<
std
::
endl
;
}
dump_tuning_cfg
(
prob_config_str
);
return
false
;
return
false
;
}
}
return
true
;
return
true
;
...
@@ -775,7 +801,8 @@ struct mlir_program
...
@@ -775,7 +801,8 @@ struct mlir_program
mlir_module
mmodule
;
mlir_module
mmodule
;
problem_params
pp
;
problem_params
pp
;
std
::
deque
<
std
::
string
>
strings
{};
std
::
deque
<
std
::
string
>
strings
{};
std
::
string
target_arch
;
std
::
string
target_arch
=
""
;
std
::
size_t
num_cu
=
0
;
std
::
string
sym_name
;
std
::
string
sym_name
;
};
};
...
@@ -832,7 +859,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
...
@@ -832,7 +859,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
}
}
}
}
code_object_op
compile_mlir
(
const
context
&
,
code_object_op
compile_mlir
(
const
context
&
migraphx_ctx
,
module
m
,
module
m
,
const
std
::
vector
<
instruction_ref
>&
inputs
,
const
std
::
vector
<
instruction_ref
>&
inputs
,
const
value
&
solution
)
const
value
&
solution
)
...
@@ -844,7 +871,7 @@ code_object_op compile_mlir(const context&,
...
@@ -844,7 +871,7 @@ code_object_op compile_mlir(const context&,
std
::
cout
<<
m
<<
std
::
endl
;
std
::
cout
<<
m
<<
std
::
endl
;
mlir_program
mp
;
mlir_program
mp
;
mp
.
find_target
(
);
mp
.
set_gpu_properties
(
migraphx_ctx
);
mp
.
parse
(
m
);
mp
.
parse
(
m
);
auto
mod_op
=
mlirModuleGetOperation
(
mp
.
mmodule
.
get
());
auto
mod_op
=
mlirModuleGetOperation
(
mp
.
mmodule
.
get
());
if
(
trace
)
if
(
trace
)
...
@@ -871,12 +898,13 @@ instruction_ref insert_mlir(module& m,
...
@@ -871,12 +898,13 @@ instruction_ref insert_mlir(module& m,
return
m
.
insert_instruction
(
ins
,
co
,
refs
);
return
m
.
insert_instruction
(
ins
,
co
,
refs
);
}
}
tuning_config
get_tuning_config_mlir
(
module
m
,
const
std
::
vector
<
shape
>&
inputs
)
tuning_config
get_tuning_config_mlir
(
const
context
&
migraphx_ctx
,
module
m
,
const
std
::
vector
<
shape
>&
inputs
)
{
{
adjust_param_shapes
(
m
,
inputs
);
adjust_param_shapes
(
m
,
inputs
);
mlir_program
mp
;
mlir_program
mp
;
mp
.
find_target
(
);
mp
.
set_gpu_properties
(
migraphx_ctx
);
mp
.
parse
(
m
);
mp
.
parse
(
m
);
return
mp
.
get_tuning_config
();
return
mp
.
get_tuning_config
();
}
}
...
@@ -903,10 +931,14 @@ instruction_ref
...
@@ -903,10 +931,14 @@ instruction_ref
insert_mlir
(
module
&
m
,
instruction_ref
,
code_object_op
co
,
const
std
::
vector
<
instruction_ref
>&
)
insert_mlir
(
module
&
m
,
instruction_ref
,
code_object_op
co
,
const
std
::
vector
<
instruction_ref
>&
)
{
{
use
(
co
);
use
(
co
);
use
(
m
);
return
m
.
end
();
return
m
.
end
();
}
}
tuning_config
get_tuning_config_mlir
(
module
,
const
std
::
vector
<
shape
>&
)
{
return
{};
}
tuning_config
get_tuning_config_mlir
(
const
context
&
,
module
,
const
std
::
vector
<
shape
>&
)
{
return
{};
}
// NOLINTEND(performance-unnecessary-value-param)
// NOLINTEND(performance-unnecessary-value-param)
#endif
#endif
...
...
src/targets/gpu/time_op.cpp
View file @
476ed17c
...
@@ -34,7 +34,7 @@ namespace gpu {
...
@@ -34,7 +34,7 @@ namespace gpu {
std
::
vector
<
argument
>
generate_arguments
(
const
std
::
vector
<
shape
>&
shapes
,
unsigned
long
seed
=
0
)
std
::
vector
<
argument
>
generate_arguments
(
const
std
::
vector
<
shape
>&
shapes
,
unsigned
long
seed
=
0
)
{
{
std
::
vector
<
argument
>
args
;
std
::
vector
<
argument
>
args
;
std
::
transform
(
shapes
.
begin
(),
shapes
.
end
(),
std
::
back_inserter
(
args
),
[
&
](
auto
&
s
)
{
std
::
transform
(
shapes
.
begin
(),
shapes
.
end
(),
std
::
back_inserter
(
args
),
[
&
](
const
auto
&
s
)
{
return
to_gpu
(
generate_argument
(
s
,
seed
++
));
return
to_gpu
(
generate_argument
(
s
,
seed
++
));
});
});
return
args
;
return
args
;
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment