Commit 0acd3ebe authored by ltqin

Start changing the gridwise GEMM to split the K dimension (KBatch)

parent 1043ab4f
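
The diff below starts splitting the GEMM K dimension into KBatch slices: each slice computes a partial C tile, and the partial tiles are accumulated into the output (which is why the C store switches from Set to AtomicAdd and the weight tensor is zero-initialized first). A minimal host-side sketch of that decomposition, with illustrative names that are not part of the library:

```cpp
#include <cassert>
#include <vector>

// Hypothetical CPU model of the K-split ("KBatch") GEMM: slice kb of the K
// axis produces a partial product that is added into C, mirroring the
// device's AtomicAdd into a zero-initialized output.
void gemm_kbatch(const std::vector<float>& a, // M x K, row-major
                 const std::vector<float>& b, // K x N, row-major
                 std::vector<float>& c,       // M x N, must start at zero
                 int M, int N, int K, int KBatch)
{
    assert(K % KBatch == 0);
    const int KPerBatch = K / KBatch;
    for(int kb = 0; kb < KBatch; ++kb) // each kb maps to one grid batch on the GPU
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float partial = 0.f;
                for(int k = kb * KPerBatch; k < (kb + 1) * KPerBatch; ++k)
                    partial += a[m * K + k] * b[k * N + n];
                c[m * N + n] += partial; // AtomicAdd on the device
            }
}
```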
@@ -148,7 +148,8 @@ template <index_t BlockSize,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks,
- bool CAccessOrderMRepeatNRepeat>
+ bool CAccessOrderMRepeatNRepeat,
+ index_t KBatch>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
{
static constexpr auto I0 = Number<0>{};
@@ -233,7 +234,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
}
__host__ __device__ static constexpr index_t
- CalculateMNGridSize(const CMNGridDesc& c_m_n_grid_desc)
+ CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
@@ -243,15 +244,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
return grid_size_mn;
}
+ __host__ __device__ static constexpr index_t CalculateGridSize(const index_t M, const index_t N)
+ {
+ const index_t grid_size_mn = (M / MPerBlock) * (N / NPerBlock);
+ return grid_size_mn;
+ }
__host__ __device__ static constexpr auto
- MakeABK0MK1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const index_t kbatch)
+ MakeABK0MK1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc)
{
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
+ assert(K0 % KBatch == 0);
const auto a_b_k0_m_k1_grid_desc = transform_tensor_descriptor(
a_k0_m_k1_grid_desc,
- make_tuple(make_unmerge_transform(make_tuple(kbatch, K0 / kbatch)),
+ make_tuple(make_unmerge_transform(make_tuple(KBatch, K0 / KBatch)),
make_pass_through_transform(M),
make_pass_through_transform(K1Value)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
@@ -260,14 +270,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
}
__host__ __device__ static constexpr auto
- MakeBBK0NK1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const index_t kbatch)
+ MakeBBK0NK1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
{
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
+ assert(K0 % KBatch == 0);
const auto b_b_k0_n_k1_grid_desc = transform_tensor_descriptor(
b_k0_n_k1_grid_desc,
- make_tuple(make_unmerge_transform(make_tuple(kbatch, K0 / kbatch)),
+ make_tuple(make_unmerge_transform(make_tuple(KBatch, K0 / KBatch)),
make_pass_through_transform(N),
make_pass_through_transform(K1Value)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
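
Both descriptor builders above apply the same unmerge transform: the K0 axis of [K0, M, K1] (or [K0, N, K1]) becomes [KBatch, K0/KBatch, ...], so each batch owns a contiguous K0 slice. A standalone sketch of the index arithmetic the transform encodes (illustrative values, not library code):

```cpp
#include <cstdio>

// What make_unmerge_transform(make_tuple(KBatch, K0 / KBatch)) does to an
// index: flat k0 decomposes into (batch b, in-batch index k0'), so batch b
// owns the contiguous range [b * K0PerBatch, (b + 1) * K0PerBatch).
int main()
{
    const int K0 = 8, KBatch = 4, K0PerBatch = K0 / KBatch;
    for(int k0 = 0; k0 < K0; ++k0)
    {
        const int b      = k0 / K0PerBatch;
        const int k0_sub = k0 % K0PerBatch;
        printf("k0 %d -> (b %d, k0' %d)\n", k0, b, k0_sub);
    }
    return 0;
}
```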
@@ -327,8 +339,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
- using ABK0MK1GridDesc = decltype(MakeABK0MK1GridDescriptor(AK0MK1GridDesc{}, I1));
- using BBK0NK1GridDesc = decltype(MakeBBK0NK1GridDescriptor(BK0NK1GridDesc{}, I1));
+ using ABK0MK1GridDesc = decltype(MakeABK0MK1GridDescriptor(AK0MK1GridDesc{}));
+ using BBK0NK1GridDesc = decltype(MakeBBK0NK1GridDescriptor(BK0NK1GridDesc{}));
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
@@ -344,24 +356,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
- p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
+ p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
- p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
+ p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
- const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
- const auto kbatch = CalculateKBatch(CMNGridDesc{}, b_k0_n_k1_grid_desc);
- if(get_block_1d_id() == 0)
- printf("*****kbatch : %d, %d, %d, %d\n",
- kbatch,
- a_b_k0_m_k1_grid_desc.GetLength(I0),
- b_b_k0_n_k1_grid_desc.GetLength(I0),
+ const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
+ const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
+ const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
+ const auto b_grid_size = CalculateGridSize(M, N);
+ const auto nBatch = get_block_1d_id() / b_grid_size;
+ const auto blockid_in_batch = get_block_1d_id() % b_grid_size;
+ if(get_block_1d_id() == 2000)
+ printf("grid size: %d, Batch: %d block_id: %d k0: %d\n",
+ b_grid_size,
+ nBatch,
+ blockid_in_batch,
K0);
// divide block work by [M, N]
const auto block_work_idx =
- c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+ c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(blockid_in_batch));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
......
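
Run() above recovers each block's K batch and its M/N tile from the flat block id: the launch grid is KBatch times larger than the pure M/N grid, and division/modulo by the M/N grid size splits the id. A small sketch of that decomposition with illustrative values:

```cpp
#include <cstdio>

// Sketch of the block-id decomposition used in Run(): grid_size_mn * KBatch
// blocks are launched; nBatch picks the K slice, blockid_in_batch the M/N tile.
int main()
{
    const int grid_size_mn = 4; // (M / MPerBlock) * (N / NPerBlock)
    const int KBatch       = 3;
    for(int block_id = 0; block_id < grid_size_mn * KBatch; ++block_id)
    {
        const int nBatch           = block_id / grid_size_mn; // which K slice
        const int blockid_in_batch = block_id % grid_size_mn; // which M/N tile
        printf("block %d -> batch %d, mn-block %d\n", block_id, nBatch, blockid_in_batch);
    }
    return 0;
}
```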
@@ -75,6 +75,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
+ constexpr index_t KBatch = 96;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
@@ -167,7 +169,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
TInWei,
TAcc,
TOut,
- InMemoryDataOperationEnum_t::Set,
+ InMemoryDataOperationEnum_t::AtomicAdd,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
@@ -203,7 +205,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
- false>(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
+ false,
+ KBatch>(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
out_gemmk0_gemmm_gemmk1_grid_desc,
......
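
Switching the C memory operation from Set to AtomicAdd lets every K batch add its partial tile into the same global C buffer, so correctness requires C to start at zero. A host-side sketch of the accumulation semantics, with a CAS loop standing in for the device's atomic float add:

```cpp
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// Emulates an atomic float add with a compare-exchange loop.
void atomic_add(std::atomic<float>& target, float v)
{
    float old = target.load();
    while(!target.compare_exchange_weak(old, old + v)) {}
}

int main()
{
    std::atomic<float> c{0.0f}; // output must start at zero
    std::vector<std::thread> batches;
    for(int kb = 0; kb < 4; ++kb) // one writer per K batch
        batches.emplace_back([&c] { atomic_add(c, 1.5f); });
    for(auto& t : batches)
        t.join();
    printf("c = %f\n", c.load()); // 6.0 regardless of interleaving
    return 0;
}
```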
@@ -46,7 +46,8 @@ template <ck::index_t BlockSize,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks,
- bool CAccessOrderMRepeatNRepeat>
+ bool CAccessOrderMRepeatNRepeat,
+ ck::index_t KBatch>
__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
@@ -108,7 +109,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
- CAccessOrderMRepeatNRepeat>;
+ CAccessOrderMRepeatNRepeat,
+ KBatch>;
{
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
......@@ -122,13 +124,11 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
- const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
- const auto a_b_k0_m_k1_grid_desc =
- GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc, kbatch);
- const auto b_b_k0_n_k1_grid_desc =
- GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc, kbatch);
+ // const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
+ const auto a_b_k0_m_k1_grid_desc = GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc);
+ const auto b_b_k0_n_k1_grid_desc = GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc);
{
std::cout << "k batch number is: " << kbatch << std::endl;
// std::cout << "k batch number is: " << kbatch << std::endl;
}
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{
@@ -147,8 +147,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
- const index_t grid_size_mn = GridwiseGemm::CalculateMNGridSize(c_m_n_grid_desc);
- const index_t grid_size = grid_size_mn * kbatch;
+ const index_t grid_size_mn = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
+ const index_t grid_size = grid_size_mn * KBatch;
{
std::cout << "mxn gridSize : " << grid_size_mn << " finally grid_size : " << grid_size
<< std::endl;
@@ -189,6 +189,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
+ a_b_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_b_k0_m_k1_grid_desc);
+ b_b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_b_k0_n_k1_grid_desc);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
......
@@ -267,6 +267,8 @@ int main(int argc, char* argv[])
{
throw std::runtime_error("wrong! layout");
}
+ // set zero to wei_device
+ wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
const auto tmp = f_make_for_device_nchw();
......
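
Because partial results are now accumulated with AtomicAdd, the weight tensor must be cleared before the kernel runs; that is what the GenerateTensorValue call above does, using the GeneratorTensor_0 added in the next hunk. A standalone check of that generator:

```cpp
#include <cassert>

// Copy of the GeneratorTensor_0 added below, exercised on its own: it returns
// 0 for any index pack, so GenerateTensorValue zero-fills the tensor.
struct GeneratorTensor_0
{
    int value = 0;
    template <typename... Is>
    float operator()(Is...)
    {
        return value;
    }
};

int main()
{
    GeneratorTensor_0 gen;
    assert(gen(0, 1, 2, 3) == 0.0f);
    return 0;
}
```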
@@ -15,6 +15,17 @@ struct GeneratorTensor_1
}
};
+ struct GeneratorTensor_0
+ {
+ int value = 0;
+ template <typename... Is>
+ float operator()(Is...)
+ {
+ return value;
+ }
+ };
struct GeneratorTensor_2
{
int min_value = 0;
......