Unverified Commit 77042e30 authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Update xdlops/rocblas fp32 arch (#1752)

Refactor supported gfx archs
parent 42772fd6
...@@ -164,10 +164,10 @@ std::string mlir_print(F f, T x) ...@@ -164,10 +164,10 @@ std::string mlir_print(F f, T x)
return ss.str(); return ss.str();
} }
// Two-column diff rows: on each line the left part is the removed get_xdlops_archs()
// (a fixed set holding "gfx908" and "gfx90a") and the right part is the added
// has_xdlops(target_arch).
// has_xdlops computes trim(split_string(target_arch, ':').front()) and returns true when
// that name starts with "gfx9" and compares >= "gfx908".
// NOTE(review): if device_name is a std::string, this `>=` is a lexicographic comparison,
// so any "gfx9..." name ordering at or after "gfx908" (e.g. "gfx909", "gfx90c") passes
// too, which is wider than the two names in the removed set; confirm that is intended.
const std::unordered_set<std::string>& get_xdlops_archs() bool has_xdlops(const std::string& target_arch)
{ {
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"}; const auto device_name = trim(split_string(target_arch, ':').front());
return supported_archs; return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
} }
struct mlir_program struct mlir_program
...@@ -560,9 +560,7 @@ struct mlir_program ...@@ -560,9 +560,7 @@ struct mlir_program
pp = pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()}; problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
// check if HW supports xdlops // check if HW supports xdlops
auto target_chip = trim(split_string(target_arch, ':').front()); if(has_xdlops(target_arch))
bool xdlops = contains(get_xdlops_archs(), target_chip);
if(xdlops)
ops.add_attributes({{"xdlopsV2", true}}); ops.add_attributes({{"xdlopsV2", true}});
} }
......
...@@ -47,16 +47,10 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s) ...@@ -47,16 +47,10 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s)
return rb; return rb;
} }
// Returns the fixed set of device names ("gfx908", "gfx90a") that
// get_compute_fp32_flag() looks the current device up in. The set is built once,
// on the first call, and the same object is returned on every call.
const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
    static std::unordered_set<std::string> fp32_archs = [] {
        std::unordered_set<std::string> names;
        names.insert("gfx908");
        names.insert("gfx90a");
        return names;
    }();
    return fp32_archs;
}
// Two-column diff rows: the left part of each line is the old body, the right part the new.
// Both compute device_name as trim(split_string(get_device_name(), ':').front()).
// Old: returns whether device_name is in get_rocblas_fp32_archs().
// New: returns true when device_name starts with "gfx9" and compares >= "gfx908".
// NOTE(review): if device_name is a std::string, this `>=` is a lexicographic comparison,
// so names such as "gfx909" or "gfx90c" pass as well; confirm that is intended.
bool get_compute_fp32_flag() bool get_compute_fp32_flag()
{ {
const auto device_name = trim(split_string(get_device_name(), ':').front()); const auto device_name = trim(split_string(get_device_name(), ':').front());
return contains(get_rocblas_fp32_archs(), device_name); return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
} }
bool get_int8_x4_format(context& ctx) bool get_int8_x4_format(context& ctx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment