"...resnet50_tensorflow.git" did not exist on "5444724e783fa7e517d54996553015dda994066e"
Commit 7a5d11c9 authored by Adam Osewski

Hide BlockwiseGemmT from GridwiseGemm class scope.

When BlockwiseGemmT was defined at GridwiseGemm class scope, it caused
compilation errors on Navi architectures, where this code is also compiled.
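Background for reviewers: the root cause is a standard C++ template rule. A class-scope member forces its type to be instantiated together with the enclosing class template, whereas a type used only inside a member function body is instantiated only when that function is actually used. Below is a minimal, self-contained sketch of the pattern with hypothetical stand-in names (XdlOnlyOp, GridwiseGemmBad, GridwiseGemmGood are illustrations, not the CK types); the "unsupported target" is faked with a static_assert instead of real MFMA builtins.

#include <type_traits>

// Stand-in for a type that only compiles on some targets.
template <typename T>
struct XdlOnlyOp
{
    static_assert(std::is_floating_point_v<T>,
                  "stand-in for 'XDL/MFMA builtins unavailable on this target'");
    T acc_{};
};

// BAD: member at class scope. Instantiating GridwiseGemmBad<int> must
// complete XdlOnlyOp<int>, so the static_assert fires even though no XDL
// path is ever executed.
template <typename T>
struct GridwiseGemmBad
{
    XdlOnlyOp<T> blockwise_gemm_{};
};

// GOOD: the XDL-only type is hidden inside the member function body, which
// is instantiated only when RunGEMM() is actually used.
template <typename T>
struct GridwiseGemmGood
{
    void RunGEMM()
    {
        XdlOnlyOp<T> blockwise_gemm{};
        (void)blockwise_gemm;
    }
};

int main()
{
    GridwiseGemmGood<int> ok{};   // compiles: RunGEMM() is never instantiated
    GridwiseGemmGood<float> gf{};
    gf.RunGEMM();                 // fine for a "supported" type
    // GridwiseGemmBad<int> bad{}; // would fail to compile
}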
parent 733c351d
@@ -7,7 +7,6 @@
 #include <sstream>
 #include <tuple>
-#include "ck/ck.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 #include "ck/host_utility/hip_check_error.hpp"
@@ -129,7 +128,8 @@ __global__ void
 const auto StrideA = gemm_desc_ptr[group_id].StrideA;
 const auto StrideB = gemm_desc_ptr[group_id].StrideB;
-auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
+using VGPRBufferT = remove_cvref_t<decltype(GridwiseGemm::GetCThreadBuffer())>;
+auto results_buffer = VGPRBufferT{};
 b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
 results_buffer.Clear();
@@ -150,7 +150,8 @@ __global__ void
 StrideA,
 StrideB,
 k_batch,
-b2c_tile_map);
+b2c_tile_map,
+results_buffer);
 } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
@@ -161,7 +162,7 @@ __global__ void
 // if (changed group_id || next [M,N] tile)
 if(!b2c_tile_map.IsFirstKSplitBlock())
 {
-gridwise_gemm.StorePartials(p_workspace);
+gridwise_gemm.StorePartials(p_workspace, results_buffer);
 }
 work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
@@ -176,7 +177,7 @@ __global__ void
 // Accumulate only when there is at least two workgroups processing splitk data-tiles
 // across same MN-output tile.
 if(neighbour_count > 1)
-gridwise_gemm.AccumulatePartials(p_workspace, neighbour_count);
+gridwise_gemm.AccumulatePartials(p_workspace, results_buffer, neighbour_count);
 // Signal waiting blocks that they can start use their workspace.
 work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
@@ -203,7 +204,8 @@ __global__ void
 stride_ds,
 stride_e,
 cde_element_op,
-b2c_tile_map);
+b2c_tile_map,
+results_buffer);
 }
 else if(work_scheduler.HasTile())
 {
......
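One subtlety in the kernel-side change above: the new static GridwiseGemm::GetCThreadBuffer() returns a reference tied to a function-local object, which would dangle if the call were ever evaluated. The kernel, however, only names the call inside decltype, an unevaluated context, so the function serves purely as a type oracle for the VGPR accumulator buffer that the kernel now owns itself. A distilled host-side sketch of the idiom, with illustrative names (Buf is not the CK type, and the local is made static here only to keep the sketch free of undefined behavior):

#include <type_traits>

struct Buf { void Clear() {} }; // stand-in for the VGPR accumulator buffer

struct GridwiseGemm
{
    static auto& GetCThreadBuffer()
    {
        static Buf type_oracle{}; // 'static' only to keep this sketch UB-free;
        return type_oracle;       // the real code returns a local's buffer
    }
};

int main()
{
    // decltype does not evaluate its operand, so no call happens here; the
    // expression only names the buffer type the gridwise GEMM would produce.
    using VGPRBufferT = std::remove_cv_t<
        std::remove_reference_t<decltype(GridwiseGemm::GetCThreadBuffer())>>;
    static_assert(std::is_same_v<VGPRBufferT, Buf>);

    VGPRBufferT results_buffer{}; // the kernel owns its accumulator now
    results_buffer.Clear();
}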
@@ -16,7 +16,6 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
-#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/utility/reduction_functions_accumulate.hpp"

 namespace ck {
@@ -318,23 +317,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 true,
 NumGemmKPrefetchStage>;

-using BlockwiseGemmT =
-    remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
-        BlockSize,
-        ComputeType,
-        ComputeType,
-        AccDataType,
-        decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
-        decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
-        MPerXdl,
-        NPerXdl,
-        MXdlPerWave,
-        NXdlPerWave,
-        KPack,
-        LoopSched>())>;
-
-BlockwiseGemmT blockwise_gemm_{};

 public:
 __host__ __device__ static constexpr auto
 GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
@@ -688,21 +670,36 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
 __device__ __host__ static constexpr auto GetNPerBlock() { return NPerBlock; }

-__device__ __host__ constexpr auto& GetCThreadBuffer()
+__device__ __host__ static constexpr auto& GetCThreadBuffer()
 {
-    return blockwise_gemm_.GetCThreadBuffer();
+    using BlockwiseGemmT =
+        remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
+            BlockSize,
+            ComputeType,
+            ComputeType,
+            AccDataType,
+            decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
+            decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
+            MPerXdl,
+            NPerXdl,
+            MXdlPerWave,
+            NXdlPerWave,
+            KPack,
+            LoopSched>())>;
+    BlockwiseGemmT blockwise_gemm;
+    return blockwise_gemm.GetCThreadBuffer();
 }

-template <bool HasMainKBlockLoop, typename Block2ETileMap>
+template <bool HasMainKBlockLoop, typename Block2ETileMap, typename CThreadBuf>
 __device__ void RunGEMM(const ADataType* __restrict__ p_a_grid,
                         const BDataType* __restrict__ p_b_grid,
                         void* __restrict__ p_shared,
-                        [[maybe_unused]] const index_t KBatch,
                         const AElementwiseOperation& a_element_op,
                         const BElementwiseOperation& b_element_op,
                         const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1,
                         const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1,
-                        const Block2ETileMap& block_2_etile_map)
+                        const Block2ETileMap& block_2_etile_map,
+                        CThreadBuf& c_thread_buf)
 {
     const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
         p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize());
@@ -760,7 +757,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 // b_mtx[K0PerBlock, NPerBlock] is in LDS
 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
 // register
-auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
+// auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();

 // LDS allocation for A and B: be careful of alignment
 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -787,6 +784,20 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 bool clear_c_thread_buf = false;

+auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
+    BlockSize,
+    ComputeType,
+    ComputeType,
+    AccDataType,
+    decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
+    decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
+    MPerXdl,
+    NPerXdl,
+    MXdlPerWave,
+    NXdlPerWave,
+    KPack,
+    LoopSched>();
+
 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1,
                                                        a_block_desc_kbatch_ak0_m_ak1,
                                                        a_blockwise_copy,
@@ -799,13 +810,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 b_grid_buf,
 b_block_buf,
 b_block_slice_copy_step,
-blockwise_gemm_,
+blockwise_gemm,
 c_thread_buf,
 num_k_block_main_loop,
 clear_c_thread_buf);
 }

-template <bool HasMainKBlockLoop, typename Block2ETileMap>
+template <bool HasMainKBlockLoop, typename Block2ETileMap, typename CThreadBuf>
 __device__ void RunGEMM(const void* __restrict__ p_a_grid_,
                         const void* __restrict__ p_b_grid_,
                         void* __restrict__ p_shared,
@@ -817,7 +828,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 const index_t StrideA,
 const index_t StrideB,
 const index_t KBatch,
-const Block2ETileMap& block_2_etile_map)
+const Block2ETileMap& block_2_etile_map,
+CThreadBuf& c_thread_buf)
 {
     const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
     const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
@@ -832,19 +844,18 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 RunGEMM<HasMainKBlockLoop>(p_a_grid,
                            p_b_grid,
                            p_shared,
-                           KBatch,
                            a_element_op,
                            b_element_op,
                            a_grid_desc_kbatch_ak0_m_ak1,
                            b_grid_desc_kbatch_bk0_n_bk1,
-                           block_2_etile_map);
+                           block_2_etile_map,
+                           c_thread_buf);
 }

 // TODO Need to do CShuffle already here:
-__device__ void StorePartials(void* __restrict__ p_workspace)
+template <typename CThreadBuf>
+__device__ void StorePartials(void* __restrict__ p_workspace, const CThreadBuf& c_thread_buf)
 {
-    const auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
 // M0 = grid_size
 // N0 = 1
 // M1 = MPerBlock
@@ -855,6 +866,21 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
 const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);

+using BlockwiseGemmT =
+    remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
+        BlockSize,
+        ComputeType,
+        ComputeType,
+        AccDataType,
+        decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
+        decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
+        MPerXdl,
+        NPerXdl,
+        MXdlPerWave,
+        NXdlPerWave,
+        KPack,
+        LoopSched>())>;
+
 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
     BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -916,7 +942,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

 const auto c_thread_mtx_on_block =
-    blockwise_gemm_.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+    BlockwiseGemmT::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
@@ -972,9 +998,25 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 w_grid_buf);
 }

-__device__ void AccumulatePartials(void* __restrict__ p_workspace, uint32_t reduce_count)
+template <typename CThreadBuf>
+__device__ void AccumulatePartials(void* __restrict__ p_workspace,
+                                   CThreadBuf& c_thread_buf,
+                                   uint32_t reduce_count)
 {
-    auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
+    using BlockwiseGemmT =
+        remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
+            BlockSize,
+            ComputeType,
+            ComputeType,
+            AccDataType,
+            decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
+            decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
+            MPerXdl,
+            NPerXdl,
+            MXdlPerWave,
+            NXdlPerWave,
+            KPack,
+            LoopSched>())>;

 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
     BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -1047,7 +1089,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 Sequence<7>{}));

 const auto c_thread_mtx_on_block =
-    blockwise_gemm_.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+    BlockwiseGemmT::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
@@ -1072,8 +1114,10 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 make_multi_index(n_thread_data_on_block));

 auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
-auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
-    p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
+auto w_grid_buf =
+    make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
+        p_workspace_grid,
+        workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

 auto acc_load = ThreadwiseTensorSliceTransfer_v2<
     AccDataType, // SrcData,
@@ -1122,7 +1166,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 }
 }

-template <typename Block2ETileMap>
+template <typename Block2ETileMap, typename CThreadBuf>
 __device__ void RunWrite(DsGridPointer p_ds_grid,
                          EDataType* __restrict__ p_e_grid,
                          void* __restrict__ p_shared,
@@ -1131,7 +1175,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 const std::array<index_t, NumDTensor> StrideDs,
 const index_t StrideE,
 const CDEElementwiseOperation& cde_element_op,
-const Block2ETileMap& block_2_etile_map)
+const Block2ETileMap& block_2_etile_map,
+const CThreadBuf& c_thread_buf)
 {
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
@@ -1167,8 +1212,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
     p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-const auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
-
 // shuffle C and write out
 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -1180,6 +1223,21 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 // divide block work by [M, N, K]
 const auto block_work_idx = block_2_etile_map.GetBottomIndex();

+using BlockwiseGemmT =
+    remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
+        BlockSize,
+        ComputeType,
+        ComputeType,
+        AccDataType,
+        decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
+        decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
+        MPerXdl,
+        NPerXdl,
+        MXdlPerWave,
+        NXdlPerWave,
+        KPack,
+        LoopSched>())>;
+
 // TODO: hacky, fix it!
 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
     BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -1225,7 +1283,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 // calculate origin of thread output tensor on global memory
 // blockwise GEMM c matrix starting index
 const auto c_thread_mtx_on_block =
-    blockwise_gemm_.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+    BlockwiseGemmT::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
......
@@ -124,10 +124,6 @@ FOREACH(subdir_path ${dir_list})
 message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
 set(add_inst 0)
 endif()
-if(("${cmake_instance}" MATCHES "ONLY XDL_INSTANCES") AND (NOT "${GPU_TARGETS}" MATCHES "gfx9"))
-    message("Found only xdl instances, but building for non-gfx9 targets. Skipping.")
-    set(add_inst 0)
-endif()
 if((add_inst EQUAL 1))
     get_filename_component(target_dir ${subdir_path} NAME)
     add_subdirectory(${target_dir})
......
-#ONLY XDL_INSTANCES
 add_instance_library(device_grouped_gemm_multiple_d_instance
 device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
 device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
......