Commit 03afa098 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

temp changes.

parent 9e0dca3d
...@@ -36,7 +36,7 @@ struct argument : raw_data<argument> ...@@ -36,7 +36,7 @@ struct argument : raw_data<argument>
} }
/// Provides a raw pointer to the data /// Provides a raw pointer to the data
std::function<char*()> data; std::function<char*()> data = nullptr;
/// Whether data is available /// Whether data is available
bool empty() const { return not data; } bool empty() const { return not data; }
......
...@@ -70,7 +70,6 @@ add_library(migraphx_gpu ...@@ -70,7 +70,6 @@ add_library(migraphx_gpu
lrn.cpp lrn.cpp
schedule_model.cpp schedule_model.cpp
adjust_allocation.cpp adjust_allocation.cpp
pack_int8_args.cpp
clip.cpp clip.cpp
) )
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu) set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
......
...@@ -61,6 +61,11 @@ struct hip_tensor_descriptor ...@@ -61,6 +61,11 @@ struct hip_tensor_descriptor
{ {
std::copy(s.lens().begin(), s.lens().end(), lens); std::copy(s.lens().begin(), s.lens().end(), lens);
std::copy(s.strides().begin(), s.strides().end(), strides); std::copy(s.strides().begin(), s.strides().end(), strides);
indices.resize(s.strides().size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&](size_t i1, size_t i2) {
return strides[i1] > strides[i2];
});
} }
__device__ __host__ hip_index<NDim> multi(size_t idx) const __device__ __host__ hip_index<NDim> multi(size_t idx) const
...@@ -69,8 +74,8 @@ struct hip_tensor_descriptor ...@@ -69,8 +74,8 @@ struct hip_tensor_descriptor
size_t tidx = idx; size_t tidx = idx;
for(size_t is = 0; is < NDim; is++) for(size_t is = 0; is < NDim; is++)
{ {
result[is] = tidx / strides[is]; result[indices[is]] = tidx / strides[indices[is]];
tidx = tidx % strides[is]; tidx = tidx % strides[indices[is]];
} }
return result; return result;
} }
...@@ -83,6 +88,7 @@ struct hip_tensor_descriptor ...@@ -83,6 +88,7 @@ struct hip_tensor_descriptor
} }
size_t lens[NDim] = {}; size_t lens[NDim] = {};
size_t strides[NDim] = {}; size_t strides[NDim] = {};
std::vector<size_t> indices{};
}; };
} // namespace device } // namespace device
......
...@@ -13,6 +13,8 @@ struct context; ...@@ -13,6 +13,8 @@ struct context;
struct miopen_quant_gemm struct miopen_quant_gemm
{ {
op::quant_dot op; op::quant_dot op;
mutable argument pack_0{};
mutable argument pack_1{};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -30,17 +32,17 @@ struct miopen_quant_gemm ...@@ -30,17 +32,17 @@ struct miopen_quant_gemm
} }
}; };
struct hip_pack // struct hip_pack
{ // {
std::string name() const { return "gpu::gemm_pack"; } // std::string name() const { return "gpu::gemm_pack"; }
shape compute_shape(const std::vector<shape>& inputs) const; // shape compute_shape(const std::vector<shape>& inputs) const;
argument // argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; // compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const // std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ // {
return shapes.size() - 1; // return shapes.size() - 1;
} // }
}; // };
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
#include <migraphx/gpu/quant_gemm.hpp> #include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/device/pack.hpp> #include <migraphx/gpu/device/pack.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -54,15 +55,21 @@ rb_type<T>* to_rocblas_type(T* x) ...@@ -54,15 +55,21 @@ rb_type<T>* to_rocblas_type(T* x)
shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
std::vector<shape> input_shapes(inputs); std::vector<shape> input_shapes(inputs);
if(!inputs.at(1).transposed())
{
input_shapes.pop_back();
}
if(inputs.at(0).transposed())
{
input_shapes.pop_back();
}
input_shapes.pop_back(); input_shapes.pop_back();
// if(!inputs.at(1).transposed())
// {
// if (pack_1.empty())
// {
// pack_1 = allocate_gpu(inputs.at(1));
// }
// }
// if(inputs.at(0).transposed())
// {
// if (pack_0.empty())
// {
// pack_0 = allocate_gpu(inputs.at(0));
// }
// }
check_shapes{input_shapes}.not_broadcasted(); check_shapes{input_shapes}.not_broadcasted();
return op.compute_shape(input_shapes); return op.compute_shape(input_shapes);
...@@ -82,26 +89,37 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -82,26 +89,37 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0]; rocblas_int ldc = args[2].get_shape().strides()[dim_0];
size_t addi_ref_num = 0;
if(!transb) if(!transb)
{ {
++addi_ref_num;
const argument& arg_b = args[args.size() - 1];
// argument for B is the last one in the input argument vector
// use the algorithm to pack A // use the algorithm to pack A
device::pack_a(ctx.get_stream().get(), args[1], arg_b); if (pack_1.empty())
{
std::cout << "allocate pack_1" << std::endl;
pack_1 = allocate_gpu(args.at(1).get_shape());
}
//assert(!pack_1.empty());
device::pack_a(ctx.get_stream().get(), pack_1, args[1]);
auto pb = from_gpu(pack_1);
std::cout << "pb = " << pb << std::endl;
} }
// need to pack A in this scenario, use the algorithm to pack B in the // need to pack A in this scenario, use the algorithm to pack B in the
// comment of the API // comment of the API
if(transa) if(transa)
{ {
++addi_ref_num; if (pack_0.empty())
const argument& arg_a = args[args.size() - 1 - addi_ref_num]; {
device::pack_b(ctx.get_stream().get(), args[0], arg_a); std::cout << "allocate pack_0" << std::endl;
pack_0 = allocate_gpu(args.at(0).get_shape());
}
device::pack_b(ctx.get_stream().get(), pack_0, args[0]);
auto a = from_gpu(args[0]);
auto pa = from_gpu(pack_0);
std::cout << "a = " << a << std::endl;
std::cout << "pa = " << pa << std::endl;
} }
bool is_3inputs = (args.size() - addi_ref_num == 4); bool is_3inputs = (args.size() == 4);
int8_t beta = 0; int8_t beta = 0;
if(is_3inputs) if(is_3inputs)
{ {
...@@ -135,10 +153,10 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -135,10 +153,10 @@ argument miopen_quant_gemm::compute(context& ctx,
m, m,
k, k,
&alpha_r, &alpha_r,
to_pointer(args[1]), (!transb) ? to_pointer(pack_1) : to_pointer(args[1]),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
ldb, ldb,
to_pointer(args[0]), transa ? to_pointer(pack_0) : to_pointer(args[0]),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
lda, lda,
&beta_r, &beta_r,
...@@ -165,11 +183,11 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -165,11 +183,11 @@ argument miopen_quant_gemm::compute(context& ctx,
m, m,
k, k,
&alpha_r, &alpha_r,
to_pointer(args[1]), (!transb) ? to_pointer(pack_1) : to_pointer(args[1]),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
ldb, ldb,
k * n, k * n,
to_pointer(args[0]), transa ? to_pointer(pack_0) : to_pointer(args[0]),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
lda, lda,
m * k, m * k,
......
...@@ -71,8 +71,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -71,8 +71,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
eliminate_allocation{"hip::allocate"}, eliminate_allocation{"hip::allocate"},
check_context<context>{}, check_context<context>{},
dead_code_elimination{}, dead_code_elimination{},
pack_int8_args{}, // pack_int8_args{},
dead_code_elimination{}, // dead_code_elimination{},
eliminate_identity{} eliminate_identity{}
}; };
// clang-format on // clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment