Commit 6433eede authored by Jing Zhang
Browse files

fixed

parent afeccb5f
...@@ -21,8 +21,8 @@ using CElementOp = PassThrough; ...@@ -21,8 +21,8 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false; static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true; static constexpr bool PermuteB = true;
static constexpr ck::index_t KPerBlock = 128; static constexpr ck::index_t KPerBlock = 128;
// clang-format off // clang-format off
......
...@@ -114,7 +114,10 @@ struct StaticBufferTupleOfVector ...@@ -114,7 +114,10 @@ struct StaticBufferTupleOfVector
// Get X // Get X
// i is offset of S, not X. i should be aligned to X // i is offset of S, not X. i should be aligned to X
template <typename X, index_t I> template <typename X,
index_t I,
typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
bool>::type = false>
__host__ __device__ constexpr auto GetAsType(Number<I> i) const __host__ __device__ constexpr auto GetAsType(Number<I> i) const
{ {
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{}; constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
...@@ -130,7 +133,10 @@ struct StaticBufferTupleOfVector ...@@ -130,7 +133,10 @@ struct StaticBufferTupleOfVector
// Set X // Set X
// i is offset of S, not X. i should be aligned to X // i is offset of S, not X. i should be aligned to X
template <typename X, index_t I> template <typename X,
index_t I,
typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
bool>::type = false>
__host__ __device__ constexpr void SetAsType(Number<I> i, X x) __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
{ {
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{}; constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
......
...@@ -43,7 +43,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") ...@@ -43,7 +43,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment