Commit 3eb036b2 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed build errors.

parent 02188419
......@@ -379,7 +379,7 @@ struct cpu_gemm
if(out_lens == c_lens)
{
visit_all(result, c)([&](auto output, auto input) {
std::memcpy(output.data(), input.data(), c_shape.bytes());
std::memcpy(output.data(), input.data(), c.get_shape().bytes());
});
}
// need broadcast
......@@ -397,7 +397,7 @@ struct cpu_gemm
visit_all(result, c)([&](auto output, auto input) {
for(std::size_t i = 0; i < m; i++)
{
std::memcpy((output.data() + i * n), input.data(), c_shape.bytes());
std::memcpy((output.data() + i * n), input.data(), c.get_shape().bytes());
}
});
}
......
......@@ -171,8 +171,7 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
return op.compute_shape(inputs);
}
void miopen_gemm::fill_result(context& ctx,
const shape& output_shape,
void miopen_gemm::fill_result(const shape& output_shape,
const argument& result,
const argument& c) const
{
......@@ -182,8 +181,8 @@ void miopen_gemm::fill_result(context& ctx,
{
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
hipMemcpy(to_pointer(args[3]),
to_pointer(args[2]),
hipMemcpy(to_pointer(result),
to_pointer(c),
output_shape.bytes(),
hipMemcpyDeviceToDevice);
});
......@@ -191,15 +190,15 @@ void miopen_gemm::fill_result(context& ctx,
else if(c.single())
{
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset) {
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
for(std::size_t i = 0; i < output_shape.elements(); ++i)
{
hipMemcpy(to_pointer(args[3], i),
to_pointer(args[2]),
args[2].get_shape().bytes(),
hipMemcpy(to_pointer(result, i),
to_pointer(c),
c.get_shape().bytes(),
hipMemcpyDeviceToDevice);
}
});
......@@ -209,15 +208,15 @@ void miopen_gemm::fill_result(context& ctx,
auto m = out_lens[0];
auto n = out_lens[1];
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset) {
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
for(std::size_t i = 0; i < m; ++i)
{
hipMemcpy(to_pointer(args[3], i * n),
to_pointer(args[2]),
args[2].get_shape().bytes(),
hipMemcpy(to_pointer(result, i * n),
to_pointer(c),
c.get_shape().bytes(),
hipMemcpyDeviceToDevice);
}
});
......@@ -232,9 +231,9 @@ void miopen_gemm::fill_result(context& ctx,
for(std::size_t i = 0; i < output_shape.elements(); ++i)
{
hipMemcpy(to_pointer(args[3], i),
to_pointer(args[2], i / n),
args[2].get_shape().type_size(),
hipMemcpy(to_pointer(result, i),
to_pointer(c, i / out_lens[0]),
c.get_shape().type_size(),
hipMemcpyDeviceToDevice);
}
});
......@@ -248,7 +247,7 @@ argument miopen_gemm::compute(context& ctx,
bool is_3inputs = (args.size() == 4);
if(is_3inputs)
{
fill_result(ctx, output_shape, args[3], args[2]);
fill_result(output_shape, args[3], args[2]);
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
......@@ -302,8 +301,7 @@ argument miopen_gemm::compute(context& ctx,
to_pointer(args[2]));
generic_rocblas_scal(
as, ctx.get_stream().get_rocblas(), 1, &alpha_r, to_pointer(args[2]));
1);
as, ctx.get_stream().get_rocblas(), 1, &alpha_r, to_pointer(args[2]), 1);
});
}
// matrix * vector
......@@ -323,7 +321,7 @@ argument miopen_gemm::compute(context& ctx,
a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = = to_rocblas_type(as(beta));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
......@@ -340,7 +338,7 @@ argument miopen_gemm::compute(context& ctx,
to_pointer(args[1]),
1,
&beta_r,
to_pointer(args[2], batch_no * n) 1);
to_pointer(args[2], batch_no * n), 1);
}
});
}
......@@ -361,7 +359,7 @@ argument miopen_gemm::compute(context& ctx,
b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = = to_rocblas_type(as(beta));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
......@@ -378,7 +376,7 @@ argument miopen_gemm::compute(context& ctx,
to_pointer(args[1], batch_no * m * n),
1,
&beta_r,
to_pointer(args[2], batch_no * m) 1);
to_pointer(args[2], batch_no * m), 1);
}
});
}
......@@ -399,6 +397,7 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int m = out_lens[out_lens.size() - 2];
rocblas_int n = out_lens[out_lens.size() - 1];
rocblas_int k = args[0].get_shape().lens()[a_lens.size() - 1];
float beta = 0.0f;
auto input_dims = std::min(a_lens.size(), b_lens.size());
std::size_t axis{0};
for(axis = 2; axis < input_dims; ++axis)
......@@ -429,9 +428,9 @@ argument miopen_gemm::compute(context& ctx,
shape::type_t t = output_shape.type();
shape a_batch_shape{t, a_batch_lens};
shape b_batch_shape{t, b_batch_lens};
shape out_diff_shape{t, out_batch_lens};
shape out_batch_shape{t, out_batch_lens};
shape_for_each(out_diff_shape, [&](auto out_idx) {
shape_for_each(out_batch_shape, [&](auto out_idx) {
std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
std::vector<std::size_t> a_idx(a_lens.size() - axis);
std::vector<std::size_t> b_idx(b_lens.size() - axis);
......@@ -451,7 +450,7 @@ argument miopen_gemm::compute(context& ctx,
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = = to_rocblas_type(as(beta));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
......
......@@ -20,8 +20,7 @@ struct miopen_gemm
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
private:
void fill_result(context& ctx,
const shape& output_shape,
void fill_result(const shape& output_shape,
const argument& result,
const argument& c) const;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment