Commit 7103d3ed authored by Alan Turner's avatar Alan Turner
Browse files

Call eval on scale lit and move gsg settings out of ck.hpp

parent 7e8b69ad
......@@ -234,7 +234,6 @@ struct find_ck_gemm_softmax_gemm
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
std::cout << "Matched GSG" << std::endl;
auto ins = r.result;
auto gemm2_ins = r.instructions["gemm2"];
auto gemm1_ins = r.instructions["gemm1"];
......@@ -244,7 +243,7 @@ struct find_ck_gemm_softmax_gemm
return;
double scale = 1.0;
scale_lit->get_literal().visit([&](const auto s) {
scale_lit->eval().visit([&](const auto s) {
// CK only supports single-valued scale
if(std::all_of(
s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
......
......@@ -165,18 +165,6 @@ template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
template <class T>
struct ck_gemm_softmax_gemm_settings
{
T scale{};
};
template <class... Ts>
constexpr ck_gemm_softmax_gemm_settings<Ts...> make_ck_gemm_softmax_gemm_settings(Ts... xs)
{
return {xs...};
}
#ifdef MIGRAPHX_CK_CHECK
#define MIGRAPHX_CK_STATIC_ASSERT static_assert
#else
......
......@@ -33,6 +33,18 @@
namespace migraphx {
template <class T>
struct ck_gemm_softmax_gemm_settings
{
T scale{};
};
template <class... Ts>
constexpr ck_gemm_softmax_gemm_settings<Ts...> make_ck_gemm_softmax_gemm_settings(Ts... xs)
{
return {xs...};
}
template <class G, class C, class A, class B, class B1, class Settings>
__device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1, Settings s)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment