"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "f0d27006988ab29ab5acaeaf2013b41de4f056a4"
Commit 4248a5cd authored by Alan Turner's avatar Alan Turner
Browse files

Integrate host and device interfaces

parent f5e3ac18
...@@ -28,4 +28,4 @@ half,https://github.com/ROCmSoftwarePlatform/half/archive/1.12.0.tar.gz -X heade ...@@ -28,4 +28,4 @@ half,https://github.com/ROCmSoftwarePlatform/half/archive/1.12.0.tar.gz -X heade
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@bef0cb20dba0d9b315df46899310478a81c21852 -X header #ROCmSoftwarePlatform/composable_kernel@bef0cb20dba0d9b315df46899310478a81c21852 -X header
...@@ -261,10 +261,15 @@ else() ...@@ -261,10 +261,15 @@ else()
message(STATUS "MIOpen does not have find mode api") message(STATUS "MIOpen does not have find mode api")
endif() endif()
#find_package(composable_kernel REQUIRED PATHS /code/composable_kernel)
find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
#target_link_libraries(migraphx_gpu PRIVATE composable_kernel::device_operations)
# Workaround broken rocblas headers # Workaround broken rocblas headers
target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1) target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1)
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels) target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels composable_kernel::device_operations)
add_subdirectory(driver) add_subdirectory(driver)
......
...@@ -112,7 +112,8 @@ const std::vector<std::string>& compiler_warnings() ...@@ -112,7 +112,8 @@ const std::vector<std::string>& compiler_warnings()
"-Wno-sign-compare", "-Wno-sign-compare",
"-Wno-unused-command-line-argument", "-Wno-unused-command-line-argument",
"-Wno-weak-vtables", "-Wno-weak-vtables",
"-Wno-c99-extensions"}; "-Wno-c99-extensions",
"-Wno-global-constructors"};
return warnings; return warnings;
} }
......
...@@ -38,6 +38,17 @@ ...@@ -38,6 +38,17 @@
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"
#include "ck/library/tensor_operation_instance/solution_instances/gemm_multiple_d_xdlop_cshuffle.hpp"
#include <iostream>
const std::vector<std::string>& const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred); get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
...@@ -58,6 +69,7 @@ static const char* const ck_gemm_kernel = R"__migraphx__( ...@@ -58,6 +69,7 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp> #include <args.hpp>
#include <migraphx/kernels/ck_gemm.hpp> #include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp> #include <migraphx/kernels/pointwise.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp>
namespace migraphx { namespace migraphx {
...@@ -68,7 +80,7 @@ extern "C" { ...@@ -68,7 +80,7 @@ extern "C" {
__global__ void ${kernel}(${params}) __global__ void ${kernel}(${params})
{ {
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) { transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
ck_gemm<CK_DeviceGemmMultipleD<${instance}>, ${blocks_per_batch}>(xs...); ck_gemm<${solution}, ${blocks_per_batch}>(xs...);
}); });
} }
...@@ -281,6 +293,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -281,6 +293,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto a_shape = inputs[0]; auto a_shape = inputs[0];
auto b_shape = inputs[1]; auto b_shape = inputs[1];
auto c_shape = inputs.back(); auto c_shape = inputs.back();
auto tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto rank = a_shape.lens().size(); auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides(); auto b_strides = b_shape.strides();
...@@ -291,39 +304,33 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -291,39 +304,33 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
m = can_fold_batch ? m * batch_count : m; m = can_fold_batch ? m * batch_count : m;
auto n = c_shape.lens().back(); auto n = c_shape.lens().back();
auto k = a_shape.lens().back(); auto k = a_shape.lens().back();
std::array<char, 3> keys{'M', 'N', 'K'};
std::array<std::size_t, 3> config{m, n, k}; const auto numDTensors = inputs.size() - 3;
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape})); const bool transA = transposed_matrix(a_shape);
auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool { const bool transB = transposed_matrix(b_shape);
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and const bool transCDE = transposed_matrix(c_shape);
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and const auto a_type = get_type(a_shape);
get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; const auto b_type = get_type(b_shape);
})}; const auto cde_type = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type);//get_type(c_shape);
const auto cde_layout = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout);
std::string ck_passthrough = "ck_passthrough";//"ck::tensor_operation::element_wise::PassThrough";
std::string cde_op = ck_passthrough;
assert(inputs.size() < 4 or v.contains("post")); assert(inputs.size() < 4 or v.contains("post"));
if(v.contains("post")) if(v.contains("post"))
{ {
ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout)); cde_op = v.at("post").to<std::string>();
ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
ip.set_ds_op(v.at("post").to<std::string>());
} }
auto padding = ip.get_pad(config); auto problem = ck::tensor_operation::device::instance::Problem{static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(k), static_cast<ck::index_t>(numDTensors), static_cast<ck::index_t>(tuning_value),
std::string gemm_type; transA, transB, transCDE,
for(auto i : range(padding.size())) a_type, b_type, cde_type,
{ ck_passthrough, ck_passthrough, cde_op, cde_layout};
if(padding[i] != 0) const auto solution = problem.GetSolution();
gemm_type += keys[i]; auto blocks_per_batch = problem.GetGridSize();
} auto block_size = problem.GetBlockSize();
if(gemm_type.empty())
gemm_type = "Default";
else
gemm_type += "Padding";
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto blocks_per_batch = ip.get_grid_size(config);
hip_compile_options options; hip_compile_options options;
auto block_size = ip.get_block_size();
auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch; auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size); options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs; options.inputs = inputs;
...@@ -341,15 +348,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -341,15 +348,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{})) if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
options.params += " -DMIGRAPHX_CK_CHECK=1"; options.params += " -DMIGRAPHX_CK_CHECK=1";
auto src = interpolate_string(ck_gemm_kernel, auto src = interpolate_string(ck_gemm_kernel,
{{"instance", ip.str()}, {{"solution", solution},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"blocks_per_batch", to_string(blocks_per_batch)}, {"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})}, {"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}}); {"kernel", options.kernel_name}});
std::cout << src << std::endl;
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -48,52 +48,15 @@ using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor> ...@@ -48,52 +48,15 @@ using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>
template <class G, class E, class A, class B, class... Ds> template <class G, class E, class A, class B, class... Ds>
__device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds) __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
{ {
constexpr const G gemm{}; constexpr auto desc = G::make_descriptor(to_ck_tensor<A>(),
to_ck_tensor<B>(),
constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>()); ck::make_tuple(to_ck_tensor<Ds>()...),
constexpr const auto b_grid_desc_n_k = to_ck_tensor<E>());
gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>()); G::Run(desc,
to_ck_const_pointer(a.data()),
constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>()); to_ck_const_pointer(b.data()),
constexpr const auto ds_grid_desc_m_n = ck::make_tuple(to_ck_const_pointer(ds.data())...),
ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...); to_ck_pointer(e.data()));
constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
using GridwiseGemm = typename G::GridwiseGemm;
// tensor descriptors for block/thread-wise copy
constexpr auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
constexpr auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
constexpr auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
static_assert(GridwiseGemm::CheckValidity(
a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map));
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
constexpr const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
to_ck_const_pointer(b.data()),
ck::make_tuple(to_ck_const_pointer(ds.data())...),
to_ck_pointer(e.data()),
p_shared_block,
gemm.a_element_op,
gemm.b_element_op,
gemm.cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
} }
template <class G, index_int BlocksPerBatch, class... Ts> template <class G, index_int BlocksPerBatch, class... Ts>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment