Commit ef5a5f4e authored by turneram's avatar turneram
Browse files

Add ref and jit ops

parent 55cb7d3a
......@@ -119,6 +119,7 @@ register_migraphx_ops(
broadcast
capture
ceil
ck_gemm
clip
concat
contiguous
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_CK_GEMM_HPP
#define MIGRAPHX_GUARD_OPERATORS_CK_GEMM_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gemm.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct ck_gemm
{
std::string name() const { return "ck_gemm"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_type().has(2);
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
{
MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
}
// only handle the case that the batch size of a and b are the same
if(!std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{
MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
}
std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0])
{
MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
}
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result = argument{output_shape};
visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -40,6 +40,7 @@
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/op/ceil.hpp>
#include <migraphx/op/ck_gemm.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp>
......
......@@ -22,9 +22,12 @@
# THE SOFTWARE.
#####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc /opt/rocm/ck)
find_package(miopen)
find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
find_package(hip REQUIRED PATHS /opt/rocm)
# rocblas
find_package(rocblas REQUIRED PATHS /opt/rocm)
message(STATUS "Build with rocblas")
......@@ -390,7 +393,8 @@ endif()
# Workaround broken rocblas headers
target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1)
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels composable_kernel::device_operations)
target_include_directories(migraphx_gpu PRIVATE /opt/rocm/ck/include)
add_subdirectory(driver)
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__(
#include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void ck_gemm_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
ck_gemm(xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
std::vector<std::string> names() const { return {"ck_gemm"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
auto out_s = inputs.back();
options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
options.inputs = inputs;
options.output = out_s;
options.kernel_name = "ck_gemm_kernel";
options.virtual_inputs = inputs;
return compile_hip_code_object(ck_gemm_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
\ No newline at end of file
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
template <class T, class U, class V>
__device__ void ck_gemm(const T& /* data_t */, const U& /* indices_t */, const V& /* output_t */)
{
}
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct ck_gemm : verify_program<ck_gemm>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {10, 20}};
migraphx::shape m2_shape{migraphx::shape::float_type, {20, 10}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment