Commit 7c636a51 authored by Alan Turner
Browse files

Formatting

parent b9f1b198
...@@ -349,9 +349,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -349,9 +349,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std::array<std::size_t, 3> config{m, n, k}; std::array<std::size_t, 3> config{m, n, k};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape})); auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool { auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
return true;/* get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and return true; /* get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; */ get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; */
})}; })};
assert(inputs.size() < 4 or v.contains("post")); assert(inputs.size() < 4 or v.contains("post"));
if(v.contains("post")) if(v.contains("post"))
...@@ -374,13 +374,14 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -374,13 +374,14 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
gemm_type += "Padding"; gemm_type += "Padding";
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type); ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto blocks_per_batch = int_div_ceil(m, 128) * int_div_ceil(n, 128);;//ip.get_grid_size(config); auto blocks_per_batch = int_div_ceil(m, 128) * int_div_ceil(n, 128);
; // ip.get_grid_size(config);
hip_compile_options options; hip_compile_options options;
auto block_size = 256;//ip.get_block_size(); auto block_size = 256; // ip.get_block_size();
auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch; auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size); options.set_launch_params(v, grid_size * block_size, block_size);
//auto new_inputs = inputs; // auto new_inputs = inputs;
auto new_inputs = inputs; auto new_inputs = inputs;
// auto out_s = inputs.back(); // auto out_s = inputs.back();
// new_inputs.back() = shape{shape::int8_type, out_s.lens(), out_s.strides()}; // new_inputs.back() = shape{shape::int8_type, out_s.lens(), out_s.strides()};
......
...@@ -46,7 +46,9 @@ using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor> ...@@ -46,7 +46,9 @@ using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>
ck_transposeb_dims(get_shape_c<Tensor>{}.strides))); ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
template <class... Xs> template <class... Xs>
constexpr void noop(Xs...) {} constexpr void noop(Xs...)
{
}
template <class G, class E, class A, class B, class... Ds> template <class G, class E, class A, class B, class... Ds>
__device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds) __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
...@@ -78,17 +80,17 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds) ...@@ -78,17 +80,17 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
// GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1); // GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
// noop(a, b, e, M, N, K, a_grid_desc_k0_m_k1, a_grid_desc_k0_m0_m1_k1, ds...); // noop(a, b, e, M, N, K, a_grid_desc_k0_m_k1, a_grid_desc_k0_m0_m1_k1, ds...);
////////////////////////// //////////////////////////
// constexpr const G gemm{}; // constexpr const G gemm{};
// constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>()); // constexpr const auto a_grid_desc_m_k =
// constexpr const auto b_grid_desc_n_k = // gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>()); constexpr const auto
// b_grid_desc_n_k =
// gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>()); // gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>());
// constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>()); // constexpr const auto e_grid_desc_m_n =
// constexpr const auto ds_grid_desc_m_n = // gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>()); constexpr const auto
// ds_grid_desc_m_n =
// ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...); // ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
// constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n); // constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
...@@ -112,7 +114,8 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds) ...@@ -112,7 +114,8 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
// __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; // __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// constexpr const bool HasMainKBlockLoop = // constexpr const bool HasMainKBlockLoop =
// GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) * // GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{})
// *
// a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{})); // a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
// GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()), // GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
// to_ck_const_pointer(b.data()), // to_ck_const_pointer(b.data()),
......
...@@ -4058,12 +4058,7 @@ def int8_gemm_verify(): ...@@ -4058,12 +4058,7 @@ def int8_gemm_verify():
outputs=['x'], outputs=['x'],
) )
convert = onnx.helper.make_node( convert = onnx.helper.make_node('Cast', inputs=['x'], outputs=['y'], to=6)
'Cast',
inputs=['x'],
outputs=['y'],
to=6
)
return ([node, convert], [m1, m2], [y]) return ([node, convert], [m1, m2], [y])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment