Commit 1b409ffe authored by Jing Zhang
Browse files

fix mfma_int8 on MI300

parent 3973caa4
...@@ -259,7 +259,6 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16> ...@@ -259,7 +259,6 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
} }
}; };
#if (defined(__gfx908__) || defined(__gfx90a__))
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x8i8; struct intrin_mfma_i32_32x32x8i8;
...@@ -278,7 +277,26 @@ struct intrin_mfma_i32_32x32x8i8<32, 32> ...@@ -278,7 +277,26 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
0); 0);
} }
}; };
#elif (defined(__gfx940__))
// Primary template; only the specializations below provide a Run() method.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x16i8;

// <16, 16> specialization: thin wrapper over __builtin_amdgcn_mfma_i32_16x16x16i8.
template <>
struct intrin_mfma_i32_16x16x16i8<16, 16>
{
    // Reinterprets each int8x4_t operand as one int32_t, reads element 0 of
    // reg_c viewed as int32x4_t, feeds all three to the builtin, and stores
    // the builtin's result back into that same element of reg_c.
    //
    // reg_a, reg_b : packed 8-bit operands, bit_cast to int32_t unchanged.
    // reg_c        : read and then overwritten in place (element 0 only).
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        const int32_t a_packed = bit_cast<int32_t>(reg_a);
        const int32_t b_packed = bit_cast<int32_t>(reg_b);

        // Snapshot of the current element 0, passed as the builtin's third argument.
        const int32x4_t acc_in = reg_c.template AsType<int32x4_t>()[Number<0>{}];

        // The three trailing arguments stay literal zeros.
        // NOTE(review): presumably the cbsz/abid/blgp modifiers, which must be
        // compile-time constants -- confirm against the builtin's documentation.
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_16x16x16i8(a_packed, b_packed, acc_in, 0, 0, 0);
    }
};
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8; struct intrin_mfma_i32_32x32x16i8;
...@@ -286,31 +304,30 @@ template <> ...@@ -286,31 +304,30 @@ template <>
struct intrin_mfma_i32_32x32x16i8<32, 32> struct intrin_mfma_i32_32x32x16i8<32, 32>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<int32x16_t>()(Number<0>{}) = reg_c.template AsType<int32x16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int32_t>(reg_a), __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int64_t>(reg_a),
bit_cast<int32_t>(reg_b), bit_cast<int64_t>(reg_b),
reg_c.template AsType<int32x16_t>()[Number<0>{}], reg_c.template AsType<int32x16_t>()[Number<0>{}],
0, 0,
0, 0,
0); 0);
} }
}; };
#endif
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x16i8; struct intrin_mfma_i32_16x16x32i8;
template <> template <>
struct intrin_mfma_i32_16x16x16i8<16, 16> struct intrin_mfma_i32_16x16x32i8<16, 16>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<int32x4_t>()(Number<0>{}) = reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a), __builtin_amdgcn_mfma_i32_16x16x32i8(bit_cast<int64_t>(reg_a),
bit_cast<int32_t>(reg_b), bit_cast<int64_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}], reg_c.template AsType<int32x4_t>()[Number<0>{}],
0, 0,
0, 0,
......
...@@ -898,6 +898,8 @@ struct vector_type<T, 256> ...@@ -898,6 +898,8 @@ struct vector_type<T, 256>
} }
}; };
using int64_t = long;
// fp64 // fp64
using double2_t = typename vector_type<double, 2>::type; using double2_t = typename vector_type<double, 2>::type;
using double4_t = typename vector_type<double, 4>::type; using double4_t = typename vector_type<double, 4>::type;
......
...@@ -10,8 +10,8 @@ cmake ...@@ -10,8 +10,8 @@ cmake
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \ -D BUILD_DEV=OFF \
-D GPU_TARGETS="gfx90a" \ -D GPU_TARGETS="gfx940" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \ -D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE} ${MY_PROJECT_SOURCE}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment