Commit de1afb7b authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into lwpck-977

parents ce562aa6 f7331c60
...@@ -462,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64> ...@@ -462,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
} }
}; };
#if defined CK_ENABLE_FP8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{ {
...@@ -506,9 +505,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8> ...@@ -506,9 +505,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
#if defined CK_ENABLE_BF8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8bf8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8bf8>
{ {
...@@ -552,9 +549,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8> ...@@ -552,9 +549,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8>
{ {
...@@ -598,9 +593,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8> ...@@ -598,9 +593,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8f8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8f8>
{ {
...@@ -644,7 +637,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8> ...@@ -644,7 +637,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
template <typename base_type, template <typename base_type,
index_t MPerXdlops, index_t MPerXdlops,
...@@ -792,7 +784,6 @@ struct MfmaSelector ...@@ -792,7 +784,6 @@ struct MfmaSelector
} }
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32>() static constexpr auto GetMfma<f8_t, 32, 32>()
{ {
...@@ -804,9 +795,7 @@ struct MfmaSelector ...@@ -804,9 +795,7 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32f8f8; return MfmaInstr::mfma_f32_16x16x32f8f8;
} }
#endif
#if defined CK_ENABLE_BF8
template <> template <>
static constexpr auto GetMfma<bf8_t, 32, 32>() static constexpr auto GetMfma<bf8_t, 32, 32>()
{ {
...@@ -818,9 +807,7 @@ struct MfmaSelector ...@@ -818,9 +807,7 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32bf8bf8; return MfmaInstr::mfma_f32_16x16x32bf8bf8;
} }
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>() static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{ {
...@@ -832,9 +819,7 @@ struct MfmaSelector ...@@ -832,9 +819,7 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32f8bf8; return MfmaInstr::mfma_f32_16x16x32f8bf8;
} }
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <> template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>() static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{ {
...@@ -846,7 +831,6 @@ struct MfmaSelector ...@@ -846,7 +831,6 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32bf8f8; return MfmaInstr::mfma_f32_16x16x32bf8f8;
} }
#endif
static constexpr auto selected_mfma = static constexpr auto selected_mfma =
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{}; mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
...@@ -1051,18 +1035,10 @@ struct XdlopsGemm ...@@ -1051,18 +1035,10 @@ struct XdlopsGemm
static_assert( static_assert(
is_same<base_type, double>::value || is_same<base_type, float>::value || is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value || is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value ||
#if defined CK_ENABLE_FP8 is_same<base_type, bf8_t>::value ||
|| is_same<base_type, f8_t>::value (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
#endif (is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value),
#if defined CK_ENABLE_BF8
|| is_same<base_type, bf8_t>::value
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|| (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
(is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value)
#endif
,
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_XDLOPS_HPP #pragma once
#define CK_AMD_XDLOPS_HPP
#include "data_type.hpp"
namespace ck { namespace ck {
...@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> ...@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
} }
}; };
#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8; struct intrin_mfma_f32_32x32x16f8f8;
...@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> ...@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif #endif
} }
}; };
#endif
#if defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8bf8; struct intrin_mfma_f32_32x32x16bf8bf8;
...@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> ...@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
#endif #endif
} }
}; };
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8bf8; struct intrin_mfma_f32_32x32x16f8bf8;
...@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> ...@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
#endif #endif
} }
}; };
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8f8; struct intrin_mfma_f32_32x32x16bf8f8;
...@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> ...@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
#endif #endif
} }
}; };
#endif
} // namespace ck } // namespace ck
#endif
...@@ -9,9 +9,7 @@ namespace ck { ...@@ -9,9 +9,7 @@ namespace ck {
using bhalf_t = ushort; using bhalf_t = ushort;
using half_t = _Float16; using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
#endif
using f8_t = _BitInt(8); using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8); using bf8_t = unsigned _BitInt(8);
...@@ -1123,5 +1121,4 @@ struct NumericUtils<bf8_t> ...@@ -1123,5 +1121,4 @@ struct NumericUtils<bf8_t>
static constexpr int bias = 16; // negative zero nan mode static constexpr int bias = 16; // negative zero nan mode
// static constexpr int bias = 15; // ieee mode // static constexpr int bias = 15; // ieee mode
}; };
} // namespace ck } // namespace ck
...@@ -192,6 +192,8 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, ...@@ -192,6 +192,8 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
#else #else
c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false); c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
#endif #endif
#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11)
c = __builtin_amdgcn_sudot4(true, bit_cast<int32_t>(a), true, bit_cast<int32_t>(b), c, false);
#else #else
const vector_type<int8_t, 4> a_vector{a}; const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b}; const vector_type<int8_t, 4> b_vector{b};
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment