Commit 1b15b21a authored by Chao Liu

update static_tensor for dealing with invalid element

parent 2fd5e6ae
#ifndef CK_STATIC_TENSOR_HPP
#define CK_STATIC_TENSOR_HPP
#include "ignore.hpp"
namespace ck {
// StaticTensor for Scalar
@@ -17,10 +15,10 @@ struct StaticTensor
static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();
__host__ __device__ constexpr StaticTensor() : invalid_element_value_{0} {}
__host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {}
__host__ __device__ constexpr StaticTensor(T invalid_element_value)
: invalid_element_value_{invalid_element_value}
: invalid_element_scalar_value_{invalid_element_value}
{
}
@@ -44,11 +42,11 @@ struct StaticTensor
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return T{0};
return zero_scalar_value_;
}
else
{
return invalid_element_value_;
return invalid_element_scalar_value_;
}
}
}
@@ -71,12 +69,14 @@ struct StaticTensor
}
else
{
return ignore;
return ignored_element_scalar_;
}
}
StaticBuffer<AddressSpace, T, element_space_size_, true> data_;
T invalid_element_value_ = T{0};
static constexpr T zero_scalar_value_ = T{0};
const T invalid_element_scalar_value_;
T ignored_element_scalar_;
};
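The members above give the scalar StaticTensor three fallback slots: a compile-time zero, a user-supplied invalid-element value, and a scratch element that absorbs writes to invalid coordinates. A minimal standalone sketch of that policy (simplified names, not the actual StaticTensor API):

#include <array>

// Illustration only; mirrors the invalid-element handling shown in the diff above.
template <typename T, int N, bool InvalidElementUseNumericalZeroValue>
struct StaticTensorSketch
{
    std::array<T, N> data_{};
    T invalid_element_scalar_value_ = T{0};
    T ignored_element_scalar_{}; // scratch slot; writes to invalid elements land here

    // a read through an invalid coordinate returns a fallback instead of whatever is in data_
    constexpr T Get(int offset, bool is_valid) const
    {
        if(is_valid)
            return data_[offset];
        else if constexpr(InvalidElementUseNumericalZeroValue)
            return T{0};
        else
            return invalid_element_scalar_value_;
    }

    // a write through an invalid coordinate is redirected to the scratch slot and dropped
    constexpr T& At(int offset, bool is_valid)
    {
        return is_valid ? data_[offset] : ignored_element_scalar_;
    }
};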
// StaticTensor for vector
@@ -97,10 +97,13 @@ struct StaticTensorTupleOfVectorBuffer
using V = vector_type<S, ScalarPerVector>;
__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_value_{0} {}
__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer()
: invalid_element_scalar_value_{0}
{
}
__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
: invalid_element_value_{invalid_element_value}
: invalid_element_scalar_value_{invalid_element_value}
{
}
@@ -125,11 +128,11 @@ struct StaticTensorTupleOfVectorBuffer
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return S{0};
return zero_scalar_value_;
}
else
{
return invalid_element_value_;
return invalid_element_scalar_value_;
}
}
}
@@ -153,7 +156,7 @@ struct StaticTensorTupleOfVectorBuffer
}
else
{
return ignore;
return ignored_element_scalar_;
}
}
@@ -186,7 +189,7 @@ struct StaticTensorTupleOfVectorBuffer
else
{
// TODO: is this the right way to initialize a vector?
return X{invalid_element_value_};
return X{invalid_element_scalar_value_};
}
}
}
@@ -237,7 +240,9 @@ struct StaticTensorTupleOfVectorBuffer
}
StaticBufferTupleOfVector<AddressSpace, S, num_of_vector_, ScalarPerVector, true> data_;
S invalid_element_value_ = S{0};
static constexpr S zero_scalar_value_ = S{0};
const S invalid_element_scalar_value_ = S{0};
S ignored_element_scalar_;
};
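The TODO above asks whether X{invalid_element_scalar_value_} is the right way to initialize a vector: for aggregate and builtin vector types, brace-initialization with a single scalar fills only the first lane and zero-initializes the rest, so an explicit splat is the unambiguous way to broadcast the invalid value into every lane. A hedged sketch (illustrative names, not the library's vector_type API):

// Illustration only; assumes X supports per-lane indexing with operator[].
template <typename X, typename S, int ScalarPerVector>
constexpr X splat_invalid_value(S s)
{
    X x{};
    for(int i = 0; i < ScalarPerVector; ++i)
    {
        x[i] = s; // broadcast the scalar into every lane
    }
    return x;
}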
template <AddressSpaceEnum_t AddressSpace,
......
@@ -114,6 +114,16 @@ struct BlockwiseTensorSliceTransfer_v4
}
}
template <typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
RunRead(src_desc, src_buf);
RunWrite(dst_desc, dst_buf);
}
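The new Run() is simply the back-to-back composition of the existing RunRead and RunWrite phases, for callers that do not need to interleave other work between the read and the write. A self-contained toy version of that two-phase pattern (illustration only, not the transfer class itself):

#include <array>

struct TwoPhaseCopySketch
{
    std::array<float, 4> staging_{}; // stands in for the thread-private buffer

    void RunRead(const float* src) // stage a slice from the source
    {
        for(int i = 0; i < 4; ++i) staging_[i] = src[i];
    }
    void RunWrite(float* dst) const // flush the staged slice to the destination
    {
        for(int i = 0; i < 4; ++i) dst[i] = staging_[i];
    }
    void Run(const float* src, float* dst) // what the added Run() does, in spirit
    {
        RunRead(src);
        RunWrite(dst);
    }
};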
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
......
@@ -165,6 +165,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
// TODO: BUG: should start at 1
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
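The TODO flags the inner loop bound: tmp is seeded with ordered_src_access_idx[I0], so iterating j from 0 folds dimension 0 into the flattened index twice; starting j at 1 gives the usual row-major flattening. A small standalone check with illustrative lengths {2, 3, 4} and index {1, 2, 3}:

#include <cassert>

int main()
{
    const int len[3] = {2, 3, 4};
    const int idx[3] = {1, 2, 3};
    const int i = 2; // flatten the first i dimensions, as the i-th step of the static_for does

    int correct = idx[0];                                           // seed with dimension 0
    for(int j = 1; j < i; ++j) correct = correct * len[j] + idx[j]; // 1*3 + 2 = 5

    int buggy = idx[0];
    for(int j = 0; j < i; ++j) buggy = buggy * len[j] + idx[j];     // (1*2 + 1)*3 + 2 = 11

    assert(correct == 5 && buggy == 11); // dimension 0 is counted twice when j starts at 0
    return 0;
}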
@@ -412,6 +413,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
// TODO: BUG: should start at 1
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
@@ -512,7 +514,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2
template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
// TODO: why is remove_cvref_t needed?
constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
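A plausible answer to the remove_cvref_t TODO (hedged, since the commit itself leaves it as a question): if the DstDesc template parameter is ever deduced as a reference type, DstDesc::GetNumOfTransform() is ill-formed because the scope operator cannot be applied to a reference; stripping cv/ref qualifiers first makes the static call valid either way. A standalone sketch:

#include <type_traits>

struct Desc
{
    static constexpr int GetNumOfTransform() { return 3; }
};

template <typename DescT>
constexpr int num_transforms()
{
    // return DescT::GetNumOfTransform();   // ill-formed if DescT is deduced as const Desc&
    return std::remove_cv_t<std::remove_reference_t<DescT>>::GetNumOfTransform();
}

static_assert(num_transforms<Desc>() == 3, "");
static_assert(num_transforms<const Desc&>() == 3, "");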
@@ -545,6 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
forward_sweep_(I0) = true;
// TODO: BUG: should start at 1
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
@@ -608,6 +612,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
// TODO: BUG: should start at 1
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
......
@@ -35,8 +35,8 @@
#include "dynamic_buffer.hpp"
#include "is_known_at_compile_time.hpp"
#include "transpose_vectors.hpp"
#include "inner_product.hpp"
#include "element_wise_operation.hpp"
// TODO: remove this
#if CK_USE_AMD_INLINE_ASM
......
@@ -24,12 +24,16 @@
#define CK_MIN_BLOCK_PER_CU 2
#endif
// buffer resource
// GPU-specific parameters
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A)
// buffer resource
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
// wave size
#define CK_GPU_WAVE_SIZE 64
#elif defined(CK_AMD_GPU_GFX1030)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#define CK_GPU_WAVE_SIZE 32
#endif
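CK_BUFFER_RESOURCE_3RD_DWORD and the new CK_GPU_WAVE_SIZE are per-architecture constants: the gfx9 parts use 64-wide waves, gfx1030 uses 32-wide waves, and each family needs a different configuration dword for buffer addressing. A hypothetical sketch (not the library's actual helper) of where such a constant typically ends up, namely the last dword of the 128-bit buffer resource descriptor consumed by AMD buffer_load/buffer_store instructions:

#include <cstdint>
#include <cstring>

struct BufferResourceSketch
{
    uint32_t dword[4]; // dword0/1: base address, dword2: range in bytes, dword3: config bits
};

inline BufferResourceSketch make_buffer_resource_sketch(const void* base,
                                                        uint32_t    size_bytes,
                                                        uint32_t    third_dword_config)
{
    BufferResourceSketch r{};
    const uint64_t addr = reinterpret_cast<uintptr_t>(base);
    std::memcpy(&r.dword[0], &addr, sizeof(addr)); // base pointer into dword0/1
    r.dword[2] = size_bytes;                       // addressable range
    r.dword[3] = third_dword_config;               // e.g. CK_BUFFER_RESOURCE_3RD_DWORD
    return r;
}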
// FMA instruction
......
@@ -5,8 +5,12 @@
namespace ck {
__device__ constexpr index_t get_wave_size() { return CK_GPU_WAVE_SIZE; }
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
__device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size(); }
__device__ index_t get_block_1d_id() { return blockIdx.x; }
} // namespace ck
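get_wave_size() now reads the per-architecture CK_GPU_WAVE_SIZE instead of assuming 64, which is what lets the wave32 gfx1030 path work; get_wave_local_1d_id() is then just the thread id divided by the wave size. Worked numbers (illustration only): thread 70 of a block sits in wave 70 / 64 = 1 on a wave64 part and in wave 70 / 32 = 2 on a wave32 part.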
......
@@ -29,8 +29,8 @@ template <typename InDataType,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MPerXdl,
ck::index_t NPerXdl,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
@@ -266,8 +266,8 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
MPerXdl,
NPerXdl,
K1,
MXdlPerWave,
NXdlPerWave,
@@ -299,10 +299,12 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
ABlockLdsAddExtraM,
BBlockLdsAddExtraN>;
#if !DEBUG_USE_C_SHUFFLE
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
#endif
// Argument
struct Argument : public BaseArgument
@@ -331,7 +333,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
#if !DEBUG_USE_C_SHUFFLE
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
#else
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_{},
#endif
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
@@ -358,8 +364,15 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
{
#if !DEBUG_USE_C_SHUFFLE
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
#else
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_ =
GridwiseGemm::
MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl(
c_grid_desc_m_n_);
#endif
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
@@ -372,8 +385,15 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
Block2CTileMap block_2_ctile_map_;
#if !DEBUG_USE_C_SHUFFLE
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
#else
typename GridwiseGemm::
CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_;
#endif
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_;
@@ -427,11 +447,17 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
#if !DEBUG_USE_C_SHUFFLE
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
#else
remove_reference_t<
typename GridwiseGemm::
CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl>,
#endif
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
@@ -444,7 +470,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
#if !DEBUG_USE_C_SHUFFLE
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
#else
arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_,
#endif
arg.in_element_op_,
arg.wei_element_op_,
arg.out_element_op_,
@@ -458,11 +488,17 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
#if !DEBUG_USE_C_SHUFFLE
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
#else
remove_reference_t<
typename GridwiseGemm::
CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl>,
#endif
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
@@ -475,7 +511,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
#if !DEBUG_USE_C_SHUFFLE
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
#else
arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_,
#endif
arg.in_element_op_,
arg.wei_element_op_,
arg.out_element_op_,
......
#ifndef ELEMENT_WISE_OPERATION_HPP
#define ELEMENT_WISE_OPERATION_HPP
namespace ck {
namespace tensor_operation {
namespace element_wise {
struct PassThrough
{
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
{
return v;
}
};
struct AddRelu
{
template <typename T1>
__host__ constexpr float operator()(float v0, T1 v1) const
{
float b = v0 + v1;
float c = b > 0 ? b : 0;
return c;
}
template <typename T1>
__device__ constexpr float operator()(float v0, T1 v1) const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
return b;
#else
float b = v1 + v0;
float c = b > 0 ? b : 0;
return c;
#endif
}
};
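// Illustration only (not part of the operator set above): AddRelu adds, then clamps at zero.
// A compile-time spot check with exactly representable values:
namespace add_relu_illustration {
constexpr float add_relu(float v0, float v1)
{
    float b = v0 + v1;
    return b > 0 ? b : 0;
}
static_assert(add_relu(1.5f, -2.0f) == 0.0f, "relu clamps a negative sum");
static_assert(add_relu(1.5f, 0.5f) == 2.0f, "a positive sum passes through");
} // namespace add_relu_illustration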
struct AddReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float b = v0 + v1;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
float c = b + v2;
return c;
#else
float b = v1 + v2;
float c = (v0 > -v1) ? b + v0 : v2;
return c;
#endif
}
};
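// Illustration only (not part of the operator set): the device branch of AddReluAdd above is an
// algebraic rewrite of the host formula, relu(v0 + v1) + v2 == (v0 + v1 > 0) ? (v0 + v1) + v2 : v2,
// with b = v1 + v2 precomputed and the condition expressed as v0 > -v1. Spot check with exactly
// representable values so the reassociation is exact:
namespace add_relu_add_illustration {
constexpr float reference(float v0, float v1, float v2)
{
    float b = v0 + v1;
    float c = b > 0 ? b : 0;
    return c + v2;
}
constexpr float refactored(float v0, float v1, float v2)
{
    float b = v1 + v2;
    return (v0 > -v1) ? b + v0 : v2;
}
static_assert(reference(3.0f, -1.0f, 4.0f) == refactored(3.0f, -1.0f, 4.0f), "");
static_assert(reference(-3.0f, 1.0f, 4.0f) == refactored(-3.0f, 1.0f, 4.0f), "");
} // namespace add_relu_add_illustration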
struct AddLeakyReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
// this uses not too many registers, but uses an fp64 mul
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif 0
// this spills registers
float a = v0 + v1;
float b = float(0.1) * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif 0
// this uses lots of registers (but no spill)
constexpr float alpha = 0.1;
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = b > 0 ? b : 0;
float d = alpha * (a + c);
return d;
#elif 1
// this uses lots of registers (but no spill), 89 Tflops
constexpr float alpha = 0.1;
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = max(b, float(0));
float d = alpha * (a + c);
return d;
#elif 1
// this spills registers, 89 Tflops
float a = v0 + v1;
float alpha = 0.1;
float b;
asm volatile("\n \
v_mul_f32_e32 %0, %1, %2 \n \
"
: "=v"(b)
: "s"(alpha), "v"(a));
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#endif
}
};
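// Illustration only (not part of the operator set): the branch of AddLeakyReluAdd selected above
// rewrites 0.1f * max(v0 + v1, 0) + v2 as alpha * (v2 / alpha + max(v0 + v1, 0)); as the in-code
// comments note, this keeps the multiply in fp32 (a bare 0.1 literal is a double) at the cost of
// extra registers. The two forms agree up to rounding of (v2 / alpha) * alpha:
namespace add_leaky_relu_add_illustration {
constexpr float reference(float v0, float v1, float v2)
{
    float a = v0 + v1;
    float c = a > 0 ? 0.1f * a : 0.0f;
    return c + v2;
}
constexpr float refactored(float v0, float v1, float v2)
{
    constexpr float alpha     = 0.1f;
    constexpr float alpha_inv = 1.0f / alpha;
    float a = v2 * alpha_inv;
    float b = v0 + v1;
    float c = b > 0 ? b : 0.0f;
    return alpha * (a + c);
}
constexpr float diff = reference(2.0f, 1.0f, 4.0f) - refactored(2.0f, 1.0f, 4.0f);
static_assert(diff < 1e-5f && diff > -1e-5f, "forms agree up to rounding");
} // namespace add_leaky_relu_add_illustration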
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
#endif