MIGraphX · Commit 3bc4a779 (unverified)

Merge branch 'develop' into doc-standard

Authored Dec 06, 2023 by Chris Austen; committed by GitHub on Dec 06, 2023.
Parents: 3053fc95, 6a72e8fc
Changes: 92 files in total. This page shows 20 changed files (page 1 of 5) with 420 additions and 102 deletions (+420 −102).
docs/driver/read.rst                                            +4    −0
src/driver/CMakeLists.txt                                       +1    −0
src/driver/main.cpp                                             +5    −0
src/driver/passes.cpp                                           +109  −0
src/driver/passes.hpp                                           +14   −17
src/eliminate_data_type.cpp                                     +72   −20
src/include/migraphx/eliminate_data_type.hpp                    +2    −1
src/include/migraphx/op/quant_convolution.hpp                   +11   −5
src/targets/gpu/CMakeLists.txt                                  +10   −0
src/targets/gpu/device_name.cpp                                 +6    −0
src/targets/gpu/fuse_mlir.cpp                                   +19   −7
src/targets/gpu/gemm_impl.cpp                                   +68   −17
src/targets/gpu/include/migraphx/gpu/device_name.hpp            +2    −0
src/targets/gpu/include/migraphx/gpu/rocblas.hpp                +2    −0
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp        +21   −7
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp     +2    −6
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp     +36   −22
src/targets/gpu/mlir.cpp                                        +2    −0
src/targets/gpu/rocblas.cpp                                     +9    −0
src/targets/gpu/target.cpp                                      +25   −0
docs/driver/read.rst

@@ -58,6 +58,10 @@ Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]})

     Optimize when reading

+.. option:: --apply-pass, -p
+
+    Passes to apply to model
+
 .. option:: --graphviz, -g

     Print out a graphviz representation.
src/driver/CMakeLists.txt

@@ -25,6 +25,7 @@
 add_executable(driver
     main.cpp
     verify.cpp
+    passes.cpp
     perf.cpp
     resnet50.cpp
     inceptionv3.cpp
src/driver/main.cpp

@@ -26,6 +26,7 @@
 #include "argument_parser.hpp"
 #include "command.hpp"
 #include "precision.hpp"
+#include "passes.hpp"
 #include "perf.hpp"
 #include "models.hpp"
 #include "marker_roctx.hpp"

@@ -83,6 +84,7 @@ struct loader
     std::vector<std::string> param_dims;
     std::vector<std::string> dyn_param_dims;
     std::vector<std::string> output_names;
+    std::vector<std::string> passes;

     void parse(argument_parser& ap)
     {

@@ -130,6 +132,7 @@ struct loader
            ap.append(),
            ap.nargs(2));
         ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
+        ap(passes, {"--apply-pass", "-p"}, ap.help("Passes to apply to model"), ap.append());
         ap(output_type,
            {"--graphviz", "-g"},
            ap.help("Print out a graphviz representation."),

@@ -337,6 +340,8 @@ struct loader
                 migraphx::dead_code_elimination{},
             });
         }
+        if(not passes.empty())
+            migraphx::run_passes(*p.get_main_module(), get_passes(passes));
         return p;
     }
test/verify/test_conv_relu_half.cpp → src/driver/passes.cpp

@@ -22,23 +22,88 @@
  * THE SOFTWARE.
  */
-#include "verify_program.hpp"
-#include <migraphx/program.hpp>
-#include <migraphx/generate.hpp>
-#include <migraphx/make_op.hpp>
-
-struct test_conv_relu_half : verify_program<test_conv_relu_half>
-{
-    migraphx::program create_program() const
-    {
-        migraphx::program p;
-        auto* mm     = p.get_main_module();
-        auto input   = mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
-        auto weights = mm->add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
-        auto conv    = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
-        mm->add_instruction(migraphx::make_op("relu"), conv);
-        return p;
-    }
-};
+#include "passes.hpp"
+
+#include <migraphx/auto_contiguous.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/eliminate_allocation.hpp>
+#include <migraphx/eliminate_common_subexpression.hpp>
+#include <migraphx/eliminate_concat.hpp>
+#include <migraphx/eliminate_contiguous.hpp>
+#include <migraphx/eliminate_data_type.hpp>
+#include <migraphx/eliminate_identity.hpp>
+#include <migraphx/eliminate_pad.hpp>
+#include <migraphx/inline_module.hpp>
+#include <migraphx/insert_pad.hpp>
+#include <migraphx/normalize_ops.hpp>
+#include <migraphx/optimize_module.hpp>
+#include <migraphx/promote_literals.hpp>
+#include <migraphx/propagate_constant.hpp>
+#include <migraphx/rewrite_gelu.hpp>
+#include <migraphx/rewrite_pooling.hpp>
+#include <migraphx/rewrite_quantization.hpp>
+#include <migraphx/rewrite_rnn.hpp>
+#include <migraphx/simplify_algebra.hpp>
+#include <migraphx/simplify_dyn_ops.hpp>
+#include <migraphx/simplify_qdq.hpp>
+#include <migraphx/simplify_reshapes.hpp>
+#include <migraphx/ranges.hpp>
+#include <unordered_map>
+
+namespace migraphx {
+namespace driver {
+inline namespace MIGRAPHX_INLINE_NS {
+
+std::unordered_map<std::string, pass> create_passes_lookup()
+{
+    std::unordered_map<std::string, pass> result;
+    // clang-format off
+    std::initializer_list<pass> passes = {
+        auto_contiguous{},
+        dead_code_elimination{},
+        eliminate_allocation{},
+        eliminate_common_subexpression{},
+        eliminate_concat{},
+        eliminate_contiguous{},
+        eliminate_data_type{},
+        eliminate_identity{},
+        eliminate_pad{},
+        inline_module{},
+        insert_pad{},
+        normalize_ops{},
+        optimize_module{},
+        promote_literals{},
+        propagate_constant{},
+        rewrite_gelu{},
+        rewrite_pooling{},
+        rewrite_quantization{},
+        rewrite_rnn{},
+        simplify_algebra{},
+        simplify_dyn_ops{},
+        simplify_qdq{},
+        simplify_reshapes{},
+    };
+    // clang-format on
+    for(const auto& pass : passes)
+        result[pass.name()] = pass;
+    result["eliminate_dead_code"] = dead_code_elimination{};
+    return result;
+}
+
+std::vector<pass> get_passes(const std::vector<std::string>& names)
+{
+    std::vector<pass> result;
+    static const std::unordered_map<std::string, pass> lookup = create_passes_lookup();
+    std::transform(names.begin(), names.end(), std::back_inserter(result), [](const std::string& name) {
+        if(not contains(lookup, name))
+            MIGRAPHX_THROW("Unknown pass: " + name);
+        return lookup.at(name);
+    });
+    return result;
+}
+
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace driver
+} // namespace migraphx
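
With these files in place, the driver's new --apply-pass/-p flag resolves pass names through get_passes and runs them on the main module. A minimal sketch of that flow, assuming the headers above (the wrapper function itself is hypothetical, not part of the commit):

    #include <string>
    #include <vector>
    #include <migraphx/pass_manager.hpp> // for migraphx::run_passes
    #include <migraphx/program.hpp>
    #include "passes.hpp"                // for migraphx::driver::get_passes

    // Sketch: what `migraphx-driver read model.onnx -p simplify_algebra
    // -p eliminate_dead_code` does after argument parsing, per the loader
    // changes in main.cpp above.
    void apply_named_passes(migraphx::program& p, const std::vector<std::string>& names)
    {
        // Unknown names throw "Unknown pass: <name>". Note that
        // "eliminate_dead_code" is an extra alias registered on top of
        // dead_code_elimination's own name.
        migraphx::run_passes(*p.get_main_module(), migraphx::driver::get_passes(names));
    }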
test/verify/test_nonzero_half.cpp → src/driver/passes.hpp

@@ -21,23 +21,20 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#ifndef MIGRAPHX_GUARD_DRIVER_PASSES_HPP
+#define MIGRAPHX_GUARD_DRIVER_PASSES_HPP
+
-#include "verify_program.hpp"
-#include <migraphx/program.hpp>
-#include <migraphx/generate.hpp>
-#include <migraphx/make_op.hpp>
+#include <migraphx/pass.hpp>
+#include <vector>

-struct test_nonzero_half : verify_program<test_nonzero_half>
-{
-    migraphx::program create_program() const
-    {
-        migraphx::program p;
-        auto* mm = p.get_main_module();
-        migraphx::shape s{migraphx::shape::half_type, {3, 4, 3, 5}};
-        auto x = mm->add_parameter("data", s);
-        auto r = mm->add_instruction(migraphx::make_op("nonzero"), x);
-        mm->add_return({r});
-        return p;
-    }
-};
+namespace migraphx {
+namespace driver {
+inline namespace MIGRAPHX_INLINE_NS {
+
+std::vector<pass> get_passes(const std::vector<std::string>& names);
+
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace driver
+} // namespace migraphx
+
+#endif
src/eliminate_data_type.cpp

@@ -31,6 +31,72 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

+void insert_convert_to_supported_type(module& m,
+                                      instruction_ref ins,
+                                      migraphx::shape::type_t target_type,
+                                      std::set<migraphx::shape::type_t> unsupported_types)
+{
+    migraphx::shape::type_t orig_type  = ins->get_shape().type();
+    std::vector<instruction_ref> inputs = ins->inputs();
+    std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](const auto& i) {
+        if(contains(unsupported_types, i->get_shape().type()))
+        {
+            return m.insert_instruction(
+                ins,
+                migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
+                i);
+        }
+        else
+        {
+            return i;
+        }
+    });
+    // if no change
+    if(inputs == ins->inputs())
+        return;
+    auto op         = ins->get_operator();
+    auto attributes = op.attributes();
+    if(attributes.contains("general_data_type"))
+    {
+        op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
+    }
+    auto new_ins = m.insert_instruction(ins, op, inputs);
+    if(orig_type == shape::tuple_type)
+    {
+        auto orig_outs = ins->outputs();
+        if(not std::all_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) {
+               return out_ins->name() == "get_tuple_elem";
+           }))
+            MIGRAPHX_THROW(
+                "eliminate_data_type: Instruction with tuple output doesn't have all its "
+                "usages as get_tuple_elem instruction");
+        std::transform(
+            orig_outs.begin(), orig_outs.end(), orig_outs.begin(), [&](const auto out_ins) {
+                auto gte_ins       = m.insert_instruction(ins, out_ins->get_operator(), new_ins);
+                auto orig_out_type = out_ins->get_shape().type();
+                if(contains(unsupported_types, orig_out_type))
+                {
+                    auto gte_convert = m.insert_instruction(
+                        ins, make_op("convert", {{"target_type", orig_out_type}}), gte_ins);
+                    return m.replace_instruction(out_ins, gte_convert);
+                }
+                else
+                {
+                    return m.replace_instruction(out_ins, gte_ins);
+                }
+            });
+    }
+    else
+    {
+        auto convert_back_ins = m.insert_instruction(
+            ins,
+            migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
+            new_ins);
+        m.replace_instruction(ins, convert_back_ins);
+    }
+}

 void eliminate_data_type::apply(module& m) const
 {
     static const std::vector<std::string> skip_op_names = {"convert",

@@ -42,31 +108,17 @@ void eliminate_data_type::apply(module& m) const
                                                            "scatternd_add",
                                                            "scatternd_mul",
                                                            "scatternd_none"};
-    if(types.empty())
+    if(unsupported_types.empty())
         return;
     for(auto ins : iterator_for(m))
     {
         if(ins->name()[0] == '@')
             continue;
-        if(contains(skip_op_names, ins->name()))
-            continue;
-        auto inputs = ins->inputs();
-        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
-            if(types.count(i->get_shape().type()) == 0)
-                return i;
-            return m.insert_instruction(
-                ins, make_op("convert", {{"target_type", target_type}}), i);
-        });
-        if(inputs == ins->inputs())
+        if(contains(skip_op_names, ins->name()) and not contains(unsupported_ops, ins->name()))
             continue;
-        auto op         = ins->get_operator();
-        auto attributes = op.attributes();
-        if(attributes.contains("general_data_type"))
-        {
-            op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
-        }
-        auto old_type = ins->get_shape().type();
-        auto out      = m.insert_instruction(ins, op, inputs);
-        auto convert  = m.insert_instruction(ins, make_op("convert", {{"target_type", old_type}}), out);
-        m.replace_instruction(ins, convert);
+        if(contains(unsupported_ops, "all") or contains(unsupported_ops, ins->name()))
+            insert_convert_to_supported_type(m, ins, target_type, unsupported_types);
     }
 }
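
For a non-tuple instruction the helper's effect is op(x_fp8) → convert_back(op(convert(x_fp8))); for tuple outputs the convert is inserted after each get_tuple_elem instead. A short sketch of invoking the pass standalone, assuming the public headers shown in this commit (the wrapper function is illustrative):

    #include <migraphx/eliminate_data_type.hpp>
    #include <migraphx/dead_code_elimination.hpp>
    #include <migraphx/module.hpp>
    #include <migraphx/pass_manager.hpp>

    // Sketch: compute every fp8e4m3fnuz op in float instead. With the default
    // unsupported_ops = {"all"}, insert_convert_to_supported_type rewrites each
    // eligible instruction:
    //   before:  y_fp8 = some_op(x_fp8)
    //   after:   y_fp8 = convert<fp8>(some_op(convert<float>(x_fp8)))
    void demote_fp8(migraphx::module& m)
    {
        migraphx::run_passes(m,
                             {migraphx::eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type},
                                                            migraphx::shape::float_type},
                              migraphx::dead_code_elimination{}});
    }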
src/include/migraphx/eliminate_data_type.hpp

@@ -40,8 +40,9 @@ struct module;
 */
 struct MIGRAPHX_EXPORT eliminate_data_type
 {
-    std::set<shape::type_t> types;
+    std::set<shape::type_t> unsupported_types;
     shape::type_t target_type;
+    std::set<std::string> unsupported_ops = {"all"};
     std::string name() const { return "eliminate_data_type"; }
     void apply(module& m) const;
 };
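
The struct is an aggregate, so its three fields brace-initialize in order; narrowing unsupported_ops restricts the rewrite to named ops. A sketch mirroring how target.cpp uses it later in this commit:

    #include <migraphx/eliminate_data_type.hpp>

    // Sketch: aggregate-initialize the pass.
    //   unsupported_types - types to eliminate from the graph
    //   target_type       - type to compute in instead
    //   unsupported_ops   - ops to rewrite; the default {"all"} rewrites every op
    migraphx::eliminate_data_type demote_fp8_pooling{
        {migraphx::shape::fp8e4m3fnuz_type}, // unsupported_types
        migraphx::shape::float_type,         // target_type
        {"pooling"}};                        // only pooling is rewritten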
src/include/migraphx/op/quant_convolution.hpp

@@ -27,6 +27,7 @@
 #include <migraphx/op/common.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/check_shapes.hpp>
+#include <migraphx/shape.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/convolution.hpp>
 #include <migraphx/value.hpp>

@@ -87,11 +88,13 @@ struct quant_convolution
         }

         // all input type must be int8_type and output is float_type
-        if(t != shape::int8_type)
+        std::set<migraphx::shape::type_t> supported_types = {shape::int8_type,
+                                                             shape::fp8e4m3fnuz_type};
+        if(not contains(supported_types, t))
         {
-            MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t");
+            MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t or "
+                           "fp8e4m3fnuz_type");
         }
-        t = shape::int32_type;
         std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
         auto padding_size = padding.size();

@@ -107,8 +110,11 @@ struct quant_convolution
                 stride[i] + 1)));
         }
-        return inputs[0].with_lens(t, output_lens);
+        if(t == shape::int8_type)
+        {
+            return inputs[0].with_lens(shape::int32_type, output_lens);
+        }
+        // else fp8 conv
+        return inputs[0].with_lens(shape::float_type, output_lens);
     }

     size_t kdims() const
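
The net rule in compute_shape: int8 inputs accumulate into int32, fp8e4m3fnuz inputs into float, and anything else throws. A hedged sketch checking both branches through the type-erased operation interface (assuming the free compute_shape helper from operation.hpp; this is an illustration of the rule, not a test from the commit):

    #include <cassert>
    #include <migraphx/make_op.hpp>
    #include <migraphx/operation.hpp>
    #include <migraphx/shape.hpp>

    // Sketch: exercise the accumulation-type rule added in this hunk.
    void check_quant_conv_output_types()
    {
        auto op = migraphx::make_op("quant_convolution");

        migraphx::shape i8{migraphx::shape::int8_type, {1, 3, 32, 32}};
        migraphx::shape w8{migraphx::shape::int8_type, {8, 3, 3, 3}};
        assert(migraphx::compute_shape(op, {i8, w8}).type() == migraphx::shape::int32_type);

        migraphx::shape f8{migraphx::shape::fp8e4m3fnuz_type, {1, 3, 32, 32}};
        migraphx::shape wf8{migraphx::shape::fp8e4m3fnuz_type, {8, 3, 3, 3}};
        assert(migraphx::compute_shape(op, {f8, wf8}).type() == migraphx::shape::float_type);
    }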
src/targets/gpu/CMakeLists.txt

@@ -259,6 +259,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
 check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
 # Beta API for automated GEMM tuning
 check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
+# rocblas FP8 API
+check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)
 set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")

@@ -288,10 +290,18 @@ else()
     message(STATUS "rocBLAS does not have User Tuning Beta API")
 endif()

+if(HAS_ROCBLAS_FP8_BETA_API)
+    target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
+    message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations")
+else()
+    message(STATUS "rocBLAS does not have Fp8 Beta API")
+endif()
+
 target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
 target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
 if(MIGRAPHX_USE_COMPOSABLEKERNEL)
     target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
     target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
 endif()
 add_subdirectory(driver)
src/targets/gpu/device_name.cpp

@@ -49,6 +49,12 @@ std::string get_device_name()
     return props.gcnArchName;
 }

+bool gfx_has_fp8_intrinsics()
+{
+    const auto device_name = trim(split_string(get_device_name(), ':').front());
+    return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
+}
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
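
The check relies on lexicographic string comparison, which is safe here because all gfx9-family names share the prefix and length. A standalone sketch of the same predicate using only std::string operations (the helper name and the use of plain std::string instead of MIGraphX's string utilities are illustrative):

    #include <string>

    // Sketch of the same gating logic. get_device_name() returns something
    // like "gfx940:sramecc+:xnack-"; everything after the first ':' is
    // stripped before comparing.
    bool has_fp8_intrinsics(const std::string& full_name)
    {
        std::string device = full_name.substr(0, full_name.find(':'));
        // Lexicographic >= works within the gfx9 family since the names have
        // equal length: "gfx908" < "gfx90a" < "gfx940" <= "gfx941" <= "gfx942".
        return device.rfind("gfx9", 0) == 0 and device >= "gfx940";
    }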
src/targets/gpu/fuse_mlir.cpp

@@ -218,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode)
             return false;
         if(ins->name() != "convolution" and ins->name() != "quant_convolution")
             return false;
+        auto input_arg_t = ins->inputs().front()->get_shape().type();
         value v    = ins->get_operator().to_value();
         auto group = v.at("group").to<int>();
         if(group != 1)

@@ -225,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode)
         // Avoid MLIR assertion: Index < Length && "Invalid index!"
         if(ins->get_shape().lens().size() != 4)
             return false;
+        if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
+            return true;
+        if(ins->get_shape().type() == shape::float_type and input_arg_t == shape::fp8e4m3fnuz_type)
+            return true;
         if(ins->get_shape().type() == shape::int8_type)
             return true;
         if(mode == mlir_mode::int8)

@@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
     const auto result_type = i.get_shape().type();
     const std::initializer_list<type_t> allowed_types = {type_t::float_type,
                                                          type_t::half_type,
+                                                         type_t::fp8e4m3fnuz_type,
                                                          type_t::int8_type,
                                                          type_t::int32_type,
                                                          type_t::bool_type};

@@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
         "softmax",
         "tanh",
     };
-    bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
+    bool is_float =
+        contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
     if(contains(any_type_ops, name))
         return true;
     if(result_type != type_t::bool_type and contains(no_bool_ops, name))

@@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
     // supported.
     if(is_float and name == "convert")
     {
+        if(result_type == shape::fp8e4m3fnuz_type)
+        {
+            return false;
+        } // else
         return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
             return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
         });

@@ -404,12 +415,13 @@ struct find_mlir_standalone_op
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto gemm_based_op = r.result;
-        // enable only for fp32/fp16/i8 types
+        // enable only for fp32/fp16/i8/fp8 types
         if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
-               return not contains({shape::type_t::float_type,
-                                    shape::type_t::half_type,
-                                    shape::type_t::int8_type},
-                                   i->get_shape().type());
+               return not contains({shape::type_t::float_type,
+                                    shape::type_t::half_type,
+                                    shape::type_t::int8_type,
+                                    shape::type_t::fp8e4m3fnuz_type},
+                                   i->get_shape().type());
           }))
             return;
         static size_t counter = 0;

@@ -531,7 +543,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
     match::find_matches(
         mpm,
-        find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
+        find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)},
         find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
 #else
     (void)mpm;
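
One subtlety in the pointwise gating above: convert is fusable when fp8 is the source type, but not when it is the destination. A condensed sketch of that predicate (the helper name is hypothetical; the logic mirrors the branch added above):

    #include <algorithm>
    #include <migraphx/instruction.hpp>
    #include <migraphx/ranges.hpp>

    // Sketch: converts *to* fp8e4m3fnuz are rejected for the fused pointwise
    // module; converts *from* float/half sources are allowed.
    bool mlir_supports_convert(const migraphx::instruction& i)
    {
        using type_t = migraphx::shape::type_t;
        if(i.get_shape().type() == type_t::fp8e4m3fnuz_type)
            return false; // no convert-to-fp8 inside the fusion
        return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
            return migraphx::contains({type_t::float_type, type_t::half_type},
                                      arg->get_shape().type());
        });
    }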
src/targets/gpu/gemm_impl.cpp

@@ -22,11 +22,14 @@
  * THE SOFTWARE.
  */
+#include <rocblas/internal/rocblas-types.h>
 #include <rocblas/rocblas.h>
+#include <migraphx/gpu/rocblas.hpp>
 #include <migraphx/gpu/gemm_impl.hpp>
 #include <migraphx/reduce_dims.hpp>
 #include <migraphx/generate.hpp>
 #include <migraphx/time.hpp>
+#include <type_traits>

 using microseconds = std::chrono::duration<double, std::micro>;

@@ -34,6 +37,20 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

+/*
+The regular rocBLAS API takes compute_type as a `rocblas_datatype` enum value, whereas the "ex3"
+BETA API takes it as a `rocblas_computetype` enum value. `rb_compute_type` is a facilitator that
+implicitly casts the stored integer enum value to whichever type the `common_args` generator
+requires.
+*/
+struct rb_compute_type
+{
+    int type = 0;
+    rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
+    rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
+    operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
+    operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
+};
+
 // Convert rocBLAS datatypes to equivalent Migraphx data types
 rocblas_datatype get_type(shape::type_t type)
 {

@@ -46,7 +63,7 @@ rocblas_datatype get_type(shape::type_t type)
     case shape::uint8_type: return rocblas_datatype_u8_r;
     case shape::int32_type: return rocblas_datatype_i32_r;
     case shape::uint32_type: return rocblas_datatype_u32_r;
-    case shape::fp8e4m3fnuz_type:
+    case shape::fp8e4m3fnuz_type: return rocblas_datatype_f8_r;
     case shape::tuple_type:
     case shape::bool_type:
     case shape::uint16_type:

@@ -183,12 +200,17 @@ struct gemm_impl
         {
             output_type = rocblas_datatype_i32_r;
         }
-        compute_type = output_type;
+        compute_type = rb_compute_type{output_type};
         if(compute_fp32)
         {
             if(arg_type == rocblas_datatype_f16_r)
                 compute_type = rocblas_datatype_f32_r;
         }
+        if(arg_type == rocblas_datatype_f8_r)
+        {
+            assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
+            compute_type = rocblas_compute_type_f32;
+        }

         auto a_lens = input_shapes[0].lens();
         auto b_lens = input_shapes[1].lens();

@@ -217,23 +239,52 @@ struct gemm_impl
     void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const
     {
-        if(strided_batched)
-        {
-            auto common_args = create_strided_batched_args_common(ctx, input_args);
-            rocblas_invoke(&rocblas_gemm_strided_batched_ex,
-                           common_args,
-                           rocblas_gemm_algo_solution_index,
-                           solution_idx,
-                           gemm_flags);
-        }
-        else
-        {
-            auto common_args = create_gemm_ex_args_common(ctx, input_args);
-            rocblas_invoke(&rocblas_gemm_ex,
-                           common_args,
-                           rocblas_gemm_algo_solution_index,
-                           solution_idx,
-                           gemm_flags);
-        }
+#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
+        if(rocblas_fp8_available() and
+           std::any_of(input_args.begin(), input_args.end(), [](const auto i) {
+               return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
+           }))
+        {
+            if(strided_batched)
+            {
+                auto common_args = create_strided_batched_args_common(ctx, input_args);
+                rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
+                               common_args,
+                               rocblas_gemm_algo_standard,
+                               solution_idx,
+                               gemm_flags);
+            }
+            else
+            {
+                auto common_args = create_gemm_ex_args_common(ctx, input_args);
+                rocblas_invoke(&rocblas_gemm_ex3,
+                               common_args,
+                               rocblas_gemm_algo_standard,
+                               solution_idx,
+                               gemm_flags);
+            }
+        }
+        else
+#endif
+        {
+            if(strided_batched)
+            {
+                auto common_args = create_strided_batched_args_common(ctx, input_args);
+                rocblas_invoke(&rocblas_gemm_strided_batched_ex,
+                               common_args,
+                               rocblas_gemm_algo_solution_index,
+                               solution_idx,
+                               gemm_flags);
+            }
+            else
+            {
+                auto common_args = create_gemm_ex_args_common(ctx, input_args);
+                rocblas_invoke(&rocblas_gemm_ex,
+                               common_args,
+                               rocblas_gemm_algo_solution_index,
+                               solution_idx,
+                               gemm_flags);
+            }
+        }
     }

@@ -331,7 +382,6 @@ struct gemm_impl
                        num_matrices,
                        compute_type);
-    }

     /**
      * Helper method to create that subset of a long rocBLAS argument list that is common
      * to multiple "gemm_ex..." calls.

@@ -366,6 +416,7 @@ struct gemm_impl
                        ldd,
                        compute_type);
+    }

 #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
     /**
      * Find best rocBLAS solution: Get list of solutions and try them all, returning the index

@@ -481,8 +532,8 @@ struct gemm_impl
     rocblas_int b_stride     = 0;
     rocblas_int c_stride     = 0;
     rocblas_int d_stride     = 0;
-    rocblas_datatype compute_type = rocblas_datatype_f32_r;
     rocblas_datatype arg_type     = rocblas_datatype_f32_r;
+    rb_compute_type compute_type  = rocblas_datatype_f32_r;
     rocblas_datatype output_type  = rocblas_datatype_f32_r;
     bool strided_batched          = true;
     bool is_3inputs               = true;
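
The rb_compute_type shim works because both rocBLAS enums are plain integer enumerations, so one stored int can be handed back as either type depending on which overload the call site selects. A minimal standalone sketch of the same pattern (stand-in enums and functions, not the rocBLAS ones):

    // Sketch of the dual-conversion shim with stand-in enums.
    enum legacy_type { legacy_f32 = 0 }; // plays the role of rocblas_datatype
    enum ex3_type    { ex3_f32    = 0 }; // plays the role of rocblas_computetype

    struct compute_type_shim
    {
        int type = 0;
        compute_type_shim(legacy_type t) : type(static_cast<int>(t)) {}
        compute_type_shim(ex3_type t) : type(static_cast<int>(t)) {}
        operator legacy_type() const { return static_cast<legacy_type>(type); }
        operator ex3_type() const { return static_cast<ex3_type>(type); }
    };

    // The same member can feed a legacy call (wants legacy_type) and an "ex3"
    // call (wants ex3_type) without a cast at the call site:
    void take_legacy(legacy_type) {}
    void take_ex3(ex3_type) {}

    void demo()
    {
        compute_type_shim c = legacy_f32;
        take_legacy(c); // implicit conversion picks operator legacy_type
        take_ex3(c);    // implicit conversion picks operator ex3_type
    }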
src/targets/gpu/include/migraphx/gpu/device_name.hpp

@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
 MIGRAPHX_GPU_EXPORT int get_device_id();

+MIGRAPHX_GPU_EXPORT bool gfx_has_fp8_intrinsics();
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
src/targets/gpu/include/migraphx/gpu/rocblas.hpp

@@ -40,6 +40,8 @@ struct context;
 MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag();

+MIGRAPHX_GPU_EXPORT bool rocblas_fp8_available();
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp

@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x)
     return y;
 }

-template <unsigned int DppCtrl,
-          unsigned int RowMask  = 0xf,
-          unsigned int BankMask = 0xf,
-          bool BoundCtrl        = false,
-          class T>
-__device__ T dpp_mov(T& x)
+template <class T, class F>
+__device__ T dpp_op(T& x, F f)
 {
     static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
     union type

@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x)
     input.data = x;
     for(index_int i = 0; i < n; i++)
     {
-        output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
+        output.reg[i] = f(input.reg[i]);
     }
     return output.data;
 }

+template <unsigned int DppCtrl,
+          unsigned int RowMask  = 0xf,
+          unsigned int BankMask = 0xf,
+          bool BoundCtrl        = false,
+          class T>
+__device__ T dpp_mov(T& x)
+{
+    return dpp_op(x, [](auto i) { return __hip_move_dpp(i, DppCtrl, RowMask, BankMask, BoundCtrl); });
+}
+
+template <unsigned int Mask, class T>
+__device__ T dpp_swizzle(T& x)
+{
+    return dpp_op(x, [](auto i) { return __hip_ds_swizzle(i, Mask); });
+}
+
 #endif // MIGRAPHX_HAS_DPP
 } // namespace migraphx
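
Because dpp_op slices its argument into 32-bit registers (n = sizeof(T)/4, with sub-4-byte types packed into one register), dpp_mov and dpp_swizzle transparently handle wide types: a double simply becomes two independent per-register moves. A small sketch assuming this header in a device context (the wrapper name is hypothetical; the control/row-mask combination is the one reduce.hpp uses below):

    // Sketch: dpp_mov on a 64-bit value. dpp_op splits the double into two
    // 32-bit registers via the union, applies __hip_move_dpp to each, and
    // reassembles the result.
    __device__ double broadcast_row_tail(double& x)
    {
        return dpp_mov<dpp_row_bcast(15), 0xa>(x);
    }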
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp

@@ -501,9 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
     {
         return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
     }
-    // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
-    // want to make this distinction. For the floating points we would end up using lowest most of
-    // the times.
+    // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
     static constexpr __device__ fp8e5m2fnuz min()
     {
         return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());

@@ -528,9 +526,7 @@ class numeric_limits<fp8e5m2>
     }
     static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
-    // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
-    // want to make this distinction. For the floating points we would end up using lowest most of
-    // the times.
+    // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
     static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
     static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp

@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op)
     in  = op(in, out);
     out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
     in  = op(in, out);
-#if __AMDGCN_WAVEFRONT_SIZE == 64
+#if __AMDGCN_WAVEFRONT_SIZE == 32
+    out = dpp_swizzle<0x1e0>(in);
+    in  = op(in, out);
+#else
     out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
     in  = op(in, out);
     out = dpp_mov<dpp_row_bcast(31), 0xc>(in);

@@ -54,9 +57,11 @@ __device__ void dpp_reduce(T& in, Op op)
 }
 #if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
 // NOLINTNEXTLINE
-#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
+#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
+    (void)f;                               \
+    x = 1
 #elif __AMDGCN_WAVEFRONT_SIZE == 64
-#define MIGRAPHX_DPP_REDUCE_ASM(x, ins)                                   \
+#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f)                                \
     __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
                      "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
                      "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \

@@ -65,29 +70,42 @@ __device__ void dpp_reduce(T& in, Op op)
                      "s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
                      "s_nop 1\n" \
                      : "=v"(x) \
-                     : "0"(x))
+                     : "0"(x)); \
+    (void)f
 #else
-#define MIGRAPHX_DPP_REDUCE_ASM(x, ins)                                   \
+#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f)                                \
     __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
                      "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
                      "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
                      "s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
                      "s_nop 1\n" \
                      : "=v"(x) \
-                     : "0"(x))
+                     : "0"(x)); \
+    auto y = dpp_swizzle<0x1e0>(x); \
+    x      = f(x, y)
 #endif

 // NOLINTNEXTLINE
-#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
-    __device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
-    __device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
-    __device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
-    __device__ inline void dpp_reduce(int32_t& x, op) \
-    { \
-        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \
-    } \
-    __device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
+#define MIGRAPHX_DPP_REDUCE(op, prefix, sign)            \
+    __device__ inline void dpp_reduce(double& x, op f)   \
+    {                                                    \
+        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64, f);     \
+    }                                                    \
+    __device__ inline void dpp_reduce(float& x, op f)    \
+    {                                                    \
+        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32, f);     \
+    }                                                    \
+    __device__ inline void dpp_reduce(half& x, op f)     \
+    {                                                    \
+        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16, f);     \
+    }                                                    \
+    __device__ inline void dpp_reduce(int32_t& x, op f)  \
+    {                                                    \
+        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32, f); \
+    }                                                    \
+    __device__ inline void dpp_reduce(uint32_t& x, op f) \
+    {                                                    \
+        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32, f);     \
+    }

 // Note: when max and min are in int32_t, signed version of instruction needs to be used.
 MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u)

@@ -99,11 +117,7 @@ template <class Op, class T, class Index, class F>
 __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
 {
     MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
-#if __AMDGCN_WAVEFRONT_SIZE == 32
-    constexpr index_int lanes_per_thread = 16;
-#else
-    constexpr index_int lanes_per_thread = 64;
-#endif
+    constexpr index_int lanes_per_thread = __AMDGCN_WAVEFRONT_SIZE;
     using type = decltype(index::invoke_loop(f, 0, _c<0>));
     __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
     type x = type(init);
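
On a 32-wide wavefront the row_bcast steps (which move data into lanes 16–63) are unavailable, so the new path folds the wave's two 16-lane rows together with one swizzle; the new f parameter threads the combine op into the asm fallback so it can do the same fold. A host-side sketch of the overall shift-then-fold shape (plain C++ simulation of one 32-lane wave; the swizzle is modeled as a final fold of the two row totals, which is a simplification of the real lane pattern):

    #include <array>

    // Sketch: shift-and-combine within each 16-lane row (row_shr:1,2,4,8),
    // then fold the two rows (the dpp_swizzle<0x1e0> step). Afterwards the
    // full reduction is available from the last lane of each half.
    template <class T, class Op>
    T simulate_wave32_reduce(std::array<T, 32> lanes, Op op)
    {
        for(unsigned shift : {1u, 2u, 4u, 8u})
        {
            auto prev = lanes;
            for(unsigned i = 0; i < 32; i++)
            {
                unsigned row = i / 16;
                if(i >= row * 16 + shift) // source lane is within the same row
                    lanes[i] = op(prev[i], prev[i - shift]);
            }
        }
        // Lane 15 now holds the reduction of lanes 0..15, lane 31 of 16..31.
        return op(lanes[15], lanes[31]);
    }

    // Usage: simulate_wave32_reduce(values, [](auto a, auto b) { return a + b; });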
src/targets/gpu/mlir.cpp

@@ -300,6 +300,8 @@ struct mlir_program
             result = mlirF32TypeGet(ctx.get());
         else if(as.type_enum() == shape::half_type)
             result = mlirF16TypeGet(ctx.get());
+        else if(as.type_enum() == shape::fp8e4m3fnuz_type)
+            result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
         else if(as.type_enum() == shape::double_type)
             result = mlirF64TypeGet(ctx.get());
         else if(as.is_integral())
src/targets/gpu/rocblas.cpp

@@ -53,6 +53,15 @@ bool get_compute_fp32_flag()
     return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
 }

+bool rocblas_fp8_available()
+{
+#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
+    return false;
+#else
+    return gfx_has_fp8_intrinsics();
+#endif
+}
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
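
rocblas_fp8_available combines a build-time check (was the "ex3" API present when MIGraphX was configured?) with a run-time one (does the current GPU have fp8 intrinsics?), so callers need only the single predicate. A sketch of the intended call-site pattern, matching how gemm_impl.cpp uses it (the wrapper function is hypothetical):

    #include <migraphx/gpu/rocblas.hpp>

    // Sketch: choose the fp8 GEMM path only when both the build and the
    // hardware support it, and the arguments actually contain fp8 data.
    bool should_use_fp8_gemm(bool any_fp8_args)
    {
        // False when built without MIGRAPHX_USE_ROCBLAS_FP8_API, or at run
        // time on pre-gfx940 hardware.
        return migraphx::gpu::rocblas_fp8_available() and any_fp8_args;
    }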
src/targets/gpu/target.cpp

@@ -105,6 +105,29 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
     unsupported_types.erase(shape::type_t::uint8_type);
     unsupported_types.erase(shape::type_t::int32_type);
     unsupported_types.erase(shape::type_t::tuple_type);
+    // whitelist supported Ops for the FP8
+    std::set<std::string> unsupported_fp8_ops = {};
+    if(not gpu::rocblas_fp8_available())
+    {
+        unsupported_fp8_ops.insert("dot");
+    }
+    // MIOpen doesn't have support for fp8 pooling yet.
+    unsupported_fp8_ops.insert("pooling");
+    if(not gpu::gfx_has_fp8_intrinsics())
+    {
+        unsupported_fp8_ops.insert("convolution");
+        unsupported_fp8_ops.insert("quant_convolution");
+    }
+    // add all device kernels
+    unsupported_fp8_ops.insert("logsoftmax");
+    unsupported_fp8_ops.insert("nonzero");
+    unsupported_fp8_ops.insert("prefix_scan_sum");
+    unsupported_fp8_ops.insert("scatter_none");
+    unsupported_fp8_ops.insert("topk");
+    unsupported_fp8_ops.insert("rnn_var_sl_shift_output");
+    unsupported_fp8_ops.insert("multinomial");
+    unsupported_fp8_ops.insert("argmax");
+    unsupported_fp8_ops.insert("argmin");
     // clang-format off
     return
     {

@@ -136,6 +159,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         prefuse_ops{},
         dead_code_elimination{},
         auto_contiguous{},
+        eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops},
+        dead_code_elimination{},
         optimize_module{},
         fuse_pointwise{},
         dead_code_elimination{},
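
The net effect at compile time: before any fp8-specific lowering, the GPU target rewrites every blacklisted fp8 op to compute in float, then dead-code-eliminates the leftovers. A hedged end-to-end sketch with an fp8 pooling, which is on the unsupported list above (the program construction and attribute spelling are illustrative, not taken from the commit):

    #include <migraphx/program.hpp>
    #include <migraphx/make_op.hpp>
    #include <migraphx/op/pooling.hpp>
    #include <migraphx/gpu/target.hpp>

    // Sketch: compiling an fp8 pooling for the gpu target should leave it
    // computing in float internally, with converts at the boundaries.
    migraphx::program compile_fp8_pooling()
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto x   = mm->add_parameter(
            "x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {1, 3, 32, 32}});
        mm->add_instruction(
            migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), x);
        p.compile(migraphx::gpu::target{});
        return p;
    }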