Commit ad9c25ea authored by Umang Yadav's avatar Umang Yadav
Browse files

add eliminate_fp8 pass

parent 4604f2e1
......@@ -49,6 +49,7 @@ add_library(migraphx
eliminate_concat.cpp
eliminate_contiguous.cpp
eliminate_data_type.cpp
eliminate_fp8.cpp
eliminate_identity.cpp
eliminate_pad.cpp
env.cpp
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <utility>
#include <migraphx/eliminate_fp8.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void eliminate_fp8::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(not contains(op_names, ins->name()))
continue;
migraphx::shape::type_t orig_type = ins->get_shape().type();
std::vector<instruction_ref> orig_inputs = ins->inputs();
std::vector<instruction_ref> new_inputs;
for(const auto& i : orig_inputs)
{
new_inputs.push_back(m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
i));
}
auto new_ins = m.insert_instruction(ins, ins->get_operator(), {new_inputs});
auto convert_back_ins = m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
new_ins);
m.replace_instruction(ins, convert_back_ins);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <set>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
This will insert convert operators for the operators that are not implemented for FP8 dtypes
*/
struct MIGRAPHX_EXPORT eliminate_fp8
{
// TODO: Add all device ops as a later PR and add tests for those.
std::set<std::string> op_names;
shape::type_t target_type = migraphx::shape::float_type;
std::string name() const { return "eliminate_fp8"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -227,18 +227,20 @@ struct gemm_impl
{
if(strided_batched)
{
auto common_args = create_strided_batched_args_common_fp8(ctx, input_args);
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args,
rocblas_compute_type_f32,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
}
else
{
auto common_args = create_gemm_ex_args_common_fp8(ctx, input_args);
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex3,
common_args,
rocblas_compute_type_f32,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
......@@ -252,6 +254,7 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
......@@ -261,6 +264,7 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
......@@ -300,6 +304,7 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
......@@ -309,6 +314,7 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_ex,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
......@@ -359,40 +365,8 @@ struct gemm_impl
output_type,
ldd,
d_stride,
num_matrices,
compute_type);
num_matrices);
}
auto create_strided_batched_args_common_fp8(context& ctx,
const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
get_alpha(),
args[1].data(),
arg_type,
ldb,
b_stride,
args[0].data(),
arg_type,
lda,
a_stride,
get_beta(),
args[2].data(),
output_type,
ldc,
c_stride,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
d_stride,
num_matrices,
rocblas_compute_type_f32);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls.
......@@ -424,33 +398,9 @@ struct gemm_impl
ldc,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
compute_type);
}
auto create_gemm_ex_args_common_fp8(context& ctx, const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
get_alpha(),
args[1].data(),
arg_type,
ldb,
args[0].data(),
arg_type,
lda,
get_beta(),
args[2].data(),
output_type,
ldc,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
rocblas_compute_type_f32);
ldd);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
......@@ -478,6 +428,7 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
......@@ -487,6 +438,7 @@ struct gemm_impl
auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args,
compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
......@@ -497,6 +449,7 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
......@@ -506,6 +459,7 @@ struct gemm_impl
auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_sol_args,
compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
......
......@@ -52,6 +52,7 @@
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/eliminate_fp8.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp>
......@@ -105,6 +106,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type);
std::set<std::string> unsupported_fp8_ops = {};
if(not gpu::rocblas_fp8_available())
{
unsupported_fp8_ops.insert("dot");
}
// clang-format off
return
{
......@@ -147,6 +153,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
dead_code_elimination{},
eliminate_fp8{unsupported_fp8_ops},
lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment