"lmdeploy/git@developer.sourcefind.cn:guobj/qwen_lmdeploy.git" did not exist on "2a4754785144a08f1e1feeb11fad87bbd6e41610"
Commit 0404f777 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 08c7f743
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "nvToolsExt.h" #include "nvToolsExt.h"
#include "tensor.hpp" #include "tensor.hpp"
#include "constant_tensor_descriptor.cuh" #include "constant_tensor_descriptor.cuh"
#include "direct_convolution.cuh" #include "direct_convolution_2.cuh"
template <class T> template <class T>
struct GeneratorConstant struct GeneratorConstant
...@@ -133,8 +133,8 @@ void device_convolution( ...@@ -133,8 +133,8 @@ void device_convolution(
constexpr auto wei_desc = WeiDesc{}; constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{}; constexpr auto out_desc = OutDesc{};
constexpr unsigned NPerBlock = 1; constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 1; constexpr unsigned KPerBlock = 2;
constexpr unsigned CPerBlockLoop = 1; constexpr unsigned CPerBlockLoop = 4;
constexpr unsigned OutTileSizeH = 2; constexpr unsigned OutTileSizeH = 2;
constexpr unsigned OutTileSizeW = 2; constexpr unsigned OutTileSizeW = 2;
constexpr unsigned YPerBlock = 8; constexpr unsigned YPerBlock = 8;
...@@ -213,7 +213,7 @@ int main() ...@@ -213,7 +213,7 @@ int main()
constexpr unsigned C = 256; constexpr unsigned C = 256;
constexpr unsigned HI = 34; constexpr unsigned HI = 34;
constexpr unsigned WI = 34; constexpr unsigned WI = 34;
constexpr unsigned K = 56; constexpr unsigned K = 64;
constexpr unsigned S = 3; constexpr unsigned S = 3;
constexpr unsigned R = 3; constexpr unsigned R = 3;
#elif 0 #elif 0
......
...@@ -62,13 +62,9 @@ __device__ void blockwise_4d_tensor_op( ...@@ -62,13 +62,9 @@ __device__ void blockwise_4d_tensor_op(
{ {
for(unsigned did3 = did3_begin; did3 < src_desc.GetLength(I3); did3 += NWorkLen3) for(unsigned did3 = did3_begin; did3 < src_desc.GetLength(I3); did3 += NWorkLen3)
{ {
const unsigned sindex = const unsigned sindex = src_desc.Get1dIndex(did0, did1, did2, did3);
src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
const unsigned dindex = const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
f(p_src[dindex], p_dst[sindex]); f(p_src[dindex], p_dst[sindex]);
...@@ -115,6 +111,8 @@ __device__ void blockwise_4d_tensor_op( ...@@ -115,6 +111,8 @@ __device__ void blockwise_4d_tensor_op(
static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value); static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value);
constexpr auto desc = make_ConstantTensorDescriptor(src_desc.GetLengths());
#if 0 #if 0
if(threadIdx.x == 0) if(threadIdx.x == 0)
{ {
...@@ -125,29 +123,27 @@ __device__ void blockwise_4d_tensor_op( ...@@ -125,29 +123,27 @@ __device__ void blockwise_4d_tensor_op(
unsigned lid = threadIdx.x; unsigned lid = threadIdx.x;
for(unsigned i = lid; i < src_desc.GetElementSize(); i += BlockSize) for(unsigned i = lid; i < desc.GetElementSize(); i += BlockSize)
{ {
unsigned is = i; unsigned is = i;
const unsigned did0 = is / src_desc.GetStride(I0); const unsigned did0 = is / desc.GetStride(I0);
is -= did0 * src_desc.GetStride(I0); is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / src_desc.GetStride(I1); const unsigned did1 = is / desc.GetStride(I1);
is -= did1 * src_desc.GetStride(I1); is -= did1 * desc.GetStride(I1);
const unsigned did2 = is / src_desc.GetStride(I2); const unsigned did2 = is / desc.GetStride(I2);
is -= did2 * src_desc.GetStride(I2); is -= did2 * desc.GetStride(I2);
const unsigned did3 = is / src_desc.GetStride(I3); const unsigned did3 = is / src_desc.GetStride(I3);
const unsigned sindex = src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 + const unsigned sindex = src_desc.Get1dIndex(did0, did1, did2, did3);
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
const unsigned dindex = dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 + const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
f(p_src[sindex], p_dst[dindex]); f(p_src[sindex], p_dst[dindex]);
} }
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment