Commit 736a37ba authored by Jing Zhang

debug

parent 15232a0d
@@ -129,6 +129,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
return p_c_thread;
}
};
+#endif
template <>
struct WithMNRepeats<1, 1>
@@ -138,10 +139,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
const FloatB* __restrict__ p_b_block,
FloatC p_c_thread)
{
-return XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread);
+p_c_thread = XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread);
+return p_c_thread;
}
};
-#endif
#endif
template <class FloatA, class FloatB, class FloatC>
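These first two hunks move the `#endif` so the `WithMNRepeats<1, 1>` specialization is no longer swallowed by the preprocessor guard, and rewrite its `Run` to assign into the by-value accumulator before returning. A stub sketch of the resulting shape (all names below are stand-ins, not the real xdlops classes):

```cpp
#include <hip/hip_runtime.h>

template <class FloatC>
__device__ FloatC XdlopsGemmRunStub(FloatC c) // placeholder for XdlopsGemm.Run
{
    return c;
}

template <int MRepeat, int NRepeat>
struct WithMNRepeatsStub; // general multi-repeat template

#if 0
// ... disabled general bodies ...
#endif // <- the commit moves the #endif up, closing the disabled region here

template <>
struct WithMNRepeatsStub<1, 1> // single-repeat path, now always compiled
{
    template <class FloatC>
    __device__ static FloatC Run(FloatC p_c_thread)
    {
        // equivalent to returning the call directly (FloatC is by value),
        // but the named local is easier to watch while single-stepping
        p_c_thread = XdlopsGemmRunStub(p_c_thread);
        return p_c_thread;
    }
};
```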
@@ -51,6 +51,8 @@ struct make_block_work_sequence<MBlockWork, NBlockWork, NBlock1MBlock0>
__device__ constexpr auto get() { return Sequence<NBlockWork, MBlockWork>{}; }
};
+#define ACCVGPR_ZERO(acc_reg_id) asm volatile("v_accvgpr_write_b32 a[" #acc_reg_id "], 0" : :);
template <index_t GridSize,
index_t BlockSize,
class ABFloat,
@@ -212,6 +214,11 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2_org
constexpr index_t c_thread_size = MPerBlock * NPerBlock / BlockSize;
auto c_thread_vec = GetRegBuffer<AccFloat, c_thread_size>();
+ACCVGPR_ZERO(0)
+ACCVGPR_ZERO(1)
+ACCVGPR_ZERO(2)
+ACCVGPR_ZERO(3)
// preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block);
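`ACCVGPR_ZERO` writes an immediate 0 into a single accumulation VGPR through `v_accvgpr_write_b32`, so the four calls above clear a[0]..a[3] and give the xdlops accumulator a known starting state for this debug run. The register index must appear as a literal inside the asm string, which is why this is a token-pasting macro expanded once per register rather than a runtime loop. A self-contained sketch, assuming a gfx908-class target where accvgprs exist:

```cpp
#include <hip/hip_runtime.h>

// Same macro as in the diff: #acc_reg_id splices the literal index into the
// asm text, e.g. ACCVGPR_ZERO(2) expands to "v_accvgpr_write_b32 a[2], 0".
#define ACCVGPR_ZERO(acc_reg_id) \
    asm volatile("v_accvgpr_write_b32 a[" #acc_reg_id "], 0" : :);

__global__ void zero_acc_demo()
{
    ACCVGPR_ZERO(0)
    ACCVGPR_ZERO(1)
    ACCVGPR_ZERO(2)
    ACCVGPR_ZERO(3)
}
```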
@@ -496,6 +503,11 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
constexpr index_t c_thread_size = MPerBlock * NPerBlock / BlockSize;
auto c_thread_vec = GetRegBuffer<AccFloat, c_thread_size>();
+ACCVGPR_ZERO(0)
+ACCVGPR_ZERO(1)
+ACCVGPR_ZERO(2)
+ACCVGPR_ZERO(3)
// preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block);
@@ -615,7 +627,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global))
-.Store(c_thread_vec.GetVector(Number<M0 * M2>{})[Number<blk_id>{}], p_c_global);
+.Store(c_thread_vec, p_c_global);
+//.Store(c_thread_vec.GetVector(Number<M0 * M2>{})[Number<blk_id>{}], p_c_global);
});
}
}
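For the C write-out, the debug version stores the whole thread-local accumulator in one `Store` call and keeps the original per-`blk_id` slicing as a comment. A plain-array sketch of what that commented-out slicing did (hypothetical type, not the repo's `StaticallyIndexedArray`):

```cpp
#include <hip/hip_runtime.h>

// The thread's C results viewed as NumBlks sub-vectors of BlkLen floats:
// the old code stored slice(blk_id) on each loop iteration; the debug code
// now hands the whole data[] to a single Store.
template <int BlkLen, int NumBlks>
struct ThreadCBufSketch
{
    float data[BlkLen * NumBlks];

    __device__ const float* slice(int blk_id) const
    {
        return data + blk_id * BlkLen; // analogue of GetVector(...)[blk_id]
    }
};
```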
@@ -56,15 +56,11 @@ struct ThreadwiseGenericTensorSliceCopy_v5
static_assert(is_valid_sequence_map<SrcDimAccessOrder>{}, "wrong! map is not valid");
static_assert(is_valid_sequence_map<DstDimAccessOrder>{}, "wrong! map is not valid");
-static_assert(
-    SliceLengths{}[SrcVectorReadDim] % math::lcm(SrcDataPerRead, DstDataPerWrite) == 0,
+static_assert(SliceLengths{}[SrcVectorReadDim] % SrcDataPerRead == 0,
               "wrong! cannot evenly divide");
-static_assert(
-    SliceLengths{}[DstVectorWriteDim] % math::lcm(SrcDataPerRead, DstDataPerWrite) == 0,
+static_assert(SliceLengths{}[DstVectorWriteDim] % DstDataPerWrite == 0,
               "wrong! cannot evenly divide");
-static_assert(ThreadBufferSize == 8 || ThreadBufferSize == 16, "");
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v5()
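This hunk relaxes the copy checks: each slice length is now tested only against its own side's vector width rather than `math::lcm(SrcDataPerRead, DstDataPerWrite)`, and the hard-coded `ThreadBufferSize == 8 || ThreadBufferSize == 16` restriction is dropped. A constexpr illustration of a configuration the old combined check rejected but the new per-side check accepts (example values, using `std::lcm` in place of `math::lcm`):

```cpp
#include <numeric> // std::lcm (C++17)

constexpr int SrcDataPerRead  = 4;
constexpr int DstDataPerWrite = 8;
constexpr int SrcSliceLength  = 4; // slice length on the vector-read dim

// new per-side check: passes
static_assert(SrcSliceLength % SrcDataPerRead == 0, "");

// old combined check: lcm(4, 8) = 8 does not divide 4, so this shape was
// rejected even though the 4-wide reads themselves line up fine
static_assert(SrcSliceLength % std::lcm(SrcDataPerRead, DstDataPerWrite) != 0, "");
```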
@@ -194,9 +190,9 @@ struct ThreadwiseGenericTensorSliceCopy_v5
// load data from src to the long-vector buffer
const auto src_coord = mSrcSliceOrigin + to_multi_index(long_vector_data_begin_id);
-auto src_buff =
-    vector_data_load<SrcData, src_data_per_access>::run(p_src, src_coord);
-// buffer_vector_load<SrcDataPerRead, SrcDesc::GetElementSpace()>(p_src, src_coord);
+auto src_buff = buffer_vector_load<SrcDataPerRead, SrcDesc::GetElementSpace()>(
+    p_src, src_coord);
+// vector_data_load<SrcData, src_data_per_access>::run(p_src, src_coord);
// store data from the long-vector buffer to dst
constexpr auto buff_off =
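The load itself switches from `vector_data_load`, a plain pointer dereference, to `buffer_vector_load`; the extra `SrcDesc::GetElementSpace()` template argument suggests the read is range-checked against the tensor's element space, which is what AMD buffer instructions provide through their resource descriptor. A hedged plain-HIP sketch of that behavior (illustrative only, not the repo's implementation):

```cpp
#include <hip/hip_runtime.h>

using float4_t = float __attribute__((ext_vector_type(4)));

// Hardware buffer loads do this range check in the buffer descriptor and
// return zero for out-of-range lanes; here it is spelled out by hand.
// Assumes DataPerRead <= 4.
template <int DataPerRead, long ElementSpace>
__device__ float4_t checked_vector_load(const float* p_src, long offset)
{
    float4_t v = {0.0f, 0.0f, 0.0f, 0.0f};
    for(int i = 0; i < DataPerRead; ++i)
    {
        if(offset + i < ElementSpace)
            v[i] = p_src[offset + i];
    }
    return v;
}
```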
@@ -132,10 +132,12 @@ intrin_mfma_f32_32x32x2f32(const float* reg_a, const float* reg_b, c_vec16_1_t::
return reg_c;
}
-__device__ c_vec4_1_t::VecType
-intrin_mfma_f32_16x16x4f32(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c)
+__device__ float_vec4_t intrin_mfma_f32_16x16x4f32(const float* reg_a,
+                                                   const float* reg_b,
+                                                   float_vec4_t reg_c)
{
-    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
+    reg_c.s4(Number<0>{}) =
+        llvm_intrin_amdgcn_mfma_f32_16x16x4f32(reg_a[0], reg_b[0], reg_c.s4[Number<0>{}], 0, 0, 0);
return reg_c;
}
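The rewritten wrapper passes the `float_vec4_t` union through by value and reads/writes the accumulator via its `s4` member, a single `float4_t`, so the type matches what the MFMA builtin produces and consumes (the `s4(...)` / `s4[...]` split matches the usage above: mutable access via `operator()`, const access via `operator[]`). A sketch of the underlying builtin call; the clang builtin and its immediate operands are upstream, while the wrapper shape here is an assumption:

```cpp
#include <hip/hip_runtime.h>

using float4_t = float __attribute__((ext_vector_type(4)));

// cbsz/abid/blgp are broadcast/modifier fields that must be compile-time
// constants, hence template parameters; 0, 0, 0 is the plain 16x16x4 MFMA.
template <int cbsz = 0, int abid = 0, int blgp = 0>
__device__ float4_t mfma_f32_16x16x4f32_sketch(float a, float b, float4_t c)
{
    return __builtin_amdgcn_mfma_f32_16x16x4f32(a, b, c, cbsz, abid, blgp);
}
```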
@@ -36,6 +36,7 @@ union float_vec4_t
StaticallyIndexedArray<float, 4> s1;
StaticallyIndexedArray<float2_t, 2> s2;
StaticallyIndexedArray<float4_t, 1> s4;
+float n[4];
__host__ __device__ constexpr float_vec4_t() {}
template <index_t vs>
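The new `float n[4]` member gives the union a runtime-indexable view: the `StaticallyIndexedArray` members take `Number<i>` compile-time indices, while `n[i]` can be walked in an ordinary loop, which is convenient for printf-style debugging of accumulator contents. A minimal sketch of that use (simplified union, not the full `float_vec4_t`):

```cpp
#include <hip/hip_runtime.h>

using float4_t = float __attribute__((ext_vector_type(4)));

union debug_vec4 // vector view + raw-float view, as in float_vec4_t
{
    float4_t s4;
    float n[4];
};

__device__ void dump_acc(float4_t acc)
{
    debug_vec4 u;
    u.s4 = acc;
    for(int i = 0; i < 4; ++i) // runtime index is fine through n[]
        printf("acc[%d] = %f\n", i, u.n[i]);
}
```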
@@ -64,15 +64,15 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
// read params: tunning parameters
-constexpr index_t GemmMPerBlock = 128;
-constexpr index_t GemmNPerBlock = 256;
+constexpr index_t GemmMPerBlock = 16;
+constexpr index_t GemmNPerBlock = 16;
constexpr index_t GemmKPerBlock = 4;
-constexpr index_t GemmMPerWave = 128;
-constexpr index_t GemmNPerWave = 64;
+constexpr index_t GemmMPerWave = 16;
+constexpr index_t GemmNPerWave = 16;
constexpr index_t GemmKPack = 4;
// read params: dependent parameters
-constexpr index_t BlockSize = 256;
+constexpr index_t BlockSize = 64;
constexpr index_t GemmM = K;
constexpr index_t GemmN = N * Ho * Wo;
@@ -83,7 +83,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
// A matrix copy
constexpr index_t GemmABlockCopyClusterLengths_GemmK = 4;
-constexpr index_t GemmABlockCopyClusterLengths_GemmM = 64;
+constexpr index_t GemmABlockCopyClusterLengths_GemmM = 16;
constexpr index_t GemmABlockCopyClusterLengths_GemmKPack = 1;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
@@ -114,8 +114,8 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
// B matrix Copy
constexpr index_t GemmBBlockCopyClusterLengths_GemmK = 4;
-constexpr index_t GemmBBlockCopyClusterLengths_GemmN = 64;
-constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack = 1;
+constexpr index_t GemmBBlockCopyClusterLengths_GemmN = 4;
+constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack = 4;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK;
@@ -140,7 +140,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
using GemmBBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN]
using GemmBBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmN, GemmKPack]
-constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
+constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack = 1;
// gridwise GEMM
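Taken together, these tuning hunks shrink the tile from a 128x256 block on 256 threads to a single-wave 16x16 block, resize both copy clusters to the 64-thread block, and widen the B-side GemmN read to 4 now that each thread owns a 4-element slice there. The new values stay mutually consistent, which a few constexprs can check (the BlockSize relation assumes one 64-lane wave per MPerWave x NPerWave sub-tile, the usual xdlops arrangement):

```cpp
using index_t = int;

constexpr index_t GemmMPerBlock = 16, GemmNPerBlock = 16;
constexpr index_t GemmMPerWave  = 16, GemmNPerWave  = 16;
constexpr index_t WaveSize      = 64;
constexpr index_t BlockSize =
    (GemmMPerBlock / GemmMPerWave) * (GemmNPerBlock / GemmNPerWave) * WaveSize;

static_assert(BlockSize == 64, "one wave per block at this tile size");

// each copy cluster must account for exactly BlockSize threads
static_assert(4 * 16 * 1 == BlockSize, "A copy: GemmK x GemmM x GemmKPack");
static_assert(4 * 4 * 4 == BlockSize, "B copy: GemmK x GemmN x GemmKPack");
```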
@@ -24,11 +24,11 @@ int main(int argc, char* argv[])
using namespace ck;
// 1x1, 56x56
-constexpr index_t N = 128;
-constexpr index_t C = 128;
-constexpr index_t HI = 56;
-constexpr index_t WI = 56;
-constexpr index_t K = 128;
+constexpr index_t N = 4;
+constexpr index_t C = 32;
+constexpr index_t HI = 2;
+constexpr index_t WI = 2;
+constexpr index_t K = 32;
constexpr index_t Y = 1;
constexpr index_t X = 1;
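The driver's problem shrinks to match (note the `// 1x1, 56x56` comment above is now stale; the debug case is 1x1 on a 2x2 image). With a 1x1 filter, stride 1 and no padding, Ho = HI and Wo = WI, so by the GemmM/GemmN definitions in the gridwise header above, the debug GEMM comes out as below (the GemmK/GemmKPack split follows the usual kpack layout and is an assumption here):

```cpp
using index_t = int;

constexpr index_t N = 4, C = 32, HI = 2, WI = 2, K = 32;

constexpr index_t GemmM = K;           // 32
constexpr index_t GemmN = N * HI * WI; // 4 * 2 * 2 = 16
constexpr index_t GemmKPack = 4;
constexpr index_t GemmK = C * 1 * 1 / GemmKPack; // 1x1 filter: 32 / 4 = 8

// two 16x16 tiles -> a grid of two workgroups, one wave each
static_assert((GemmM / 16) * (GemmN / 16) == 2, "");
```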