Commit 8e897da7 authored by Jing Zhang
Browse files

add gridwise_gemm_v3

parent baac64e4
......@@ -148,7 +148,7 @@ template <index_t BlockSize,
typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks,
ActivTypeEnum_t activ_type = ActivTypeEnum_t::None>
struct GridwiseGemmDlops_km_kn_mn_v3
struct GridwiseGemmDlops_km_kn_mn_v2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......
#ifndef CK_GRIDWISE_GEMM_V2_ADD_HPP
#define CK_GRIDWISE_GEMM_V2_ADD_HPP
#ifndef CK_GRIDWISE_GEMM_V3_HPP
#define CK_GRIDWISE_GEMM_V3_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
......@@ -301,7 +301,7 @@ template <index_t BlockSize,
index_t bias_type = 0,
index_t out_type = 1,
index_t add_type = 0>
struct GridwiseGemmDlops_km_kn_mn_v3_add
struct GridwiseGemmDlops_km_kn_mn_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -733,8 +733,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
{
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
if constexpr(activ_type > 0)
{
static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
if constexpr(activ_type == 1)
{
......@@ -754,7 +752,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
}
});
}
}
template <typename CThreadBuff,
typename CGlobalBuff,
......
......@@ -4,7 +4,7 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v2_add.hpp"
#include "gridwise_gemm_dlops_v3.hpp"
template <ck::index_t BlockSize,
typename FloatAB,
......@@ -294,7 +294,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
// GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3_add<
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
BlockSize,
FloatAB,
FloatAcc,
......
......@@ -4,7 +4,7 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v2_add.hpp"
#include "gridwise_gemm_dlops_v3.hpp"
template <ck::index_t BlockSize,
typename FloatAB,
......@@ -259,7 +259,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
// GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3_add<
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
BlockSize,
FloatAB,
FloatAcc,
......
......@@ -4,7 +4,7 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v2_add.hpp"
#include "gridwise_gemm_dlops_v3.hpp"
template <ck::index_t BlockSize,
typename FloatAB,
......@@ -298,7 +298,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
// GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3_add<
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
BlockSize,
FloatAB,
FloatAcc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment