Commit c027803e authored by Paul's avatar Paul
Browse files

Use device name

parent ee391f22
...@@ -43,6 +43,11 @@ auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string( ...@@ -43,6 +43,11 @@ auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(
return std::string(props.gcnArchName); return std::string(props.gcnArchName);
} }
std::string get_arch_name(const hipDeviceProp_t& props)
{
return get_arch_name(rank<1>{}, props);
}
int get_device_id() int get_device_id()
{ {
int device; int device;
...@@ -58,7 +63,7 @@ std::string get_device_name() ...@@ -58,7 +63,7 @@ std::string get_device_name()
auto status = hipGetDeviceProperties(&props, get_device_id()); auto status = hipGetDeviceProperties(&props, get_device_id());
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Failed to get device properties"); MIGRAPHX_THROW("Failed to get device properties");
return get_arch_name(rank<1>{}, props); return get_arch_name(props);
} }
} // namespace gpu } // namespace gpu
......
...@@ -170,7 +170,9 @@ struct hip_device ...@@ -170,7 +170,9 @@ struct hip_device
std::size_t stream_id() const { return current_stream; } std::size_t stream_id() const { return current_stream; }
std::string get_device_name() const { return device_props.gcnArchName; } std::string get_device_name() const { return get_arch_name(device_props); }
std::string get_gfx_name() const { return trim(split_string(get_device_name(), ':').front()); }
std::size_t get_device_major() const { return device_props.major; } std::size_t get_device_major() const { return device_props.major; }
......
...@@ -27,10 +27,14 @@ ...@@ -27,10 +27,14 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <string> #include <string>
struct hipDeviceProp_t;
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
std::string get_arch_name(const hipDeviceProp_t& props);
std::string get_device_name(); std::string get_device_name();
int get_device_id(); int get_device_id();
......
...@@ -333,7 +333,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -333,7 +333,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
cde_op}; cde_op};
} }
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
auto a_shape = inputs[0]; auto a_shape = inputs[0];
auto b_shape = inputs[1]; auto b_shape = inputs[1];
...@@ -343,7 +343,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -343,7 +343,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto problem = create_problem(inputs, v); auto problem = create_problem(inputs, v);
const auto include_header = problem.GetIncludeHeader(); const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions("gfx90a"); const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
const auto solution = solutions.at(tuning_value); const auto solution = solutions.at(tuning_value);
const auto template_str = solution.template_str; const auto template_str = solution.template_str;
const auto blocks_per_batch = solution.grid_size; const auto blocks_per_batch = solution.grid_size;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment