Unverified Commit b37322ae authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Explicitly set rocblas_pointer_mode in examples (#1331)



* fix rocblas pointer mode

* fix formatting

* formatting

* revert header change
Co-authored-by: default avatarumangyadav <umang.yadav@amd.com>
parent bb0e04ce
......@@ -23,7 +23,7 @@
*/
#include <algorithm>
#include <hip/hip_runtime.h>
#include <rocblas.h>
#include <rocblas/rocblas.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> // MIGraphX's C++ API
#include <numeric>
......@@ -56,11 +56,13 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
migraphx::arguments args) const override
{
// create rocblas stream handle
auto rocblas_handle = create_rocblas_handle_ptr(ctx);
rocblas_int n = args[1].get_shape().lengths()[0];
float* alpha = reinterpret_cast<float*>(args[0].data());
float* vec_ptr = reinterpret_cast<float*>(args[1].data());
MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rocblas_handle, n, alpha, vec_ptr, 1));
auto rb_handle = create_rocblas_handle_ptr(ctx);
MIGRAPHX_ROCBLAS_ASSERT(rocblas_set_pointer_mode(rb_handle, rocblas_pointer_mode_device));
rocblas_int n = args[1].get_shape().lengths()[0];
float* alpha = reinterpret_cast<float*>(args[0].data());
float* vec_ptr = reinterpret_cast<float*>(args[1].data());
MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rb_handle, n, alpha, vec_ptr, 1));
MIGRAPHX_ROCBLAS_ASSERT(rocblas_destroy_handle(rb_handle));
return args[1];
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment