Commit b0964da4 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

More changes to match the GEMM GPU implementation

parent dacfb9b8
...@@ -82,8 +82,6 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed ...@@ -82,8 +82,6 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
{ {
std::vector<T> result(s.elements()); std::vector<T> result(s.elements());
std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed}); std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
// divide a value to avoid integer overflow
std::transform(result.begin(), result.end(), result.begin(), [](auto i) { return i / 32; });
// std::generate(result.begin(), result.end(), [&]{ return seed % 7; }); // std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
// std::generate(result.begin(), result.end(), []{ return 1; }); // std::generate(result.begin(), result.end(), []{ return 1; });
return result; return result;
......
...@@ -24,6 +24,7 @@ struct rocblas_quant_gemm ...@@ -24,6 +24,7 @@ struct rocblas_quant_gemm
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
void batch_not_transposed(const std::vector<std::size_t>& strides) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
......
...@@ -56,10 +56,27 @@ shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const ...@@ -56,10 +56,27 @@ shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
std::vector<shape> in_shapes(inputs); std::vector<shape> in_shapes(inputs);
in_shapes.pop_back(); in_shapes.pop_back();
check_shapes{in_shapes}.not_broadcasted(); check_shapes{in_shapes}.not_broadcasted();
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
return op.compute_shape(in_shapes); return op.compute_shape(in_shapes);
} }
// Verify that the batch dimensions of a (possibly batched) matrix are laid out
// contiguously ahead of the matrix dimensions, i.e. the batch is not transposed
// or interleaved with the matrix data. rocBLAS strided-batched GEMM requires
// each batch element to be a dense matrix at a fixed stride, so a violating
// layout cannot be handed to it directly.
//
// strides: full stride vector of the input shape; the last two entries are the
//          matrix (row/column) strides, everything before them is batch.
// Throws MIGRAPHX_THROW when a transposed/interleaved batch layout is detected.
void rocblas_quant_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
{
    // Rank <= 2 means there are no batch dimensions to validate.
    if(strides.size() <= 2)
        return;
    const auto num_batch = strides.size() - 2;
    // Approximate size of one matrix: the larger of the two matrix strides.
    const auto matrix_size = std::max(strides[num_batch], strides[num_batch + 1]);
    // A pair of adjacent batch strides is bad when they increase left-to-right
    // (batch dims transposed) or either is smaller than a whole matrix
    // (batch interleaved with matrix elements). Scan the original range
    // directly instead of copying it into a temporary vector.
    const auto batch_end = strides.begin() + num_batch;
    const auto bad       = std::adjacent_find(strides.begin(), batch_end, [&](auto i, auto j) {
        return (i < j or i < matrix_size or j < matrix_size);
    });
    if(bad != batch_end)
    {
        MIGRAPHX_THROW("DOT: batch size {" + to_string_range(strides) + "} is transposed!");
    }
}
argument rocblas_quant_gemm::compute(context& ctx, argument rocblas_quant_gemm::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment