Commit be11c5dd authored by Jing Zhang's avatar Jing Zhang
Browse files

new c buffer type

parent 649af8b2
......@@ -131,15 +131,16 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
const FloatB* __restrict__ p_b_block,
FloatC p_c_thread)
{
p_c_thread.s.x.l =
XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread.s.x.l);
p_c_thread.s.y.l = XdlopsGemm.template Run<M, N, K>(
p_a_block + MPerXdlops, p_b_block, p_c_thread.s.y.l);
p_c_thread.At(Number<64>{})(Number<0>{}) = XdlopsGemm.template Run<M, N, K>(
p_a_block, p_b_block, p_c_thread.At(Number<64>{})[Number<0>{}]);
p_c_thread.At(Number<64>{})(Number<1>{}) = XdlopsGemm.template Run<M, N, K>(
p_a_block + MPerXdlops, p_b_block, p_c_thread.At(Number<64>{})[Number<1>{}]);
return p_c_thread;
}
};
#if 0
template <>
struct WithMNRepeats<1, 2>
{
......@@ -168,6 +169,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
return XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread);
}
};
#endif
#endif
template <class FloatA, class FloatB, class FloatC>
......
......@@ -209,7 +209,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
__shared__ ABFloat p_b_block[b_block_space];
// get zero-initialized output register of vector type
auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
// auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
auto c_thread_vec = float_vec128_t{};
// preload data into LDS
{
......
......@@ -93,11 +93,12 @@ struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride>
{
__device__ static c_vec32_2_t::VecType
run(const float* reg_a, const float* reg_b, c_vec32_2_t::VecType reg_c)
__device__ static float_vec64_t run(const float* reg_a, const float* reg_b, float_vec64_t reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.At(Number<32>{})(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a[0], reg_b[0], reg_c.At(Number<32>{})[Number<0>{}], 1, 0, 0);
reg_c.At(Number<32>{})(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a[0], reg_b[0], reg_c.At(Number<32>{})[Number<1>{}], 1, 1, 0);
return reg_c;
}
};
......@@ -464,5 +465,5 @@ struct intrin_mfma_f32_4x4x2bf16<8, 64>
return reg_c;
}
};
}
} // namespace ck
#endif
......@@ -10,6 +10,8 @@ typedef float float4_t __attribute__((ext_vector_type(4)));
typedef float float8_t __attribute__((ext_vector_type(8)));
typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32)));
typedef float float64_t __attribute__((ext_vector_type(64)));
typedef float float128_t __attribute__((ext_vector_type(128)));
// float16
typedef _Float16 half_t;
......@@ -137,6 +139,107 @@ union float_vec16_t
}
};
union float_vec32_t
{
StaticallyIndexedArray<float, 32> s1;
StaticallyIndexedArray<float2_t, 16> s2;
StaticallyIndexedArray<float4_t, 8> s4;
StaticallyIndexedArray<float_vec8_t, 4> s8;
StaticallyIndexedArray<float_vec16_t, 2> s16;
StaticallyIndexedArray<float32_t, 1> s32;
__host__ __device__ constexpr float_vec32_t() {}
template <index_t vs>
__host__ __device__ auto& At(Number<vs>);
template <>
__host__ __device__ auto& At(Number<1>)
{
return s1;
}
template <>
__host__ __device__ auto& At(Number<2>)
{
return s2;
}
template <>
__host__ __device__ auto& At(Number<4>)
{
return s4;
}
template <>
__host__ __device__ auto& At(Number<8>)
{
return s8;
}
template <>
__host__ __device__ auto& At(Number<16>)
{
return s16;
}
};
union float_vec64_t
{
StaticallyIndexedArray<float, 64> s1;
StaticallyIndexedArray<float32_t, 2> s32;
StaticallyIndexedArray<float64_t, 1> s64;
float n[64];
__host__ __device__ constexpr float_vec64_t() {}
template <index_t vs>
__host__ __device__ auto& At(Number<vs>);
template <>
__host__ __device__ auto& At(Number<1>)
{
return s1;
}
template <>
__host__ __device__ auto& At(Number<32>)
{
return s32;
}
};
union float_vec128_t
{
StaticallyIndexedArray<float, 64> s1;
StaticallyIndexedArray<float32_t, 4> s32;
StaticallyIndexedArray<float_vec64_t, 2> s64;
StaticallyIndexedArray<float128_t, 1> s128;
float n[128];
__host__ __device__ constexpr float_vec128_t() {}
template <index_t vs>
__host__ __device__ auto& At(Number<vs>);
template <>
__host__ __device__ auto& At(Number<1>)
{
return s1;
}
template <>
__host__ __device__ auto& At(Number<32>)
{
return s32;
}
template <>
__host__ __device__ auto& At(Number<64>)
{
return s64;
}
};
template <typename T, index_t BufferSize>
constexpr auto GetRegBuffer();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment