Commit 3eb036b2 authored by Shucai Xiao
Browse files

fixed build errors.

parent 02188419
...@@ -379,7 +379,7 @@ struct cpu_gemm ...@@ -379,7 +379,7 @@ struct cpu_gemm
if(out_lens == c_lens) if(out_lens == c_lens)
{ {
visit_all(result, c)([&](auto output, auto input) { visit_all(result, c)([&](auto output, auto input) {
std::memcpy(output.data(), input.data(), c_shape.bytes()); std::memcpy(output.data(), input.data(), c.get_shape().bytes());
}); });
} }
// need broadcast // need broadcast
...@@ -397,7 +397,7 @@ struct cpu_gemm ...@@ -397,7 +397,7 @@ struct cpu_gemm
visit_all(result, c)([&](auto output, auto input) { visit_all(result, c)([&](auto output, auto input) {
for(std::size_t i = 0; i < m; i++) for(std::size_t i = 0; i < m; i++)
{ {
std::memcpy((output.data() + i * n), input.data(), c_shape.bytes()); std::memcpy((output.data() + i * n), input.data(), c.get_shape().bytes());
} }
}); });
} }
......
...@@ -171,8 +171,7 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const ...@@ -171,8 +171,7 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
void miopen_gemm::fill_result(context& ctx, void miopen_gemm::fill_result(const shape& output_shape,
const shape& output_shape,
const argument& result, const argument& result,
const argument& c) const const argument& c) const
{ {
...@@ -182,8 +181,8 @@ void miopen_gemm::fill_result(context& ctx, ...@@ -182,8 +181,8 @@ void miopen_gemm::fill_result(context& ctx,
{ {
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
hipMemcpy(to_pointer(args[3]), hipMemcpy(to_pointer(result),
to_pointer(args[2]), to_pointer(c),
output_shape.bytes(), output_shape.bytes(),
hipMemcpyDeviceToDevice); hipMemcpyDeviceToDevice);
}); });
...@@ -191,15 +190,15 @@ void miopen_gemm::fill_result(context& ctx, ...@@ -191,15 +190,15 @@ void miopen_gemm::fill_result(context& ctx,
else if(c.single()) else if(c.single())
{ {
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset) { auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data() + offset));
}; };
for(std::size_t i = 0; i < output_shape.elements(); ++i) for(std::size_t i = 0; i < output_shape.elements(); ++i)
{ {
hipMemcpy(to_pointer(args[3], i), hipMemcpy(to_pointer(result, i),
to_pointer(args[2]), to_pointer(c),
args[2].get_shape().bytes(), c.get_shape().bytes(),
hipMemcpyDeviceToDevice); hipMemcpyDeviceToDevice);
} }
}); });
...@@ -209,15 +208,15 @@ void miopen_gemm::fill_result(context& ctx, ...@@ -209,15 +208,15 @@ void miopen_gemm::fill_result(context& ctx,
auto m = out_lens[0]; auto m = out_lens[0];
auto n = out_lens[1]; auto n = out_lens[1];
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg, std::size_t offset) { auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data() + offset));
}; };
for(std::size_t i = 0; i < m; ++i) for(std::size_t i = 0; i < m; ++i)
{ {
hipMemcpy(to_pointer(args[3], i * n), hipMemcpy(to_pointer(result, i * n),
to_pointer(args[2]), to_pointer(c),
args[2].get_shape().bytes(), c.get_shape().bytes(),
hipMemcpyDeviceToDevice); hipMemcpyDeviceToDevice);
} }
}); });
...@@ -232,9 +231,9 @@ void miopen_gemm::fill_result(context& ctx, ...@@ -232,9 +231,9 @@ void miopen_gemm::fill_result(context& ctx,
for(std::size_t i = 0; i < output_shape.elements(); ++i) for(std::size_t i = 0; i < output_shape.elements(); ++i)
{ {
hipMemcpy(to_pointer(args[3], i), hipMemcpy(to_pointer(result, i),
to_pointer(args[2], i / n), to_pointer(c, i / out_lens[0]),
args[2].get_shape().type_size(), c.get_shape().type_size(),
hipMemcpyDeviceToDevice); hipMemcpyDeviceToDevice);
} }
}); });
...@@ -248,7 +247,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -248,7 +247,7 @@ argument miopen_gemm::compute(context& ctx,
bool is_3inputs = (args.size() == 4); bool is_3inputs = (args.size() == 4);
if(is_3inputs) if(is_3inputs)
{ {
fill_result(ctx, output_shape, args[3], args[2]); fill_result(output_shape, args[3], args[2]);
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
...@@ -302,8 +301,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -302,8 +301,7 @@ argument miopen_gemm::compute(context& ctx,
to_pointer(args[2])); to_pointer(args[2]));
generic_rocblas_scal( generic_rocblas_scal(
as, ctx.get_stream().get_rocblas(), 1, &alpha_r, to_pointer(args[2])); as, ctx.get_stream().get_rocblas(), 1, &alpha_r, to_pointer(args[2]), 1);
1);
}); });
} }
// matrix * vector // matrix * vector
...@@ -323,7 +321,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -323,7 +321,7 @@ argument miopen_gemm::compute(context& ctx,
a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data() + offset));
}; };
...@@ -340,7 +338,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -340,7 +338,7 @@ argument miopen_gemm::compute(context& ctx,
to_pointer(args[1]), to_pointer(args[1]),
1, 1,
&beta_r, &beta_r,
to_pointer(args[2], batch_no * n) 1); to_pointer(args[2], batch_no * n), 1);
} }
}); });
} }
...@@ -361,7 +359,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -361,7 +359,7 @@ argument miopen_gemm::compute(context& ctx,
b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data() + offset));
}; };
...@@ -378,7 +376,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -378,7 +376,7 @@ argument miopen_gemm::compute(context& ctx,
to_pointer(args[1], batch_no * m * n), to_pointer(args[1], batch_no * m * n),
1, 1,
&beta_r, &beta_r,
to_pointer(args[2], batch_no * m) 1); to_pointer(args[2], batch_no * m), 1);
} }
}); });
} }
...@@ -399,6 +397,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -399,6 +397,7 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int m = out_lens[out_lens.size() - 2]; rocblas_int m = out_lens[out_lens.size() - 2];
rocblas_int n = out_lens[out_lens.size() - 1]; rocblas_int n = out_lens[out_lens.size() - 1];
rocblas_int k = args[0].get_shape().lens()[a_lens.size() - 1]; rocblas_int k = args[0].get_shape().lens()[a_lens.size() - 1];
float beta = 0.0f;
auto input_dims = std::min(a_lens.size(), b_lens.size()); auto input_dims = std::min(a_lens.size(), b_lens.size());
std::size_t axis{0}; std::size_t axis{0};
for(axis = 2; axis < input_dims; ++axis) for(axis = 2; axis < input_dims; ++axis)
...@@ -429,9 +428,9 @@ argument miopen_gemm::compute(context& ctx, ...@@ -429,9 +428,9 @@ argument miopen_gemm::compute(context& ctx,
shape::type_t t = output_shape.type(); shape::type_t t = output_shape.type();
shape a_batch_shape{t, a_batch_lens}; shape a_batch_shape{t, a_batch_lens};
shape b_batch_shape{t, b_batch_lens}; shape b_batch_shape{t, b_batch_lens};
shape out_diff_shape{t, out_batch_lens}; shape out_batch_shape{t, out_batch_lens};
shape_for_each(out_diff_shape, [&](auto out_idx) { shape_for_each(out_batch_shape, [&](auto out_idx) {
std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end()); std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
std::vector<std::size_t> a_idx(a_lens.size() - axis); std::vector<std::size_t> a_idx(a_lens.size() - axis);
std::vector<std::size_t> b_idx(b_lens.size() - axis); std::vector<std::size_t> b_idx(b_lens.size() - axis);
...@@ -451,7 +450,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -451,7 +450,7 @@ argument miopen_gemm::compute(context& ctx,
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data() + offset));
}; };
......
...@@ -20,8 +20,7 @@ struct miopen_gemm ...@@ -20,8 +20,7 @@ struct miopen_gemm
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; } int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
private: private:
void fill_result(context& ctx, void fill_result(const shape& output_shape,
const shape& output_shape,
const argument& result, const argument& result,
const argument& c) const; const argument& c) const;
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment