Commit be11c5dd authored by Jing Zhang
Browse files

new c buffer type

parent 649af8b2
...@@ -131,15 +131,16 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops ...@@ -131,15 +131,16 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
const FloatB* __restrict__ p_b_block, const FloatB* __restrict__ p_b_block,
FloatC p_c_thread) FloatC p_c_thread)
{ {
p_c_thread.s.x.l = p_c_thread.At(Number<64>{})(Number<0>{}) = XdlopsGemm.template Run<M, N, K>(
XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread.s.x.l); p_a_block, p_b_block, p_c_thread.At(Number<64>{})[Number<0>{}]);
p_c_thread.s.y.l = XdlopsGemm.template Run<M, N, K>( p_c_thread.At(Number<64>{})(Number<1>{}) = XdlopsGemm.template Run<M, N, K>(
p_a_block + MPerXdlops, p_b_block, p_c_thread.s.y.l); p_a_block + MPerXdlops, p_b_block, p_c_thread.At(Number<64>{})[Number<1>{}]);
return p_c_thread; return p_c_thread;
} }
}; };
#if 0
template <> template <>
struct WithMNRepeats<1, 2> struct WithMNRepeats<1, 2>
{ {
...@@ -168,6 +169,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops ...@@ -168,6 +169,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
return XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread); return XdlopsGemm.template Run<M, N, K>(p_a_block, p_b_block, p_c_thread);
} }
}; };
#endif
#endif #endif
template <class FloatA, class FloatB, class FloatC> template <class FloatA, class FloatB, class FloatC>
......
...@@ -209,7 +209,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2 ...@@ -209,7 +209,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
__shared__ ABFloat p_b_block[b_block_space]; __shared__ ABFloat p_b_block[b_block_space];
// get zero-initialized output register of vector type // get zero-initialized output register of vector type
auto c_thread_vec = blockwise_gemm.CreateOutputVecZero(); // auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
auto c_thread_vec = float_vec128_t{};
// preload data into LDS // preload data into LDS
{ {
......
...@@ -93,11 +93,12 @@ struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride> ...@@ -93,11 +93,12 @@ struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
template <index_t AStride, index_t BStride> template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride> struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride>
{ {
__device__ static c_vec32_2_t::VecType __device__ static float_vec64_t run(const float* reg_a, const float* reg_b, float_vec64_t reg_c)
run(const float* reg_a, const float* reg_b, c_vec32_2_t::VecType reg_c)
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); reg_c.At(Number<32>{})(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); reg_a[0], reg_b[0], reg_c.At(Number<32>{})[Number<0>{}], 1, 0, 0);
reg_c.At(Number<32>{})(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a[0], reg_b[0], reg_c.At(Number<32>{})[Number<1>{}], 1, 1, 0);
return reg_c; return reg_c;
} }
}; };
...@@ -464,5 +465,5 @@ struct intrin_mfma_f32_4x4x2bf16<8, 64> ...@@ -464,5 +465,5 @@ struct intrin_mfma_f32_4x4x2bf16<8, 64>
return reg_c; return reg_c;
} }
}; };
} } // namespace ck
#endif #endif
...@@ -10,6 +10,8 @@ typedef float float4_t __attribute__((ext_vector_type(4))); ...@@ -10,6 +10,8 @@ typedef float float4_t __attribute__((ext_vector_type(4)));
typedef float float8_t __attribute__((ext_vector_type(8))); typedef float float8_t __attribute__((ext_vector_type(8)));
typedef float float16_t __attribute__((ext_vector_type(16))); typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32))); typedef float float32_t __attribute__((ext_vector_type(32)));
// Wider members of the float vector typedef series above: vectors of 64 and
// 128 `float` elements declared with the ext_vector_type attribute.
// NOTE(review): despite the names, these are element counts, not bit widths --
// float64_t here is a 64-element float vector, not a 64-bit floating-point type.
typedef float float64_t __attribute__((ext_vector_type(64)));
typedef float float128_t __attribute__((ext_vector_type(128)));
// float16 // float16
typedef _Float16 half_t; typedef _Float16 half_t;
...@@ -137,6 +139,107 @@ union float_vec16_t ...@@ -137,6 +139,107 @@ union float_vec16_t
} }
}; };
// Union overlaying several differently grouped views on one block of float
// storage. At(Number<W>{}) returns a reference to the view whose elements are
// grouped W-wide (W = 1, 2, 4, 8 or 16).
union float_vec32_t
{
    // Overlapping views of the same storage, finest grouping first.
    StaticallyIndexedArray<float, 32> s1;
    StaticallyIndexedArray<float2_t, 16> s2;
    StaticallyIndexedArray<float4_t, 8> s4;
    StaticallyIndexedArray<float_vec8_t, 4> s8;
    StaticallyIndexedArray<float_vec16_t, 2> s16;
    StaticallyIndexedArray<float32_t, 1> s32; // no At() specialization exposes this view

    // Empty body: this constructor does not initialize any member.
    __host__ __device__ constexpr float_vec32_t() {}

    // Primary template is declared but never defined; only the explicit
    // specializations below provide a body.
    template <index_t vs>
    __host__ __device__ auto& At(Number<vs>);

    // Scalar view.
    template <>
    __host__ __device__ auto& At(Number<1>) { return s1; }

    // 2-wide view.
    template <>
    __host__ __device__ auto& At(Number<2>) { return s2; }

    // 4-wide view.
    template <>
    __host__ __device__ auto& At(Number<4>) { return s4; }

    // 8-wide view.
    template <>
    __host__ __device__ auto& At(Number<8>) { return s8; }

    // 16-wide view.
    template <>
    __host__ __device__ auto& At(Number<16>) { return s16; }
};
// Union overlaying several views on one block of 64 floats. At(Number<W>{})
// returns a reference to the view whose elements are grouped W-wide
// (W = 1 or 32).
union float_vec64_t
{
    // Overlapping views of the same storage.
    StaticallyIndexedArray<float, 64> s1;
    StaticallyIndexedArray<float32_t, 2> s32;
    StaticallyIndexedArray<float64_t, 1> s64; // no At() specialization exposes this view
    float n[64];                              // plain C array over the same storage

    // Empty body: this constructor does not initialize any member, so
    // `float_vec64_t{}` does NOT produce zeroed storage.
    __host__ __device__ constexpr float_vec64_t() {}

    // Primary template is declared but never defined; only the explicit
    // specializations below provide a body.
    template <index_t vs>
    __host__ __device__ auto& At(Number<vs>);

    // Scalar view.
    template <>
    __host__ __device__ auto& At(Number<1>) { return s1; }

    // View as two 32-wide vectors.
    template <>
    __host__ __device__ auto& At(Number<32>) { return s32; }
};
union float_vec128_t
{
StaticallyIndexedArray<float, 64> s1;
StaticallyIndexedArray<float32_t, 4> s32;
StaticallyIndexedArray<float_vec64_t, 2> s64;
StaticallyIndexedArray<float128_t, 1> s128;
float n[128];
__host__ __device__ constexpr float_vec128_t() {}
template <index_t vs>
__host__ __device__ auto& At(Number<vs>);
template <>
__host__ __device__ auto& At(Number<1>)
{
return s1;
}
template <>
__host__ __device__ auto& At(Number<32>)
{
return s32;
}
template <>
__host__ __device__ auto& At(Number<64>)
{
return s64;
}
};
template <typename T, index_t BufferSize> template <typename T, index_t BufferSize>
constexpr auto GetRegBuffer(); constexpr auto GetRegBuffer();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment