"src/vscode:/vscode.git/clone" did not exist on "b16ab01db51c453e0bee8153978e5e9cbc9efbf2"
Commit e72ecc75 authored by Alan Turner

Add batched gemm

parent 54dd72b6
@@ -46,6 +46,41 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP(ck_gemm);
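// ck_batched_gemm mirrors ck_gemm: it wraps the dot operation it replaces so that it can
// later be lowered to a Composable Kernel batched gemm kernel; the output shape is taken
// from the wrapped op applied to the last two inputs.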
struct ck_batched_gemm
{
operation op = make_op("dot");
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::ck_batched_gemm"; }
void check_gemm_shape(const shape& s) const
{
if(contains(s.lens(), 1))
MIGRAPHX_THROW("Invalid shape for ck_batched_gemm");
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.not_broadcasted();
// if(mods.size() != 1)
// MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
auto n = inputs.size();
auto a = inputs[n - 2];
auto b = inputs[n - 1];
check_gemm_shape(a);
check_gemm_shape(b);
return op.compute_shape({a, b});
}
};
MIGRAPHX_REGISTER_OP(ck_batched_gemm);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
@@ -62,6 +97,20 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return true;
}
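// Match a dot whose inputs have a batch dimension (rank of at least 3) and whose first
// input's last dimension is at most 1024.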
MIGRAPHX_PRED_MATCHER(is_ck_batched_gemm, instruction_ref ins)
{
if(ins->name() != "dot")
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
if(a.lens().size() < 3 or b.lens().size() < 3)
return false;
if(a.lens().back() > 1024)
return false;
return true;
}
struct find_ck_gemm
{
// Find a gemm that can be replaced with a ck_gemm
@@ -74,9 +123,25 @@ struct find_ck_gemm
}
};
struct find_ck_batched_gemm
{
// Find a batched gemm that can be replaced with a ck_batched_gemm
auto matcher() const { return match::name("dot")(is_ck_batched_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
mpm.get_module().replace_instruction(ins, ck_batched_gemm{ins->get_operator()}, ins->inputs());
}
};
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
    match::find_matches(mpm, find_ck_gemm{});
    match::find_matches(mpm, find_ck_batched_gemm{});
}
} // namespace gpu
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <fstream>
#include <filesystem>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/env.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/file_buffer.hpp>
std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
// NOLINTNEXTLINE
static const char* const ck_batched_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_batched_gemm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
namespace migraphx {
extern "C" {
__global__ void ck_batched_gemm_kernel(void* a_p, void* b_p, void* c_p)
{
make_tensors()(a_p, b_p, c_p)([&](auto a, auto b, auto c) {
auto settings = make_ck_batched_gemm_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_COUNT}),
MIGRAPHX_MAKE_CONSTANT(int64_t{BATCHSTRIDEA}),
MIGRAPHX_MAKE_CONSTANT(int64_t{BATCHSTRIDEB}),
MIGRAPHX_MAKE_CONSTANT(int64_t{BATCHSTRIDEC}));
ck_batched_gemm<CK_DeviceBatchedGemmMultipleD<${instance}>>(settings, a, b, c);
});
}
}
} // namespace migraphx
)__migraphx__";
static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }
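// Offsets into the comma-separated template-argument list of a CK instance (see
// CK_DeviceBatchedGemmMultipleD below): index 15 is BlockSize and index 13 is the
// GemmSpecialization (padding) argument.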
static std::size_t block_size_index = 15;
static std::size_t padding_index = 13;
static std::size_t get_block_size(const std::vector<std::string>& s)
{
return std::stoull(s[block_size_index]);
}
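// Number of workgroups needed to tile an m x n output, using the instance's MPerBlock and
// NPerBlock (fields 16 and 17). Illustrative example: 128x128 tiles on a 1024x1024 output
// give ceil(1024/128) * ceil(1024/128) = 64 workgroups per batch.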
static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t m, std::size_t n)
{
auto mpb = std::stoull(s[block_size_index + 1]);
auto npb = std::stoull(s[block_size_index + 2]);
return int_div_ceil(m, mpb) * int_div_ceil(n, npb);
}
static void set_padding(std::vector<std::string>& s, const std::string& p) { s[padding_index] = p; }
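// Run 'action' before forwarding to f; used below to log gemm shapes when
// MIGRAPHX_LOG_CK_GEMM is enabled.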
template <class F, class Action>
auto action_decorate(F f, Action action)
{
return [=](auto&&... xs) {
action();
f(std::forward<decltype(xs)>(xs)...);
};
}
using tuning_entry = std::pair<std::vector<shape>, size_t>;
static std::vector<tuning_entry> read_tuning(const std::string& s)
{
if(not fs::exists(s))
return {};
return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
}
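// A tuning entry maps a set of input shapes to the index of the CK instance to use; the
// JSON file is named by the MIGRAPHX_CK_TUNING environment variable.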
static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
if(tuning.empty())
std::cout << "*********** Warning: No CK tuning!" << std::endl;
auto it = std::find_if(
tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
if(it == tuning.end())
{
std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
return 6;
}
return it->second;
}
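// Stride between consecutive matrices of a batched tensor. Illustrative example: a packed
// {2, 64, 32} tensor has strides {2048, 32, 1}, so the batch stride is 2048 elements.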
static std::size_t get_batch_stride(const shape& s)
{
return s.strides()[s.strides().size() - 3];
}
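// Compiles gpu::ck_batched_gemm: selects a CK instance, pads when m, n or k is not a
// multiple of 8, injects the batch count/strides as macros and sets the launch parameters.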
struct ck_batched_gemm_compiler : compiler<ck_batched_gemm_compiler>
{
static std::string get_layout(const shape& s)
{
return s.transposed() ? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
}
static std::string get_type(const shape& s)
{
if(s.type() == shape::half_type)
return "ck::half_t";
return shape::cpp_type(s.type());
}
std::vector<std::string> names() const { return {"ck_batched_gemm", "gpu::ck_batched_gemm"}; }
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto c_shape = inputs[2];
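// The output is {batch..., m, n}: m and n are its last two dimensions, k is the last
// dimension of A.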
auto m = c_shape.lens()[c_shape.lens().size() - 2];
auto n = c_shape.lens().back();
auto k = a_shape.lens().back();
auto i = v.get("tuning_val", get_tuning_for(inputs));
auto& instance = get_instance(i, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
});
const bool pad_m = m % 8;
const bool pad_n = n % 8;
const bool pad_k = k % 8;
if(pad_m or pad_n or pad_k)
{
std::string padding_t = "ck::tensor_operation::device::GemmSpecialization::";
padding_t += pad_m ? "M" : "";
padding_t += pad_n ? "N" : "";
padding_t += pad_k ? "K" : "";
padding_t += "Padding";
set_padding(instance, padding_t);
}
hip_compile_options options;
// batch_count
auto out_lens = c_shape.lens();
auto batch_count = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
auto batchStrideA = get_batch_stride(a_shape);
auto batchStrideB = get_batch_stride(b_shape);
auto batchStrideC = get_batch_stride(c_shape);
options.params += " -DBATCH_COUNT=" + std::to_string(batch_count);
options.params += " -DBATCHSTRIDEA=" + std::to_string(batchStrideA);
options.params += " -DBATCHSTRIDEB=" + std::to_string(batchStrideB);
options.params += " -DBATCHSTRIDEC=" + std::to_string(batchStrideC);
std::cout << "Batch_count: " << batch_count << std::endl;
std::cout << "BatchStrideA: " << batchStrideA << std::endl;
std::cout << "BatchStrideB: " << batchStrideB << std::endl;
std::cout << "BatchStrideC: " << batchStrideC << std::endl;
auto block_size = get_block_size(instance);
auto grid_size = batch_count * get_grid_size(instance, m, n);
options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs;
options.output = c_shape;
options.kernel_name = "ck_batched_gemm_kernel";
options.virtual_inputs = inputs;
auto src = interpolate_string(ck_batched_gemm_kernel, {{"instance", join_strings(instance, ",")}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto shapes = to_shapes(ins->inputs());
return action_decorate(replace(compile_op(ctx, shapes, op.to_value())), [=] {
if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
std::cout << "ck_batched_gemm: " << to_json_string(to_value(shapes)) << std::endl;
});
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -110,6 +110,15 @@ constexpr F for_each(Iterator first, Iterator last, F f)
return f;
}
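// Assign val to every element of [first, last); a minimal constexpr counterpart of
// std::fill for kernel code.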
template <class Iterator, class T>
constexpr void fill(Iterator first, Iterator last, const T& val)
{
    while(first != last)
    {
        *first = val;
        ++first;
    }
}
template <class Iterator, class Predicate>
constexpr Iterator find_if(Iterator first, Iterator last, Predicate p)
{
@@ -59,6 +59,15 @@ constexpr auto to_ck_tensor()
});
}
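// Build a 2-D CK tensor descriptor from the last two dimensions of a batched tensor; the
// batch offset is applied separately through the batch strides.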
template <class Tensor>
constexpr auto to_ck_batched_tensor()
{
constexpr auto s = get_shape_c<Tensor>{};
constexpr auto sz = s.lens.size();
return ck::make_naive_tensor_descriptor(ck::make_tuple(s.lens[sz - 2], s.lens[sz - 1]),
ck::make_tuple(s.strides[sz - 2], s.strides[sz - 1]));
}
template <class F>
struct ck_function_adaptor : F
{
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_BATCHED_GEMM_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_BATCHED_GEMM_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ck.hpp>
#include <migraphx/kernels/ck_batched_gemm_includes.hpp>
#include <migraphx/kernels/shape.hpp>
namespace migraphx {
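// Batched gemm settings: the batch count and the per-batch strides of A, B and the output,
// passed in as integral constants generated from the kernel macros.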
template <class T0, class T1, class T2, class T3>
struct ck_batched_gemm_settings
{
T0 batch_count{};
T1 batchStrideA{};
T2 batchStrideB{};
T3 batchStrideC{};
};
template <class... Ts>
constexpr ck_batched_gemm_settings<Ts...> make_ck_batched_gemm_settings(Ts... xs)
{
return {xs...};
}
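// Element offset of each tensor for a given batch index: g_idx * batch_stride, widened to
// long_index_t.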
template <ck::index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch
{
__device__ ComputePtrOffsetOfStridedBatch(ck::index_t BatchStrideA,
ck::index_t BatchStrideB,
std::array<ck::index_t, NumDTensor> BatchStrideDs,
ck::index_t BatchStrideE)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE)
{
}
__host__ __device__ constexpr ck::long_index_t GetAPtrOffset(ck::index_t g_idx) const
{
return g_idx * static_cast<ck::long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr ck::long_index_t GetBPtrOffset(ck::index_t g_idx) const
{
return g_idx * static_cast<ck::long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr auto GetDsPtrOffset(ck::index_t g_idx) const
{
std::array<ck::long_index_t, NumDTensor> ds_offset;
ck::static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_offset[i] = g_idx * static_cast<ck::long_index_t>(BatchStrideDs_[i]);
});
return ds_offset;
}
__host__ __device__ constexpr ck::long_index_t GetEPtrOffset(ck::index_t g_idx) const
{
return g_idx * static_cast<ck::long_index_t>(BatchStrideE_);
}
private:
ck::index_t BatchStrideA_;
ck::index_t BatchStrideB_;
std::array<ck::index_t, NumDTensor> BatchStrideDs_;
ck::index_t BatchStrideE_;
};
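// Device-side batched gemm: each workgroup handles one (batch, output tile) pair. The
// compile-time descriptors describe a single matrix; the batch is selected by offsetting
// the raw pointers by the batch strides.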
template <class G, class Settings, class A, class B, class E, class... Ds>
__device__ void ck_batched_gemm(Settings s, A a, B b, E e, Ds... ds)
{
constexpr const G gemm{};
constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_batched_tensor<A>());
constexpr const auto b_grid_desc_n_k = gemm.matrix_padder.PadBDescriptor_N_K(to_ck_batched_tensor<B>());
constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_batched_tensor<E>());
constexpr const auto ds_grid_desc_m_n =
ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_batched_tensor<Ds>())...);
constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
using GridwiseGemm = typename G::GridwiseGemm;
// tensor descriptors for block/thread-wise copy
constexpr auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
constexpr auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
constexpr auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
constexpr const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
static constexpr ck::index_t NumDTensor = gemm.NumDTensor;
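// Every D tensor gets the same batch stride as the output (E).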
std::array<ck::index_t, NumDTensor> batchStrideDs;
ck::static_for<0, NumDTensor, 1>{}(
[&](auto i) { batchStrideDs[i] = s.batchStrideC; });
const ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch{s.batchStrideA, s.batchStrideB, batchStrideDs, s.batchStrideC};
auto batch_count = s.batch_count;
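// The launched grid is batch_count copies of the per-matrix tile grid; recover the batch
// index for this block from its 1-D id.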
const ck::index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(ck::get_grid_size() / batch_count);
const ck::index_t g_idx = __builtin_amdgcn_readfirstlane(ck::get_block_1d_id() / num_blocks_per_batch);
const ck::long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<ck::long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const ck::long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<ck::long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const ck::long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<ck::long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto p_ds_grid_grp = ck::make_tuple(ds.data()...);
ck::static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid_grp[i] + ds_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(a.data() + a_batch_offset,
b.data() + b_batch_offset,
p_ds_grid_grp,
e.data() + e_batch_offset,
p_shared,
gemm.a_element_op,
gemm.b_element_op,
gemm.cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_BG_INCLUDES_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_BG_INCLUDES_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
// #include <ck/utility/common_header.hpp>
// #include <ck/tensor_description/tensor_descriptor.hpp>
// #include <ck/tensor_description/tensor_descriptor_helper.hpp>
// #include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
// #include <ck/tensor_operation/gpu/device/device_gemm.hpp>
// #include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
// #include <ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp>
// #include <ck/tensor_operation/gpu/device/matrix_padder.hpp>
// #include <ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
namespace migraphx {
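// Maps a 1-D block id to an (m0, n0) output tile index, discarding the batch part of the
// id (the '% (M0 * N0)' in CalculateBottomIndex). It appears to be adapted from CK's
// BlockToCTileMap_M00_N0_M01Adapt.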
template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
{
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__
__device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
ck::index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
{
}
__host__ __device__ constexpr ck::index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const ck::index_t grid_size = M0 * N0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
ck::index_t idx_N0 = block_1d_id % N0;
ck::index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
ck::index_t idx_M00 = idx_M0 / M01_;
ck::index_t idx_M01 = idx_M0 % M01_;
ck::index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return ck::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool constexpr ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ __device__ constexpr bool
CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
private:
ck::index_t M01_;
CGridDesc_M_N c_grid_desc_m_n_;
};
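// A stripped-down, constexpr-friendly stand-in for CK's device batched gemm class: it
// carries the matrix padder, the GridwiseGemm type and the elementwise operations so that
// ck_batched_gemm in ck_batched_gemm.hpp can call GridwiseGemm::Run directly.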
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
ck::index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
ck::index_t BBlockLdsExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler()>
struct CK_DeviceBatchedGemmMultipleD
{
ck::tensor_operation::device::MatrixPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t>
matrix_padder{MPerBlock, NPerBlock, KPerBlock};
// GridwiseGemm
using GridwiseGemm = ck::GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ck::InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
// return block_id to E matrix tile idx (m0, n0) mapping
template <class EGridDesc_M_N>
__device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n_)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n_);
}
static constexpr ck::index_t NumDTensor = DsDataType::Size();
AElementwiseOperation a_element_op{};
BElementwiseOperation b_element_op{};
CDEElementwiseOperation cde_element_op{};
};
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
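// Verify test: a small 2x3x3 half-precision batched dot exercising the ck_batched_gemm
// lowering.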
struct ck_batched_gemm : verify_program<ck_batched_gemm>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::size_t b = 2;
std::size_t m = 3;
std::size_t n = 3;
std::size_t k = 3;
migraphx::shape m1_shape{migraphx::shape::half_type, {b, m, k}};
migraphx::shape m2_shape{migraphx::shape::half_type, {b, k, n}};
std::vector<float> v1(b * m * k, 1);
std::vector<float> v2(b * k * n, 1);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, v2});
mm->add_instruction(migraphx::make_op("dot"), l1, l2);
return p;
}
};