Commit cb965031 authored by Khalique Ahmed
Browse files

use other device name function

parent 2ec8ba6a
...@@ -20,8 +20,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -20,8 +20,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
migraphx::shape batch_shape{result.get_shape().type(), batch_lens}; migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) { hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
const index_int max_block_size = 120; const index_int max_block_size = 128;
// const index_int max_block_size = 128;
const index_int block_size = compute_block_size(batch_item_num, max_block_size); const index_int block_size = compute_block_size(batch_item_num, max_block_size);
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
type init = lowest(); type init = lowest();
......
...@@ -97,17 +97,6 @@ struct hip_device ...@@ -97,17 +97,6 @@ struct hip_device
return rbhandle.get(); return rbhandle.get();
} }
std::string get_device_name()
{
hipDeviceProp_t props{};
// int device;
// if (not (hipGetDevice(&device) == hipSuccess))
// MIGRAPHX_THROW("Unable to get hip device");
// if (not (hipGetDeviceProperties(&props, device) == hipSuccess))
// MIGRAPHX_THROW("Unable to get hip device properties");
return "gfx" + std::to_string(props.gcnArch);
}
void wait() const void wait() const
{ {
if(s == nullptr) if(s == nullptr)
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp> #include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp> #include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/elu.hpp> #include <migraphx/gpu/elu.hpp>
#include <migraphx/gpu/equal.hpp> #include <migraphx/gpu/equal.hpp>
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
...@@ -111,7 +112,7 @@ struct miopen_apply ...@@ -111,7 +112,7 @@ struct miopen_apply
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 #if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
auto& ctx = get_context(); auto& ctx = get_context();
const auto device_name = const auto device_name =
trim(split_string(ctx.get_stream().get_device_name(), ':').front()); trim(split_string(get_device_name(), ':').front());
if(contains(get_rocblas_fp32_archs(), device_name)) if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true; compute_fp32 = true;
rocblas_gemm_flags flag; rocblas_gemm_flags flag;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment