Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ad9c25ea
Commit
ad9c25ea
authored
Nov 26, 2023
by
Umang Yadav
Browse files
add eliminate_fp8 pass
parent
4604f2e1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
137 additions
and
61 deletions
+137
-61
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/eliminate_fp8.cpp
src/eliminate_fp8.cpp
+62
-0
src/include/migraphx/eliminate_fp8.hpp
src/include/migraphx/eliminate_fp8.hpp
+52
-0
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+15
-61
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+7
-0
No files found.
src/CMakeLists.txt
View file @
ad9c25ea
...
...
@@ -49,6 +49,7 @@ add_library(migraphx
eliminate_concat.cpp
eliminate_contiguous.cpp
eliminate_data_type.cpp
eliminate_fp8.cpp
eliminate_identity.cpp
eliminate_pad.cpp
env.cpp
...
...
src/eliminate_fp8.cpp
0 → 100644
View file @
ad9c25ea
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <utility>
#include <migraphx/eliminate_fp8.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
eliminate_fp8
::
apply
(
module
&
m
)
const
{
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
not
contains
(
op_names
,
ins
->
name
()))
continue
;
migraphx
::
shape
::
type_t
orig_type
=
ins
->
get_shape
().
type
();
std
::
vector
<
instruction_ref
>
orig_inputs
=
ins
->
inputs
();
std
::
vector
<
instruction_ref
>
new_inputs
;
for
(
const
auto
&
i
:
orig_inputs
)
{
new_inputs
.
push_back
(
m
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
target_type
)}}),
i
));
}
auto
new_ins
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
{
new_inputs
});
auto
convert_back_ins
=
m
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
orig_type
)}}),
new_ins
);
m
.
replace_instruction
(
ins
,
convert_back_ins
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/include/migraphx/eliminate_fp8.hpp
0 → 100644
View file @
ad9c25ea
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <set>
#include <string>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
/**
This will insert convert operators for the operators that are not implemented for FP8 dtypes
*/
struct
MIGRAPHX_EXPORT
eliminate_fp8
{
// TODO: Add all device ops as a later PR and add tests for those.
std
::
set
<
std
::
string
>
op_names
;
shape
::
type_t
target_type
=
migraphx
::
shape
::
float_type
;
std
::
string
name
()
const
{
return
"eliminate_fp8"
;
}
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/gemm_impl.cpp
View file @
ad9c25ea
...
...
@@ -227,18 +227,20 @@ struct gemm_impl
{
if
(
strided_batched
)
{
auto
common_args
=
create_strided_batched_args_common
_fp8
(
ctx
,
input_args
);
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex3
,
common_args
,
rocblas_compute_type_f32
,
rocblas_gemm_algo_standard
,
solution_idx
,
gemm_flags
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common
_fp8
(
ctx
,
input_args
);
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex3
,
common_args
,
rocblas_compute_type_f32
,
rocblas_gemm_algo_standard
,
solution_idx
,
gemm_flags
);
...
...
@@ -252,6 +254,7 @@ struct gemm_impl
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
...
...
@@ -261,6 +264,7 @@ struct gemm_impl
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
...
...
@@ -300,6 +304,7 @@ struct gemm_impl
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
check_valid
=
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
rocblas_gemm_flags_check_solution_index
);
...
...
@@ -309,6 +314,7 @@ struct gemm_impl
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
check_valid
=
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
rocblas_gemm_flags_check_solution_index
);
...
...
@@ -359,40 +365,8 @@ struct gemm_impl
output_type
,
ldd
,
d_stride
,
num_matrices
,
compute_type
);
num_matrices
);
}
auto
create_strided_batched_args_common_fp8
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
pack
(
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
get_alpha
(),
args
[
1
].
data
(),
arg_type
,
ldb
,
b_stride
,
args
[
0
].
data
(),
arg_type
,
lda
,
a_stride
,
get_beta
(),
args
[
2
].
data
(),
output_type
,
ldc
,
c_stride
,
is_3inputs
?
args
[
3
].
data
()
:
args
[
2
].
data
(),
output_type
,
ldd
,
d_stride
,
num_matrices
,
rocblas_compute_type_f32
);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls.
...
...
@@ -424,33 +398,9 @@ struct gemm_impl
ldc
,
is_3inputs
?
args
[
3
].
data
()
:
args
[
2
].
data
(),
output_type
,
ldd
,
compute_type
);
}
auto
create_gemm_ex_args_common_fp8
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
pack
(
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
get_alpha
(),
args
[
1
].
data
(),
arg_type
,
ldb
,
args
[
0
].
data
(),
arg_type
,
lda
,
get_beta
(),
args
[
2
].
data
(),
output_type
,
ldc
,
is_3inputs
?
args
[
3
].
data
()
:
args
[
2
].
data
(),
output_type
,
ldd
,
rocblas_compute_type_f32
);
ldd
);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
...
...
@@ -478,6 +428,7 @@ struct gemm_impl
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
nullptr
,
...
...
@@ -487,6 +438,7 @@ struct gemm_impl
auto
common_sol_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
common_sol_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
solution_indices
.
data
(),
...
...
@@ -497,6 +449,7 @@ struct gemm_impl
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
nullptr
,
...
...
@@ -506,6 +459,7 @@ struct gemm_impl
auto
common_sol_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
common_sol_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
solution_indices
.
data
(),
...
...
src/targets/gpu/target.cpp
View file @
ad9c25ea
...
...
@@ -52,6 +52,7 @@
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/eliminate_fp8.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp>
...
...
@@ -105,6 +106,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types
.
erase
(
shape
::
type_t
::
uint8_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
int32_type
);
unsupported_types
.
erase
(
shape
::
type_t
::
tuple_type
);
std
::
set
<
std
::
string
>
unsupported_fp8_ops
=
{};
if
(
not
gpu
::
rocblas_fp8_available
())
{
unsupported_fp8_ops
.
insert
(
"dot"
);
}
// clang-format off
return
{
...
...
@@ -147,6 +153,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
enable_pass
(
mlir_enabled
(),
fuse_mlir
{
&
ctx
}),
dead_code_elimination
{},
eliminate_fp8
{
unsupported_fp8_ops
},
lowering
{
&
ctx
,
options
.
offload_copy
},
eliminate_contiguous
{
"gpu::contiguous"
},
dead_code_elimination
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment