Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e9a3bdaa
Commit
e9a3bdaa
authored
Dec 06, 2023
by
Khalique Ahmed
Browse files
manual merge
parents
d7dfe995
a09dc502
Changes
66
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
408 additions
and
103 deletions
+408
-103
docs/driver/read.rst
docs/driver/read.rst
+4
-0
src/driver/CMakeLists.txt
src/driver/CMakeLists.txt
+1
-0
src/driver/main.cpp
src/driver/main.cpp
+5
-0
src/driver/passes.cpp
src/driver/passes.cpp
+109
-0
src/driver/passes.hpp
src/driver/passes.hpp
+14
-17
src/eliminate_data_type.cpp
src/eliminate_data_type.cpp
+72
-20
src/include/migraphx/eliminate_data_type.hpp
src/include/migraphx/eliminate_data_type.hpp
+2
-1
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+10
-0
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+1
-1
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+68
-17
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
+2
-0
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
+21
-7
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
+2
-6
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+36
-22
src/targets/gpu/rocblas.cpp
src/targets/gpu/rocblas.cpp
+10
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+17
-0
test/verify/gemm_2args_bmv.cpp
test/verify/gemm_2args_bmv.cpp
+8
-3
test/verify/gemm_2args_mm_1.cpp
test/verify/gemm_2args_mm_1.cpp
+8
-3
test/verify/gemm_2args_mm_2.cpp
test/verify/gemm_2args_mm_2.cpp
+9
-3
test/verify/gemm_2args_mm_3.cpp
test/verify/gemm_2args_mm_3.cpp
+9
-3
No files found.
docs/driver/read.rst
View file @
e9a3bdaa
...
@@ -58,6 +58,10 @@ Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]})
...
@@ -58,6 +58,10 @@ Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]})
Optimize when reading
Optimize when reading
.. option:: --apply-pass, -p
Passes to apply to model
.. option:: --graphviz, -g
.. option:: --graphviz, -g
Print out a graphviz representation.
Print out a graphviz representation.
...
...
src/driver/CMakeLists.txt
View file @
e9a3bdaa
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
add_executable
(
driver
add_executable
(
driver
main.cpp
main.cpp
verify.cpp
verify.cpp
passes.cpp
perf.cpp
perf.cpp
resnet50.cpp
resnet50.cpp
inceptionv3.cpp
inceptionv3.cpp
...
...
src/driver/main.cpp
View file @
e9a3bdaa
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include "argument_parser.hpp"
#include "argument_parser.hpp"
#include "command.hpp"
#include "command.hpp"
#include "precision.hpp"
#include "precision.hpp"
#include "passes.hpp"
#include "perf.hpp"
#include "perf.hpp"
#include "models.hpp"
#include "models.hpp"
#include "marker_roctx.hpp"
#include "marker_roctx.hpp"
...
@@ -83,6 +84,7 @@ struct loader
...
@@ -83,6 +84,7 @@ struct loader
std
::
vector
<
std
::
string
>
param_dims
;
std
::
vector
<
std
::
string
>
param_dims
;
std
::
vector
<
std
::
string
>
dyn_param_dims
;
std
::
vector
<
std
::
string
>
dyn_param_dims
;
std
::
vector
<
std
::
string
>
output_names
;
std
::
vector
<
std
::
string
>
output_names
;
std
::
vector
<
std
::
string
>
passes
;
void
parse
(
argument_parser
&
ap
)
void
parse
(
argument_parser
&
ap
)
{
{
...
@@ -130,6 +132,7 @@ struct loader
...
@@ -130,6 +132,7 @@ struct loader
ap
.
append
(),
ap
.
append
(),
ap
.
nargs
(
2
));
ap
.
nargs
(
2
));
ap
(
optimize
,
{
"--optimize"
,
"-O"
},
ap
.
help
(
"Optimize when reading"
),
ap
.
set_value
(
true
));
ap
(
optimize
,
{
"--optimize"
,
"-O"
},
ap
.
help
(
"Optimize when reading"
),
ap
.
set_value
(
true
));
ap
(
passes
,
{
"--apply-pass"
,
"-p"
},
ap
.
help
(
"Passes to apply to model"
),
ap
.
append
());
ap
(
output_type
,
ap
(
output_type
,
{
"--graphviz"
,
"-g"
},
{
"--graphviz"
,
"-g"
},
ap
.
help
(
"Print out a graphviz representation."
),
ap
.
help
(
"Print out a graphviz representation."
),
...
@@ -337,6 +340,8 @@ struct loader
...
@@ -337,6 +340,8 @@ struct loader
migraphx
::
dead_code_elimination
{},
migraphx
::
dead_code_elimination
{},
});
});
}
}
if
(
not
passes
.
empty
())
migraphx
::
run_passes
(
*
p
.
get_main_module
(),
get_passes
(
passes
));
return
p
;
return
p
;
}
}
...
...
test/verify/gemm_add_broadcast_half
.cpp
→
src/driver/passes
.cpp
View file @
e9a3bdaa
...
@@ -22,28 +22,88 @@
...
@@ -22,28 +22,88 @@
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include "verify_program.hpp"
#include "passes.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/eliminate_allocation.hpp>
struct
gemm_add_broadcast_half
:
verify_program
<
gemm_add_broadcast_half
>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/optimize_module.hpp>
#include <migraphx/promote_literals.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
namespace
migraphx
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
std
::
unordered_map
<
std
::
string
,
pass
>
create_passes_lookup
()
{
std
::
unordered_map
<
std
::
string
,
pass
>
result
;
// clang-format off
std
::
initializer_list
<
pass
>
passes
=
{
auto_contiguous
{},
dead_code_elimination
{},
eliminate_allocation
{},
eliminate_common_subexpression
{},
eliminate_concat
{},
eliminate_contiguous
{},
eliminate_data_type
{},
eliminate_identity
{},
eliminate_pad
{},
inline_module
{},
insert_pad
{},
normalize_ops
{},
optimize_module
{},
promote_literals
{},
propagate_constant
{},
rewrite_gelu
{},
rewrite_pooling
{},
rewrite_quantization
{},
rewrite_rnn
{},
simplify_algebra
{},
simplify_dyn_ops
{},
simplify_qdq
{},
simplify_reshapes
{},
};
// clang-format on
for
(
const
auto
&
pass
:
passes
)
result
[
pass
.
name
()]
=
pass
;
result
[
"eliminate_dead_code"
]
=
dead_code_elimination
{};
return
result
;
}
std
::
vector
<
pass
>
get_passes
(
const
std
::
vector
<
std
::
string
>&
names
)
{
{
migraphx
::
program
create_program
()
const
std
::
vector
<
pass
>
result
;
{
static
const
std
::
unordered_map
<
std
::
string
,
pass
>
lookup
=
create_passes_lookup
();
migraphx
::
program
p
;
std
::
transform
(
auto
*
mm
=
p
.
get_main_module
();
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
result
),
[](
const
std
::
string
&
name
)
{
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
2
,
3
}};
if
(
not
contains
(
lookup
,
name
))
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
3
,
4
}};
MIGRAPHX_THROW
(
"Unknown pass: "
+
name
);
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
1
,
4
}};
return
lookup
.
at
(
name
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
});
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
return
result
;
auto
l3
=
mm
->
add_parameter
(
"3"
,
m3_shape
);
}
auto
l3_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
2
,
4
}}}),
l3
);
auto
dot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
l1
,
l2
);
}
// namespace MIGRAPHX_INLINE_NS
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot
,
l3_b
);
}
// namespace driver
return
p
;
}
// namespace migraphx
}
};
test/verify/test_nonzero_half.c
pp
→
src/driver/passes.h
pp
View file @
e9a3bdaa
...
@@ -21,23 +21,20 @@
...
@@ -21,23 +21,20 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_DRIVER_PASSES_HPP
#define MIGRAPHX_GUARD_DRIVER_PASSES_HPP
#include "verify_program.hpp"
#include <migraphx/pass.hpp>
#include <migraphx/program.hpp>
#include <vector>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_nonzero_half
:
verify_program
<
test_nonzero_half
>
namespace
migraphx
{
{
namespace
driver
{
migraphx
::
program
create_program
()
const
inline
namespace
MIGRAPHX_INLINE_NS
{
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
half_type
,
{
3
,
4
,
3
,
5
}};
auto
x
=
mm
->
add_parameter
(
"data"
,
s
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"nonzero"
),
x
);
mm
->
add_return
({
r
});
return
p
;
std
::
vector
<
pass
>
get_passes
(
const
std
::
vector
<
std
::
string
>&
names
);
}
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace driver
}
// namespace migraphx
#endif
src/eliminate_data_type.cpp
View file @
e9a3bdaa
...
@@ -31,6 +31,72 @@
...
@@ -31,6 +31,72 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
insert_convert_to_supported_type
(
module
&
m
,
instruction_ref
ins
,
migraphx
::
shape
::
type_t
target_type
,
std
::
set
<
migraphx
::
shape
::
type_t
>
unsupported_types
)
{
migraphx
::
shape
::
type_t
orig_type
=
ins
->
get_shape
().
type
();
std
::
vector
<
instruction_ref
>
inputs
=
ins
->
inputs
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
const
auto
&
i
)
{
if
(
contains
(
unsupported_types
,
i
->
get_shape
().
type
()))
{
return
m
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
target_type
)}}),
i
);
}
else
{
return
i
;
}
});
// if no change
if
(
inputs
==
ins
->
inputs
())
return
;
auto
op
=
ins
->
get_operator
();
auto
attributes
=
op
.
attributes
();
if
(
attributes
.
contains
(
"general_data_type"
))
{
op
=
make_op
(
attributes
[
"general_data_type"
].
to
<
std
::
string
>
(),
op
.
to_value
());
}
auto
new_ins
=
m
.
insert_instruction
(
ins
,
op
,
inputs
);
if
(
orig_type
==
shape
::
tuple_type
)
{
auto
orig_outs
=
ins
->
outputs
();
if
(
not
std
::
all_of
(
orig_outs
.
begin
(),
orig_outs
.
end
(),
[
&
](
const
auto
out_ins
)
{
return
out_ins
->
name
()
==
"get_tuple_elem"
;
}))
MIGRAPHX_THROW
(
"eliminate_data_type: Instruction with tuple output doesn't have all its "
"usages as get_tuple_elem instruction"
);
std
::
transform
(
orig_outs
.
begin
(),
orig_outs
.
end
(),
orig_outs
.
begin
(),
[
&
](
const
auto
out_ins
)
{
auto
gte_ins
=
m
.
insert_instruction
(
ins
,
out_ins
->
get_operator
(),
new_ins
);
auto
orig_out_type
=
out_ins
->
get_shape
().
type
();
if
(
contains
(
unsupported_types
,
orig_out_type
))
{
auto
gte_convert
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
orig_out_type
}}),
gte_ins
);
return
m
.
replace_instruction
(
out_ins
,
gte_convert
);
}
else
{
return
m
.
replace_instruction
(
out_ins
,
gte_ins
);
}
});
}
else
{
auto
convert_back_ins
=
m
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
orig_type
)}}),
new_ins
);
m
.
replace_instruction
(
ins
,
convert_back_ins
);
}
}
void
eliminate_data_type
::
apply
(
module
&
m
)
const
void
eliminate_data_type
::
apply
(
module
&
m
)
const
{
{
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
"convert"
,
static
const
std
::
vector
<
std
::
string
>
skip_op_names
=
{
"convert"
,
...
@@ -42,31 +108,17 @@ void eliminate_data_type::apply(module& m) const
...
@@ -42,31 +108,17 @@ void eliminate_data_type::apply(module& m) const
"scatternd_add"
,
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_mul"
,
"scatternd_none"
};
"scatternd_none"
};
if
(
unsupported_types
.
empty
())
return
;
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()[
0
]
==
'@'
)
if
(
ins
->
name
()[
0
]
==
'@'
)
continue
;
continue
;
if
(
contains
(
skip_op_names
,
ins
->
name
()))
if
(
contains
(
skip_op_names
,
ins
->
name
())
and
not
contains
(
unsupported_ops
,
ins
->
name
()))
continue
;
auto
inputs
=
ins
->
inputs
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
i
)
{
if
(
types
.
count
(
i
->
get_shape
().
type
())
==
0
)
return
i
;
return
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
target_type
}}),
i
);
});
if
(
inputs
==
ins
->
inputs
())
continue
;
continue
;
auto
op
=
ins
->
get_operator
();
if
(
contains
(
unsupported_ops
,
"all"
)
or
contains
(
unsupported_ops
,
ins
->
name
()))
auto
attributes
=
op
.
attributes
();
insert_convert_to_supported_type
(
m
,
ins
,
target_type
,
unsupported_types
);
if
(
attributes
.
contains
(
"general_data_type"
))
{
op
=
make_op
(
attributes
[
"general_data_type"
].
to
<
std
::
string
>
(),
op
.
to_value
());
}
auto
old_type
=
ins
->
get_shape
().
type
();
auto
out
=
m
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
convert
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
old_type
}}),
out
);
m
.
replace_instruction
(
ins
,
convert
);
}
}
}
}
...
...
src/include/migraphx/eliminate_data_type.hpp
View file @
e9a3bdaa
...
@@ -40,8 +40,9 @@ struct module;
...
@@ -40,8 +40,9 @@ struct module;
*/
*/
struct
MIGRAPHX_EXPORT
eliminate_data_type
struct
MIGRAPHX_EXPORT
eliminate_data_type
{
{
std
::
set
<
shape
::
type_t
>
types
;
std
::
set
<
shape
::
type_t
>
unsupported_
types
;
shape
::
type_t
target_type
;
shape
::
type_t
target_type
;
std
::
set
<
std
::
string
>
unsupported_ops
=
{
"all"
};
std
::
string
name
()
const
{
return
"eliminate_data_type"
;
}
std
::
string
name
()
const
{
return
"eliminate_data_type"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
...
...
src/targets/gpu/CMakeLists.txt
View file @
e9a3bdaa
...
@@ -259,6 +259,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
...
@@ -259,6 +259,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
# Beta API for automated GEMM tuning
# Beta API for automated GEMM tuning
check_library_exists
(
roc::rocblas
"rocblas_gemm_ex_get_solutions"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
check_library_exists
(
roc::rocblas
"rocblas_gemm_ex_get_solutions"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
# rocblas FP8 API
check_library_exists
(
roc::rocblas
"rocblas_gemm_strided_batched_ex3"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_FP8_BETA_API
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
...
@@ -288,10 +290,18 @@ else()
...
@@ -288,10 +290,18 @@ else()
message
(
STATUS
"rocBLAS does not have User Tuning Beta API"
)
message
(
STATUS
"rocBLAS does not have User Tuning Beta API"
)
endif
()
endif
()
if
(
HAS_ROCBLAS_FP8_BETA_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS
)
message
(
STATUS
"MIGraphX is using Beta API of rocBLAS for FP8 computations"
)
else
()
message
(
STATUS
"rocBLAS does not have Fp8 Beta API"
)
endif
()
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
target_link_libraries
(
migraphx_gpu PRIVATE composable_kernel::jit_library
)
target_link_libraries
(
migraphx_gpu PRIVATE composable_kernel::jit_library
)
target_compile_definitions
(
migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1
)
endif
()
endif
()
add_subdirectory
(
driver
)
add_subdirectory
(
driver
)
...
...
src/targets/gpu/fuse_mlir.cpp
View file @
e9a3bdaa
...
@@ -531,7 +531,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
...
@@ -531,7 +531,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
match
::
find_matches
(
match
::
find_matches
(
mpm
,
mpm
,
find_mlir_standalone_convolution_op
{
get_mode
(
"convolution"
,
mlir_mode
::
int8
)},
find_mlir_standalone_convolution_op
{
get_mode
(
"convolution"
,
mlir_mode
::
fast
)},
find_mlir_standalone_dot_op
{
get_mode
(
"dot"
,
mlir_mode
::
none
)});
find_mlir_standalone_dot_op
{
get_mode
(
"dot"
,
mlir_mode
::
none
)});
#else
#else
(
void
)
mpm
;
(
void
)
mpm
;
...
...
src/targets/gpu/gemm_impl.cpp
View file @
e9a3bdaa
...
@@ -22,11 +22,14 @@
...
@@ -22,11 +22,14 @@
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
#include <migraphx/time.hpp>
#include <type_traits>
using
microseconds
=
std
::
chrono
::
duration
<
double
,
std
::
micro
>
;
using
microseconds
=
std
::
chrono
::
duration
<
double
,
std
::
micro
>
;
...
@@ -34,6 +37,20 @@ namespace migraphx {
...
@@ -34,6 +37,20 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
/*
Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it
as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast integer enum
value to required type that can be used inside `common_args` generator.
*/
struct
rb_compute_type
{
int
type
=
0
;
rb_compute_type
(
rocblas_datatype
t
)
:
type
(
static_cast
<
int
>
(
t
))
{}
rb_compute_type
(
rocblas_computetype
t
)
:
type
(
static_cast
<
int
>
(
t
))
{}
operator
rocblas_datatype
()
const
{
return
static_cast
<
rocblas_datatype
>
(
type
);
}
operator
rocblas_computetype
()
const
{
return
static_cast
<
rocblas_computetype
>
(
type
);
}
};
// Convert rocBLAS datatypes to equivalent Migraphx data types
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype
get_type
(
shape
::
type_t
type
)
rocblas_datatype
get_type
(
shape
::
type_t
type
)
{
{
...
@@ -46,7 +63,7 @@ rocblas_datatype get_type(shape::type_t type)
...
@@ -46,7 +63,7 @@ rocblas_datatype get_type(shape::type_t type)
case
shape
::
uint8_type
:
return
rocblas_datatype_u8_r
;
case
shape
::
uint8_type
:
return
rocblas_datatype_u8_r
;
case
shape
::
int32_type
:
return
rocblas_datatype_i32_r
;
case
shape
::
int32_type
:
return
rocblas_datatype_i32_r
;
case
shape
::
uint32_type
:
return
rocblas_datatype_u32_r
;
case
shape
::
uint32_type
:
return
rocblas_datatype_u32_r
;
case
shape
::
fp8e4m3fnuz_type
:
case
shape
::
fp8e4m3fnuz_type
:
return
rocblas_datatype_f8_r
;
case
shape
::
tuple_type
:
case
shape
::
tuple_type
:
case
shape
::
bool_type
:
case
shape
::
bool_type
:
case
shape
::
uint16_type
:
case
shape
::
uint16_type
:
...
@@ -183,12 +200,17 @@ struct gemm_impl
...
@@ -183,12 +200,17 @@ struct gemm_impl
{
{
output_type
=
rocblas_datatype_i32_r
;
output_type
=
rocblas_datatype_i32_r
;
}
}
compute_type
=
output_type
;
compute_type
=
rb_compute_type
{
output_type
}
;
if
(
compute_fp32
)
if
(
compute_fp32
)
{
{
if
(
arg_type
==
rocblas_datatype_f16_r
)
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
compute_type
=
rocblas_datatype_f32_r
;
}
}
if
(
arg_type
==
rocblas_datatype_f8_r
)
{
assert
(
get_type
(
input_shapes
[
1
].
type
())
==
rocblas_datatype_f8_r
);
compute_type
=
rocblas_compute_type_f32
;
}
auto
a_lens
=
input_shapes
[
0
].
lens
();
auto
a_lens
=
input_shapes
[
0
].
lens
();
auto
b_lens
=
input_shapes
[
1
].
lens
();
auto
b_lens
=
input_shapes
[
1
].
lens
();
...
@@ -217,23 +239,52 @@ struct gemm_impl
...
@@ -217,23 +239,52 @@ struct gemm_impl
void
run
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
input_args
,
int32_t
solution_idx
=
0
)
const
void
run
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
input_args
,
int32_t
solution_idx
=
0
)
const
{
{
if
(
strided_batched
)
#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
if
(
rocblas_fp8_available
()
and
std
::
any_of
(
input_args
.
begin
(),
input_args
.
end
(),
[](
const
auto
i
)
{
return
i
.
get_shape
().
type
()
==
migraphx
::
shape
::
fp8e4m3fnuz_type
;
}))
{
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
if
(
strided_batched
)
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
{
common_args
,
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_gemm_algo_solution_index
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex3
,
solution_idx
,
common_args
,
gemm_flags
);
rocblas_gemm_algo_standard
,
solution_idx
,
gemm_flags
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex3
,
common_args
,
rocblas_gemm_algo_standard
,
solution_idx
,
gemm_flags
);
}
}
}
else
else
#endif
{
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
if
(
strided_batched
)
rocblas_invoke
(
&
rocblas_gemm_ex
,
{
common_args
,
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_gemm_algo_solution_index
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
solution_idx
,
common_args
,
gemm_flags
);
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
}
}
}
}
}
...
@@ -331,7 +382,6 @@ struct gemm_impl
...
@@ -331,7 +382,6 @@ struct gemm_impl
num_matrices
,
num_matrices
,
compute_type
);
compute_type
);
}
}
/**
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls.
* to multiple "gemm_ex..." calls.
...
@@ -366,6 +416,7 @@ struct gemm_impl
...
@@ -366,6 +416,7 @@ struct gemm_impl
ldd
,
ldd
,
compute_type
);
compute_type
);
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
...
@@ -481,8 +532,8 @@ struct gemm_impl
...
@@ -481,8 +532,8 @@ struct gemm_impl
rocblas_int
b_stride
=
0
;
rocblas_int
b_stride
=
0
;
rocblas_int
c_stride
=
0
;
rocblas_int
c_stride
=
0
;
rocblas_int
d_stride
=
0
;
rocblas_int
d_stride
=
0
;
rocblas_datatype
compute_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
arg_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
arg_type
=
rocblas_datatype_f32_r
;
rb_compute_type
compute_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
output_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
output_type
=
rocblas_datatype_f32_r
;
bool
strided_batched
=
true
;
bool
strided_batched
=
true
;
bool
is_3inputs
=
true
;
bool
is_3inputs
=
true
;
...
...
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
View file @
e9a3bdaa
...
@@ -40,6 +40,8 @@ struct context;
...
@@ -40,6 +40,8 @@ struct context;
MIGRAPHX_GPU_EXPORT
bool
get_compute_fp32_flag
();
MIGRAPHX_GPU_EXPORT
bool
get_compute_fp32_flag
();
MIGRAPHX_GPU_EXPORT
bool
rocblas_fp8_available
();
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
View file @
e9a3bdaa
...
@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x)
...
@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x)
return
y
;
return
y
;
}
}
template
<
unsigned
int
DppCtrl
,
template
<
class
T
,
class
F
>
unsigned
int
RowMask
=
0xf
,
__device__
T
dpp_op
(
T
&
x
,
F
f
)
unsigned
int
BankMask
=
0xf
,
bool
BoundCtrl
=
false
,
class
T
>
__device__
T
dpp_mov
(
T
&
x
)
{
{
static
const
index_int
n
=
sizeof
(
T
)
<
4
?
1
:
sizeof
(
T
)
/
4
;
static
const
index_int
n
=
sizeof
(
T
)
<
4
?
1
:
sizeof
(
T
)
/
4
;
union
type
union
type
...
@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x)
...
@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x)
input
.
data
=
x
;
input
.
data
=
x
;
for
(
index_int
i
=
0
;
i
<
n
;
i
++
)
for
(
index_int
i
=
0
;
i
<
n
;
i
++
)
{
{
output
.
reg
[
i
]
=
__hip_move_dpp
(
input
.
reg
[
i
],
DppCtrl
,
RowMask
,
BankMask
,
BoundCtrl
);
output
.
reg
[
i
]
=
f
(
input
.
reg
[
i
]
);
}
}
return
output
.
data
;
return
output
.
data
;
}
}
template
<
unsigned
int
DppCtrl
,
unsigned
int
RowMask
=
0xf
,
unsigned
int
BankMask
=
0xf
,
bool
BoundCtrl
=
false
,
class
T
>
__device__
T
dpp_mov
(
T
&
x
)
{
return
dpp_op
(
x
,
[](
auto
i
)
{
return
__hip_move_dpp
(
i
,
DppCtrl
,
RowMask
,
BankMask
,
BoundCtrl
);
});
}
template
<
unsigned
int
Mask
,
class
T
>
__device__
T
dpp_swizzle
(
T
&
x
)
{
return
dpp_op
(
x
,
[](
auto
i
)
{
return
__hip_ds_swizzle
(
i
,
Mask
);
});
}
#endif // MIGRAPHX_HAS_DPP
#endif // MIGRAPHX_HAS_DPP
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
View file @
e9a3bdaa
...
@@ -501,9 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
...
@@ -501,9 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
{
{
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
}
}
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
static
constexpr
__device__
fp8e5m2fnuz
min
()
static
constexpr
__device__
fp8e5m2fnuz
min
()
{
{
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
...
@@ -528,9 +526,7 @@ class numeric_limits<fp8e5m2>
...
@@ -528,9 +526,7 @@ class numeric_limits<fp8e5m2>
}
}
static
constexpr
__device__
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
static
constexpr
__device__
fp8e5m2
min
()
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
min
()
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
lowest
()
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
lowest
()
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
e9a3bdaa
...
@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op)
...
@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op)
in
=
op
(
in
,
out
);
in
=
op
(
in
,
out
);
out
=
dpp_mov
<
dpp_row_shr
(
8
),
0xf
,
0xc
>
(
in
);
out
=
dpp_mov
<
dpp_row_shr
(
8
),
0xf
,
0xc
>
(
in
);
in
=
op
(
in
,
out
);
in
=
op
(
in
,
out
);
#if __AMDGCN_WAVEFRONT_SIZE == 64
#if __AMDGCN_WAVEFRONT_SIZE == 32
out
=
dpp_swizzle
<
0x1e0
>
(
in
);
in
=
op
(
in
,
out
);
#else
out
=
dpp_mov
<
dpp_row_bcast
(
15
),
0xa
>
(
in
);
out
=
dpp_mov
<
dpp_row_bcast
(
15
),
0xa
>
(
in
);
in
=
op
(
in
,
out
);
in
=
op
(
in
,
out
);
out
=
dpp_mov
<
dpp_row_bcast
(
31
),
0xc
>
(
in
);
out
=
dpp_mov
<
dpp_row_bcast
(
31
),
0xc
>
(
in
);
...
@@ -54,9 +57,11 @@ __device__ void dpp_reduce(T& in, Op op)
...
@@ -54,9 +57,11 @@ __device__ void dpp_reduce(T& in, Op op)
}
}
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
(void)f; \
x = 1
#elif __AMDGCN_WAVEFRONT_SIZE == 64
#elif __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
)
\
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
, f)
\
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
...
@@ -65,29 +70,42 @@ __device__ void dpp_reduce(T& in, Op op)
...
@@ -65,29 +70,42 @@ __device__ void dpp_reduce(T& in, Op op)
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" \
"s_nop 1\n" \
: "=v"(x) \
: "=v"(x) \
: "0"(x))
: "0"(x)); \
(void)f
#else
#else
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
)
\
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
, f)
\
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" \
"s_nop 1\n" \
: "=v"(x) \
: "=v"(x) \
: "0"(x))
: "0"(x)); \
auto y = dpp_swizzle<0x1e0>(x); \
x = f(x, y)
#endif
#endif
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(double& x, op f) \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
{ \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64, f); \
__device__ inline void dpp_reduce(int32_t& x, op) \
} \
{ \
__device__ inline void dpp_reduce(float& x, op f) \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \
{ \
} \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32, f); \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
} \
__device__ inline void dpp_reduce(half& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16, f); \
} \
__device__ inline void dpp_reduce(int32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32, f); \
} \
__device__ inline void dpp_reduce(uint32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32, f); \
}
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
,
_u
)
...
@@ -99,11 +117,7 @@ template <class Op, class T, class Index, class F>
...
@@ -99,11 +117,7 @@ template <class Op, class T, class Index, class F>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
{
{
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
#if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr
index_int
lanes_per_thread
=
__AMDGCN_WAVEFRONT_SIZE
;
constexpr
index_int
lanes_per_thread
=
16
;
#else
constexpr
index_int
lanes_per_thread
=
64
;
#endif
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
__shared__
type
buffer
[
idx
.
max_nlocal
()
/
lanes_per_thread
];
__shared__
type
buffer
[
idx
.
max_nlocal
()
/
lanes_per_thread
];
type
x
=
type
(
init
);
type
x
=
type
(
init
);
...
...
src/targets/gpu/rocblas.cpp
View file @
e9a3bdaa
...
@@ -53,6 +53,16 @@ bool get_compute_fp32_flag()
...
@@ -53,6 +53,16 @@ bool get_compute_fp32_flag()
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx908"
);
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx908"
);
}
}
bool
rocblas_fp8_available
()
{
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return
false
;
#else
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx940"
);
#endif
}
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/targets/gpu/target.cpp
View file @
e9a3bdaa
...
@@ -105,6 +105,21 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -105,6 +105,21 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
std
::
set
<
std
::
string
>
unsupported_fp8_ops
=
{};
if
(
not
gpu
::
rocblas_fp8_available
())
{
unsupported_fp8_ops
.
insert
(
"dot"
);
}
// add all device kernels
unsupported_fp8_ops
.
insert
(
"logsoftmax"
);
unsupported_fp8_ops
.
insert
(
"nonzero"
);
unsupported_fp8_ops
.
insert
(
"prefix_scan_sum"
);
unsupported_fp8_ops
.
insert
(
"scatter_none"
);
unsupported_fp8_ops
.
insert
(
"topk"
);
unsupported_fp8_ops
.
insert
(
"rnn_var_sl_shift_output"
);
unsupported_fp8_ops
.
insert
(
"multinomial"
);
unsupported_fp8_ops
.
insert
(
"argmax"
);
unsupported_fp8_ops
.
insert
(
"argmin"
);
// clang-format off
// clang-format off
return
return
{
{
...
@@ -135,6 +150,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -135,6 +150,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
dead_code_elimination
{},
prefuse_ops
{},
prefuse_ops
{},
dead_code_elimination
{},
dead_code_elimination
{},
eliminate_data_type
{{
migraphx
::
shape
::
fp8e4m3fnuz_type
},
shape
::
float_type
,
unsupported_fp8_ops
},
dead_code_elimination
{},
optimize_module
{},
optimize_module
{},
fuse_pointwise
{},
fuse_pointwise
{},
dead_code_elimination
{},
dead_code_elimination
{},
...
...
test/verify/gemm_2args_bmv.cpp
View file @
e9a3bdaa
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_bmv
:
verify_program
<
gemm_2args_bmv
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_bmv
:
verify_program
<
gemm_2args_bmv
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
,
3
,
5
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
3
,
3
,
5
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
5
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
5
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
ul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
l2
);
auto
ul2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
l2
);
...
@@ -46,3 +47,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
...
@@ -46,3 +47,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_bmv
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_bmv
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_bmv
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_1.cpp
View file @
e9a3bdaa
...
@@ -27,14 +27,15 @@
...
@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_1
:
verify_program
<
gemm_2args_mm_1
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_1
:
verify_program
<
gemm_2args_mm_1
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
3
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
1
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
bl2
=
auto
bl2
=
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
...
@@ -45,3 +46,7 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mm_1
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_1
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_1
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_2.cpp
View file @
e9a3bdaa
...
@@ -24,17 +24,19 @@
...
@@ -24,17 +24,19 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_2
:
verify_program
<
gemm_2args_mm_2
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_2
:
verify_program
<
gemm_2args_mm_2
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
2
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
bl2
=
auto
bl2
=
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mm_2
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_2
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_2
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_2args_mm_3.cpp
View file @
e9a3bdaa
...
@@ -24,17 +24,19 @@
...
@@ -24,17 +24,19 @@
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_3
:
verify_program
<
gemm_2args_mm_3
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
gemm_2args_mm_3
:
verify_program
<
gemm_2args_mm_3
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
float_t
ype
,
{
1
,
2
,
3
}};
migraphx
::
shape
m1_shape
{
DT
ype
,
{
1
,
2
,
3
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
3
,
4
}};
migraphx
::
shape
m2_shape
{
DT
ype
,
{
3
,
3
,
4
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
bl1
=
auto
bl1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
3
}}}),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
3
}}}),
l1
);
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
...
@@ -45,3 +47,7 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
return
p
;
return
p
;
}
}
};
};
template
struct
gemm_2args_mm_3
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_3
<
migraphx
::
shape
::
half_type
>;
template
struct
gemm_2args_mm_3
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment