Commit c04204af authored by Alan Turner's avatar Alan Turner
Browse files

Fix header schema

parent 29aae082
...@@ -22,10 +22,11 @@ ...@@ -22,10 +22,11 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/fuse_ck.hpp> #include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/ck.hpp> #include <migraphx/gpu/ck.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -24,9 +24,10 @@ ...@@ -24,9 +24,10 @@
#ifndef MIGRAPHX_GUARD_GPU_CK_HPP #ifndef MIGRAPHX_GUARD_GPU_CK_HPP
#define MIGRAPHX_GUARD_GPU_CK_HPP #define MIGRAPHX_GUARD_GPU_CK_HPP
#include <migraphx/ck.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/compile_src.hpp> #include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include "ck/host/device_gemm_multiple_d.hpp" #include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp" #include "ck/host/device_batched_gemm_softmax_gemm.hpp"
...@@ -35,6 +36,13 @@ namespace migraphx { ...@@ -35,6 +36,13 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
#endif
// NOLINTNEXTLINE // NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__( const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push #pragma clang diagnostic push
......
...@@ -21,22 +21,16 @@ ...@@ -21,22 +21,16 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_CK_HPP #ifndef MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
#define MIGRAPHX_GUARD_CK_HPP #define MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
#include <migraphx/env.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
#endif
struct gemm_softmax_gemm struct gemm_softmax_gemm
{ {
...@@ -49,7 +43,7 @@ struct gemm_softmax_gemm ...@@ -49,7 +43,7 @@ struct gemm_softmax_gemm
return pack(f(self.op, "op"), f(self.scale, "scale")); return pack(f(self.op, "op"), f(self.scale, "scale"));
} }
std::string name() const { return "gemm_softmax_gemm"; } std::string name() const { return "gpu::gemm_softmax_gemm"; }
void check_gemm_shape(const shape& s) const void check_gemm_shape(const shape& s) const
{ {
...@@ -75,7 +69,8 @@ struct gemm_softmax_gemm ...@@ -75,7 +69,8 @@ struct gemm_softmax_gemm
static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); } static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
}; };
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_CK_HPP
#endif
...@@ -23,11 +23,12 @@ ...@@ -23,11 +23,12 @@
*/ */
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/gpu/prefuse_ops.hpp> #include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/match/layernorm.hpp> #include <migraphx/match/layernorm.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/ck.hpp> #include <migraphx/gpu/ck.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -53,7 +53,7 @@ ...@@ -53,7 +53,7 @@
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/split_single_dyn_dim.hpp> #include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/gpu/allocation_model.hpp> #include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/ck.hpp> #include <migraphx/gpu/ck.hpp>
#include <migraphx/gpu/compile_miopen.hpp> #include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp> #include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment