Commit 698c188d authored by Brian Pickrell's avatar Brian Pickrell
Browse files

First try to compile a call to rocblas-beta API; doesn't compile. Note the...

First attempt at compiling a call to the rocBLAS beta API (rocblas_gemm_ex_get_solutions); it does not compile yet. Note that the ROCm version in the Dockerfile is updated from 5.4.2 to 5.4.3, which may require building a new Docker image.
parent fc9ebb06
Pipeline #667 failed with stages
in 0 seconds
...@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && ...@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl &&
curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
# Add rocm repository # Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.4.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list' RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.4.3/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
......
...@@ -21,10 +21,16 @@ ...@@ -21,10 +21,16 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#define ROCBLAS_BETA_FEATURES_API 1
#include <rocblas/rocblas.h> #include <rocblas/rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp> #include <migraphx/gpu/gemm_impl.hpp>
// #include "rocblas_gemm_ex_get_solutions.hpp"
#include <migraphx/generate.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/time.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -102,6 +108,45 @@ static rocblas_int get_batch_stride(const argument& a) ...@@ -102,6 +108,45 @@ static rocblas_int get_batch_stride(const argument& a)
return a.get_shape().strides()[a.get_shape().strides().size() - 3]; return a.get_shape().strides()[a.get_shape().strides().size() - 3];
} }
std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsigned long seed = 0)
{
std::vector<argument> args;
std::transform(shapes.begin(), shapes.end(), std::back_inserter(args), [&](auto& s) {
return to_gpu(generate_argument(s, seed++));
});
return args;
}
// from perf.cpp
// Duration type with a double tick count and a millisecond period; used as
// the template argument of time<>() below.
using milliseconds = std::chrono::duration<double, std::milli>;

// Runs `op` repeatedly on generated inputs and reports average timings.
//
// ictx:   context used to construct the migraphx::context the op runs on.
// op:     operation to time; its output shape is obtained from
//         op.compute_shape(inputs).
// inputs: input shapes; the arguments are produced by generate_arguments().
// n:      number of timed runs. It is also the divisor of both averages and
//         is not checked here, so callers should pass n > 0.
//
// Returns {sum of time<milliseconds>(run) / n, sum of get_elapsed_ms() / n}.
// Both are presumably in milliseconds -- TODO confirm against time<>() and
// context::get_elapsed_ms().
std::pair<double, double>
time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
{
    // TODO: Use std::ref
    migraphx::context ctx = ictx;
    auto& gpu_ctx         = any_cast<migraphx::gpu::context>(ctx);
    auto out_shape        = op.compute_shape(inputs);
    // op.finalize(ctx, output, inputs);
    auto generated_args = generate_arguments(inputs);

    // One unit of work: run the op, then wait on the context.
    auto launch_and_wait = [&] {
        op.compute(ctx, out_shape, generated_args);
        ctx.finish();
    };

    gpu_ctx.enable_perf_measurement();
    // First run is executed outside the timed loop and is not accumulated.
    launch_and_wait();

    double host_total   = 0.0;
    double device_total = 0.0;
    for(auto iteration : range(n))
    {
        (void)iteration; // only the iteration count matters
        host_total += time<milliseconds>(launch_and_wait);
        device_total += gpu_ctx.get_elapsed_ms();
    }

    const double host_avg   = host_total / n;
    const double device_avg = device_total / n;
    return {host_avg, device_avg};
}
template <class T> template <class T>
void gemm_impl(context& ctx, void gemm_impl(context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -183,74 +228,76 @@ void gemm_impl(context& ctx, ...@@ -183,74 +228,76 @@ void gemm_impl(context& ctx,
// instead of rocblas_gemm_strided_batched_ex. // instead of rocblas_gemm_strided_batched_ex.
m *= num_matrices; m *= num_matrices;
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
alpha_v,
to_pointer(args.at(1)),
arg_type,
ldb,
to_pointer(args.at(0)),
arg_type,
lda,
beta_v,
to_pointer(args[2]),
output_type,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldd,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
}
else
{
auto a_stride = get_batch_stride(args[0]);
auto b_stride = get_batch_stride(args[1]);
auto c_stride = get_batch_stride(args[2]);
auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
alpha_v,
to_pointer(args.at(1)),
arg_type,
ldb,
b_stride,
to_pointer(args.at(0)),
arg_type,
lda,
a_stride,
beta_v,
to_pointer(args[2]),
output_type,
ldc,
c_stride,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldd,
d_stride,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
} }
auto da = to_pointer(args.at(0));
auto db = to_pointer(args.at(1));
auto dc = to_pointer(args.at(2));
auto type = arg_type;
#define GEMM_EX_ARGS \
handle, transa, transb, m, n, k, alpha_v, da, type, lda, db, type, ldb, beta_v, dc, type, ldc, \
dc, type, ldc, type, rocblas_gemm_algo_solution_index
rocblas_handle handle;
CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
// Get number of solutions
rocblas_int size;
CHECK_ROCBLAS_ERROR(
rocblas_gemm_ex_get_solutions(GEMM_EX_ARGS, rocblas_gemm_flags_none, NULL, &size));
rocblas_cout << size << " solution(s) found" << std::endl;
// Fill array with list of solutions
std::vector<rocblas_int> ary(size);
CHECK_ROCBLAS_ERROR(
rocblas_gemm_ex_get_solutions(GEMM_EX_ARGS, rocblas_gemm_flags_none, ary.data(), &size));
// GEMM_EX_ARGS;
// rocblas_gemm_ex_get_solutions_template(rocblas_handle handle,
// rocblas_operation trans_a,
// rocblas_operation trans_b,
// rocblas_int m,
// rocblas_int n,
// rocblas_int k,
// const void* alpha,
// const void* a,
// rocblas_datatype a_type,
// rocblas_stride offsetAin,
// rocblas_int lda,
// rocblas_stride stride_a,
// const void* b,
// rocblas_datatype b_type,
// rocblas_stride offsetBin,
// rocblas_int ldb,
// rocblas_stride stride_b,
// const void* beta,
// const void* c,
// rocblas_datatype c_type,
// rocblas_stride offsetCin,
// rocblas_int ldc,
// rocblas_stride stride_c,
// void* d,
// rocblas_datatype d_type,
// rocblas_stride offsetDin,
// rocblas_int ldd,
// rocblas_stride stride_d,
// rocblas_int batch_count,
// rocblas_datatype compute_type,
// uint32_t flags,
// rocblas_int* list_array,
// rocblas_int* list_size)
// return pack(ctx.get_stream().get_rocblas());
// Get number of solutions
// rocblas_int size;
// CHECK_ROCBLAS_ERROR(
// rocblas_gemm_ex_get_solutions(GEMM_EX_ARGS, rocblas_gemm_flags_none, NULL, &size));
}); });
printf("here I am = ================================================");
exit(8);
} }
void gemm(context& ctx, void gemm(context& ctx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment