Unverified Commit 42e6dcea authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Fix compilation errors with Clang20.0. (#1533)

* fix clang20 compilation errors for gfx90a

* fix clang20 compilation errors for gfx11 targets
parent 65f8d144
......@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}
template <>
__device__ static constexpr auto TailScheduler<1>()
__device__ constexpr auto TailScheduler<1>()
{
// schedule
constexpr auto num_ds_read_inst =
......@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}
template <>
__device__ static constexpr auto TailScheduler<2>()
__device__ constexpr auto TailScheduler<2>()
{
// schedule
constexpr auto num_ds_read_inst =
......
......@@ -324,55 +324,55 @@ struct DppSelector
static constexpr auto GetDpp();
template <>
static constexpr auto GetDpp<half_t, 8, 32>()
constexpr auto GetDpp<half_t, 8, 32>()
{
return DppInstr::dpp8_f16_8x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 8, 16>()
constexpr auto GetDpp<half_t, 8, 16>()
{
return DppInstr::dpp8_f16_8x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 16, 16>()
constexpr auto GetDpp<half_t, 16, 16>()
{
return DppInstr::dpp8_f16_16x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 32, 8>()
constexpr auto GetDpp<half_t, 32, 8>()
{
return DppInstr::dpp8_f16_32x8x2;
}
template <>
static constexpr auto GetDpp<half_t, 1, 32>()
constexpr auto GetDpp<half_t, 1, 32>()
{
return DppInstr::dpp8_f16_1x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 32>()
constexpr auto GetDpp<half_t, 2, 32>()
{
return DppInstr::dpp8_f16_2x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 16>()
constexpr auto GetDpp<half_t, 2, 16>()
{
return DppInstr::dpp8_f16_2x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 16>()
constexpr auto GetDpp<half_t, 4, 16>()
{
return DppInstr::dpp8_f16_4x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 32>()
constexpr auto GetDpp<half_t, 4, 32>()
{
return DppInstr::dpp8_f16_4x32x2;
}
......
......@@ -415,7 +415,7 @@ struct WmmaSelector
static constexpr auto GetWmma();
template <>
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
{
#ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
......@@ -425,7 +425,7 @@ struct WmmaSelector
}
template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
{
#ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
......@@ -435,19 +435,19 @@ struct WmmaSelector
}
template <>
static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
{
return WmmaInstr::wmma_f16_16x16x16_f16;
}
template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
{
return WmmaInstr::wmma_bf16_16x16x16_bf16;
}
template <>
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
{
#ifdef __gfx12__
return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
......@@ -458,7 +458,7 @@ struct WmmaSelector
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu4;
}
......
......@@ -651,97 +651,97 @@ struct MfmaSelector
static constexpr auto GetMfma();
template <>
static constexpr auto GetMfma<double, 16, 16>()
constexpr auto GetMfma<double, 16, 16>()
{
return MfmaInstr::mfma_f64_16x16x4f64;
}
template <>
static constexpr auto GetMfma<float, 64, 64>()
constexpr auto GetMfma<float, 64, 64>()
{
return MfmaInstr::mfma_f32_32x32x1xf32;
}
template <>
static constexpr auto GetMfma<float, 32, 64>()
constexpr auto GetMfma<float, 32, 64>()
{
return MfmaInstr::mfma_f32_32x32x1xf32;
}
template <>
static constexpr auto GetMfma<float, 16, 64>()
constexpr auto GetMfma<float, 16, 64>()
{
return MfmaInstr::mfma_f32_16x16x1xf32;
}
template <>
static constexpr auto GetMfma<float, 8, 64>()
constexpr auto GetMfma<float, 8, 64>()
{
return MfmaInstr::mfma_f32_4x4x1xf32;
}
template <>
static constexpr auto GetMfma<float, 4, 64>()
constexpr auto GetMfma<float, 4, 64>()
{
return MfmaInstr::mfma_f32_4x4x1xf32;
}
template <>
static constexpr auto GetMfma<float, 32, 32>()
constexpr auto GetMfma<float, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x2xf32;
}
template <>
static constexpr auto GetMfma<float, 16, 16>()
constexpr auto GetMfma<float, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x4xf32;
}
template <>
static constexpr auto GetMfma<half_t, 64, 64>()
constexpr auto GetMfma<half_t, 64, 64>()
{
return MfmaInstr::mfma_f32_32x32x4f16;
}
template <>
static constexpr auto GetMfma<half_t, 32, 64>()
constexpr auto GetMfma<half_t, 32, 64>()
{
return MfmaInstr::mfma_f32_32x32x4f16;
}
template <>
static constexpr auto GetMfma<half_t, 32, 32>()
constexpr auto GetMfma<half_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x8f16;
}
template <>
static constexpr auto GetMfma<half_t, 16, 16>()
constexpr auto GetMfma<half_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x16f16;
}
template <>
static constexpr auto GetMfma<half_t, 16, 64>()
constexpr auto GetMfma<half_t, 16, 64>()
{
return MfmaInstr::mfma_f32_16x16x4f16;
}
template <>
static constexpr auto GetMfma<half_t, 8, 64>()
constexpr auto GetMfma<half_t, 8, 64>()
{
return MfmaInstr::mfma_f32_4x4x4f16;
}
template <>
static constexpr auto GetMfma<half_t, 4, 64>()
constexpr auto GetMfma<half_t, 4, 64>()
{
return MfmaInstr::mfma_f32_4x4x4f16;
}
template <>
static constexpr auto GetMfma<bhalf_t, 32, 32>()
constexpr auto GetMfma<bhalf_t, 32, 32>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
......@@ -751,7 +751,7 @@ struct MfmaSelector
}
template <>
static constexpr auto GetMfma<bhalf_t, 16, 16>()
constexpr auto GetMfma<bhalf_t, 16, 16>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
......@@ -762,72 +762,72 @@ struct MfmaSelector
#if defined(CK_USE_AMD_MFMA_GFX940)
template <>
static constexpr auto GetMfma<int8_t, 32, 32>()
constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x16i8;
}
template <>
static constexpr auto GetMfma<int8_t, 16, 16>()
constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x32i8;
}
#else
template <>
static constexpr auto GetMfma<int8_t, 32, 32>()
constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x8i8;
}
template <>
static constexpr auto GetMfma<int8_t, 16, 16>()
constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x16i8;
}
#endif
template <>
static constexpr auto GetMfma<f8_t, 32, 32>()
constexpr auto GetMfma<f8_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x16f8f8;
}
template <>
static constexpr auto GetMfma<f8_t, 16, 16>()
constexpr auto GetMfma<f8_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x32f8f8;
}
template <>
static constexpr auto GetMfma<bf8_t, 32, 32>()
constexpr auto GetMfma<bf8_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
}
template <>
static constexpr auto GetMfma<bf8_t, 16, 16>()
constexpr auto GetMfma<bf8_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
}
template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{
return MfmaInstr::mfma_f32_32x32x16f8bf8;
}
template <>
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
{
return MfmaInstr::mfma_f32_16x16x32f8bf8;
}
template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{
return MfmaInstr::mfma_f32_32x32x16bf8f8;
}
template <>
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
{
return MfmaInstr::mfma_f32_16x16x32bf8f8;
}
......
......@@ -1727,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<0>()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
......@@ -1759,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<1>()
{
// Mem: Q^T LDS load
// Comp: OGrad x V
......@@ -1777,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<2>()
{
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad
......@@ -1796,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<3>()
{
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
......@@ -1830,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<4>()
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment