Commit ad9c25ea authored by Umang Yadav's avatar Umang Yadav
Browse files

add eliminate_fp8 pass

parent 4604f2e1
...@@ -49,6 +49,7 @@ add_library(migraphx ...@@ -49,6 +49,7 @@ add_library(migraphx
eliminate_concat.cpp eliminate_concat.cpp
eliminate_contiguous.cpp eliminate_contiguous.cpp
eliminate_data_type.cpp eliminate_data_type.cpp
eliminate_fp8.cpp
eliminate_identity.cpp eliminate_identity.cpp
eliminate_pad.cpp eliminate_pad.cpp
env.cpp env.cpp
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <utility>
#include <migraphx/eliminate_fp8.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void eliminate_fp8::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(not contains(op_names, ins->name()))
continue;
migraphx::shape::type_t orig_type = ins->get_shape().type();
std::vector<instruction_ref> orig_inputs = ins->inputs();
std::vector<instruction_ref> new_inputs;
for(const auto& i : orig_inputs)
{
new_inputs.push_back(m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
i));
}
auto new_ins = m.insert_instruction(ins, ins->get_operator(), {new_inputs});
auto convert_back_ins = m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
new_ins);
m.replace_instruction(ins, convert_back_ins);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <set>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
This will insert convert operators for the operators that are not implemented for FP8 dtypes
*/
struct MIGRAPHX_EXPORT eliminate_fp8
{
// TODO: Add all device ops as a later PR and add tests for those.
std::set<std::string> op_names;
shape::type_t target_type = migraphx::shape::float_type;
std::string name() const { return "eliminate_fp8"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -227,18 +227,20 @@ struct gemm_impl ...@@ -227,18 +227,20 @@ struct gemm_impl
{ {
if(strided_batched) if(strided_batched)
{ {
auto common_args = create_strided_batched_args_common_fp8(ctx, input_args); auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3, rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args, common_args,
rocblas_compute_type_f32,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
solution_idx, solution_idx,
gemm_flags); gemm_flags);
} }
else else
{ {
auto common_args = create_gemm_ex_args_common_fp8(ctx, input_args); auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex3, rocblas_invoke(&rocblas_gemm_ex3,
common_args, common_args,
rocblas_compute_type_f32,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
solution_idx, solution_idx,
gemm_flags); gemm_flags);
...@@ -252,6 +254,7 @@ struct gemm_impl ...@@ -252,6 +254,7 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args); auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex, rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args, common_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
solution_idx, solution_idx,
gemm_flags); gemm_flags);
...@@ -261,6 +264,7 @@ struct gemm_impl ...@@ -261,6 +264,7 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args); auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex, rocblas_invoke(&rocblas_gemm_ex,
common_args, common_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
solution_idx, solution_idx,
gemm_flags); gemm_flags);
...@@ -300,6 +304,7 @@ struct gemm_impl ...@@ -300,6 +304,7 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args); auto common_args = create_strided_batched_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex, check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args, common_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
solution_idx, solution_idx,
rocblas_gemm_flags_check_solution_index); rocblas_gemm_flags_check_solution_index);
...@@ -309,6 +314,7 @@ struct gemm_impl ...@@ -309,6 +314,7 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args); auto common_args = create_gemm_ex_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_ex, check_valid = rocblas_invoke(&rocblas_gemm_ex,
common_args, common_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
solution_idx, solution_idx,
rocblas_gemm_flags_check_solution_index); rocblas_gemm_flags_check_solution_index);
...@@ -359,40 +365,8 @@ struct gemm_impl ...@@ -359,40 +365,8 @@ struct gemm_impl
output_type, output_type,
ldd, ldd,
d_stride, d_stride,
num_matrices, num_matrices);
compute_type);
} }
auto create_strided_batched_args_common_fp8(context& ctx,
const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
get_alpha(),
args[1].data(),
arg_type,
ldb,
b_stride,
args[0].data(),
arg_type,
lda,
a_stride,
get_beta(),
args[2].data(),
output_type,
ldc,
c_stride,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
d_stride,
num_matrices,
rocblas_compute_type_f32);
}
/** /**
* Helper method to create that subset of a long rocBLAS argument list that is common * Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls. * to multiple "gemm_ex..." calls.
...@@ -424,33 +398,9 @@ struct gemm_impl ...@@ -424,33 +398,9 @@ struct gemm_impl
ldc, ldc,
is_3inputs ? args[3].data() : args[2].data(), is_3inputs ? args[3].data() : args[2].data(),
output_type, output_type,
ldd, ldd);
compute_type);
}
auto create_gemm_ex_args_common_fp8(context& ctx, const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
get_alpha(),
args[1].data(),
arg_type,
ldb,
args[0].data(),
arg_type,
lda,
get_beta(),
args[2].data(),
output_type,
ldc,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
rocblas_compute_type_f32);
} }
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/** /**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index * Find best rocBLAS solution: Get list of solutions and try them all, returning the index
...@@ -478,6 +428,7 @@ struct gemm_impl ...@@ -478,6 +428,7 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args); auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions, rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_args, common_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
gemm_flags, gemm_flags,
nullptr, nullptr,
...@@ -487,6 +438,7 @@ struct gemm_impl ...@@ -487,6 +438,7 @@ struct gemm_impl
auto common_sol_args = create_strided_batched_args_common(ctx, input_args); auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions, rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args, common_sol_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
gemm_flags, gemm_flags,
solution_indices.data(), solution_indices.data(),
...@@ -497,6 +449,7 @@ struct gemm_impl ...@@ -497,6 +449,7 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args); auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions, rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_args, common_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
gemm_flags, gemm_flags,
nullptr, nullptr,
...@@ -506,6 +459,7 @@ struct gemm_impl ...@@ -506,6 +459,7 @@ struct gemm_impl
auto common_sol_args = create_gemm_ex_args_common(ctx, input_args); auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions, rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_sol_args, common_sol_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
gemm_flags, gemm_flags,
solution_indices.data(), solution_indices.data(),
......
...@@ -52,6 +52,7 @@ ...@@ -52,6 +52,7 @@
#include <migraphx/simplify_qdq.hpp> #include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/split_single_dyn_dim.hpp> #include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/eliminate_fp8.hpp>
#include <migraphx/gpu/allocation_model.hpp> #include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_miopen.hpp> #include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp> #include <migraphx/gpu/compile_ops.hpp>
...@@ -105,6 +106,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -105,6 +106,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type); unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type); unsupported_types.erase(shape::type_t::tuple_type);
std::set<std::string> unsupported_fp8_ops = {};
if(not gpu::rocblas_fp8_available())
{
unsupported_fp8_ops.insert("dot");
}
// clang-format off // clang-format off
return return
{ {
...@@ -147,6 +153,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -147,6 +153,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}), enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
dead_code_elimination{}, dead_code_elimination{},
eliminate_fp8{unsupported_fp8_ops},
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"}, eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{}, dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment