Commit 037205c5 authored by Umang Yadav
Browse files

Works now

parent 3aa465fd
...@@ -286,9 +286,9 @@ endif() ...@@ -286,9 +286,9 @@ endif()
if(HAS_ROCBLAS_FP8_BETA_API) if(HAS_ROCBLAS_FP8_BETA_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS) target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUA "MIGraphX is using BETA API of rocBLAS for FP8 computations") message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations")
else() else()
message(STATUS "rocBLAS does not have FP8 BETA API") message(STATUS "rocBLAS does not have Fp8 Beta API")
endif() endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
......
...@@ -229,7 +229,7 @@ struct gemm_impl ...@@ -229,7 +229,7 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common_fp8(ctx, input_args); auto common_args = create_strided_batched_args_common_fp8(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3, rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args, common_args,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_standard,
solution_idx, solution_idx,
gemm_flags); gemm_flags);
} }
...@@ -238,7 +238,7 @@ struct gemm_impl ...@@ -238,7 +238,7 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common_fp8(ctx, input_args); auto common_args = create_gemm_ex_args_common_fp8(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex3, rocblas_invoke(&rocblas_gemm_ex3,
common_args, common_args,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_standard,
solution_idx, solution_idx,
gemm_flags); gemm_flags);
} }
...@@ -388,7 +388,7 @@ struct gemm_impl ...@@ -388,7 +388,7 @@ struct gemm_impl
ldd, ldd,
d_stride, d_stride,
num_matrices, num_matrices,
rocblas_compute_type_f8_f8_f32); rocblas_compute_type_f32);
} }
/** /**
...@@ -447,7 +447,7 @@ struct gemm_impl ...@@ -447,7 +447,7 @@ struct gemm_impl
is_3inputs ? args[3].data() : args[2].data(), is_3inputs ? args[3].data() : args[2].data(),
output_type, output_type,
ldd, ldd,
rocblas_compute_type_f8_f8_f32); rocblas_compute_type_f32);
} }
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/** /**
......
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct gemm_2args_bmv : verify_program<gemm_2args_bmv> template<migraphx::shape::type_t DType>
struct gemm_2args_bmv : verify_program<gemm_2args_bmv<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}}; migraphx::shape m1_shape{DType, {2, 3, 3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}}; migraphx::shape m2_shape{DType, {5}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2); auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
...@@ -46,3 +47,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv> ...@@ -46,3 +47,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
return p; return p;
} }
}; };
template struct gemm_2args_bmv<migraphx::shape::float_type>;
template struct gemm_2args_bmv<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1> template <migraphx::shape::type_t DType>
struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{DType, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; migraphx::shape m2_shape{DType, {1, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = auto bl2 =
...@@ -45,3 +46,6 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1> ...@@ -45,3 +46,6 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
return p; return p;
} }
}; };
template struct gemm_2args_mm_1<migraphx::shape::float_type>;
template struct gemm_2args_mm_1<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment