"src/include/threadwise_tensor_slice_op.hpp" did not exist on "b7d052459d1f67cd3c1fdcb331027da18a479e63"
Commit 7a20cf67 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

tmp save

parent a793afc9
@@ -76,7 +76,6 @@ __global__ void
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
         const CDEElementwiseOperation cde_element_op,
-        const index_t batch_count,
         const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
         const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
         const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -89,18 +88,19 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// const index_t num_blocks_per_batch =
// __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
//const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t a_batch_offset =
const long_index_t a_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
const long_index_t b_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
const long_index_t e_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto ds_group_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
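Replacing the flat-index arithmetic (block id divided by blocks-per-batch) with a direct read of blockIdx.y implies the kernel is now meant to be launched with the merged-group count on the grid's y dimension instead of folded into a 1D grid. A minimal sketch of that mapping, assuming such a 2D launch; the helper name is hypothetical, only blockIdx.y and the readfirstlane pattern come from the hunk above:

#include <hip/hip_runtime.h>

// Minimal sketch, assuming gridDim.y == group_count / NumGroupsToMerge.
// blockIdx.y is already uniform within a wave; readfirstlane makes that
// explicit so the index lands in a scalar register, as in the kernel above.
__device__ inline int get_group_index()
{
    return __builtin_amdgcn_readfirstlane(static_cast<int>(blockIdx.y));
}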
@@ -110,12 +110,12 @@ __global__ void
         DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();

     static_for<0, NumDTensor, 1>{}(
-        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
+        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });

-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
-                                                  p_b_grid + b_batch_offset,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_group_offset,
+                                                  p_b_grid + b_group_offset,
                                                   p_ds_grid_grp,
-                                                  p_e_grid + e_batch_offset,
+                                                  p_e_grid + e_group_offset,
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
@@ -130,7 +130,6 @@ __global__ void
     ignore = p_b_grid;
     ignore = p_ds_grid;
     ignore = p_e_grid;
-    ignore = batch_count;
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -200,7 +199,8 @@ template <index_t NDimSpatial,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler(),
           typename AComputeType = ADataType,
-          typename BComputeType = AComputeType>
+          typename BComputeType = AComputeType,
+          index_t NumGroupsToMerge = 1>
 struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
     : public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
                                                ALayout, // output image
@@ -220,6 +220,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
     // TODO: Extend support for more spatial dimensions.
     static_assert(NDimSpatial == 2 || NDimSpatial == 3,
                   "wrong! only implemented for 2D and 3D now");
+    static_assert(NumGroupsToMerge >= 1,
+                  "wrong! NumGroupsToMerge must be greater than or equal to 1!");

     using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
@@ -242,7 +244,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                                              NPerBlock,
                                              KPerBlock,
                                              DoPadGemmM,
-                                             DoPadGemmN>{};
+                                             DoPadGemmN,
+                                             NumGroupsToMerge>{};

     static auto GetDummyABDsEGridDescriptor()
     {
@@ -453,10 +456,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             // A/B/Ds/E Batch Stride
             compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
             compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
-            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
+            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0] * NumGroupsToMerge; // sure?

             static_for<0, NumDTensor, 1>{}([&](auto i) {
-                compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
+                compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0] * NumGroupsToMerge;
             });

             static constexpr auto NonSpatialDimsNum = Number<3>{};
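The hunk above scales only the E and Ds batch strides by NumGroupsToMerge, leaving A and B untouched; the author's "// sure?" marks this as unverified. The offset arithmetic it implies, as a hypothetical host-side sketch (function and parameter names are illustrative, not the library's API):

#include <cstdint>

// Sketch of the per-work-group E offset implied by the change above: g_idx
// now counts merged groups, so the E pointer advances by NumGroupsToMerge
// single-group strides per step of g_idx.
std::int64_t e_ptr_offset(std::int64_t e_stride_per_group,
                          int num_groups_to_merge,
                          int g_idx)
{
    return static_cast<std::int64_t>(g_idx) * e_stride_per_group * num_groups_to_merge;
}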
@@ -713,7 +716,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         float ave_time = 0;

-        for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
+        for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) // why is there a for?
         {
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
                                             arg.b_grid_desc_n_k_container_[i],
@@ -752,7 +755,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                 return launch_and_time_kernel(
                     stream_config,
                     kernel,
-                    dim3(grid_size),
+                    dim3(grid_size), // change to gdx, gdy, gdz after removing for and calculating all in one go
                     dim3(BlockSize),
                     0,
                     arg.p_a_grid_,
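The comment added above records the intended follow-up: drop the per-descriptor loop and issue a single launch with a multi-dimensional grid. A hypothetical helper for that launch shape, consistent with the kernel now reading blockIdx.y (names are illustrative; only dim3 and the x/y split follow from the diff):

#include <hip/hip_runtime.h>

// Sketch of the launch shape the TODO points at: GEMM work-groups on
// gridDim.x, merged-group count on gridDim.y (read back as g_idx in the
// kernel), gridDim.z unused for now. Assumes group_count is divisible by
// num_groups_to_merge, which IsSupportedArgument checks below.
inline dim3 make_conv_grid(int gemm_grid_size, int group_count, int num_groups_to_merge)
{
    return dim3(gemm_grid_size, group_count / num_groups_to_merge, 1);
}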
@@ -762,7 +765,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.cde_element_op_,
-                    arg.a_g_n_k_wos_lengths_[0], // Group count
+                    // arg.a_g_n_k_wos_lengths_[0], // Group count
                     arg.a_grid_desc_ak0_m_ak1_container_[i],
                     arg.b_grid_desc_bk0_n_bk1_container_[i],
                     arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
@@ -798,8 +801,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             return false;
         }

-        const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
-        const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
+        const index_t ConvG = arg.b_g_k_c_xs_lengths_[I0];
+        const index_t ConvK = arg.b_g_k_c_xs_lengths_[I1];
+        const index_t ConvC = arg.b_g_k_c_xs_lengths_[I2];

         // Specialization
         if constexpr(ConvBackwardDataSpecialization ==
@@ -894,6 +898,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             return false;
         }

+        if constexpr(NumGroupsToMerge > 1)
+        {
+            if(!(ConvC == 1))
+            {
+                return false;
+            }
+            if(ConvG % NumGroupsToMerge != 0)
+            {
+                return false;
+            }
+            if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
+            {
+                return false;
+            }
+        }
+
         // Gridwise GEMM size
         for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
         {
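The new block gates group merging on the only configuration it currently handles. Restated as a hypothetical free-standing predicate (names are illustrative, not from the library):

// Sketch of the checks above: merging is allowed only when each group has a
// single channel on the C side, the group count divides evenly by
// NumGroupsToMerge, and the tensors use the grouped
// NSpatialGK / GKSpatial / NSpatialGC layout family.
bool supports_group_merging(int conv_g, int conv_c, int num_groups_to_merge, bool layouts_ok)
{
    if(num_groups_to_merge <= 1)
        return true; // merging disabled; no extra constraints apply
    return conv_c == 1 && conv_g % num_groups_to_merge == 0 && layouts_ok;
}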
@@ -1031,7 +1051,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             << ABlockTransferSrcScalarPerVector << ", "
             << BBlockTransferSrcScalarPerVector << ", "
-            << CShuffleMXdlPerWavePerShuffle << ", "
-            << CShuffleNXdlPerWavePerShuffle
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle << ", "
+            << NumGroupsToMerge
             << ">";

         return str.str();