Commit fe15fcc0 authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

debugging prints added.

parent 6c5111b7
...@@ -23,6 +23,7 @@ add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp) ...@@ -23,6 +23,7 @@ add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2)
add_example_executable(example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp) add_example_executable(example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp)
target_compile_options(example_gemm_xdl_fp16_streamk_v3 PRIVATE -ggdb -O1 -march=native)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_streamk_v3) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_streamk_v3)
add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp) add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
......
...@@ -8,7 +8,8 @@ ...@@ -8,7 +8,8 @@
using ADataType = ck::half_t; using ADataType = ck::half_t;
using BDataType = ck::half_t; using BDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = ck::half_t; // using CShuffleDataType = ck::half_t;
using CShuffleDataType = float;
using CDataType = ck::half_t; using CDataType = ck::half_t;
using ALayout = Row; using ALayout = Row;
......
...@@ -239,6 +239,25 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -239,6 +239,25 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
return true; return true;
} }
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
if(workspace_size != 0)
{
workspace.Realloc(workspace_size);
gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer());
}
// if(workspace_size != 0)
// {
// float* ws_ptr = reinterpret_cast<float*>(malloc(workspace_size));
// size_t ws_dwords = workspace_size / sizeof(float);
// workspace.FromDevice(ws_ptr);
// printf("ws size=%0zu\n",workspace_size);
// for(size_t i = 0; i < ws_dwords; i++)
// {
// uint32_t rere = reinterpret_cast<uint32_t*>(ws_ptr)[i];
// printf("%4lu : %f(0x%08x)\n", i, ws_ptr[i], rere);
// }
// free(ws_ptr);
// }
bool pass = true; bool pass = true;
if(config.do_verification) if(config.do_verification)
...@@ -261,8 +280,15 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -261,8 +280,15 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else #else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); printf("device copy initiated\n"); // HS
if((workspace_size != 0) && (Streamk_sel > 0))
{
printf("entered if\n");
workspace.FromDevice(c_m_n_device_result.mData.data());
}
else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
printf("device copy finished\n"); // HS
pass &= ck::utils::check_err(c_m_n_device_result, pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result, c_m_n_host_result,
"Error: Incorrect results!", "Error: Incorrect results!",
...@@ -273,8 +299,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -273,8 +299,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
if(config.time_kernel) if(config.time_kernel)
{ {
printf("before running timing\n");
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
printf("after running timing\n");
std::size_t flop = 2_uz * M * N * K; std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype = std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
......
...@@ -131,25 +131,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout ...@@ -131,25 +131,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
printf("inside run\n");
if(stream_config.log_level_ > 0) if(stream_config.log_level_ > 0)
{ {
arg.Print(); arg.Print();
} }
printf("done printing arg\n");
if(!GridwiseGemm::CheckValidity(arg)) if(!GridwiseGemm::CheckValidity(arg))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
printf("done checking arg validity\n");
float ave_time = 0; float ave_time = 0;
index_t k_grain = KPerBlock; index_t k_grain = KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
printf("done finding k_split\n");
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy == if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic) StreamKReductionStrategy::Atomic)
{ {
hipGetErrorString(hipMemsetAsync( hipGetErrorString(hipMemsetAsync(
arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_)); arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
} }
...@@ -216,12 +218,14 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout ...@@ -216,12 +218,14 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
hipGetErrorString(hipMemsetAsync( hipGetErrorString(hipMemsetAsync(
workspace_semaphore, workspace_semaphore,
0, 0,
arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(), sizeof(uint32_t),
//arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
stream_config.stream_id_)); stream_config.stream_id_));
}; };
printf("before ave_time\n");
ave_time = launch_and_time_kernel_with_preprocess( ave_time = launch_and_time_kernel_with_preprocess(
stream_config, preprocess, kernel, grid_dim, dim3(BlockSize), 0, arg); stream_config, preprocess, kernel, grid_dim, dim3(BlockSize), 0, arg);
printf("after ave_time\n");
} }
} }
}; };
...@@ -242,7 +246,9 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout ...@@ -242,7 +246,9 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy>; minimum_occupancy>;
printf("before running lambda\n");
Run(kernel); Run(kernel);
printf("after running lambda\n");
} }
} }
// Tail number could be One to Seven // Tail number could be One to Seven
...@@ -443,6 +449,28 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout ...@@ -443,6 +449,28 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
} }
}; };
/// @brief Report how much workspace the stream-k reduction needs for this argument.
///
/// When the tile map's ReductionStrategy is StreamKReductionStrategy::Reduction, the
/// size is whatever block_2_ctile_map_streamk.get_workspace_size() returns for an
/// accumulator element of sizeof(GemmAccDataType). For every other strategy no
/// workspace is requested and 0 is returned.
///
/// @param pArg Type-erased argument; must actually be this operation's Argument when
///             the Reduction strategy is compiled in.
/// @return Required workspace size (presumably in bytes - the caller passes it to a
///         device buffer reallocation; confirm against get_workspace_size), or 0.
/// @throws std::runtime_error if pArg is not an Argument of this operation and the
///         Reduction strategy is active.
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
    const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
    if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
                 StreamKReductionStrategy::Reduction)
    {
        // dynamic_cast yields nullptr for a null or foreign argument; dereferencing it
        // would be undefined behaviour, so fail loudly instead.
        if(p_arg == nullptr)
        {
            throw std::runtime_error(
                "wrong! GetWorkSpaceSize expects a DeviceGemm_Xdl_CShuffle_Streamk_V3 Argument");
        }
        return p_arg->block_2_ctile_map_streamk.get_workspace_size(sizeof(GemmAccDataType));
    }
    else
    {
        // Non-reduction strategies do not use a workspace; the argument is not inspected.
        return 0;
    }
}
/// @brief Attach a caller-provided workspace buffer to the argument.
///
/// Stores p_workspace in the argument's p_workspace_ member. The StreamConfig
/// parameter is accepted to match the overridden interface and is not used here.
///
/// @param pArg        Type-erased argument; must actually be this operation's Argument.
/// @param p_workspace Workspace pointer to record (the example caller passes a device
///                    buffer; this function itself does not validate or touch it).
/// @throws std::runtime_error if pArg is not an Argument of this operation.
void SetWorkSpacePointer(BaseArgument* pArg,
                         void* p_workspace,
                         const StreamConfig& = StreamConfig{}) const override
{
    Argument* pArg_ = dynamic_cast<Argument*>(pArg);
    // dynamic_cast yields nullptr for a null or foreign argument; writing through it
    // would be undefined behaviour, so fail loudly instead.
    if(pArg_ == nullptr)
    {
        throw std::runtime_error(
            "wrong! SetWorkSpacePointer expects a DeviceGemm_Xdl_CShuffle_Streamk_V3 Argument");
    }
    pArg_->p_workspace_ = p_workspace;
}
static constexpr bool IsValidCompilationParameter() static constexpr bool IsValidCompilationParameter()
{ {
// TODO: properly implement this check // TODO: properly implement this check
......
...@@ -1191,6 +1191,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1191,6 +1191,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const BElementwiseOperation b_element_op{}; const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{}; const CElementwiseOperation c_element_op{};
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
// Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M, // Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
// problem.N, // problem.N,
// AK0Number * problem.KPadded, // AK0Number * problem.KPadded,
...@@ -1218,8 +1228,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1218,8 +1228,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx &&
static_cast<uint32_t>(block_idx) < static_cast<uint32_t>(block_idx) <
block_2_ctile_map_streamk.reduction_start_block_idx; block_2_ctile_map_streamk.reduction_start_block_idx;
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start; num_k_block_main_loop = iter_end - iter_start;
...@@ -1229,6 +1238,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1229,6 +1238,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(Block2CTileMap_streamk::ReductionStrategy == if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction) StreamKReductionStrategy::Reduction)
{ {
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
if(is_reduction_block) if(is_reduction_block)
{ {
// descriptors // descriptors
...@@ -1347,7 +1358,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1347,7 +1358,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
CElementwiseOperation{}}; CElementwiseOperation{}};
// block synchronization // block synchronization
wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start); wg_barrier.wait_eq(0, block_2_ctile_map_streamk.sk_num_blocks);
// wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end -
// tile_acc_offset_start);
#if 0 #if 0
if(threadIdx.x == 0) { if(threadIdx.x == 0) {
...@@ -1428,7 +1441,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1428,7 +1441,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
partial_acc_store_step_m); partial_acc_store_step_m);
} }
} }
return; continue;
} }
} }
...@@ -1446,25 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1446,25 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end - 1, tile_idx, iter_offset); iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
problem.MPadded,
problem.K,
problem.KPadded,
problem.StrideA,
problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K,
problem.KPadded,
problem.N,
problem.NPadded,
problem.StrideB,
problem.BK0);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto block_work_idx = auto block_work_idx =
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
...@@ -1764,7 +1758,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1764,7 +1758,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData, CShuffleDataType, // typename SrcData,
CDataType, // typename DstData, CShuffleDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
...@@ -1881,17 +1875,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1881,17 +1875,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
} }
}); });
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(tile_idx);
}
}
} }
// exit condition // exit condition
iter_end -= current_iter_length; iter_end -= current_iter_length;
...@@ -1905,7 +1888,17 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 ...@@ -1905,7 +1888,17 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
// make sure next loop LDS is ready for use // make sure next loop LDS is ready for use
block_sync_lds(); block_sync_lds();
} }
} if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(0);
}
}
} // for loop
} }
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment