Commit c04204af authored by Alan Turner's avatar Alan Turner
Browse files

Fix header schema

parent 29aae082
......@@ -22,10 +22,11 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/ck.hpp>
#include <migraphx/gpu/ck.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,9 +24,10 @@
#ifndef MIGRAPHX_GUARD_GPU_CK_HPP
#define MIGRAPHX_GUARD_GPU_CK_HPP
#include <migraphx/ck.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
......@@ -35,6 +36,13 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
#endif
// NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
......
......@@ -21,22 +21,16 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_CK_HPP
#define MIGRAPHX_GUARD_CK_HPP
#ifndef MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
#define MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
#include <migraphx/env.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/check_shapes.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
#endif
namespace gpu {
struct gemm_softmax_gemm
{
......@@ -49,7 +43,7 @@ struct gemm_softmax_gemm
return pack(f(self.op, "op"), f(self.scale, "scale"));
}
std::string name() const { return "gemm_softmax_gemm"; }
std::string name() const { return "gpu::gemm_softmax_gemm"; }
void check_gemm_shape(const shape& s) const
{
......@@ -75,7 +69,8 @@ struct gemm_softmax_gemm
static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#endif // MIGRAPHX_GUARD_GPU_CK_HPP
......@@ -23,11 +23,12 @@
*/
#include <migraphx/permutation.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/ck.hpp>
#include <migraphx/gpu/ck.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -53,7 +53,7 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/ck.hpp>
#include <migraphx/gpu/ck.hpp>
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment