Commit 5899b0fc authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent e72ecc75
...@@ -131,16 +131,17 @@ struct find_ck_batched_gemm ...@@ -131,16 +131,17 @@ struct find_ck_batched_gemm
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
    // The matched instruction is the gemm that will be lowered to a CK kernel.
    auto gemm_ins = r.result;
    // Replace it in-place with the CK-backed batched gemm, carrying over the
    // original operator attributes and input instructions unchanged.
    mpm.get_module().replace_instruction(
        gemm_ins, ck_batched_gemm{gemm_ins->get_operator()}, gemm_ins->inputs());
}
}; };
} // namespace } // namespace
// Entry point of the fuse_ck pass: rewrites gemm instructions in the module
// into CK (Composable Kernel) backed operators.
void fuse_ck::apply(module_pass_manager& mpm) const
{
    // Run the plain gemm matcher first, then the batched variant; each matcher
    // replaces matched instructions with its corresponding ck_* operator.
    match::find_matches(mpm, find_ck_gemm{});
    match::find_matches(mpm, find_ck_batched_gemm{});
}
} // namespace gpu } // namespace gpu
......
...@@ -131,10 +131,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs) ...@@ -131,10 +131,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
return it->second; return it->second;
} }
static std::size_t get_batch_stride(const shape& s) static std::size_t get_batch_stride(const shape& s) { return s.strides()[s.strides().size() - 3]; }
{
return s.strides()[s.strides().size() - 3];
}
struct ck_batched_gemm_compiler : compiler<ck_batched_gemm_compiler> struct ck_batched_gemm_compiler : compiler<ck_batched_gemm_compiler>
{ {
...@@ -186,7 +183,7 @@ struct ck_batched_gemm_compiler : compiler<ck_batched_gemm_compiler> ...@@ -186,7 +183,7 @@ struct ck_batched_gemm_compiler : compiler<ck_batched_gemm_compiler>
hip_compile_options options; hip_compile_options options;
// batch_count // batch_count
auto out_lens = c_shape.lens(); auto out_lens = c_shape.lens();
auto batch_count = std::accumulate( auto batch_count = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
auto batchStrideA = get_batch_stride(a_shape); auto batchStrideA = get_batch_stride(a_shape);
...@@ -209,7 +206,8 @@ struct ck_batched_gemm_compiler : compiler<ck_batched_gemm_compiler> ...@@ -209,7 +206,8 @@ struct ck_batched_gemm_compiler : compiler<ck_batched_gemm_compiler>
options.kernel_name = "ck_batched_gemm_kernel"; options.kernel_name = "ck_batched_gemm_kernel";
options.virtual_inputs = inputs; options.virtual_inputs = inputs;
auto src = interpolate_string(ck_batched_gemm_kernel, {{"instance", join_strings(instance, ",")}}); auto src =
interpolate_string(ck_batched_gemm_kernel, {{"instance", join_strings(instance, ",")}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -111,12 +111,13 @@ constexpr F for_each(Iterator first, Iterator last, F f) ...@@ -111,12 +111,13 @@ constexpr F for_each(Iterator first, Iterator last, F f)
} }
template <class Iterator, class T> template <class Iterator, class T>
constexpr void fill (Iterator first, Iterator last, const T& val) constexpr void fill(Iterator first, Iterator last, const T& val)
{ {
while (first != last) { while(first != last)
*first = val; {
++first; *first = val;
} ++first;
}
} }
template <class Iterator, class Predicate> template <class Iterator, class Predicate>
......
...@@ -62,10 +62,10 @@ constexpr auto to_ck_tensor() ...@@ -62,10 +62,10 @@ constexpr auto to_ck_tensor()
template <class Tensor> template <class Tensor>
constexpr auto to_ck_batched_tensor() constexpr auto to_ck_batched_tensor()
{ {
constexpr auto s = get_shape_c<Tensor>{}; constexpr auto s = get_shape_c<Tensor>{};
constexpr auto sz = s.lens.size(); constexpr auto sz = s.lens.size();
return ck::make_naive_tensor_descriptor(ck::make_tuple(s.lens[sz - 2], s.lens[sz - 1]), return ck::make_naive_tensor_descriptor(ck::make_tuple(s.lens[sz - 2], s.lens[sz - 1]),
ck::make_tuple(s.strides[sz - 2], s.strides[sz - 1])); ck::make_tuple(s.strides[sz - 2], s.strides[sz - 1]));
} }
template <class F> template <class F>
......
...@@ -53,13 +53,13 @@ template <ck::index_t NumDTensor> ...@@ -53,13 +53,13 @@ template <ck::index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch struct ComputePtrOffsetOfStridedBatch
{ {
// Records the per-batch strides used to offset the A/B/Ds/E base pointers
// for a given batch index (see the Get*PtrOffset accessors).
__device__ ComputePtrOffsetOfStridedBatch(ck::index_t batch_stride_a,
                                          ck::index_t batch_stride_b,
                                          std::array<ck::index_t, NumDTensor> batch_stride_ds,
                                          ck::index_t batch_stride_e)
    : BatchStrideA_(batch_stride_a),
      BatchStrideB_(batch_stride_b),
      BatchStrideDs_(batch_stride_ds),
      BatchStrideE_(batch_stride_e)
{
}
...@@ -94,15 +94,17 @@ struct ComputePtrOffsetOfStridedBatch ...@@ -94,15 +94,17 @@ struct ComputePtrOffsetOfStridedBatch
ck::index_t BatchStrideE_; ck::index_t BatchStrideE_;
}; };
template <class G, class Settings, class A, class B, class E, class... Ds> template <class G, class Settings, class A, class B, class E, class... Ds>
__device__ void ck_batched_gemm(Settings s, A a, B b, E e, Ds... ds) __device__ void ck_batched_gemm(Settings s, A a, B b, E e, Ds... ds)
{ {
constexpr const G gemm{}; constexpr const G gemm{};
constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_batched_tensor<A>()); constexpr const auto a_grid_desc_m_k =
constexpr const auto b_grid_desc_n_k = gemm.matrix_padder.PadBDescriptor_N_K(to_ck_batched_tensor<B>()); gemm.matrix_padder.PadADescriptor_M_K(to_ck_batched_tensor<A>());
constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_batched_tensor<E>()); constexpr const auto b_grid_desc_n_k =
gemm.matrix_padder.PadBDescriptor_N_K(to_ck_batched_tensor<B>());
constexpr const auto e_grid_desc_m_n =
gemm.matrix_padder.PadCDescriptor_M_N(to_ck_batched_tensor<E>());
constexpr const auto ds_grid_desc_m_n = constexpr const auto ds_grid_desc_m_n =
ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_batched_tensor<Ds>())...); ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_batched_tensor<Ds>())...);
constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n); constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
...@@ -124,17 +126,18 @@ __device__ void ck_batched_gemm(Settings s, A a, B b, E e, Ds... ds) ...@@ -124,17 +126,18 @@ __device__ void ck_batched_gemm(Settings s, A a, B b, E e, Ds... ds)
constexpr const bool HasMainKBlockLoop = constexpr const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) * GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{})); a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
static constexpr ck::index_t NumDTensor = gemm.NumDTensor; static constexpr ck::index_t NumDTensor = gemm.NumDTensor;
std::array<ck::index_t, NumDTensor> batchStrideDs; std::array<ck::index_t, NumDTensor> batchStrideDs;
ck::static_for<0, NumDTensor, 1>{}( ck::static_for<0, NumDTensor, 1>{}([&](auto i) { batchStrideDs[i] = s.batchStrideC; });
[&](auto i) { batchStrideDs[i] = s.batchStrideC; }); const ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch{
const ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch{s.batchStrideA, s.batchStrideB, batchStrideDs, s.batchStrideC}; s.batchStrideA, s.batchStrideB, batchStrideDs, s.batchStrideC};
auto batch_count = s.batch_count; auto batch_count = s.batch_count;
const ck::index_t num_blocks_per_batch = const ck::index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(ck::get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(ck::get_grid_size() / batch_count);
const ck::index_t g_idx = __builtin_amdgcn_readfirstlane(ck::get_block_1d_id() / num_blocks_per_batch); const ck::index_t g_idx =
__builtin_amdgcn_readfirstlane(ck::get_block_1d_id() / num_blocks_per_batch);
const ck::long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const ck::long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<ck::long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); static_cast<ck::long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
...@@ -154,7 +157,7 @@ __device__ void ck_batched_gemm(Settings s, A a, B b, E e, Ds... ds) ...@@ -154,7 +157,7 @@ __device__ void ck_batched_gemm(Settings s, A a, B b, E e, Ds... ds)
GridwiseGemm::template Run<HasMainKBlockLoop>(a.data() + a_batch_offset, GridwiseGemm::template Run<HasMainKBlockLoop>(a.data() + a_batch_offset,
b.data() + b_batch_offset, b.data() + b_batch_offset,
p_ds_grid_grp, p_ds_grid_grp,
e.data() + e_batch_offset, e.data() + e_batch_offset,
p_shared, p_shared,
gemm.a_element_op, gemm.a_element_op,
......
...@@ -32,14 +32,14 @@ struct ck_batched_gemm : verify_program<ck_batched_gemm> ...@@ -32,14 +32,14 @@ struct ck_batched_gemm : verify_program<ck_batched_gemm>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::size_t b = 2; std::size_t b = 2;
std::size_t m = 3; std::size_t m = 3;
std::size_t n = 3; std::size_t n = 3;
std::size_t k = 3; std::size_t k = 3;
migraphx::shape m1_shape{migraphx::shape::half_type, {b, m, k}}; migraphx::shape m1_shape{migraphx::shape::half_type, {b, m, k}};
std::vector<float> v1(b*m*k, 1); std::vector<float> v1(b * m * k, 1);
std::vector<float> v2(b*k*n, 1);//{1, 2, 3, 4, 5, 6, 7, 8}; std::vector<float> v2(b * k * n, 1); //{1, 2, 3, 4, 5, 6, 7, 8};
// auto l1 = mm->add_parameter("1", m1_shape); // auto l1 = mm->add_parameter("1", m1_shape);
// auto l2 = mm->add_parameter("2", m1_shape); // auto l2 = mm->add_parameter("2", m1_shape);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment