Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
52585d4f
Unverified
Commit
52585d4f
authored
Nov 10, 2023
by
Chris Austen
Committed by
GitHub
Nov 10, 2023
Browse files
Merge branch 'develop' into enable_navi_32_ci
parents
f0370072
d8011adf
Changes
81
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1125 additions
and
388 deletions
+1125
-388
src/include/migraphx/op/normalize_attribute.hpp
src/include/migraphx/op/normalize_attribute.hpp
+2
-0
src/include/migraphx/op/quantizelinear.hpp
src/include/migraphx/op/quantizelinear.hpp
+5
-5
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+322
-136
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-1
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+8
-8
src/onnx/include/migraphx/onnx/onnx_parser.hpp
src/onnx/include/migraphx/onnx/onnx_parser.hpp
+5
-4
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+1
-0
src/onnx/parse_generic_op.cpp
src/onnx/parse_generic_op.cpp
+1
-1
src/onnx/parse_isinf.cpp
src/onnx/parse_isinf.cpp
+87
-0
src/onnx/parse_loop.cpp
src/onnx/parse_loop.cpp
+10
-0
src/onnx/parse_resize.cpp
src/onnx/parse_resize.cpp
+90
-63
src/onnx/parse_slice.cpp
src/onnx/parse_slice.cpp
+6
-4
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+5
-2
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+1
-1
src/simplify_dyn_ops.cpp
src/simplify_dyn_ops.cpp
+46
-2
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+2
-2
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+12
-1
src/targets/gpu/compile_ops.cpp
src/targets/gpu/compile_ops.cpp
+10
-7
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+474
-145
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+37
-6
No files found.
src/include/migraphx/op/normalize_attribute.hpp
View file @
52585d4f
...
@@ -40,6 +40,8 @@ namespace op {
...
@@ -40,6 +40,8 @@ namespace op {
* 2. use_rank (default) vs use_len:
* 2. use_rank (default) vs use_len:
* `use_rank` sets the max value/index of the attribute as the rank of lens.
* `use_rank` sets the max value/index of the attribute as the rank of lens.
* `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
* `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
* Uses the dynamic_dimension.max value for dynamic shapes. Returns the original vector
* (no normalization) if any of dynamic_dimension[axes] are not fixed.
* 3. `clip_min` vs. `not_clip_min` (default):
* 3. `clip_min` vs. `not_clip_min` (default):
* Clip values less than the minimum to the minimum or not.
* Clip values less than the minimum to the minimum or not.
* 4. `include_min` vs. `exclude_min` (default):
* 4. `include_min` vs. `exclude_min` (default):
...
...
src/include/migraphx/op/quantizelinear.hpp
View file @
52585d4f
...
@@ -30,11 +30,11 @@
...
@@ -30,11 +30,11 @@
#include <migraphx/par_for.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <cmath>
#include <cmath>
#include <fenv.h>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
struct
quantizelinear
struct
quantizelinear
{
{
std
::
string
name
()
const
{
return
"quantizelinear"
;
}
std
::
string
name
()
const
{
return
"quantizelinear"
;
}
...
@@ -71,26 +71,26 @@ struct quantizelinear
...
@@ -71,26 +71,26 @@ struct quantizelinear
{
{
y_zero_point
=
args
.
at
(
2
);
y_zero_point
=
args
.
at
(
2
);
}
}
argument
result
{
output_shape
};
argument
result
{
output_shape
};
auto
rounding_mode
=
fegetround
();
fesetround
(
FE_TONEAREST
);
visit_all
(
result
,
y_zero_point
)([
&
](
auto
output
,
auto
zero_pts
)
{
visit_all
(
result
,
y_zero_point
)([
&
](
auto
output
,
auto
zero_pts
)
{
visit_all
(
x
,
y_scale
)([
&
](
auto
input
,
auto
scales
)
{
visit_all
(
x
,
y_scale
)([
&
](
auto
input
,
auto
scales
)
{
using
quant_type
=
typename
decltype
(
output
)
::
value_type
;
using
quant_type
=
typename
decltype
(
output
)
::
value_type
;
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
int64_t
quantized
=
static_cast
<
int64_t
>
(
std
::
round
(
input
[
i
]
/
scales
[
i
]))
+
int64_t
quantized
=
static_cast
<
int64_t
>
(
std
::
nearbyint
(
input
[
i
]
/
scales
[
i
]))
+
static_cast
<
int64_t
>
(
zero_pts
[
i
]);
static_cast
<
int64_t
>
(
zero_pts
[
i
]);
output
[
i
]
=
std
::
max
(
static_cast
<
int64_t
>
(
min_value
),
output
[
i
]
=
std
::
max
(
static_cast
<
int64_t
>
(
min_value
),
std
::
min
(
static_cast
<
int64_t
>
(
max_value
),
quantized
));
std
::
min
(
static_cast
<
int64_t
>
(
max_value
),
quantized
));
});
});
});
});
});
});
fesetround
(
rounding_mode
);
return
result
;
return
result
;
}
}
};
};
}
// namespace op
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/op/slice.hpp
View file @
52585d4f
This diff is collapsed.
Click to expand it.
src/include/migraphx/operators.hpp
View file @
52585d4f
...
@@ -84,6 +84,7 @@
...
@@ -84,6 +84,7 @@
#include <migraphx/op/mod.hpp>
#include <migraphx/op/mod.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/nearbyint.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/nonmaxsuppression.hpp>
#include <migraphx/op/nonmaxsuppression.hpp>
#include <migraphx/op/nonzero.hpp>
#include <migraphx/op/nonzero.hpp>
...
@@ -110,7 +111,6 @@
...
@@ -110,7 +111,6 @@
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/roialign.hpp>
#include <migraphx/op/roialign.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter_add.hpp>
#include <migraphx/op/scatter_add.hpp>
...
...
src/normalize_attributes.cpp
View file @
52585d4f
...
@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
...
@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
{
if
(
input_shape
.
dynamic
())
if
(
input_shape
.
dynamic
())
{
{
// return the unchanged `vec` if the dynamic_dimensions at `axes` are not fixed
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
ax
)
{
return
not
input_shape
.
dyn_dims
().
at
(
ax
).
is_fixed
();
}))
{
return
vec
;
}
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
max_vals
.
begin
(),
[
&
](
auto
i
)
{
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
max_vals
.
begin
(),
[
&
](
auto
i
)
{
const
auto
&
dd
=
input_shape
.
dyn_dims
().
at
(
i
);
return
input_shape
.
dyn_dims
().
at
(
i
).
max
;
if
(
not
dd
.
is_fixed
())
{
MIGRAPHX_THROW
(
"NORMALIZE_ATTR: 'use_lens' on a non-fixed dynamic dimension, axis="
+
std
::
to_string
(
i
));
}
return
dd
.
max
;
});
});
}
}
else
else
...
...
src/onnx/include/migraphx/onnx/onnx_parser.hpp
View file @
52585d4f
...
@@ -97,10 +97,11 @@ struct onnx_parser
...
@@ -97,10 +97,11 @@ struct onnx_parser
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
};
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
shape
::
dynamic_dimension
>>
map_dyn_input_dims
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
shape
::
dynamic_dimension
>>
map_dyn_input_dims
;
bool
use_dyn_output
=
false
;
bool
use_dyn_output
=
false
;
bool
skip_unknown_operators
=
false
;
bool
skip_unknown_operators
=
false
;
int64_t
max_loop_iterations
=
10
;
int64_t
max_loop_iterations
=
10
;
int64_t
opset_version
=
13
;
int64_t
limit_max_iterations
=
std
::
numeric_limits
<
uint16_t
>::
max
();
int64_t
opset_version
=
13
;
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
...
...
src/onnx/onnx.cpp
View file @
52585d4f
...
@@ -67,6 +67,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
...
@@ -67,6 +67,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
}
}
parser
.
skip_unknown_operators
=
options
.
skip_unknown_operators
;
parser
.
skip_unknown_operators
=
options
.
skip_unknown_operators
;
parser
.
max_loop_iterations
=
options
.
max_loop_iterations
;
parser
.
max_loop_iterations
=
options
.
max_loop_iterations
;
parser
.
limit_max_iterations
=
options
.
limit_max_iterations
;
parser
.
use_dyn_output
=
options
.
use_dyn_output
;
parser
.
use_dyn_output
=
options
.
use_dyn_output
;
if
(
options
.
print_program_on_error
)
if
(
options
.
print_program_on_error
)
...
...
src/onnx/parse_generic_op.cpp
View file @
52585d4f
...
@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
...
@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
"Neg"
,
"neg"
},
{
"Neg"
,
"neg"
},
{
"Reciprocal"
,
"recip"
},
{
"Reciprocal"
,
"recip"
},
{
"Relu"
,
"relu"
},
{
"Relu"
,
"relu"
},
{
"Round"
,
"
round
"
},
{
"Round"
,
"
nearbyint
"
},
{
"Sigmoid"
,
"sigmoid"
},
{
"Sigmoid"
,
"sigmoid"
},
{
"Sign"
,
"sign"
},
{
"Sign"
,
"sign"
},
{
"Sin"
,
"sin"
},
{
"Sin"
,
"sin"
},
...
...
src/onnx/parse_isinf.cpp
0 → 100644
View file @
52585d4f
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_isinf
:
op_parser
<
parse_isinf
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"IsInf"
,
"isinf"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
parser
,
onnx_parser
::
node_info
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
bool
detect_negative
=
true
;
bool
detect_positive
=
true
;
if
(
contains
(
info
.
attributes
,
"detect_negative"
))
{
detect_negative
=
static_cast
<
bool
>
(
parser
.
parse_value
(
info
.
attributes
.
at
(
"detect_negative"
)).
at
<
int
>
());
}
if
(
contains
(
info
.
attributes
,
"detect_positive"
))
{
detect_positive
=
static_cast
<
bool
>
(
parser
.
parse_value
(
info
.
attributes
.
at
(
"detect_positive"
)).
at
<
int
>
());
}
auto
x_shape
=
args
[
0
]
->
get_shape
();
if
(
not
detect_negative
and
not
detect_positive
)
{
return
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
x_shape
.
lens
()}}),
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
shape
::
bool_type
},
{
false
}}));
}
auto
is_inf
=
info
.
add_instruction
(
make_op
(
"isinf"
),
args
[
0
]);
if
(
detect_negative
and
detect_positive
)
{
return
is_inf
;
}
auto
zero_l
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
x_shape
.
type
()},
{
0
}});
auto
mb_zero
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
x_shape
.
lens
()}}),
zero_l
);
auto
cond
=
info
.
add_broadcastable_binary_op
(
detect_negative
?
"less"
:
"greater"
,
args
[
0
],
mb_zero
);
if
(
cond
->
get_shape
().
type
()
!=
shape
::
bool_type
)
{
cond
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
shape
::
bool_type
}}),
cond
);
}
return
info
.
add_instruction
(
make_op
(
"logical_and"
),
is_inf
,
cond
);
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_loop.cpp
View file @
52585d4f
...
@@ -58,6 +58,16 @@ struct parse_loop : op_parser<parse_loop>
...
@@ -58,6 +58,16 @@ struct parse_loop : op_parser<parse_loop>
}
}
}
}
// cap max_iter because loop uses static shapes with max_iter size and huge numbers
// here can cause overflow
if
(
max_iterations
>
parser
.
limit_max_iterations
)
{
std
::
cerr
<<
"WARNING: PARSE_LOOP max_iterations exceeds the maximum loop "
"iterations limit, it will be changed from "
<<
max_iterations
<<
" to "
<<
parser
.
limit_max_iterations
<<
".
\n
"
;
max_iterations
=
parser
.
limit_max_iterations
;
}
// condition input is empty
// condition input is empty
if
(
args
.
at
(
1
)
->
name
()
==
"undefined"
)
if
(
args
.
at
(
1
)
->
name
()
==
"undefined"
)
{
{
...
...
src/onnx/parse_resize.cpp
View file @
52585d4f
...
@@ -181,6 +181,76 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
...
@@ -181,6 +181,76 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
return
nearest_mode
;
return
nearest_mode
;
}
}
static
std
::
vector
<
double
>
get_scales
(
const
onnx_parser
::
attribute_map
&
attr
)
{
std
::
vector
<
double
>
scales
;
if
(
contains
(
attr
,
"scales"
))
{
copy
(
attr
.
at
(
"scales"
).
floats
(),
std
::
back_inserter
(
scales
));
}
return
scales
;
}
static
void
parse_args
(
const
std
::
vector
<
instruction_ref
>&
args
,
const
std
::
vector
<
size_t
>&
in_lens
,
const
std
::
string
&
op_name
,
std
::
vector
<
double
>&
vec_scale
,
std
::
vector
<
std
::
size_t
>&
out_lens
)
{
for
(
const
auto
&
arg
:
args
)
{
if
(
arg
->
name
()
==
"undefined"
or
arg
==
args
.
front
())
{
continue
;
}
// skipped empty input
auto
lens
=
arg
->
get_shape
().
lens
();
if
(
lens
.
empty
())
{
continue
;
}
auto
type
=
arg
->
get_shape
().
type
();
// output size
if
(
type
==
shape
::
int64_type
)
{
auto
arg_out_s
=
arg
->
eval
();
check_arg_empty
(
arg_out_s
,
"PARSE_"
+
op_name
+
": dynamic output size is not supported!"
);
arg_out_s
.
visit
([
&
](
const
auto
&
ol
)
{
out_lens
.
assign
(
ol
.
begin
(),
ol
.
end
());
});
if
(
out_lens
.
size
()
!=
in_lens
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_"
+
op_name
+
": specified output size does not match input size"
);
}
// compute the scale
vec_scale
.
resize
(
in_lens
.
size
());
std
::
transform
(
in_lens
.
begin
(),
in_lens
.
end
(),
out_lens
.
begin
(),
vec_scale
.
begin
(),
[](
auto
iss
,
auto
oss
)
{
return
1.0
*
oss
/
iss
;
});
}
else
{
// scale input
if
(
lens
[
0
]
==
in_lens
.
size
())
{
auto
arg_scale
=
arg
->
eval
();
check_arg_empty
(
arg_scale
,
"PARSE_"
+
op_name
+
": dynamic input scale is not supported!"
);
arg_scale
.
visit
([
&
](
const
auto
&
v
)
{
vec_scale
.
assign
(
v
.
begin
(),
v
.
end
());
});
}
}
}
}
struct
parse_resize
:
op_parser
<
parse_resize
>
struct
parse_resize
:
op_parser
<
parse_resize
>
{
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Resize"
},
{
"Upsample"
}};
}
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Resize"
},
{
"Upsample"
}};
}
...
@@ -214,72 +284,30 @@ struct parse_resize : op_parser<parse_resize>
...
@@ -214,72 +284,30 @@ struct parse_resize : op_parser<parse_resize>
std
::
vector
<
std
::
size_t
>
out_lens
(
in_lens
.
size
());
std
::
vector
<
std
::
size_t
>
out_lens
(
in_lens
.
size
());
// scale
// scale
std
::
vector
<
double
>
vec_scale
;
std
::
vector
<
double
>
vec_scale
=
get_scales
(
info
.
attributes
)
;
for
(
const
auto
&
arg
:
args
)
// If `scales` was not an attribute, it must be an input
if
(
vec_scale
.
empty
())
{
{
if
(
arg
->
name
()
==
"undefined"
or
arg
==
args
.
front
())
// Depending on the args, it *must* populate the `vec_scale`, and might populate
{
// `out_lens`
continue
;
parse_args
(
args
,
in_lens
,
opd
.
op_name
,
vec_scale
,
out_lens
);
}
}
// skipped empty input
auto
lens
=
arg
->
get_shape
().
lens
();
if
(
lens
.
empty
())
{
continue
;
}
auto
type
=
arg
->
get_shape
().
type
();
// output size
if
(
type
==
shape
::
int64_type
)
{
auto
arg_out_s
=
arg
->
eval
();
check_arg_empty
(
arg_out_s
,
"PARSE_"
+
opd
.
op_name
+
": dynamic output size is not supported!"
);
arg_out_s
.
visit
([
&
](
const
auto
&
ol
)
{
out_lens
.
assign
(
ol
.
begin
(),
ol
.
end
());
});
if
(
out_lens
.
size
()
!=
in_lens
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
": specified output size does not match input size"
);
}
// compute the scale
if
(
in_lens
.
size
()
!=
vec_scale
.
size
())
vec_scale
.
resize
(
in_lens
.
size
());
{
std
::
transform
(
in_lens
.
begin
(),
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
": ranks of input and scale are different!"
);
in_lens
.
end
(),
}
out_lens
.
begin
(),
vec_scale
.
begin
(),
[](
auto
iss
,
auto
oss
)
{
return
1.0
*
oss
/
iss
;
});
}
else
{
// scale input
// if the output was not calculated yet, we update it based on the scales
if
(
lens
[
0
]
==
in_lens
.
size
())
if
(
all_of
(
out_lens
.
cbegin
(),
out_lens
.
cend
(),
[](
auto
o
)
{
return
o
==
0
;
}))
{
{
auto
arg_scale
=
arg
->
eval
();
std
::
transform
(
check_arg_empty
(
arg_scale
,
in_lens
.
begin
(),
"PARSE_"
+
opd
.
op_name
+
in_lens
.
end
(),
": dynamic input scale is not supported!"
);
vec_scale
.
begin
(),
out_lens
.
begin
(),
arg_scale
.
visit
([
&
](
const
auto
&
v
)
{
vec_scale
.
assign
(
v
.
begin
(),
v
.
end
());
});
[
&
](
auto
idx
,
auto
scale
)
{
return
static_cast
<
std
::
size_t
>
(
idx
*
scale
);
});
if
(
in_lens
.
size
()
!=
vec_scale
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
": ranks of input and scale are different!"
);
}
std
::
transform
(
in_lens
.
begin
(),
in_lens
.
end
(),
vec_scale
.
begin
(),
out_lens
.
begin
(),
[
&
](
auto
idx
,
auto
scale
)
{
return
static_cast
<
std
::
size_t
>
(
idx
*
scale
);
});
}
}
}
}
shape
out_s
{
in_s
.
type
(),
out_lens
};
shape
out_s
{
in_s
.
type
(),
out_lens
};
...
@@ -288,7 +316,6 @@ struct parse_resize : op_parser<parse_resize>
...
@@ -288,7 +316,6 @@ struct parse_resize : op_parser<parse_resize>
// reshape input to one-dimension
// reshape input to one-dimension
std
::
vector
<
int64_t
>
rsp_lens
=
{
static_cast
<
int64_t
>
(
in_s
.
elements
())};
std
::
vector
<
int64_t
>
rsp_lens
=
{
static_cast
<
int64_t
>
(
in_s
.
elements
())};
args
[
0
]
=
info
.
make_contiguous
(
args
[
0
]);
auto
rsp
=
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
rsp_lens
}}),
args
[
0
]);
auto
rsp
=
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
rsp_lens
}}),
args
[
0
]);
if
(
mode
==
"nearest"
)
if
(
mode
==
"nearest"
)
...
...
src/onnx/parse_slice.cpp
View file @
52585d4f
...
@@ -46,6 +46,9 @@ struct parse_slice : op_parser<parse_slice>
...
@@ -46,6 +46,9 @@ struct parse_slice : op_parser<parse_slice>
void
always_insert
(
instruction_ref
arg
)
{
op_args
.
insert
(
op_args
.
begin
(),
arg
);
}
void
always_insert
(
instruction_ref
arg
)
{
op_args
.
insert
(
op_args
.
begin
(),
arg
);
}
/**
* Either insert argument into `this->op_args` or return the constant value of the argument
*/
std
::
vector
<
int64_t
>
insert
(
instruction_ref
arg
)
std
::
vector
<
int64_t
>
insert
(
instruction_ref
arg
)
{
{
std
::
vector
<
int64_t
>
result
;
std
::
vector
<
int64_t
>
result
;
...
@@ -144,16 +147,15 @@ struct parse_slice : op_parser<parse_slice>
...
@@ -144,16 +147,15 @@ struct parse_slice : op_parser<parse_slice>
sd
.
op
.
axes
=
axes
;
sd
.
op
.
axes
=
axes
;
}
}
if
(
not
sd
.
steps
.
empty
(
))
if
(
std
::
any_of
(
sd
.
steps
.
begin
(),
sd
.
steps
.
end
(),
[](
auto
s
)
{
return
s
!=
1
;
}
))
{
{
if
(
sd
.
op
.
starts
.
empty
()
or
sd
.
op
.
ends
.
empty
())
if
(
sd
.
op
.
starts
.
empty
()
or
sd
.
op
.
ends
.
empty
())
MIGRAPHX_THROW
(
"PARSE_SLICE: steps and variable starts and ends is not supported"
);
MIGRAPHX_THROW
(
"PARSE_SLICE: steps and variable starts and/or ends is not supported"
);
if
(
sd
.
op
.
axes
.
empty
())
if
(
sd
.
op
.
axes
.
empty
())
MIGRAPHX_THROW
(
"PARSE_SLICE: steps and variable axes is not supported"
);
MIGRAPHX_THROW
(
"PARSE_SLICE: steps and variable axes is not supported"
);
}
}
assert
(
sd
.
steps
.
empty
()
or
sd
.
steps
.
size
()
==
sd
.
op
.
axes
.
size
());
// If any axes have negative step, prepare to add a "reverse" op
// If any axes have negative step, prepare to add a "reverse" op
for
(
auto
i
:
range
(
sd
.
steps
.
size
()))
for
(
auto
i
:
range
(
sd
.
steps
.
size
()))
{
{
...
...
src/py/migraphx_py.cpp
View file @
52585d4f
...
@@ -472,7 +472,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -472,7 +472,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
map_dyn_input_dims
,
map_dyn_input_dims
,
bool
skip_unknown_operators
,
bool
skip_unknown_operators
,
bool
print_program_on_error
,
bool
print_program_on_error
,
int64_t
max_loop_iterations
)
{
int64_t
max_loop_iterations
,
int64_t
limit_max_iterations
)
{
migraphx
::
onnx_options
options
;
migraphx
::
onnx_options
options
;
options
.
default_dim_value
=
default_dim_value
;
options
.
default_dim_value
=
default_dim_value
;
options
.
default_dyn_dim_value
=
default_dyn_dim_value
;
options
.
default_dyn_dim_value
=
default_dyn_dim_value
;
...
@@ -481,6 +482,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -481,6 +482,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
options
.
skip_unknown_operators
=
skip_unknown_operators
;
options
.
skip_unknown_operators
=
skip_unknown_operators
;
options
.
print_program_on_error
=
print_program_on_error
;
options
.
print_program_on_error
=
print_program_on_error
;
options
.
max_loop_iterations
=
max_loop_iterations
;
options
.
max_loop_iterations
=
max_loop_iterations
;
options
.
limit_max_iterations
=
limit_max_iterations
;
return
migraphx
::
parse_onnx
(
filename
,
options
);
return
migraphx
::
parse_onnx
(
filename
,
options
);
},
},
"Parse onnx file"
,
"Parse onnx file"
,
...
@@ -492,7 +494,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -492,7 +494,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(),
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(),
py
::
arg
(
"skip_unknown_operators"
)
=
false
,
py
::
arg
(
"skip_unknown_operators"
)
=
false
,
py
::
arg
(
"print_program_on_error"
)
=
false
,
py
::
arg
(
"print_program_on_error"
)
=
false
,
py
::
arg
(
"max_loop_iterations"
)
=
10
);
py
::
arg
(
"max_loop_iterations"
)
=
10
,
py
::
arg
(
"limit_max_iterations"
)
=
std
::
numeric_limits
<
uint16_t
>::
max
());
m
.
def
(
m
.
def
(
"parse_onnx_buffer"
,
"parse_onnx_buffer"
,
...
...
src/rewrite_quantization.cpp
View file @
52585d4f
...
@@ -47,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -47,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
y_scale
->
get_shape
().
type
()}}),
x
);
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
y_scale
->
get_shape
().
type
()}}),
x
);
}
}
auto
div
=
m
.
insert_instruction
(
ins
,
make_op
(
"div"
),
x
,
y_scale
);
auto
div
=
m
.
insert_instruction
(
ins
,
make_op
(
"div"
),
x
,
y_scale
);
auto
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"
round
"
),
div
);
auto
add_zero_point
=
m
.
insert_instruction
(
ins
,
make_op
(
"
nearbyint
"
),
div
);
if
(
ins
->
inputs
().
size
()
==
3
)
if
(
ins
->
inputs
().
size
()
==
3
)
{
{
...
...
src/simplify_dyn_ops.cpp
View file @
52585d4f
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -131,10 +132,53 @@ struct find_const_4in_slice
...
@@ -131,10 +132,53 @@ struct find_const_4in_slice
}
}
};
};
/**
* Simplify dimensions_of to a literal when the input arugment has a static shape
* or the dynamic dimensions from `start` to `end` are fixed.
*/
struct
find_static_dimensions_of
{
auto
matcher
()
const
{
return
match
::
name
(
"dimensions_of"
)();
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
input
=
ins
->
inputs
().
at
(
0
);
auto
dimensions_of_value
=
ins
->
get_operator
().
to_value
();
auto
start
=
dimensions_of_value
.
at
(
"start"
).
to
<
std
::
size_t
>
();
auto
end
=
dimensions_of_value
.
at
(
"end"
).
to
<
std
::
size_t
>
();
if
(
input
->
get_shape
().
dynamic
())
{
// check if dynamic dimensions from start to end are fixed
auto
dds
=
input
->
get_shape
().
dyn_dims
();
if
(
std
::
any_of
(
dds
.
begin
()
+
start
,
dds
.
begin
()
+
end
,
[](
auto
dd
)
{
return
not
dd
.
is_fixed
();
}))
{
return
;
}
}
std
::
size_t
output_ndim
=
end
-
start
;
std
::
vector
<
int64_t
>
vec_shape
(
output_ndim
);
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
output_ndim
});
std
::
vector
<
std
::
size_t
>
input_lens
=
input
->
get_shape
().
to_static
(
1
).
lens
();
std
::
transform
(
input_lens
.
begin
()
+
start
,
input_lens
.
begin
()
+
end
,
vec_shape
.
begin
(),
[](
auto
i
)
{
return
int64_t
(
i
);
});
migraphx
::
shape
output_shape
{
migraphx
::
shape
::
int64_type
,
{
end
-
start
}};
auto
lit_ins
=
m
.
add_literal
(
migraphx
::
literal
{
output_shape
,
vec_shape
});
m
.
replace_instruction
(
ins
,
lit_ins
);
}
};
void
simplify_dyn_ops
::
apply
(
module
&
m
)
const
void
simplify_dyn_ops
::
apply
(
module
&
m
)
const
{
{
match
::
find_matches
(
match
::
find_matches
(
m
,
m
,
find_static_2in_broadcasts
{},
find_const_3in_slice
{},
find_const_4in_slice
{});
find_static_2in_broadcasts
{},
find_static_dimensions_of
{},
find_const_3in_slice
{},
find_const_4in_slice
{});
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/simplify_reshapes.cpp
View file @
52585d4f
...
@@ -647,8 +647,8 @@ struct find_broadcast_transpose
...
@@ -647,8 +647,8 @@ struct find_broadcast_transpose
{
{
auto
transpose
=
r
.
result
;
auto
transpose
=
r
.
result
;
auto
transpose_lens
=
transpose
->
get_shape
().
lens
();
auto
transpose_lens
=
transpose
->
get_shape
().
lens
();
auto
bcast_ins
=
r
.
instructions
[
"bcast_ins"
];
auto
bcast_ins
=
r
.
instructions
[
"bcast_ins"
];
auto
input
=
bcast_ins
->
inputs
().
front
();
auto
input
=
bcast_ins
->
inputs
().
front
();
// scalar transformation does not need extra transpose
// scalar transformation does not need extra transpose
if
(
not
input
->
get_shape
().
scalar
())
if
(
not
input
->
get_shape
().
scalar
())
{
{
...
...
src/targets/gpu/CMakeLists.txt
View file @
52585d4f
# ####################################################################################
# ####################################################################################
# The MIT License (MIT)
# The MIT License (MIT)
#
#
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# of this software and associated documentation files (the "Software"), to deal
...
@@ -245,10 +245,14 @@ else()
...
@@ -245,10 +245,14 @@ else()
endif
()
endif
()
# Check miopen find mode api
# Check miopen find mode api
include
(
CheckLibraryExists
)
include
(
CheckLibraryExists
)
get_target_property
(
MIOPEN_LOCATION MIOpen LOCATION
)
get_target_property
(
MIOPEN_LOCATION MIOpen LOCATION
)
get_target_property
(
ROCBLAS_LOCATION roc::rocblas LOCATION
)
check_library_exists
(
MIOpen
"miopenHiddenSetConvolutionFindMode"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_MODE_API
)
check_library_exists
(
MIOpen
"miopenHiddenSetConvolutionFindMode"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_MODE_API
)
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
# Beta API for automated GEMM tuning
check_library_exists
(
roc::rocblas
"rocblas_gemm_ex_get_solutions"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
...
@@ -271,6 +275,13 @@ else()
...
@@ -271,6 +275,13 @@ else()
message
(
STATUS
"MIOpen does not have find mode api"
)
message
(
STATUS
"MIOpen does not have find mode api"
)
endif
()
endif
()
if
(
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS
)
message
(
STATUS
"MIGraphx is using Beta API of rocBLAS"
)
else
()
message
(
STATUS
"rocBLAS does not have User Tuning Beta API"
)
endif
()
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
...
...
src/targets/gpu/compile_ops.cpp
View file @
52585d4f
...
@@ -168,6 +168,7 @@ struct compile_plan
...
@@ -168,6 +168,7 @@ struct compile_plan
}
}
const
compiled_result
&
benchmark
(
problem_cache
&
pc
)
const
const
compiled_result
&
benchmark
(
problem_cache
&
pc
)
const
{
{
const
auto
trace_level
=
value_of
(
MIGRAPHX_TRACE_BENCHMARKING
{});
if
(
results
.
empty
())
if
(
results
.
empty
())
MIGRAPHX_THROW
(
"No configs to tune"
);
MIGRAPHX_THROW
(
"No configs to tune"
);
if
(
results
.
size
()
==
1
)
if
(
results
.
size
()
==
1
)
...
@@ -178,9 +179,10 @@ struct compile_plan
...
@@ -178,9 +179,10 @@ struct compile_plan
}
}
if
(
not
config
)
if
(
not
config
)
MIGRAPHX_THROW
(
"Multiple kernels without config"
);
MIGRAPHX_THROW
(
"Multiple kernels without config"
);
std
::
cout
<<
"Benchmarking "
<<
preop
.
name
()
<<
": "
<<
results
.
size
()
<<
" configs"
if
(
trace_level
>
0
)
<<
std
::
endl
;
std
::
cout
<<
"Benchmarking "
<<
preop
.
name
()
<<
": "
<<
results
.
size
()
<<
" configs"
if
(
enabled
(
MIGRAPHX_TRACE_BENCHMARKING
{}))
<<
std
::
endl
;
if
(
trace_level
>
1
)
std
::
cout
<<
"Problem: "
<<
config
->
problem
<<
std
::
endl
;
std
::
cout
<<
"Problem: "
<<
config
->
problem
<<
std
::
endl
;
std
::
vector
<
double
>
times
;
std
::
vector
<
double
>
times
;
times
.
reserve
(
results
.
size
());
times
.
reserve
(
results
.
size
());
...
@@ -189,22 +191,23 @@ struct compile_plan
...
@@ -189,22 +191,23 @@ struct compile_plan
config
->
solutions
.
begin
(),
config
->
solutions
.
begin
(),
std
::
back_inserter
(
times
),
std
::
back_inserter
(
times
),
[
&
](
const
auto
&
cr
,
const
auto
&
solution
)
{
[
&
](
const
auto
&
cr
,
const
auto
&
solution
)
{
if
(
enabled
(
MIGRAPHX_TRACE_BENCHMARKING
{})
)
if
(
trace_level
>
1
)
std
::
cout
<<
"Benchmarking solution: "
<<
solution
<<
std
::
endl
;
std
::
cout
<<
"Benchmarking solution: "
<<
solution
<<
std
::
endl
;
if
(
not
cr
.
has_value
())
if
(
not
cr
.
has_value
())
{
{
if
(
enabled
(
MIGRAPHX_TRACE_BENCHMARKING
{})
)
if
(
trace_level
>
1
)
std
::
cout
<<
"No binary"
<<
std
::
endl
;
std
::
cout
<<
"No binary"
<<
std
::
endl
;
return
std
::
numeric_limits
<
double
>::
max
();
return
std
::
numeric_limits
<
double
>::
max
();
}
}
auto
t
=
time_op
(
auto
t
=
time_op
(
*
ctx
,
cr
->
replace
.
code_object
,
to_shapes
(
cr
->
ins
->
inputs
()),
20
);
*
ctx
,
cr
->
replace
.
code_object
,
to_shapes
(
cr
->
ins
->
inputs
()),
20
);
if
(
enabled
(
MIGRAPHX_TRACE_BENCHMARKING
{})
)
if
(
trace_level
>
1
)
std
::
cout
<<
t
<<
"ms"
<<
std
::
endl
;
std
::
cout
<<
t
<<
"ms"
<<
std
::
endl
;
return
t
;
return
t
;
});
});
auto
i
=
std
::
distance
(
times
.
begin
(),
std
::
min_element
(
times
.
begin
(),
times
.
end
()));
auto
i
=
std
::
distance
(
times
.
begin
(),
std
::
min_element
(
times
.
begin
(),
times
.
end
()));
std
::
cout
<<
"Fastest solution: "
<<
config
->
solutions
.
at
(
i
)
<<
std
::
endl
;
if
(
trace_level
>
0
)
std
::
cout
<<
"Fastest solution: "
<<
config
->
solutions
.
at
(
i
)
<<
std
::
endl
;
pc
.
insert
(
preop
.
name
(),
config
->
problem
,
config
->
solutions
.
at
(
i
));
pc
.
insert
(
preop
.
name
(),
config
->
problem
,
config
->
solutions
.
at
(
i
));
if
(
not
results
[
i
].
has_value
())
if
(
not
results
[
i
].
has_value
())
MIGRAPHX_THROW
(
"No valid tuned compilation."
);
MIGRAPHX_THROW
(
"No valid tuned compilation."
);
...
...
src/targets/gpu/gemm_impl.cpp
View file @
52585d4f
This diff is collapsed.
Click to expand it.
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
52585d4f
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
struct
context
;
struct
context
;
void
blas_shape
(
const
shape
&
s
);
shape
transpose_batch
(
const
shape
&
s
,
unsigned
trans_batch
);
shape
transpose_batch
(
const
shape
&
s
,
unsigned
trans_batch
);
void
blas_shape
(
const
shape
&
s
);
template
<
class
Op
>
template
<
class
Op
>
struct
rocblas_gemm
struct
rocblas_gemm
...
@@ -52,6 +51,7 @@ struct rocblas_gemm
...
@@ -52,6 +51,7 @@ struct rocblas_gemm
float
beta
=
0
;
float
beta
=
0
;
bool
compute_fp32
=
false
;
bool
compute_fp32
=
false
;
unsigned
trans_batch
=
0
;
unsigned
trans_batch
=
0
;
int32_t
solution_idx
=
0
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -60,7 +60,8 @@ struct rocblas_gemm
...
@@ -60,7 +60,8 @@ struct rocblas_gemm
pack
(
f
(
self
.
alpha
,
"alpha"
),
pack
(
f
(
self
.
alpha
,
"alpha"
),
f
(
self
.
beta
,
"beta"
),
f
(
self
.
beta
,
"beta"
),
f
(
self
.
compute_fp32
,
"compute_fp32"
),
f
(
self
.
compute_fp32
,
"compute_fp32"
),
f
(
self
.
trans_batch
,
"trans_batch"
)));
f
(
self
.
trans_batch
,
"trans_batch"
),
f
(
self
.
solution_idx
,
"solution_idx"
)));
}
}
std
::
string
name
()
const
std
::
string
name
()
const
...
@@ -76,6 +77,8 @@ struct rocblas_gemm
...
@@ -76,6 +77,8 @@ struct rocblas_gemm
{
{
std
::
vector
<
shape
>
in_shapes
(
inputs
);
std
::
vector
<
shape
>
in_shapes
(
inputs
);
in_shapes
.
pop_back
();
in_shapes
.
pop_back
();
// When input shapes are A, B, C the GEMM equation is C = α AB+ β C where α, β are
// scalars
check_shapes
{
in_shapes
,
*
this
}.
has
(
2
,
3
);
check_shapes
{
in_shapes
,
*
this
}.
has
(
2
,
3
);
blas_shape
(
inputs
[
0
]);
blas_shape
(
inputs
[
0
]);
blas_shape
(
inputs
[
1
]);
blas_shape
(
inputs
[
1
]);
...
@@ -111,11 +114,12 @@ struct rocblas_gemm
...
@@ -111,11 +114,12 @@ struct rocblas_gemm
{
{
if
(
this
->
name
()
==
"gpu::gemm"
)
if
(
this
->
name
()
==
"gpu::gemm"
)
{
{
gemm
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
);
gemm
_compute
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
}
else
else
{
{
gemm
(
ctx
,
output_shape
,
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
compute_fp32
);
gemm_compute
(
ctx
,
output_shape
,
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
compute_fp32
,
solution_idx
);
}
}
return
args
.
back
();
return
args
.
back
();
}
}
...
@@ -124,6 +128,33 @@ struct rocblas_gemm
...
@@ -124,6 +128,33 @@ struct rocblas_gemm
{
{
return
shapes
.
size
()
-
1
;
return
shapes
.
size
()
-
1
;
}
}
void
finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if
(
enabled
(
MIGRAPHX_ENABLE_GEMM_TUNING
{})
or
ctx
.
get_exhaustive_tune_flag
())
{
if
(
this
->
name
()
==
"gpu::gemm"
)
{
solution_idx
=
gemm_finalize
(
ctx
,
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
else
{
solution_idx
=
gemm_finalize
(
ctx
,
output_shape
,
input_shapes
,
int32_t
(
alpha
),
int32_t
(
beta
),
compute_fp32
,
solution_idx
);
}
}
#else
// suppress compiler warnings
(
void
)
ctx
,
(
void
)
output_shape
,
(
void
)
input_shapes
;
#endif
}
};
};
}
// namespace gpu
}
// namespace gpu
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment