Commit 7a20cf67 authored by Jakub Piasecki

tmp save

parent a793afc9
@@ -76,7 +76,6 @@ __global__ void
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
             const CDEElementwiseOperation cde_element_op,
-            const index_t batch_count,
             const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
             const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
             const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -89,18 +88,19 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
     // offset base pointer for each work-group
-    const index_t num_blocks_per_batch =
-        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
-    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset =
+    // const index_t num_blocks_per_batch =
+    // __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
+    //const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
+    const long_index_t a_group_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
+    const long_index_t b_group_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t e_batch_offset =
+    const long_index_t e_group_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
-    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
+    const auto ds_group_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -110,12 +110,12 @@ __global__ void
         DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();

     static_for<0, NumDTensor, 1>{}(
-        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
+        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });

-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
-                                                  p_b_grid + b_batch_offset,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_group_offset,
+                                                  p_b_grid + b_group_offset,
                                                   p_ds_grid_grp,
-                                                  p_e_grid + e_batch_offset,
+                                                  p_e_grid + e_group_offset,
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
@@ -130,7 +130,6 @@ __global__ void
     ignore = p_b_grid;
     ignore = p_ds_grid;
     ignore = p_e_grid;
-    ignore = batch_count;
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -200,7 +199,8 @@ template <index_t NDimSpatial,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler(),
           typename AComputeType = ADataType,
-          typename BComputeType = AComputeType>
+          typename BComputeType = AComputeType,
+          index_t NumGroupsToMerge = 1>
 struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
     : public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
                                                ALayout, // output image
@@ -220,6 +220,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
     // TODO: Extend support for more spatial dimensions.
     static_assert(NDimSpatial == 2 || NDimSpatial == 3,
                   "wrong! only implemented for 2D and 3D now");
+    static_assert(NumGroupsToMerge >= 1, "wrong! NumGroupsToMerge must be greater or equal to 1!");
+
     using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
@@ -242,7 +244,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                                  NPerBlock,
                                  KPerBlock,
                                  DoPadGemmM,
-                                 DoPadGemmN>{};
+                                 DoPadGemmN,
+                                 NumGroupsToMerge>{};

     static auto GetDummyABDsEGridDescriptor()
     {
@@ -453,10 +456,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             // A/B/Ds/E Batch Stride
             compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
             compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
-            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
+            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0] * NumGroupsToMerge; // sure?

             static_for<0, NumDTensor, 1>{}([&](auto i) {
-                compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
+                compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0] * NumGroupsToMerge;
             });

     static constexpr auto NonSpatialDimsNum = Number<3>{};
@@ -713,7 +716,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             float ave_time = 0;

-            for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
+            for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) // why is there a for?
             {
                 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
                                                 arg.b_grid_desc_n_k_container_[i],
@@ -752,7 +755,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                 return launch_and_time_kernel(
                     stream_config,
                     kernel,
-                    dim3(grid_size),
+                    dim3(grid_size), // change to gdx, gdy, gdz after removing for and calculating all in one go
                     dim3(BlockSize),
                     0,
                     arg.p_a_grid_,
@@ -762,7 +765,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.cde_element_op_,
-                    arg.a_g_n_k_wos_lengths_[0], // Group count
+                    //arg.a_g_n_k_wos_lengths_[0], // Group count
                     arg.a_grid_desc_ak0_m_ak1_container_[i],
                     arg.b_grid_desc_bk0_n_bk1_container_[i],
                     arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
@@ -798,8 +801,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             return false;
         }

-        const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
-        const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
+        const index_t ConvG = arg.b_g_k_c_xs_lengths_[I0];
+        const index_t ConvK = arg.b_g_k_c_xs_lengths_[I1];
+        const index_t ConvC = arg.b_g_k_c_xs_lengths_[I2];

         // Specifialization
         if constexpr(ConvBackwardDataSpecialization ==
@@ -894,6 +898,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             return false;
         }

+        if constexpr(NumGroupsToMerge > 1)
+        {
+            if(!(ConvC == 1))
+            {
+                return false;
+            }
+            if(ConvG % NumGroupsToMerge != 0)
+            {
+                return false;
+            }
+            if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
+            {
+                return false;
+            }
+        }
+
         // Gridwise GEMM size
         for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
         {
@@ -1031,7 +1051,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             << ABlockTransferSrcScalarPerVector << ", "
             << BBlockTransferSrcScalarPerVector << ", "
             << CShuffleMXdlPerWavePerShuffle << ", "
-            << CShuffleNXdlPerWavePerShuffle
+            << CShuffleNXdlPerWavePerShuffle << ", "
+            << NumGroupsToMerge
             << ">";

         return str.str();
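
Note on the dispatch change above: the kernel now takes its group index from blockIdx.y, while the host side still launches dim3(grid_size) inside the per-descriptor for loop (see the "change to gdx, gdy, gdz after removing for and calculating all in one go" note). Below is a minimal, self-contained HIP sketch of that target scheme; it is illustrative only, not CK code, and names such as scale_per_group, a_group_stride, and n_per_group are assumptions:

#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

// Each block reads its group index straight from blockIdx.y; the per-group
// base offsets are then a single multiply, mirroring what
// GetAPtrOffset/GetEPtrOffset do with the batch strides.
__global__ void scale_per_group(const float* a,
                                float* e,
                                long a_group_stride,
                                long e_group_stride,
                                int n_per_group)
{
    const int g_idx = blockIdx.y; // no get_block_1d_id() / num_blocks_per_batch division
    const long a_off = static_cast<long>(g_idx) * a_group_stride;
    const long e_off = static_cast<long>(g_idx) * e_group_stride;
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n_per_group)
        e[e_off + i] = 2.f * a[a_off + i];
}

int main()
{
    const int groups = 4, n = 256;
    std::vector<float> h(groups * n, 1.f);
    float *a, *e;
    hipMalloc(reinterpret_cast<void**>(&a), groups * n * sizeof(float));
    hipMalloc(reinterpret_cast<void**>(&e), groups * n * sizeof(float));
    hipMemcpy(a, h.data(), groups * n * sizeof(float), hipMemcpyHostToDevice);
    // gdx covers one group's tiles, gdy enumerates groups: the whole problem
    // launches in one go instead of looping over groups on the host.
    const dim3 grid((n + 255) / 256, groups);
    hipLaunchKernelGGL(scale_per_group, grid, dim3(256), 0, 0,
                       a, e, static_cast<long>(n), static_cast<long>(n), n);
    hipMemcpy(h.data(), e, groups * n * sizeof(float), hipMemcpyDeviceToHost);
    printf("e[0] = %f, e[last] = %f\n", h.front(), h.back()); // expect 2.0, 2.0
    hipFree(a);
    hipFree(e);
    return 0;
}

Reading the group index from blockIdx.y removes both the integer division by num_blocks_per_batch and the batch_count kernel argument, which is exactly what the hunks above delete.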
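
The new IsSupportedArgument block gates group merging on three conditions: ConvC == 1, ConvG divisible by NumGroupsToMerge, and an NSpatialGK/GKSpatial/NSpatialGC layout. The argument constructor also scales BatchStrideE_ and BatchStrideDs_ by NumGroupsToMerge so that one g_idx step covers a whole merged chunk (the "// sure?" marks that scaling as an open question in this WIP commit). A hedged restatement of the gating logic as a free function, with illustrative names (groups_mergeable is not a CK API):

#include <cstdint>

// Mirrors the new IsSupportedArgument checks: merging is a no-op for 1 group,
// requires C == 1 per group, an even split of G, and a supported layout.
bool groups_mergeable(std::int64_t conv_g,
                      std::int64_t conv_c,
                      std::int64_t num_groups_to_merge,
                      bool is_nspatialgk_gkspatial_nspatialgc_layout)
{
    if(num_groups_to_merge < 1)
        return false; // matches static_assert(NumGroupsToMerge >= 1)
    if(num_groups_to_merge == 1)
        return true;  // merging disabled, nothing else to check
    if(conv_c != 1)
        return false; // only single-channel-per-group convolutions can merge
    if(conv_g % num_groups_to_merge != 0)
        return false; // groups must split evenly into merged chunks
    return is_nspatialgk_gkspatial_nspatialgc_layout;
}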