Commit 5ec6a912 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents d39c3f5d 3bb0fe6c
......@@ -84,4 +84,9 @@ inline bool is_gfx11_supported()
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
}
inline bool is_gfx12_supported()
{
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
}
} // namespace ck
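The hunk above adds is_gfx12_supported() next to the existing gfx11 helper. As a rough illustration only (not part of this commit), a host-side dispatcher could gate a WMMA code path on these checks; run_wmma_path and run_fallback below are hypothetical placeholders, and the header providing the ck helpers is assumed to be included.

template <typename WmmaPath, typename Fallback>
void dispatch_by_device(WmmaPath run_wmma_path, Fallback run_fallback)
{
    if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
    {
        run_wmma_path(); // RDNA3/RDNA4 targets with WMMA support
    }
    else
    {
        run_fallback(); // any other target
    }
}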
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -300,9 +300,9 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
dpp_gemm.template Run(a_thread_vec.template AsType<dpp_input_type>(),
b_thread_vec.template AsType<dpp_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
dpp_gemm.Run(a_thread_vec.template AsType<dpp_input_type>(),
b_thread_vec.template AsType<dpp_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......
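A large share of this merge is the mechanical change from dpp_gemm.template Run(...) / xdlops_gemm.template Run(...) to plain .Run(...). A minimal standalone example (not CK code) of the rule involved: the template disambiguator is only required when an explicit template-argument list follows a member of a dependent object; with deduced arguments, as at these call sites, it can be dropped.

struct Gemm
{
    template <typename T>
    void Run(T) {}
};

template <typename G>
void caller(G g)
{
    g.template Run<int>(0); // disambiguator required: explicit <int> on a dependent object
    g.Run(0);               // arguments deduced: no disambiguator needed, as in this commit
}

int main() { caller(Gemm{}); }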
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -613,7 +613,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......@@ -681,7 +681,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......@@ -749,10 +749,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......@@ -808,10 +807,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......@@ -840,10 +838,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......@@ -901,10 +898,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......@@ -939,10 +935,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -144,12 +144,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
......@@ -259,7 +259,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......@@ -319,10 +319,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......@@ -446,12 +445,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
......@@ -584,7 +583,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......@@ -668,7 +667,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......
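Another recurring change in this merge adds __device__ to the formerly host-only BlockHasHotloop / BlockLoopTailNum helpers. A simplified, hedged sketch of why the dual qualifier is useful: the same constexpr predicate can then drive both the host-side launch configuration and a device-side branch. The names below are demo placeholders, not the pipeline classes themselves.

#include <hip/hip_runtime.h>

__host__ __device__ constexpr bool BlockHasHotloopDemo(int num_loop, int prefetch_stages)
{
    return num_loop > prefetch_stages; // same shape as the helpers in these pipelines
}

__global__ void demo_kernel(int num_loop)
{
    if(BlockHasHotloopDemo(num_loop, 2)) // device-side use
    {
        // ... hot-loop path ...
    }
}

void demo_launch(int num_loop)
{
    const bool has_hotloop = BlockHasHotloopDemo(num_loop, 2); // host-side use
    (void)has_hotloop;
    demo_kernel<<<1, 64>>>(num_loop);
}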
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -153,12 +153,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = PrefetchStages;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
if(num_loop % PrefetchStages == 1)
{
......@@ -303,7 +303,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......@@ -374,7 +374,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......@@ -428,10 +428,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......@@ -480,10 +479,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......@@ -646,12 +644,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = PrefetchStages;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
if(num_loop % PrefetchStages == 1)
{
......@@ -821,7 +819,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......@@ -914,7 +912,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......@@ -990,7 +988,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......@@ -1066,7 +1064,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -146,12 +146,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
......@@ -381,7 +381,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......@@ -440,10 +440,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -147,12 +147,12 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t GlobalBufferNum = 2;
static constexpr index_t HotloopUnroll = 2;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
if(num_loop % HotloopUnroll == 1)
{
......@@ -403,7 +403,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......@@ -472,10 +472,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......@@ -529,10 +528,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......@@ -562,10 +560,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -444,7 +444,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......@@ -513,10 +513,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
......@@ -564,10 +563,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(
......@@ -607,10 +605,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -13,6 +13,504 @@
namespace ck {
#ifdef __gfx12__
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatAcc,
typename ABlockDesc,
typename BBlockDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool AEnableLds = true,
bool BEnableLds = true,
bool TransposeC = false>
/* Option: Read from LDS; a big buffer holds the data required by all threads
* Source
* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* Destination
* C, non-transpose
* thread level: MRepeat x NRepeat x MAccVgprs
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*
* Option: Read from VMEM; a small buffer holds each thread's own required data (skips LDS)
* Source:
* A(if skip LDS): MRepeat x KPack
* B(if skip LDS): NRepeat x KPack
* Destination
* C, non-transpose
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct BlockwiseGemmWMMA
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto WmmaK = Number<16>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hardcode WaveSize, since the current HIP runtime (5.4.0-10984) does not return the correct value.
static constexpr index_t WaveSize = 32;
// When using LDS, each row (16 consecutive lanes) reads the whole data from the source buffer.
// When not using LDS, each row reads half of the data from the source buffer and exchanges it
// via permutation.
static constexpr index_t A_KRow = 2;
static constexpr index_t B_KRow = 2;
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
wmma_gemm.GetRegSizePerWmma(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
// Default, Block buffer in LDS, thread level offset enabled
__device__ static auto CalculateAThreadOriginDataIndex()
{
if constexpr(AEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
if constexpr(BEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
return make_tuple(
Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
}
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
}
// Thread-level register descriptor. Vector-write.
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
return make_naive_tensor_descriptor(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
AccStride));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Provide dimension size
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Describes how data is allocated in the thread-copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize());
static_assert(KPack % (A_K1 * A_KRow) == 0, "");
static_assert(KPack % (B_K1 * B_KRow) == 0, "");
// determine the loop-over direction based on MRepeat vs NRepeat
if constexpr(MRepeat < NRepeat)
{
static_for<0, KPerBlock / KPack, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a =
typename vector_type<FloatA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<FloatB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, ..
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a =
typename vector_type<FloatA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<FloatB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
protected:
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
template <bool EnableLds>
struct AThreadCopySelector;
template <>
struct AThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
};
template <>
struct AThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
false>;
};
template <bool EnableLds>
struct BThreadCopySelector;
template <>
struct BThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
B_K1>;
};
template <>
struct BThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
false>;
};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
#else
template <index_t BlockSize,
typename FloatA,
typename FloatB,
......@@ -352,10 +850,9 @@ struct BlockwiseGemmWMMA
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......@@ -411,10 +908,9 @@ struct BlockwiseGemmWMMA
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......@@ -529,5 +1025,6 @@ struct BlockwiseGemmWMMA
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
#endif
} // namespace ck
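For orientation, a small compile-time check (illustrative numbers, not taken from this commit) of the wave layout that the BlockwiseGemmWMMA constructor's static_asserts enforce on the wave32 gfx11/gfx12 path:

constexpr int MPerBlock = 128, MRepeat = 2, MPerWMMA = 16; // MWaves = 128 / (2 * 16) = 4
constexpr int NPerBlock = 64,  NRepeat = 2, NPerWMMA = 16; // NWaves =  64 / (2 * 16) = 2
constexpr int WaveSize  = 32;                              // wave32 on gfx11/gfx12
constexpr int MWaves = MPerBlock / (MRepeat * MPerWMMA);
constexpr int NWaves = NPerBlock / (NRepeat * NPerWMMA);
static_assert(MWaves * NWaves * WaveSize == 256, "block size must equal MWaves * NWaves * WaveSize");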
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -340,10 +340,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......@@ -488,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// sync point.
if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
{
#ifdef __gfx12__
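// On gfx12 the workgroup barrier is issued as a split signal/wait pair instead of a single s_barrier.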
asm volatile("\
s_barrier_signal -1 \n \
s_barrier_wait -1 \
" ::);
#else
asm volatile("s_barrier" ::);
#endif
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
......@@ -530,10 +536,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// TODO: insert setprio in a more precise manner, since we
// could have more than one MFMA instruction in a single call
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
......@@ -963,10 +968,9 @@ struct BlockwiseGemmXdlops_v2
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -281,10 +281,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include <fstream>
#include <variant>
// functions to return the corresponding structs based on generated template parameters
using layouts = std::variant<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::NHWGK,
ck::tensor_layout::convolution::GNDHWK,
ck::tensor_layout::convolution::NDHWGK>;
// return the layout type: currently this is the only type supported in MIOpen
auto layout_type(std::string type)
{
if(type == "ck::tensor_layout::convolution::NHWGK")
{
return ck::tensor_layout::convolution::NHWGK{};
}
throw std::runtime_error("Incorrect layout");
}
// return the right gemm spec based on the generated template parameters
ck::tensor_operation::device::GemmSpecialization gemm_type(std::string type)
{
if(type == "ck::tensor_operation::device::GemmSpecialization::Default")
{
return ck::tensor_operation::device::GemmSpecialization::Default;
}
if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding")
{
return ck::tensor_operation::device::GemmSpecialization::MNKPadding;
}
throw std::runtime_error("Incorrect gemm spec: " + type);
}
// return the type of convolution
ck::tensor_operation::device::ConvolutionForwardSpecialization conv_type(std::string type)
{
if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default")
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
}
if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0")
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
}
if(type ==
"ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0")
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
}
if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC")
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
}
throw std::runtime_error("Incorrect conv spec: " + type);
}
// Function to call on MatrixPadder via a wrapper struct
// NOTE: CK only uses MNKPadding for forward convolution
template <typename CDesc_MRaw_NRaw>
auto pad(ck::index_t mpb,
ck::index_t npb,
ck::index_t kpb,
ck::tensor_operation::device::GemmSpecialization gemm,
CDesc_MRaw_NRaw conv)
{
if(gemm == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
{
ck::tensor_operation::device::MatrixPadder<
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
ck::index_t,
ck::index_t,
ck::index_t>
a;
a.MPerTile_ = mpb;
a.NPerTile_ = npb;
a.KPerTile_ = kpb;
auto tmp = grid_desc(a, conv);
return tmp;
}
throw std::runtime_error("Incorrect template parameters, check gemm spec");
}
// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num
// dims
// FIXME: add a way to properly pass in the layout
auto transform_conv(ck::index_t num_dim,
ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
ck::Array<ck::index_t, 5> out_lengths,
ck::Array<ck::index_t, 5> out_strides)
{
if(num_dim == 2 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
{
ck::tensor_operation::TransformConvFwdToGemm<
2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
conv_fwd;
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
}
if(num_dim == 2 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
ck::tensor_operation::TransformConvFwdToGemm<
2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
conv_fwd;
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
}
if(num_dim == 2 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
ck::tensor_operation::TransformConvFwdToGemm<
2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
conv_fwd;
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
}
if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
{
ck::tensor_operation::TransformConvFwdToGemm<
2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
conv_fwd;
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
}
throw std::runtime_error("Incorrect conv spec");
}
auto transform_conv_3d(ck::index_t num_dim,
ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
ck::Array<ck::index_t, 6> out_lengths,
ck::Array<ck::index_t, 6> out_strides)
{
if(num_dim == 3 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
{
ck::tensor_operation::TransformConvFwdToGemm<
3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
conv_fwd;
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
}
if(num_dim == 3 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
ck::tensor_operation::TransformConvFwdToGemm<
3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
conv_fwd;
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
}
if(num_dim == 3 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
ck::tensor_operation::TransformConvFwdToGemm<
3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
conv_fwd;
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
}
if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
{
ck::tensor_operation::TransformConvFwdToGemm<
3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
conv_fwd;
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
}
throw std::runtime_error("Incorrect conv spec");
}
auto transform_conv_1d(ck::index_t num_dim,
ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
ck::Array<ck::index_t, 4> out_lengths,
ck::Array<ck::index_t, 4> out_strides)
{
if(num_dim == 1 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
{
ck::tensor_operation::TransformConvFwdToGemm<
1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
conv_fwd;
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
}
if(num_dim == 1 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
ck::tensor_operation::TransformConvFwdToGemm<
1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
conv_fwd;
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
}
if(num_dim == 1 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
ck::tensor_operation::TransformConvFwdToGemm<
1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
conv_fwd;
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
}
if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
{
ck::tensor_operation::TransformConvFwdToGemm<
1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
conv_fwd;
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
}
throw std::runtime_error("Incorrect dims or conv spec");
}
template <typename CGridDesc_M_N>
auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder)
{
if(m_per_block == 32 && n_per_block == 64)
{
auto b2e = ck::BlockToCTileMap_M00_N0_M01Adapt<32, 64, CGridDesc_M_N>(matrix_padder);
return b2e.CalculateGridSize(matrix_padder);
}
if(m_per_block == 32 && n_per_block == 128)
{
ck::BlockToCTileMap_M00_N0_M01Adapt<32, 128, CGridDesc_M_N> b2e(matrix_padder);
return b2e.CalculateGridSize(matrix_padder);
}
if(m_per_block == 64 && n_per_block == 32)
{
ck::BlockToCTileMap_M00_N0_M01Adapt<64, 32, CGridDesc_M_N> b2e(matrix_padder);
return b2e.CalculateGridSize(matrix_padder);
}
if(m_per_block == 64 && n_per_block == 64)
{
ck::BlockToCTileMap_M00_N0_M01Adapt<64, 64, CGridDesc_M_N> b2e(matrix_padder);
return b2e.CalculateGridSize(matrix_padder);
}
if(m_per_block == 64 && n_per_block == 128)
{
ck::BlockToCTileMap_M00_N0_M01Adapt<64, 128, CGridDesc_M_N> b2e(matrix_padder);
return b2e.CalculateGridSize(matrix_padder);
}
if(m_per_block == 128 && n_per_block == 32)
{
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 32, CGridDesc_M_N> b2e(matrix_padder);
return b2e.CalculateGridSize(matrix_padder);
}
if(m_per_block == 128 && n_per_block == 64)
{
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 64, CGridDesc_M_N> b2e(matrix_padder);
return b2e.CalculateGridSize(matrix_padder);
}
if(m_per_block == 128 && n_per_block == 128)
{
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 128, CGridDesc_M_N> b2e(matrix_padder);
return b2e.CalculateGridSize(matrix_padder);
}
if(m_per_block == 128 && n_per_block == 256)
{
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 256, CGridDesc_M_N> b2e(matrix_padder);
return b2e.CalculateGridSize(matrix_padder);
}
if(m_per_block == 256 && n_per_block == 128)
{
ck::BlockToCTileMap_M00_N0_M01Adapt<256, 128, CGridDesc_M_N> b2e(matrix_padder);
return b2e.CalculateGridSize(matrix_padder);
}
throw std::runtime_error("Incorrect template parameters");
}
// wrapper functions by dims to get grid size - uses above 3 functions
// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions
auto get_launch_params_1d(ck::host::Solution solution,
ck::Array<ck::index_t, 4> out_lengths,
ck::Array<ck::index_t, 4> out_strides)
{
auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
auto conv_to_gemm_transformer = transform_conv_1d(num_dim, ConvSpec, out_lengths, out_strides);
auto matrix_padder =
pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
return b2e;
}
auto get_launch_params(ck::host::Solution solution,
ck::Array<ck::index_t, 5> out_lengths,
ck::Array<ck::index_t, 5> out_strides)
{
auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
auto conv_to_gemm_transformer = transform_conv(num_dim, ConvSpec, out_lengths, out_strides);
auto matrix_padder =
pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
return b2e;
}
auto get_launch_params_3d(ck::host::Solution solution,
ck::Array<ck::index_t, 6> out_lengths,
ck::Array<ck::index_t, 6> out_strides)
{
auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
auto conv_to_gemm_transformer = transform_conv_3d(num_dim, ConvSpec, out_lengths, out_strides);
auto matrix_padder =
pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
return b2e;
}
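A hedged usage sketch for the helpers above: `solution` is assumed to come from the ck::host codegen API, and the caller supplies the output lengths/strides in whatever ordering TransformConvFwdToGemm expects for the chosen layout (not shown here).

auto example_grid_size_3d(const ck::host::Solution& solution,
                          const ck::Array<ck::index_t, 6>& out_lengths,
                          const ck::Array<ck::index_t, 6>& out_strides)
{
    // Reads NumDim, MPerBlock/NPerBlock/KPerBlock and the specialization strings from the
    // solution, builds the conv-to-gemm transform plus the padded C descriptor, and returns
    // the block-to-C-tile grid size for the 3D case.
    return get_launch_params_3d(solution, out_lengths, out_strides);
}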
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace {
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batches, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
* \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
* computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm,
typename AsPointer, // tuples if multi-AB, pointers if not
typename BsPointer,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop,
bool isMultiA,
bool isMultiB>
__device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
AsPointer p_as_grid,
BsPointer p_bs_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
DsPointer p_ds_grid_grp;
static constexpr index_t NumDTensor =
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
if constexpr(isMultiA || isMultiB)
{
AsPointer p_as_grid_grp;
BsPointer p_bs_grid_grp;
const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
static_for<0, NumATensor, 1>{}(
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
static_for<0, NumBTensor, 1>{}(
[&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid_grp,
p_bs_grid_grp,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map);
}
else
{
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid + a_batch_offset,
p_bs_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map);
}
#else
ignore = p_as_grid;
ignore = p_bs_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif
}
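// Editorial note (not part of this header): the work-group decomposition above with
// assumed numbers. For a launch of 1024 work-groups and batch_count == 4:
//   num_blocks_per_batch = 1024 / 4 = 256
//   work-group 700 -> g_idx = 700 / 256 = 2
// so that work-group handles a tile of the third batch, and its A/B/Ds/E pointers are
// advanced by the batch-2 offsets before GridwiseGemm::Run is called.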
template <typename GridwiseGemm,
          typename AsPointer, // tuples if multi AB, pointers if not
typename BsPointer,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop,
bool isMultiA,
bool isMultiB>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
AsPointer p_as_grid,
BsPointer p_bs_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemm,
        AsPointer, // tuples if multi AB, pointers if not
BsPointer,
DsPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfBatch,
HasMainKBlockLoop,
isMultiA,
isMultiB>(p_as_grid,
p_bs_grid,
p_ds_grid,
                  p_e_grid,
a_element_op,
b_element_op,
cde_element_op,
batch_count,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map,
compute_ptr_offset_of_batch);
}
} // namespace
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
//
// @brief Device Convolution operation.
//
// Supports:
// @li Forward convolution with up to 3 spatial dimensions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
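// Editorial sketch (not part of this header): the descriptors below realize the
// convolution as an implicit GEMM. For the 2D case the per-group mapping is
// GemmM = N * Ho * Wo, GemmN = K, GemmK = C * Y * X, so A (input), B (weights) and
// E (output) are viewed as GemmM x GemmK, GemmN x GemmK and GemmM x GemmN matrices.
// The sizes below are made up for illustration only.
namespace conv_to_gemm_example {
constexpr index_t N_ = 2, Hi_ = 34, Wi_ = 34, C_ = 64, K_ = 128, Y_ = 3, X_ = 3;
constexpr index_t Ho_ = Hi_ - Y_ + 1, Wo_ = Wi_ - X_ + 1; // stride 1, no padding/dilation
constexpr index_t GemmM_ = N_ * Ho_ * Wo_;                // 2 * 32 * 32 = 2048
constexpr index_t GemmN_ = K_;                            // 128
constexpr index_t GemmK_ = C_ * Y_ * X_;                  // 64 * 3 * 3 = 576
} // namespace conv_to_gemm_example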
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename ComputeDataType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
ADataType>()), // ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ComputeDataType>
{
using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
__host__ __device__ static auto
MakeAGridDescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<index_t, NDimSpatial>& input_left_pads,
const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
return in_gemmm_gemmk_desc;
}
template <typename BLay>
__host__ __device__ static auto
MakeBGridDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
return wei_gemmn_gemmk_desc;
}
template <typename ELay>
__host__ __device__ static auto
MakeEGridDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
return out_gemmm_gemmn_desc;
}
    // The shapes of Ds and E must match; only their strides may differ.
    // e_g_n_k_wos_lengths is passed so that Ds can be logically broadcast to E's shape.
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
ds_g_n_k_wos_strides[i]);
},
Number<NumDTensor>{});
}
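    // Editorial sketch (not part of this header): how a D tensor is broadcast against E
    // purely through strides, reusing e_g_n_k_wos_lengths. For an assumed 2D output with
    // logical order [G, N, K, Ho, Wo] and a per-channel bias stored as [G, K], one would
    // pass
    //   ds_g_n_k_wos_strides[i] = {K /*G*/, 0 /*N*/, 1 /*K*/, 0 /*Ho*/, 0 /*Wo*/};
    // so every (n, ho, wo) position of E reads the same K-element bias vector.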
// desc for problem definition
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
    // If we are using multi-AB and one of the data type template parameters is not a tuple,
    // wrap it in one
using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
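    // Editorial sketch (not part of this header): what the two aliases above resolve to
    // for an assumed instantiation with ADataType = half_t (not a tuple) and
    // BDataType = ck::Tuple<half_t, half_t> (multi-B). Only A gets wrapped:
    //   static_assert(is_same_v<GemmADataType, ck::Tuple<half_t>>);
    //   static_assert(is_same_v<GemmBDataType, ck::Tuple<half_t, half_t>>);
    // When neither A nor B is a tuple, both aliases stay the plain data types.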
#define GridwiseGemmTemplateParameters \
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
// Use appropriate gridwise gemm
using GridwiseGemm =
std::conditional_t<isMultiA || isMultiB,
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmTemplateParameters>,
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
    // If ADataType or BDataType is a tuple, the user has to pass a ck::Array of pointers.
using APointers =
std::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
using BPointers =
std::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
    // Use a Tuple for both cases of GridPointer so it can be initialized in the Argument
    // constructor body (not in the initializer list, which a single const pointer would require).
using AGridPointer = remove_cvref_t<
decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>;
using BGridPointer = remove_cvref_t<
decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>;
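    // Editorial note (not part of this header): with an assumed isMultiA == true and
    // NumATensor == 2, APointers is ck::Array<const void*, 2>&, e.g.
    //   ck::Array<const void*, 2> p_as = {p_a0, p_a1};
    // whereas in the single-A case a plain const void* is expected. AGridPointer and
    // BGridPointer then hold the typed per-tensor pointers as a Tuple, which is filled
    // element by element in the Argument constructor below.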
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument
struct Argument
{
__device__ __host__ Argument(
APointers p_as,
BPointers p_bs,
const ck::Array<const void*, NumDTensor>& p_ds,
void* p_e,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<index_t, NDimSpatial>& input_left_pads,
const ck::Array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
: p_as_grid_{},
p_bs_grid_{},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
compute_ptr_offset_of_batch_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
// A/B/E Batch Stride
if constexpr(isMultiA || isMultiB)
{
static_for<0, NumATensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_batch_ for multiple AB
compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
                // Use GemmADataType/GemmBDataType to iterate over the tuple (even if the
                // passed data type is not a tuple)
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmADataType>>;
                // It is possible that one of A/B is a single pointer while the other is a
                // tuple. In that case we still take the multi-AB path, but have to cast the
                // single pointer instead of a tuple of pointers.
if constexpr(isMultiA)
{
// p_as is tuple
p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
}
else
{
// if MultiB and not MultiA then p_as is single pointer
p_as_grid_(i) = static_cast<const DataType*>(p_as);
}
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_batch_ for multiple AB
compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
                // It is possible that one of A/B is a single pointer while the other is a
                // tuple. In that case we still take the multi-AB path, but have to cast the
                // single pointer instead of a tuple of pointers.
if constexpr(isMultiB)
{
// p_bs is tuple
p_bs_grid_(i) = static_cast<const DataType*>(p_bs[i.value]);
}
else
{
// if MultiA and not MultiB then p_bs is single pointer
p_bs_grid_(i) = static_cast<const DataType*>(p_bs);
}
});
}
else
{
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
// p_as and p_bs are pointers
p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
p_bs_grid_(I0) = static_cast<const BDataType*>(p_bs);
}
// populate pointer, batch stride, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
// D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
});
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
// populate desc for Ds/E
if constexpr(isMultiA || isMultiB)
{
const auto as_grid_desc_ak0_m_ak1 =
generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 =
generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number<NumBTensor>{});
if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
}
}
else
{
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
}
}
}
// private:
// pointers (tuple if multi AB, pointer if no)
AGridPointer p_as_grid_;
BGridPointer p_bs_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
        // tensor descriptors for problem definition
index_t num_group_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
compute_ptr_offset_of_batch_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument()
ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
ck::Array<index_t, NDimSpatial> conv_filter_strides_;
ck::Array<index_t, NDimSpatial> conv_filter_dilations_;
ck::Array<index_t, NDimSpatial> input_left_pads_;
ck::Array<index_t, NDimSpatial> input_right_pads_;
};
static __device__ __host__ auto MakeArgument(
APointers p_as,
BPointers p_bs,
const ck::Array<const void*, NumDTensor>& p_ds,
void* p_e,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<index_t, NDimSpatial>& input_left_pads,
const ck::Array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op};
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
static constexpr auto AEnableLds_auto =
(NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
......@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
}
else
{
if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
if(!(arg.a_kz_stride_ == 1))
{
printf("DeviceOp: Vector Access A-k check failure\n");
return false;
index_t LastK =
AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6);
if(LastK % ABlockTransferSrcScalarPerVector == 0)
{
printf("DeviceOp: Vector Access A-k check failure\n");
return false;
}
}
}
......
......@@ -70,8 +70,9 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
bool pass = true;
pass = pass && arg.K_ % K1 == 0;
......
......@@ -56,7 +56,7 @@ __global__ void
bool input_permute,
bool output_permute)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off
// ***************************************************
......@@ -159,6 +159,7 @@ __global__ void
ignore = O;
ignore = G0;
ignore = G1;
ignore = alpha;
ignore = input_permute;
ignore = output_permute;
#endif // end of if (defined(__gfx11__) || defined(__gfx12__))
......@@ -187,7 +188,7 @@ __global__ void
index_t head_size,
float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off
// ***************************************************
......@@ -321,7 +322,7 @@ __global__ void
index_t head_size,
float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off
// ***************************************************
......@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -247,7 +247,8 @@ struct DeviceColumnToImageImpl
independent_filter_strides,
conv_filter_dilations,
input_left_pads_with_offset,
input_right_pads);
input_right_pads,
N);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// for sanity check of vector memory access
for(index_t i = 0; i < NumATensor; ++i)
{
as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1;
as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1;
as_max_read_elems_[i] =
tie(as_continous_dim_[i], as_max_read_elems_[i]) =
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
}
for(index_t i = 0; i < NumBTensor; ++i)
{
bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1;
bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1;
bs_max_read_elems_[i] =
tie(bs_continous_dim_[i], bs_max_read_elems_[i]) =
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
}
for(index_t i = 0; i < NumDTensor; ++i)
{
ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
ds_max_read_elems_[i] =
tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
}
e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1;
e_max_write_elems_ = CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
tie(e_continous_dim_, e_max_write_elems_) =
CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
}
// pointers
......@@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
// in the memory or not.
std::array<bool, NumATensor> as_mz_consecutive_;
std::array<bool, NumATensor> as_kz_consecutive_;
std::array<bool, NumBTensor> bs_nz_consecutive_;
std::array<bool, NumBTensor> bs_kz_consecutive_;
std::array<bool, NumDTensor> ds_nz_consecutive_;
bool e_nz_consecutive_;
        // Index of the dim group whose last dimension is contiguous in memory
        // (0 = first group, 1 = second group).
        std::array<index_t, NumATensor> as_continous_dim_;
        std::array<index_t, NumBTensor> bs_continous_dim_;
        std::array<index_t, NumDTensor> ds_continous_dim_;
index_t e_continous_dim_;
std::array<index_t, NumATensor> as_max_read_elems_;
std::array<index_t, NumBTensor> bs_max_read_elems_;
......@@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_a_vector_size =
arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m =
ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i];
ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0;
const bool valid_a_access_dim_k =
ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1;
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
if(!((valid_a_vector_size && valid_a_access_dim) ||
ABlockTransferSrcScalarPerVector == 1))
......@@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_b_vector_size =
arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n =
BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i];
BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0;
const bool valid_b_access_dim_k =
BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1;
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
if(!((valid_b_vector_size && valid_b_access_dim) ||
BBlockTransferSrcScalarPerVector == 1))
......@@ -699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_d_vector_size =
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension.
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1;
if(!((valid_d_vector_size && valid_d_access_dim) ||
CDEBlockTransferScalarPerVector_NPerBlock == 1))
{
......@@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_e_vector_size =
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim = arg.e_nz_consecutive_;
const bool valid_e_access_dim = arg.e_continous_dim_ == 1;
if(!((valid_e_vector_size && valid_e_access_dim) ||
CDEBlockTransferScalarPerVector_NPerBlock == 1))
{
......
......@@ -442,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
// for sanity check of vector memory access
a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
a_max_read_elems_ =
tie(a_continous_dim_, a_max_read_elems_) =
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
b_max_read_elems_ =
tie(b_continous_dim_, b_max_read_elems_) =
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
for(index_t i = 0; i < NumDTensor; ++i)
{
ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
ds_max_read_elems_[i] =
tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
}
e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
e_max_write_elems_ =
tie(e_continous_dim_, e_max_write_elems_) =
CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
}
......@@ -501,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
// in the memory or not.
bool a_mz_consecutive_;
bool a_kz_consecutive_;
bool b_nz_consecutive_;
bool b_kz_consecutive_;
std::array<bool, NumDTensor> ds_nz_consecutive_;
bool e_nz_consecutive_;
        // Index of the dim group whose last dimension is contiguous in memory
        // (0 = first group, 1 = second group).
index_t a_continous_dim_;
index_t b_continous_dim_;
std::array<index_t, NumDTensor> ds_continous_dim_;
index_t e_continous_dim_;
index_t a_max_read_elems_;
index_t b_max_read_elems_;
......@@ -601,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return false;
}
if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" &&
ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" &&
std::is_same<ADataType, double>::value)
if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
{
return false;
}
......@@ -624,8 +613,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_a_vector_size =
arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
const bool valid_a_access_dim_m =
ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0;
const bool valid_a_access_dim_k =
ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1;
const bool valid_a_access_dim =
valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
if(!(valid_a_vector_size && valid_a_access_dim))
......@@ -635,8 +626,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_b_vector_size =
arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
const bool valid_b_access_dim_n =
BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0;
const bool valid_b_access_dim_k =
BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1;
const bool valid_b_access_dim =
valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
if(!(valid_b_vector_size && valid_b_access_dim))
......@@ -650,7 +643,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension.
const bool valid_d_access_dim =
arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1;
arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_d_vector_size && valid_d_access_dim))
{
valid_ds_access = false;
......@@ -665,7 +658,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim =
arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1;
arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_e_vector_size && valid_e_access_dim))
{
return false;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
}
// Determine the beginning and end idx of the group representing the FCD.
index_t begin_idx, end_idx;
if(strides[NumDim1 - 1] == 1)
index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
{
begin_idx = 0;
end_idx = NumDim1 - 1;
        // check whether all dims of the first group (e.g. MZ or NZ) have length 1
bool dims1_are_ones = true;
for(index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
{
if(lengths[dim_idx] != 1)
{
dims1_are_ones = false;
}
}
if(dims1_are_ones)
{
begin_idx = NumDim1;
end_idx = NumDim1 + NumDim2 - 1;
continous_dim = 1;
}
else
{
begin_idx = 0;
end_idx = NumDim1 - 1;
continous_dim = 0;
}
}
else if(strides[NumDim1 - 1] == 1)
{
begin_idx = 0;
end_idx = NumDim1 - 1;
continous_dim = 0;
}
else if(strides[NumDim1 + NumDim2 - 1] == 1)
{
begin_idx = NumDim1;
end_idx = NumDim1 + NumDim2 - 1;
begin_idx = NumDim1;
end_idx = NumDim1 + NumDim2 - 1;
continous_dim = 1;
}
else
{
// The dimension consecutive in memory is not the last dimension of any group, so only
// one element can be read/written at once.
return 1;
consecutive_stride = 1;
continous_dim = 0;
return make_tuple(continous_dim, consecutive_stride);
}
index_t consecutive_stride = 1;
for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
{
if(strides[dim_idx] == consecutive_stride)
......@@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
}
}
const index_t max_subsequent_elems = consecutive_stride;
return max_subsequent_elems;
return make_tuple(continous_dim, max_subsequent_elems);
}
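// Editorial sketch (assumed shapes, not from this commit): for an A tensor with M-dims of
// lengths {8, 16} and K-dims of lengths {4, 32}, packed so that the K group carries the
// unit stride, the helper now reports both the contiguous dim group and the maximum
// vector-access width:
//   index_t cont_dim, max_elems;
//   tie(cont_dim, max_elems) =
//       CalculateMaxRead<2, 2>({8, 16, 4, 32}, {2048, 128, 32, 1});
//   // cont_dim == 1  (the second, K group owns the unit stride)
//   // max_elems == 4 * 32 == 128  (K0 and K1 are consecutive in memory)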
} // namespace device
......