"sgl-router/src/git@developer.sourcefind.cn:change/sglang.git" did not exist on "067068f27173b41a6bec437be7bbfcb3e355b080"
Commit b2888adf authored by Chao Liu's avatar Chao Liu
Browse files

change file extension to hip.hpp and hip.cpp

parent a414e3fd
#pragma once #pragma once
#include "common.cuh" #include "common.hip.hpp"
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.cuh" #include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.cuh" #include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.cuh" #include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.cuh" #include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_gemm.cuh" #include "blockwise_gemm.hip.hpp"
template <unsigned GridSize, template <unsigned GridSize,
unsigned BlockSize, unsigned BlockSize,
...@@ -199,9 +199,8 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -199,9 +199,8 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread); threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);
const Float* p_in_global_block_begin = const Float* p_in_global_block_begin =
p_in_global + p_in_global + in_chwn_global_desc.Get1dIndex(
in_chwn_global_desc.Get1dIndex( 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
const Float* p_wei_global_block_begin = const Float* p_wei_global_block_begin =
p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
...@@ -258,11 +257,10 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -258,11 +257,10 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
out_hkwn_thread_desc, out_hkwn_thread_desc,
p_out_thread, p_out_thread,
out_khwn_global_desc, out_khwn_global_desc,
p_out_global + p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, n_block_data_begin + n_thread_data_begin),
n_block_data_begin + n_thread_data_begin),
out_hkwn_thread_desc.GetLengths(), out_hkwn_thread_desc.GetLengths(),
reorder_khwn_from_hkwn); reorder_khwn_from_hkwn);
} }
#pragma once #pragma once
#include "common.cuh" #include "common.hip.hpp"
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.cuh" #include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.cuh" #include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.cuh" #include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.cuh" #include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_gemm.cuh" #include "blockwise_gemm.hip.hpp"
template <unsigned GridSize, template <unsigned GridSize,
unsigned BlockSize, unsigned BlockSize,
...@@ -283,11 +283,10 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded( ...@@ -283,11 +283,10 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(
out_hkwn_thread_desc, out_hkwn_thread_desc,
p_out_thread, p_out_thread,
out_khwn_global_desc, out_khwn_global_desc,
p_out_global + p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, n_block_data_begin + n_thread_data_begin),
n_block_data_begin + n_thread_data_begin),
out_hkwn_thread_desc.GetLengths(), out_hkwn_thread_desc.GetLengths(),
reorder_khwn_from_hkwn); reorder_khwn_from_hkwn);
} }
#pragma once #pragma once
#include "common.cuh" #include "common.hip.hpp"
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.cuh" #include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.cuh" #include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.cuh" #include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.cuh" #include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_gemm.cuh" #include "blockwise_gemm.hip.hpp"
template <unsigned GridSize, template <unsigned GridSize,
unsigned BlockSize, unsigned BlockSize,
...@@ -339,11 +339,10 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p ...@@ -339,11 +339,10 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
out_hkwn_thread_desc, out_hkwn_thread_desc,
p_out_thread, p_out_thread,
out_khwn_global_desc, out_khwn_global_desc,
p_out_global + p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, n_block_data_begin + n_thread_data_begin),
n_block_data_begin + n_thread_data_begin),
out_hkwn_thread_desc.GetLengths(), out_hkwn_thread_desc.GetLengths(),
reorder_khwn_from_hkwn); reorder_khwn_from_hkwn);
} }
#pragma once #pragma once
#include "common.cuh" #include "common.hip.hpp"
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.cuh" #include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.cuh" #include "blockwise_4d_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.cuh" #include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_gemm.cuh" #include "blockwise_gemm.hip.hpp"
template <unsigned GridSize, template <unsigned GridSize,
unsigned BlockSize, unsigned BlockSize,
...@@ -160,11 +160,10 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw(const Float* const __restric ...@@ -160,11 +160,10 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw(const Float* const __restric
// convert [N,C,Hi,Wi] to [C,Hi,Wi,N] // convert [N,C,Hi,Wi] to [C,Hi,Wi,N]
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>( blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
in_nchw_global_desc, in_nchw_global_desc,
p_in_global + p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
in_nchw_global_desc.Get1dIndex(n_block_data_begin, c_block_data_begin,
c_block_data_begin, hi_block_data_begin,
hi_block_data_begin, wi_block_data_begin),
wi_block_data_begin),
in_chwn_block_desc, in_chwn_block_desc,
p_in_block, p_in_block,
in_nchw_block_desc.GetLengths(), in_nchw_block_desc.GetLengths(),
...@@ -245,11 +244,10 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw(const Float* const __restric ...@@ -245,11 +244,10 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw(const Float* const __restric
out_hkwn_thread_desc, out_hkwn_thread_desc,
p_out_thread, p_out_thread,
out_nkhw_global_desc, out_nkhw_global_desc,
p_out_global + p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin,
out_nkhw_global_desc.Get1dIndex(n_block_data_begin, k_block_data_begin + k_thread_data_begin,
k_block_data_begin + k_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, wo_block_data_begin + wo_thread_data_begin),
wo_block_data_begin + wo_thread_data_begin),
out_hkwn_thread_desc.GetLengths(), out_hkwn_thread_desc.GetLengths(),
reorder_nkhw_from_hkwn); reorder_nkhw_from_hkwn);
#else #else
...@@ -263,11 +261,10 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw(const Float* const __restric ...@@ -263,11 +261,10 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr_nkhw(const Float* const __restric
out_nkhw_thread_desc, out_nkhw_thread_desc,
p_out_thread, p_out_thread,
out_nkhw_global_desc, out_nkhw_global_desc,
p_out_global + p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin,
out_nkhw_global_desc.Get1dIndex(n_block_data_begin, k_block_data_begin + k_thread_data_begin,
k_block_data_begin + k_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, wo_block_data_begin + wo_thread_data_begin),
wo_block_data_begin + wo_thread_data_begin),
out_nkhw_thread_desc.GetLengths()); out_nkhw_thread_desc.GetLengths());
#endif #endif
} }
#pragma once #pragma once
#include "common.cuh" #include "common.hip.hpp"
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.cuh" #include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.cuh" #include "blockwise_4d_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.cuh" #include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_gemm.cuh" #include "blockwise_gemm.hip.hpp"
template <unsigned GridSize, template <unsigned GridSize,
unsigned BlockSize, unsigned BlockSize,
...@@ -166,11 +166,10 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(const Float* const __restric ...@@ -166,11 +166,10 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(const Float* const __restric
// convert [N,C,Hi,Wi] to [C,Hi,Wi,N] // convert [N,C,Hi,Wi] to [C,Hi,Wi,N]
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>( blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
in_nchw_global_desc, in_nchw_global_desc,
p_in_global + p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
in_nchw_global_desc.Get1dIndex(n_block_data_begin, c_block_data_begin,
c_block_data_begin, hi_block_data_begin,
hi_block_data_begin, wi_block_data_begin),
wi_block_data_begin),
in_chwn_block_desc, in_chwn_block_desc,
p_in_block, p_in_block,
in_nchw_block_desc.GetLengths(), in_nchw_block_desc.GetLengths(),
...@@ -180,10 +179,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(const Float* const __restric ...@@ -180,10 +179,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(const Float* const __restric
#if 1 #if 1
// weight: global mem to LDS, // weight: global mem to LDS,
// format is [S,R,C,K], no conversion needed // format is [S,R,C,K], no conversion needed
blockwise_wei_copy.Run( blockwise_wei_copy.Run(p_wei_global + wei_srck_global_desc.Get1dIndex(
p_wei_global + 0, 0, c_block_data_begin, k_block_data_begin),
wei_srck_global_desc.Get1dIndex(0, 0, c_block_data_begin, k_block_data_begin), p_wei_block);
p_wei_block);
#endif #endif
__syncthreads(); __syncthreads();
...@@ -219,11 +217,10 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(const Float* const __restric ...@@ -219,11 +217,10 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(const Float* const __restric
out_hkwn_thread_desc, out_hkwn_thread_desc,
p_out_thread, p_out_thread,
out_nkhw_global_desc, out_nkhw_global_desc,
p_out_global + p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, k_block_data_begin + k_thread_data_begin,
k_block_data_begin + k_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, wo_block_data_begin + wo_thread_data_begin),
wo_block_data_begin + wo_thread_data_begin),
out_hkwn_thread_desc.GetLengths(), out_hkwn_thread_desc.GetLengths(),
reorder_nkhw_from_hkwn); reorder_nkhw_from_hkwn);
} }
#pragma once #pragma once
#include "common.cuh" #include "common.hip.hpp"
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.cuh" #include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.cuh" #include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.cuh" #include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_2d_tensor_op.cuh" #include "threadwise_2d_tensor_op.hip.hpp"
#include "blockwise_gemm.cuh" #include "blockwise_gemm.hip.hpp"
// define B = flatten(N, Hi, Wi) // define B = flatten(N, Hi, Wi)
template <unsigned GridSize, template <unsigned GridSize,
...@@ -121,7 +121,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(const Float* const __restric ...@@ -121,7 +121,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(const Float* const __restric
decltype(in_cb_block_desc), decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths())>{}; decltype(in_cb_block_desc.GetLengths())>{};
#elif 0 #elif 0
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize, const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
Float, Float,
decltype(in_cb_global_desc), decltype(in_cb_global_desc),
decltype(in_cb_block_desc), decltype(in_cb_block_desc),
...@@ -129,7 +129,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(const Float* const __restric ...@@ -129,7 +129,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(const Float* const __restric
InBlockCopyThreadPerDim0, InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1>{}; InBlockCopyThreadPerDim1>{};
#elif 1 #elif 1
const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize, const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize,
Float, Float,
decltype(in_cb_global_desc), decltype(in_cb_global_desc),
decltype(in_cb_block_desc), decltype(in_cb_block_desc),
......
#pragma once #pragma once
#include "common.cuh" #include "common.hip.hpp"
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.cuh" #include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.cuh" #include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.cuh" #include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_2d_tensor_op.cuh" #include "threadwise_2d_tensor_op.hip.hpp"
#include "blockwise_gemm.cuh" #include "blockwise_gemm.hip.hpp"
// define B = flatten(N, Hi, Wi) // define B = flatten(N, Hi, Wi)
template <unsigned GridSize, template <unsigned GridSize,
...@@ -121,7 +121,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b ...@@ -121,7 +121,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b
decltype(in_cb_block_desc), decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths())>{}; decltype(in_cb_block_desc.GetLengths())>{};
#elif 0 #elif 0
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize, const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
Float, Float,
decltype(in_cb_global_desc), decltype(in_cb_global_desc),
decltype(in_cb_block_desc), decltype(in_cb_block_desc),
...@@ -129,7 +129,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b ...@@ -129,7 +129,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b
InBlockCopyThreadPerDim0, InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1>{}; InBlockCopyThreadPerDim1>{};
#elif 1 #elif 1
const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize, const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize,
Float, Float,
decltype(in_cb_global_desc), decltype(in_cb_global_desc),
decltype(in_cb_block_desc), decltype(in_cb_block_desc),
......
#pragma once #pragma once
#include "common.cuh" #include "common.hip.hpp"
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.cuh" #include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.cuh" #include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.cuh" #include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_2d_tensor_op.cuh" #include "threadwise_2d_tensor_op.hip.hpp"
#include "blockwise_gemm.cuh" #include "blockwise_gemm.hip.hpp"
// define B = flatten(N, Hi, Wi) // define B = flatten(N, Hi, Wi)
template <unsigned GridSize, template <unsigned GridSize,
......
#pragma once #pragma once
#include "common.cuh" #include "common.hip.hpp"
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.cuh" #include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.cuh" #include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.cuh" #include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_2d_tensor_op.cuh" #include "threadwise_2d_tensor_op.hip.hpp"
#include "blockwise_gemm.cuh" #include "blockwise_gemm.hip.hpp"
// define B = N*Hi*Wi // define B = N*Hi*Wi
template <unsigned GridSize, template <unsigned GridSize,
...@@ -220,10 +220,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline ...@@ -220,10 +220,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
#if 1 #if 1
// preload next data // preload next data
// input: global mem to LDS, // input: global mem to LDS,
blockwise_in_copy.Run( blockwise_in_copy.Run(p_in_global + in_cb_global_desc.Get1dIndex(
p_in_global + c_block_data_begin + CPerBlock, b_block_data_begin),
in_cb_global_desc.Get1dIndex(c_block_data_begin + CPerBlock, b_block_data_begin), p_in_block_next);
p_in_block_next);
#endif #endif
#if 1 #if 1
......
#pragma once #pragma once
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
#include "blockwise_winograd_transform.cuh" #include "blockwise_winograd_transform.hip.hpp"
#include "threadwise_winograd_transform.cuh" #include "threadwise_winograd_transform.hip.hpp"
template <class Float, template <class Float,
class InGlobalDesc, class InGlobalDesc,
...@@ -189,18 +189,17 @@ __global__ void gridwise_winograd_convolution(const Float* const __restrict__ p_ ...@@ -189,18 +189,17 @@ __global__ void gridwise_winograd_convolution(const Float* const __restrict__ p_
S, S,
R, R,
OutTileSizeH, OutTileSizeH,
OutTileSizeW>( OutTileSizeW>(in_transform_thread_block_desc,
in_transform_thread_block_desc, p_in_transform_block + in_transform_block_desc.Get1dIndex(
p_in_transform_block + n_thread_data_begin,
in_transform_block_desc.Get1dIndex(n_thread_data_begin, c_thread_data,
c_thread_data, y_thread_data_begin * InTileSizeH,
y_thread_data_begin * InTileSizeH, x_thread_data_begin * InTileSizeW),
x_thread_data_begin * InTileSizeW), wei_transform_thread_block_desc,
wei_transform_thread_block_desc, p_wei_transform_block + wei_transform_block_desc.Get1dIndex(
p_wei_transform_block + k_thread_data_begin, c_thread_data, 0, 0),
wei_transform_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), out_transform_thread_desc,
out_transform_thread_desc, p_out_transform_thread);
p_out_transform_thread);
} }
}; };
......
...@@ -22,7 +22,8 @@ std::ostream& LogRange(std::ostream& os, Range&& r, std::string delim) ...@@ -22,7 +22,8 @@ std::ostream& LogRange(std::ostream& os, Range&& r, std::string delim)
return os; return os;
} }
typedef enum { typedef enum
{
Half = 0, Half = 0,
Float = 1, Float = 1,
} DataType_t; } DataType_t;
......
#pragma once #pragma once
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
template <class Float, class Desc, class F> template <class Float, class Desc, class F>
__device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f) __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)
......
#pragma once #pragma once
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
template <class Float, class Desc, class F> template <class Float, class Desc, class F>
__device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f) __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)
......
#pragma once #pragma once
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.hip.hpp"
// optimized for scenario if p_in, p_wei, p_out are in register // optimized for scenario if p_in, p_wei, p_out are in register
template <class Float, class InDesc, class WeiDesc, class OutDesc> template <class Float, class InDesc, class WeiDesc, class OutDesc>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment