"...resnet50_tensorflow.git" did not exist on "36db24502c550bc899b6277d177afa97a8975ebd"
Commit d19487eb authored by carlushuang's avatar carlushuang
Browse files

update device code

parent e32db0e9
......@@ -73,3 +73,72 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
return 0;
#endif
}
template <typename... Args, typename F, typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
PreProcessFunc preprocess,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
{
#if CK_TIME_KERNEL
if(stream_config.time_kernel_)
{
#if DEBUG_LOG
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
__func__,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z);
printf("Warm up 1 time\n");
#endif
// warm up
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
const int nrepeat = 10;
#if DEBUG_LOG
printf("Start running %d times...\n", nrepeat);
#endif
hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
for(int i = 0; i < nrepeat; ++i)
{
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
}
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
float total_time = 0;
hip_check_error(hipEventElapsedTime(&total_time, start, stop));
return total_time / nrepeat;
}
else
{
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
return 0;
}
#else
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
return 0;
#endif
}
......@@ -148,6 +148,22 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
StreamKReductionStrategy::Atomic)
{
hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
ave_time = launch_and_time_kernel(stream_config,
kernel,
grid_dims,
dim3(BlockSize),
0,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg.p_workspace_,
karg.M,
karg.N,
karg.K,
karg.StrideA,
karg.StrideB,
karg.StrideC,
karg.block_mapping);
}
else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
......@@ -156,27 +172,32 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
workspace_semaphore =
workspace_semaphore + karg.block_mapping.get_workspace_size_for_acc(
sizeof(typename GridwiseGemm::FloatAcc));
hipGetErrorString(hipMemset(
workspace_semaphore, 0, karg.block_mapping.get_workspace_size_for_semaphore()));
auto preprocess = [&]() {
hipGetErrorString(
hipMemset(workspace_semaphore,
0,
karg.block_mapping.get_workspace_size_for_semaphore()));
};
ave_time = launch_and_time_kernel_with_preprocess(stream_config,
preprocess,
kernel,
grid_dims,
dim3(BlockSize),
0,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg.p_workspace_,
karg.M,
karg.N,
karg.K,
karg.StrideA,
karg.StrideB,
karg.StrideC,
karg.block_mapping);
}
ave_time = launch_and_time_kernel(stream_config,
kernel,
grid_dims,
dim3(BlockSize),
0,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg.p_workspace_,
karg.M,
karg.N,
karg.K,
karg.StrideA,
karg.StrideB,
karg.StrideC,
karg.block_mapping);
return ave_time;
}
......
......@@ -23,7 +23,7 @@ namespace ck {
template <typename GridwiseGemm>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
const typename GridwiseGemm::FloatAB* p_b_grid,
......@@ -978,7 +978,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun
{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_multi_index(0, 0, 0, 0),
......@@ -1007,8 +1007,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun
false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be false,
// othre wise has scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be false,
// othre wise has scratch
{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
......@@ -1050,6 +1052,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// make sure it's safe to do ds_read
block_sync_lds();
c_block_copy_lds_to_global.SetSrcSliceOrigin(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_tuple(0, 0, 0, 0));
// LDS to global
if(is_dp_block)
c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
......
......@@ -156,11 +156,11 @@ bool profile_gemm_streamk_impl(int do_verification,
c_element_op,
NumSKBlocks);
DeviceMem workspace;
std::size_t workspace_size = op_ptr->GetWorkSpaceSize(argument_ptr);
std::size_t workspace_size = op_ptr->GetWorkSpaceSize(argument_ptr.get());
if(workspace_size != 0)
{
workspace.Realloc(workspace_size);
op_ptr->SetWorkSpacePointer(argument_ptr, workspace.GetDeviceBuffer());
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
}
auto invoker_ptr = op_ptr->MakeInvokerPointer();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment