Commit 3aa465fd authored by Umang Yadav

compiles all right

parent a6c57726
......@@ -253,6 +253,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
# rocblas FP8 API
check_library_exists(roc::rocblas "rocblas_gemm_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
......@@ -282,6 +284,13 @@ else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()
if(HAS_ROCBLAS_FP8_BETA_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphX is using the BETA API of rocBLAS for FP8 computations")
else()
message(STATUS "rocBLAS does not have the FP8 BETA API")
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
......
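The three defines above are attached to migraphx_gpu with PUBLIC visibility, so every consumer of the target can gate FP8-only code paths at compile time. A minimal sketch of that pattern (the macro name comes from the diff above; the guarded bodies are illustrative only):

#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
// Safe to reference the beta entry points here, e.g.
// rocblas_gemm_ex3 and rocblas_gemm_strided_batched_ex3.
#else
// rocBLAS predates the FP8 beta API: fp8 GEMMs must be rewritten to fp32
// (see convert_fp8_to_fp32 further down in this commit).
#endif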
......@@ -23,10 +23,12 @@
*/
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
#include <type_traits>
using microseconds = std::chrono::duration<double, std::micro>;
......@@ -46,7 +48,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::fp8e4m3fnuz_type: return rocblas_datatype_f8_r;
case shape::tuple_type:
case shape::bool_type:
case shape::uint16_type:
......@@ -217,23 +219,50 @@ struct gemm_impl
void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const
{
    if(rocblas_fp8_available() and
       std::any_of(input_args.begin(), input_args.end(), [](const auto i) {
           return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
       }))
    {
        if(strided_batched)
        {
            auto common_args = create_strided_batched_args_common_fp8(ctx, input_args);
            rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
                           common_args,
                           rocblas_gemm_algo_solution_index,
                           solution_idx,
                           gemm_flags);
        }
        else
        {
            auto common_args = create_gemm_ex_args_common_fp8(ctx, input_args);
            rocblas_invoke(&rocblas_gemm_ex3,
                           common_args,
                           rocblas_gemm_algo_solution_index,
                           solution_idx,
                           gemm_flags);
        }
    }
    else
    {
        if(strided_batched)
        {
            auto common_args = create_strided_batched_args_common(ctx, input_args);
            rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                           common_args,
                           rocblas_gemm_algo_solution_index,
                           solution_idx,
                           gemm_flags);
        }
        else
        {
            auto common_args = create_gemm_ex_args_common(ctx, input_args);
            rocblas_invoke(&rocblas_gemm_ex,
                           common_args,
                           rocblas_gemm_algo_solution_index,
                           solution_idx,
                           gemm_flags);
        }
    }
}
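// Dispatch summary for run():
//   fp8 inputs,   strided_batched -> rocblas_gemm_strided_batched_ex3
//   fp8 inputs,   flat            -> rocblas_gemm_ex3
//   other dtypes, strided_batched -> rocblas_gemm_strided_batched_ex
//   other dtypes, flat            -> rocblas_gemm_ex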
......@@ -331,6 +360,36 @@ struct gemm_impl
num_matrices,
compute_type);
}
auto create_strided_batched_args_common_fp8(context& ctx,
const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
get_alpha(),
args[1].data(),
arg_type,
ldb,
b_stride,
args[0].data(),
arg_type,
lda,
a_stride,
get_beta(),
args[2].data(),
output_type,
ldc,
c_stride,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
d_stride,
num_matrices,
rocblas_compute_type_f8_f8_f32);
}
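// As in the non-fp8 helpers, B and A are passed swapped with n and m exchanged:
// rocBLAS expects column-major storage, so a row-major C = A*B is evaluated as
// the column-major product B'*A'. The fp8 variant differs only in hard-coding
// rocblas_compute_type_f8_f8_f32 instead of passing the compute_type member.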
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
......@@ -366,6 +425,30 @@ struct gemm_impl
ldd,
compute_type);
}
auto create_gemm_ex_args_common_fp8(context& ctx, const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
get_alpha(),
args[1].data(),
arg_type,
ldb,
args[0].data(),
arg_type,
lda,
get_beta(),
args[2].data(),
output_type,
ldc,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
rocblas_compute_type_f8_f8_f32);
}
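// Non-batched analogue of the helper above: the same argument list minus the
// per-matrix strides and the matrix count.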
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
......@@ -481,8 +564,8 @@ struct gemm_impl
rocblas_int b_stride = 0;
rocblas_int c_stride = 0;
rocblas_int d_stride = 0;
rocblas_datatype arg_type     = rocblas_datatype_f32_r;
rocblas_datatype compute_type = rocblas_datatype_f32_r;
rocblas_datatype output_type  = rocblas_datatype_f32_r;
bool strided_batched = true;
bool is_3inputs = true;
......
......@@ -40,6 +40,8 @@ struct context;
MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag();
MIGRAPHX_GPU_EXPORT bool rocblas_fp8_available();
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -220,12 +220,44 @@ struct miopen_apply
return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}}));
}
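// Fallback used when rocBLAS cannot run fp8 GEMMs directly: convert every input
// to fp32, re-issue the operator on the fp32 values, then convert the result
// back to fp8e4m3fnuz. Returns the intermediate fp32 instruction so the caller
// can lower it through the regular fp32 path.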
instruction_ref convert_fp8_to_fp32(instruction_ref ins)
{
std::vector<instruction_ref> fp8_inputs = ins->inputs();
std::vector<instruction_ref> fp32_inputs;
for(const auto& i : fp8_inputs)
{
fp32_inputs.push_back(mod->insert_instruction(
ins,
migraphx::make_op(
"convert",
{{"target_type", migraphx::to_value(migraphx::shape::type_t::float_type)}}),
i));
}
auto fp32_ins = mod->insert_instruction(ins, ins->get_operator(), {fp32_inputs});
auto fp8_ins = mod->insert_instruction(
ins,
migraphx::make_op(
"convert",
{{"target_type", migraphx::to_value(migraphx::shape::type_t::fp8e4m3fnuz_type)}}),
fp32_ins);
mod->replace_instruction(ins, fp8_ins);
return fp32_ins;
}
template <typename Op>
void add_gemm_op(const std::string& name)
{
apply_map.emplace(name, [=](instruction_ref ins) {
std::vector<instruction_ref> refs = ins->inputs();
assert(refs.size() == 2);
if(not rocblas_fp8_available() and
std::any_of(refs.begin(), refs.end(), [](const auto i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
{
// replace fp8 ins with fp32 ins
ins = convert_fp8_to_fp32(ins);
}
auto output = insert_allocation(ins, ins->get_shape());
refs.push_back(output);
return mod->replace_instruction(ins, rocblas_gemm<Op>{Op{}, 1, 0, compute_fp32}, refs);
......
......@@ -53,6 +53,15 @@ bool get_compute_fp32_flag()
return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
}
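// An fp8 GEMM is only dispatched to rocBLAS when both halves of this check pass:
// the library was built with the beta FP8 API (compile-time macro from CMake)
// and the device is gfx940 or newer (runtime).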
bool rocblas_fp8_available()
{
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return false;
#else
const auto device_name = trim(split_string(get_device_name(), ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
#endif
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -29,26 +29,26 @@
#include <migraphx/make_op.hpp>
template <migraphx::shape::type_t From, migraphx::shape::type_t To>
struct test_convert : verify_program<test_convert<From, To>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{From, {8, 24}};
migraphx::shape sb{From, {24, 6}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto ia = mm->add_instruction(
    migraphx::make_op("convert", {{"target_type", migraphx::to_value(To)}}), pa);
auto ib = mm->add_instruction(
    migraphx::make_op("convert", {{"target_type", migraphx::to_value(To)}}), pb);
mm->add_instruction(migraphx::make_op("dot"), ia, ib);
return p;
};
};
template struct test_convert<migraphx::shape::int8_type, migraphx::shape::float_type>;
template struct test_convert<migraphx::shape::fp8e4m3fnuz_type, migraphx::shape::float_type>;
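// Explicit instantiations: the first keeps the original int8 -> float coverage,
// the second exercises the new fp8e4m3fnuz -> float conversion path.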