"docs/source/git@developer.sourcefind.cn:norm/vllm.git" did not exist on "e11222333f43c8466c57d0223380dcf297b02bac"
Commit 8e897da7 authored by Jing Zhang's avatar Jing Zhang
Browse files

add gridwise_gemm_v3

parent baac64e4
...@@ -148,7 +148,7 @@ template <index_t BlockSize, ...@@ -148,7 +148,7 @@ template <index_t BlockSize,
typename AGlobalMoveSliceWindowStepHacks, typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks, typename BGlobalMoveSliceWindowStepHacks,
ActivTypeEnum_t activ_type = ActivTypeEnum_t::None> ActivTypeEnum_t activ_type = ActivTypeEnum_t::None>
struct GridwiseGemmDlops_km_kn_mn_v3 struct GridwiseGemmDlops_km_kn_mn_v2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
......
#ifndef CK_GRIDWISE_GEMM_V2_ADD_HPP #ifndef CK_GRIDWISE_GEMM_V3_HPP
#define CK_GRIDWISE_GEMM_V2_ADD_HPP #define CK_GRIDWISE_GEMM_V3_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
...@@ -301,7 +301,7 @@ template <index_t BlockSize, ...@@ -301,7 +301,7 @@ template <index_t BlockSize,
index_t bias_type = 0, index_t bias_type = 0,
index_t out_type = 1, index_t out_type = 1,
index_t add_type = 0> index_t add_type = 0>
struct GridwiseGemmDlops_km_kn_mn_v3_add struct GridwiseGemmDlops_km_kn_mn_v3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -733,8 +733,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -733,8 +733,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
{ {
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
if constexpr(activ_type > 0)
{
static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) { static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
if constexpr(activ_type == 1) if constexpr(activ_type == 1)
{ {
...@@ -753,7 +751,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -753,7 +751,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
c_thread_buf(i) = x; c_thread_buf(i) = x;
} }
}); });
}
} }
template <typename CThreadBuff, template <typename CThreadBuff,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v2_add.hpp" #include "gridwise_gemm_dlops_v3.hpp"
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
typename FloatAB, typename FloatAB,
...@@ -294,7 +294,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -294,7 +294,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), ""); static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
// GEMM // GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3_add< using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v2_add.hpp" #include "gridwise_gemm_dlops_v3.hpp"
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
typename FloatAB, typename FloatAB,
...@@ -259,7 +259,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -259,7 +259,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), ""); static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
// GEMM // GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3_add< using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v2_add.hpp" #include "gridwise_gemm_dlops_v3.hpp"
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
typename FloatAB, typename FloatAB,
...@@ -298,7 +298,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -298,7 +298,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), ""); static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
// GEMM // GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3_add< using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment