Commit a3edbec9 authored by Jing Zhang's avatar Jing Zhang
Browse files

debugging vector global store

parent 631d9892
......@@ -587,6 +587,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
constexpr index_t BlkSize = blockwise_gemm.GetBlkSize();
constexpr index_t NumBlks = blockwise_gemm.GetNumBlks();
#if 1
// force unrolling the output loop to get rid of scratch (spill) registers
static_for<0, NumBlks, 1>{}([&](auto blk_id) {
// calculate origin of thread output tensor on global memory
......@@ -618,8 +619,39 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
m_thread_data_on_global % M2,
n_thread_data_on_global))
.Store(c_thread_vec, p_c_global);
//.Store(c_thread_vec.GetVector(Number<M0 * M2>{})[Number<blk_id>{}], p_c_global);
});
#else
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
decltype(c_g_m0_m1_m2_n_global_desc),
CThreadCopySliceLengths,
arithmetic_sequence_gen<0, 5, 1>::type,
4,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryOp>(
make_multi_index(0, 0, 0, 0, 0),
make_multi_index(g_block_data_on_global,
m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global))
.Run(c_thread_vec.n + i * BlkSize, p_c_global);
}
#endif
}
}
};
......
......@@ -56,13 +56,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
class FloatA,
class FloatB,
class FloatC>
// Issue one mfma_f32_32x32x1f32 step on the accumulator reg_c.
// NOTE(review): operands are taken by value (not via pointer) so they stay in
// VGPRs; the pointer form forced a reinterpret_cast round-trip that the
// commit removes — presumably to avoid scratch spills (see commit message).
__device__ FloatC run(const FloatA a, const FloatB b, FloatC reg_c) const
{
    return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, AStride, BStride>::run(
        a, b, reg_c);
}
};
......@@ -90,12 +87,9 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
class FloatA,
class FloatB,
class FloatC>
// Issue one mfma_f32_32x32x2f32 step on the accumulator reg_c.
// NOTE(review): by-value scalar operands, consistent with the 32x32x1
// specialization above — the previous pointer + reinterpret_cast form is gone.
__device__ FloatC run(const FloatA a, const FloatB b, FloatC reg_c) const
{
    return intrin_mfma_f32_32x32x2f32(a, b, reg_c);
}
};
......@@ -749,6 +743,23 @@ struct XdlopsGemm_t
constexpr index_t BStride = K * KRepeats;
static_if<!IsKReduction>{}([&](auto) {
static_for<0, K, 1>{}([&](auto k_i) {
index_t a_offset = k_i * M + laneId;
reg_a.GetVector(Number<data_size>{})(Number<k_i>{}) = lds_load(p_a_wave, a_offset);
index_t b_offset = k_i * N + laneId;
reg_b.GetVector(Number<data_size>{})(Number<k_i>{}) = lds_load(p_b_wave, b_offset);
});
static_for<0, K * KRepeats, 1>{}([&](auto k_i) {
p_c_thread = mfma_type.template run<MPerXdlops, NPerXdlops, AStride, BStride>(
reg_a.GetVector(Number<mfma_type.k_base>{})[Number<k_i>{}],
reg_b.GetVector(Number<mfma_type.k_base>{})[Number<k_i>{}],
p_c_thread);
});
#if 0
for(index_t m_i = 0; m_i < MRepeats; ++m_i)
for(index_t k_i = 0; k_i < K; ++k_i)
......@@ -775,15 +786,6 @@ struct XdlopsGemm_t
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
#if 0
// load into registers
for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
{
a[k_i] = p_a_wave[(k_i + blk_id) * M + blk_td];
b[k_i] = p_b_wave[(k_i + blk_id) * N + blk_td];
}
#endif
static_for<0, K, mfma_type.num_input_blks>{}([&](auto k_i) {
index_t a_offset = (k_i + blk_id) * M + blk_td;
......@@ -796,8 +798,6 @@ struct XdlopsGemm_t
lds_load(p_b_wave, b_offset);
});
// for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
// for(index_t i = 0; i < KRepeats; ++i)
static_for<0, K, mfma_type.num_input_blks>{}([&](auto k_i) {
static_for<0, KRepeats, 1>{}([&](auto i) {
constexpr index_t offset = k_i * KRepeats + i;
......
......@@ -93,12 +93,12 @@ struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride>
{
// 64x64 tile: two 32x32x1 xdlops with cbsz=1 and abid 0/1 select the two
// input blocks; each call accumulates into one half of the 64-float result.
// Operands are plain scalars (reg_a, reg_b) rather than pointers, so no
// dereference/scratch traffic is generated.
__device__ static float_vec64_t run(const float reg_a, const float reg_b, float_vec64_t reg_c)
{
    reg_c.v32(Number<0>{}) =
        llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c.v32[Number<0>{}], 1, 0, 0);
    reg_c.v32(Number<1>{}) =
        llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c.v32[Number<1>{}], 1, 1, 0);
    return reg_c;
}
};
......@@ -125,10 +125,12 @@ struct intrin_mfma_f32_32x32x1f32<32, 64, AStride, BStride>
}
};
// Single mfma_f32_32x32x2f32 accumulating into a 16-float vector.
// Return/accumulator type is float_vec16_t (replacing the old
// c_vec16_1_t::VecType with its .s.x member), addressed through the
// statically-indexed s16 view; operands are scalar floats.
__device__ float_vec16_t intrin_mfma_f32_32x32x2f32(const float reg_a,
                                                    const float reg_b,
                                                    float_vec16_t reg_c)
{
    reg_c.s16(Number<0>{}) =
        llvm_intrin_amdgcn_mfma_f32_32x32x2f32(reg_a, reg_b, reg_c.s16[Number<0>{}], 0, 0, 0);
    return reg_c;
}
......
......@@ -219,6 +219,7 @@ union float_vec64_t
StaticallyIndexedArray<float_vec32_t, 2> s32;
StaticallyIndexedArray<float32_t, 2> v32;
StaticallyIndexedArray<float64_t, 1> s64;
// float n[64];
__host__ __device__ constexpr float_vec64_t() { s64(Number<0>{}) = 0; }
template <index_t vs>
......@@ -308,6 +309,12 @@ constexpr auto GetRegBuffer<float, 16>()
return float_vec16_t{};
}
// Specialization: value-initialized 32-float register buffer (float_vec32_t),
// filling the gap between the existing 16- and 64-float specializations.
template <>
constexpr auto GetRegBuffer<float, 32>()
{
return float_vec32_t{};
}
template <>
constexpr auto GetRegBuffer<float, 64>()
{
......
......@@ -64,15 +64,15 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
// read params: tunning parameters
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 16;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 16;
constexpr index_t GemmNPerWave = 16;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4;
// read params: dependent parameters
constexpr index_t BlockSize = 64;
constexpr index_t BlockSize = 256;
constexpr index_t GemmM = K;
constexpr index_t GemmN = N * Ho * Wo;
......@@ -83,7 +83,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
// A matrix copy
constexpr index_t GemmABlockCopyClusterLengths_GemmK = 4;
constexpr index_t GemmABlockCopyClusterLengths_GemmM = 16;
constexpr index_t GemmABlockCopyClusterLengths_GemmM = 64;
constexpr index_t GemmABlockCopyClusterLengths_GemmKPack = 1;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
......@@ -113,9 +113,9 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack = 4;
// B matrix Copy
constexpr index_t GemmBBlockCopyClusterLengths_GemmK = 4;
constexpr index_t GemmBBlockCopyClusterLengths_GemmN = 4;
constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack = 1;
constexpr index_t GemmBBlockCopyClusterLengths_GemmK = 2;
constexpr index_t GemmBBlockCopyClusterLengths_GemmN = 32;
constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack = 4;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK;
......@@ -141,7 +141,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
using GemmBBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmN, GemmKPack]
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack = 1;
// gridwise GEMM
constexpr auto wkgrp_schd_order = NBlock1MBlock0;
......
......@@ -24,11 +24,11 @@ int main(int argc, char* argv[])
using namespace ck;
// 1x1, 56x56
constexpr index_t N = 64;
constexpr index_t C = 128;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 128;
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
......
......@@ -10,7 +10,7 @@ cmake
-D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \
-D CMAKE_BUILD_TYPE=Release \
-D DEVICE_BACKEND="AMD" \
-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx908 -gline-tables-only -save-temps=$CWD" \
-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx908 -save-temps=$CWD" \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH="/opt/rocm" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment