Commit b9f1b198 authored by Alan Turner

Reconfigure to use int8 ck gemms

parent ac7a0025
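
Context for the diff below: the matchers and the CK kernel wrapper are switched from the half-precision `dot` path to `quant_dot`, routed through a hard-coded `DeviceGemmMultipleD_Dl` int8 instance, and the new ONNX tests exercise `MatMulInteger`. As a reference point only (not MIGraphX or CK API; the helper name below is illustrative), this is the int8 GEMM semantics being targeted: int8 inputs with the products accumulated in int32.

#include <cstddef>
#include <cstdint>
#include <vector>

// Reference semantics sketch: C[i][j] = sum_p int32(A[i][p]) * int32(B[p][j]),
// i.e. int8 operands widened and accumulated in int32, as MatMulInteger requires.
std::vector<int32_t> int8_gemm_ref(const std::vector<int8_t>& a,
                                   const std::vector<int8_t>& b,
                                   std::size_t m,
                                   std::size_t k,
                                   std::size_t n)
{
    std::vector<int32_t> c(m * n, 0);
    for(std::size_t i = 0; i < m; ++i)
    {
        for(std::size_t j = 0; j < n; ++j)
        {
            int32_t acc = 0;
            for(std::size_t p = 0; p < k; ++p)
                acc += static_cast<int32_t>(a[i * k + p]) * static_cast<int32_t>(b[p * n + j]);
            c[i * n + j] = acc;
        }
    }
    return c;
}
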
......@@ -17,7 +17,7 @@ namespace gpu {
struct ck_gemm
{
operation op = make_op("dot");
operation op = make_op("quant_dot");
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -46,7 +46,7 @@ struct ck_gemm
check_gemm_shape(input);
auto r = op.compute_shape({a, b});
if(mods.empty())
return r;
return r.with_type(shape::int8_type);
return r.with_type(mods.front()->get_output_shapes().front().type());
}
};
......@@ -56,7 +56,7 @@ namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot")
if(ins->name() != "quant_dot")
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
......@@ -87,7 +87,7 @@ struct find_ck_gemm_pointwise
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end());
if(ins->get_shape().type() != shape::half_type)
if(ins->get_shape().type() != shape::int8_type and ins->get_shape().type() != shape::int32_type)
return;
if(gemm_idx != 0)
{
......@@ -110,7 +110,7 @@ struct find_ck_gemm_pointwise
struct find_ck_gemm
{
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
auto matcher() const { return match::name("quant_dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
......
......@@ -58,6 +58,60 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
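// Note (hedged reading of the arguments below): this pins one DeviceGemmMultipleD_Dl
// instance in place of the templated ${instance} string. Visible here: row-major
// A/B/E layouts, int8 A/B data with an int32 accumulator and an int8 E output (per
// the EDataType comment), PassThrough element-wise ops, and MNKPadding. The 256 and
// the two 128s appear to be the block size and the M/N per-block tile; the remaining
// integers and S<...> sequences are CK's per-instance thread-mapping and block/thread
// transfer tuning parameters.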
using GEMM = ck::tensor_operation::device::DeviceGemmMultipleD_Dl<
Row,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
int32_t,
Empty_Tuple,
int8_t, //EDataType
PassThrough,
PassThrough,
PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
256,
128,
128,
16,
4,
4,
4,
1,
S<8,2>,
S<8,2>,
S<8,1,1,4>,
S<2,1,128,1>,
S<1,2,0,3>,
S<1,2,0,3>,
S<4,1,1,4>,
S<1,2,0,3>,
S<1,1,1,4>,
S<2,1,4,4>,
S<8,1,32,1>,
S<0,3,1,2>,
S<0,3,1,2>,
S<1,1,4,1>,
S<0,3,1,2>,
S<1,1,4,4>,
S<0,1,2,3,4,5>,
5,
4>;
namespace migraphx {
......@@ -68,7 +122,7 @@ extern "C" {
__global__ void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
ck_gemm<CK_DeviceGemmMultipleD<${instance}>, ${blocks_per_batch}>(xs...);
ck_gemm<GEMM, ${blocks_per_batch}>(xs...);
});
}
......@@ -295,9 +349,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std::array<std::size_t, 3> config{m, n, k};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
return true; /* get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; */
})};
assert(inputs.size() < 4 or v.contains("post"));
if(v.contains("post"))
......@@ -320,19 +374,23 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
gemm_type += "Padding";
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto blocks_per_batch = ip.get_grid_size(config);
auto blocks_per_batch = int_div_ceil(m, 128) * int_div_ceil(n, 128); // ip.get_grid_size(config); 128x128 matches the hard-coded tile
hip_compile_options options;
auto block_size = ip.get_block_size();
auto block_size = 256; // ip.get_block_size(); matches the hard-coded instance
auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs;
//auto new_inputs = inputs;
auto new_inputs = inputs;
// auto out_s = inputs.back();
// new_inputs.back() = shape{shape::int8_type, out_s.lens(), out_s.strides()};
options.inputs = new_inputs;
options.output = c_shape;
options.kernel_name = v.get("kernel", "ck_gemm_kernel");
options.virtual_inputs = inputs;
options.virtual_inputs = new_inputs;
if(can_fold_batch)
{
auto vinputs = inputs;
auto vinputs = new_inputs;
fold_batch_dims(vinputs[0]);
remove_batch_dims(vinputs[1]);
std::for_each(vinputs.begin() + 2, vinputs.end(), fold_batch_dims);
......
......@@ -45,55 +45,88 @@ template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
template <class... Xs>
constexpr void noop(Xs...) {}
template <class G, class E, class A, class B, class... Ds>
__device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
{
constexpr const G gemm{};
constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
constexpr const auto b_grid_desc_n_k =
gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>());
constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
constexpr const auto ds_grid_desc_m_n =
ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
using GridwiseGemm = typename G::GridwiseGemm;
// tensor descriptors for block/thread-wise copy
constexpr auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
constexpr auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
constexpr auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
static_assert(GridwiseGemm::CheckValidity(
a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map));
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
constexpr const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
constexpr auto desc = G::make_descriptor(to_ck_tensor<A>(),
to_ck_tensor<ck_transposeb<B>>(),
ck::make_tuple(to_ck_tensor<Ds>()...),
to_ck_tensor<E>());
G::Run(desc,
to_ck_const_pointer(a.data()),
to_ck_const_pointer(b.data()),
ck::make_tuple(to_ck_const_pointer(ds.data())...),
to_ck_pointer(e.data()),
p_shared_block,
gemm.a_element_op,
gemm.b_element_op,
gemm.cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
to_ck_pointer(e.data()));
// constexpr const auto M = a.get_shape().lens[0];
// constexpr const auto N = b.get_shape().lens[1];
// constexpr const auto K = a.get_shape().lens[1];
// constexpr const auto K1Number = ck::Number<4>{};
// constexpr const auto K0 = K / 4;
// using GridwiseGemm = typename G::GridwiseGemm;
// constexpr auto a_grid_desc_k0_m_k1 = ck::transform_tensor_descriptor(
// to_ck_tensor<A>(),
// ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
// ck::make_pass_through_transform(M)),
// ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
// ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
// constexpr const auto a_grid_desc_k0_m0_m1_k1 =
// GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
// noop(a, b, e, M, N, K, a_grid_desc_k0_m_k1, a_grid_desc_k0_m0_m1_k1, ds...);
//////////////////////////
// constexpr const G gemm{};
// constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
// constexpr const auto b_grid_desc_n_k =
// gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>());
// constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
// constexpr const auto ds_grid_desc_m_n =
// ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
// constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
// using GridwiseGemm = typename G::GridwiseGemm;
// // tensor descriptors for block/thread-wise copy
// constexpr auto a_grid_desc_ak0_m_ak1 =
// GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
// constexpr auto b_grid_desc_bk0_n_bk1 =
// GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
// constexpr auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
// GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
// constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
// GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
// static_assert(GridwiseGemm::CheckValidity(
// a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map));
// __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// constexpr const bool HasMainKBlockLoop =
// GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
// a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
// GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
// to_ck_const_pointer(b.data()),
// ck::make_tuple(to_ck_const_pointer(ds.data())...),
// to_ck_pointer(e.data()),
// p_shared_block,
// gemm.a_element_op,
// gemm.b_element_op,
// gemm.cde_element_op,
// a_grid_desc_ak0_m_ak1,
// b_grid_desc_bk0_n_bk1,
// ds_grid_desc_mblock_mperblock_nblock_nperblock,
// e_grid_desc_mblock_mperblock_nblock_nperblock,
// block_2_etile_map);
}
template <class G, index_int BlocksPerBatch, class... Ts>
......
......@@ -4031,6 +4031,43 @@ def matmulinteger_test():
    return ([node], [m1, m2], [y])


@onnx_test()
def int8_gemm():
    m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [256, 256])
    m2 = helper.make_tensor_value_info('2', TensorProto.INT8, [256, 256])
    y = helper.make_tensor_value_info('y', TensorProto.INT8, [256, 256])

    node = onnx.helper.make_node(
        'MatMulInteger',
        inputs=['1', '2'],
        outputs=['y'],
    )

    return ([node], [m1, m2], [y])


@onnx_test()
def int8_gemm_verify():
    m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [256, 256])
    m2 = helper.make_tensor_value_info('2', TensorProto.INT8, [256, 256])
    y = helper.make_tensor_value_info('y', TensorProto.INT32, [256, 256])

    node = onnx.helper.make_node(
        'MatMulInteger',
        inputs=['1', '2'],
        outputs=['x'],
    )

    convert = onnx.helper.make_node(
        'Cast',
        inputs=['x'],
        outputs=['y'],
        to=6)

    return ([node, convert], [m1, m2], [y])


@onnx_test()
def matmulinteger_dyn_error():
    m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [None, 6, 16])
......
(Binary ONNX protobufs for the new int8_gemm and int8_gemm_verify tests are added in this commit; their contents are not human-readable in the diff.)