Commit 0acd3ebe authored by ltqin

start change gridwise k split

parent 1043ab4f
@@ -148,7 +148,8 @@ template <index_t BlockSize,
           typename CGridStepHacks,
           typename AGridMoveSliceWindowStepHacks,
           typename BGridMoveSliceWindowStepHacks,
-          bool CAccessOrderMRepeatNRepeat>
+          bool CAccessOrderMRepeatNRepeat,
+          index_t KBatch>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
 {
     static constexpr auto I0 = Number<0>{};
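The new KBatch parameter splits the GEMM reduction dimension K0 into KBatch slices that are computed by separate groups of workgroups and then summed into C. A minimal standalone sketch of why that split preserves the result (the sizes below are made up for illustration; on the GPU the final accumulation is the AtomicAdd change further down):

// Illustrative only: summing KBatch partial reductions over disjoint K0 slices
// reproduces the full K0 reduction.
#include <cassert>
#include <cstdio>
#include <vector>

int main()
{
    const int K0 = 8, KBatch = 4; // hypothetical sizes; KBatch must divide K0
    assert(K0 % KBatch == 0);
    const int K0PerBatch = K0 / KBatch;

    std::vector<float> a(K0), b(K0);
    for(int k = 0; k < K0; ++k)
    {
        a[k] = k + 1.0f;
        b[k] = 0.5f * k;
    }

    // full reduction: what a single kernel seeing the whole K0 would compute
    float c_full = 0.0f;
    for(int k = 0; k < K0; ++k)
        c_full += a[k] * b[k];

    // KBatch partial reductions, each over its own K0 slice, accumulated into C
    float c_split = 0.0f;
    for(int kb = 0; kb < KBatch; ++kb)
    {
        float c_partial = 0.0f;
        for(int k = kb * K0PerBatch; k < (kb + 1) * K0PerBatch; ++k)
            c_partial += a[k] * b[k];
        c_split += c_partial; // on the GPU this accumulation becomes an atomic add into C
    }

    std::printf("full = %f, split = %f\n", c_full, c_split);
    return 0;
}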
@@ -233,7 +234,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
     }

     __host__ __device__ static constexpr index_t
-    CalculateMNGridSize(const CMNGridDesc& c_m_n_grid_desc)
+    CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
     {
         const auto M = c_m_n_grid_desc.GetLength(I0);
         const auto N = c_m_n_grid_desc.GetLength(I1);
@@ -243,15 +244,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         return grid_size_mn;
     }

+    __host__ __device__ static constexpr index_t CalculateGridSize(const index_t M, const index_t N)
+    {
+        const index_t grid_size_mn = (M / MPerBlock) * (N / NPerBlock);
+        return grid_size_mn;
+    }
+
     __host__ __device__ static constexpr auto
-    MakeABK0MK1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const index_t kbatch)
+    MakeABK0MK1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc)
     {
         const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
         const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);
+
+        assert(K0 % KBatch == 0);

         const auto a_b_k0_m_k1_grid_desc = transform_tensor_descriptor(
             a_k0_m_k1_grid_desc,
-            make_tuple(make_unmerge_transform(make_tuple(kbatch, K0 / kbatch)),
+            make_tuple(make_unmerge_transform(make_tuple(KBatch, K0 / KBatch)),
                        make_pass_through_transform(M),
                        make_pass_through_transform(K1Value)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
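The unmerge transform above reshapes the leading K0 axis of the A descriptor into a (KBatch, K0 / KBatch) pair, with the batch index taken as the slow-varying part of the original k0 coordinate. A small sketch of that index arithmetic, independent of the descriptor machinery (the K0 and KBatch values are hypothetical):

// Illustrative only: the (KBatch, K0/KBatch) split maps a flat k0 index to a
// (batch, k0-in-batch) pair and back.
#include <cassert>
#include <cstdio>

int main()
{
    const int K0 = 12, KBatch = 3;
    assert(K0 % KBatch == 0);
    const int K0PerBatch = K0 / KBatch;

    for(int k0 = 0; k0 < K0; ++k0)
    {
        const int batch       = k0 / K0PerBatch; // which K slice this element belongs to
        const int k0_in_batch = k0 % K0PerBatch; // offset inside that slice
        assert(batch * K0PerBatch + k0_in_batch == k0);
        std::printf("k0=%2d -> (batch=%d, k0_in_batch=%d)\n", k0, batch, k0_in_batch);
    }
    return 0;
}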
@@ -260,14 +270,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
     }

     __host__ __device__ static constexpr auto
-    MakeBBK0NK1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const index_t kbatch)
+    MakeBBK0NK1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
     {
         const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
         const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);
+
+        assert(K0 % KBatch == 0);

         const auto b_b_k0_n_k1_grid_desc = transform_tensor_descriptor(
             b_k0_n_k1_grid_desc,
-            make_tuple(make_unmerge_transform(make_tuple(kbatch, K0 / kbatch)),
+            make_tuple(make_unmerge_transform(make_tuple(KBatch, K0 / KBatch)),
                        make_pass_through_transform(N),
                        make_pass_through_transform(K1Value)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
@@ -327,8 +339,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         return c_blockid_to_m0_n0_block_cluster_adaptor;
     }

-    using ABK0MK1GridDesc = decltype(MakeABK0MK1GridDescriptor(AK0MK1GridDesc{}, I1));
-    using BBK0NK1GridDesc = decltype(MakeBBK0NK1GridDescriptor(BK0NK1GridDesc{}, I1));
+    using ABK0MK1GridDesc = decltype(MakeABK0MK1GridDescriptor(AK0MK1GridDesc{}));
+    using BBK0NK1GridDesc = decltype(MakeBBK0NK1GridDescriptor(BK0NK1GridDesc{}));
     using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
     using CBlockClusterAdaptor      = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
@@ -344,24 +356,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
                                const CBlockClusterAdaptor& c_block_cluster_adaptor)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-            p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
+            p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-            p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
+            p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());

-        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
-        const auto kbatch = CalculateKBatch(CMNGridDesc{}, b_k0_n_k1_grid_desc);
-        if(get_block_1d_id() == 0)
-            printf("*****kbatch : %d, %d, %d, %d\n",
-                   kbatch,
-                   a_b_k0_m_k1_grid_desc.GetLength(I0),
-                   b_b_k0_n_k1_grid_desc.GetLength(I0),
+        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
+        const auto M  = a_b_k0_m_k1_grid_desc.GetLength(I2);
+        const auto N  = b_b_k0_n_k1_grid_desc.GetLength(I2);
+
+        const auto b_grid_size      = CalculateGridSize(M, N);
+        const auto nBatch           = get_block_1d_id() / b_grid_size;
+        const auto blockid_in_batch = get_block_1d_id() % b_grid_size;
+        if(get_block_1d_id() == 2000)
+            printf("grid size: %d, Batch: %d block_id: %d k0: %d\n",
+                   b_grid_size,
+                   nBatch,
+                   blockid_in_batch,
                    K0);

         // divide block work by [M, N]
         const auto block_work_idx =
-            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(blockid_in_batch));

         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
...
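In the kernel body the grid is now KBatch copies of the M x N tile grid, and each workgroup recovers its K slice and its tile from the flat block id by division and modulo, as in the Run() hunk above. A host-side sketch of that mapping with made-up problem and tile sizes:

// Illustrative only: how a flat block id splits into (nBatch, blockid_in_batch)
// when the grid is launched with KBatch copies of the MxN tile grid.
#include <cstdio>

int main()
{
    const int M = 256, N = 512, MPerBlock = 128, NPerBlock = 128, KBatch = 4;

    const int b_grid_size = (M / MPerBlock) * (N / NPerBlock); // blocks per K batch
    const int grid_size   = b_grid_size * KBatch;              // total blocks launched

    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        const int nBatch           = block_id / b_grid_size; // which K slice this block works on
        const int blockid_in_batch = block_id % b_grid_size; // which MxN tile inside that slice
        std::printf("block %2d -> batch %d, tile %d\n", block_id, nBatch, blockid_in_batch);
    }
    return 0;
}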
@@ -75,6 +75,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;

+    constexpr index_t KBatch = 96;
+
 #elif 1
     // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
     constexpr index_t BlockSize = 256;
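KBatch = 96 is a tuning choice for this configuration; it only works if it divides the GEMM K0 of the problem, which the asserts added to the descriptor builders enforce. A quick host-side sanity check along those lines (the K0 value below is a made-up example, not derived from a real layer shape):

// Illustrative only: check that the chosen KBatch divides the problem's GEMM K0
// before launching, mirroring the assert in the descriptor code.
#include <cstdio>

int main()
{
    const int KBatch = 96;
    const int K0     = 960; // hypothetical GEMM K0 for some problem
    if(K0 % KBatch != 0)
        std::printf("KBatch %d does not divide K0 %d, pick another KBatch\n", KBatch, K0);
    else
        std::printf("K0 %d splits into %d slices of %d\n", K0, KBatch, K0 / KBatch);
    return 0;
}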
@@ -167,7 +169,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
         TInWei,
         TAcc,
         TOut,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum_t::AtomicAdd,
         decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
         decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
         decltype(wei_gemmm_gemmn_grid_desc),
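Changing the C write from Set to AtomicAdd is what lets the KBatch partial results land in one output buffer without an extra reduction kernel; it also means the buffer must start at zero (see the GeneratorTensor_0 change further below). A serial sketch of the difference between the two memory operations (values are arbitrary):

// Illustrative only: with KBatch partial results targeting the same C element,
// a plain "Set" store keeps only the last writer, while "AtomicAdd"-style
// accumulation (shown serially here) sums all of them.
#include <cstdio>

int main()
{
    const float partial[3] = {1.0f, 2.0f, 3.0f}; // per-K-batch partial results

    float c_set = 0.0f;
    for(float p : partial)
        c_set = p;   // Set semantics: last write wins

    float c_add = 0.0f; // must start from zero
    for(float p : partial)
        c_add += p;  // AtomicAdd semantics: contributions sum

    std::printf("Set -> %f (only the last partial), AtomicAdd -> %f (sum of partials)\n",
                c_set, c_add);
    return 0;
}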
@@ -203,18 +205,19 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
         decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
         decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
         decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
-        false>(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
-               static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
-               static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
-               out_gemmk0_gemmm_gemmk1_grid_desc,
-               in_gemmk0_gemmn_gemmk1_grid_desc,
-               wei_gemmm_gemmn_grid_desc,
-               out_gemmk0_gemmm_gemmk1_grid_step_hacks,
-               in_gemmk0_gemmn_gemmk1_grid_step_hacks,
-               wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
-               out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
-               in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
-               nrepeat);
+        false,
+        KBatch>(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
+                static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
+                static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
+                out_gemmk0_gemmm_gemmk1_grid_desc,
+                in_gemmk0_gemmn_gemmk1_grid_desc,
+                wei_gemmm_gemmn_grid_desc,
+                out_gemmk0_gemmm_gemmk1_grid_step_hacks,
+                in_gemmk0_gemmn_gemmk1_grid_step_hacks,
+                wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
+                out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
+                in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
+                nrepeat);

     float perf = static_cast<float>(calculate_convolution_flops(
                      in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
...
@@ -46,7 +46,8 @@ template <ck::index_t BlockSize,
           typename CGridStepHacks,
           typename AGridMoveSliceWindowStepHacks,
           typename BGridMoveSliceWindowStepHacks,
-          bool CAccessOrderMRepeatNRepeat>
+          bool CAccessOrderMRepeatNRepeat,
+          ck::index_t KBatch>
 __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
                                        const FloatAB* p_b_grid,
                                        FloatC* p_c_grid,
@@ -108,7 +109,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
         CGridStepHacks,
         AGridMoveSliceWindowStepHacks,
         BGridMoveSliceWindowStepHacks,
-        CAccessOrderMRepeatNRepeat>;
+        CAccessOrderMRepeatNRepeat,
+        KBatch>;

     {
         std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
@@ -122,13 +124,11 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
         std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
                   << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
     }
-    const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
-    const auto a_b_k0_m_k1_grid_desc =
-        GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc, kbatch);
-    const auto b_b_k0_n_k1_grid_desc =
-        GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc, kbatch);
+    // const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
+    const auto a_b_k0_m_k1_grid_desc = GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc);
+    const auto b_b_k0_n_k1_grid_desc = GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc);
     {
-        std::cout << "k batch number is: " << kbatch << std::endl;
+        // std::cout << "k batch number is: " << kbatch << std::endl;
     }
     if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
     {
@@ -147,8 +147,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
     using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);

-    const index_t grid_size_mn = GridwiseGemm::CalculateMNGridSize(c_m_n_grid_desc);
-    const index_t grid_size    = grid_size_mn * kbatch;
+    const index_t grid_size_mn = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
+    const index_t grid_size    = grid_size_mn * KBatch;
     {
         std::cout << "mxn gridSize : " << grid_size_mn << " finally grid_size : " << grid_size
                   << std::endl;
@@ -189,6 +189,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
     a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
     b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
+    a_b_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_b_k0_m_k1_grid_desc);
+    b_b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_b_k0_n_k1_grid_desc);
     c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
     c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
...
@@ -267,6 +267,8 @@ int main(int argc, char* argv[])
     {
         throw std::runtime_error("wrong! layout");
     }

+    // set zero to wei_device
+    wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
     const auto tmp = f_make_for_device_nchw();
...
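Because the kernel now accumulates with AtomicAdd, the weight output buffer has to be zero-filled before launch; that is what the GenerateTensorValue call with GeneratorTensor_0 above does. A simplified stand-in for that zero-fill with a generator functor (the generate_tensor_value() helper below is illustrative, not the repo's host tensor API):

// Illustrative only: zero-filling an output buffer with a generator functor
// before atomically accumulating into it.
#include <cstddef>
#include <cstdio>
#include <vector>

// same shape as the GeneratorTensor_0 added in this commit
struct GeneratorTensor_0
{
    int value = 0;

    template <typename... Is>
    float operator()(Is...)
    {
        return value;
    }
};

// stand-in for GenerateTensorValue: fill every element from the generator
template <typename Generator>
void generate_tensor_value(std::vector<float>& buf, Generator g)
{
    for(std::size_t i = 0; i < buf.size(); ++i)
        buf[i] = g(i);
}

int main()
{
    std::vector<float> wei(8, -1.0f);                // pretend this holds stale data
    generate_tensor_value(wei, GeneratorTensor_0{}); // must be all zeros before AtomicAdd
    for(float v : wei)
        std::printf("%.1f ", v);
    std::printf("\n");
    return 0;
}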
@@ -15,6 +15,17 @@ struct GeneratorTensor_1
     }
 };

+struct GeneratorTensor_0
+{
+    int value = 0;
+
+    template <typename... Is>
+    float operator()(Is...)
+    {
+        return value;
+    }
+};
+
 struct GeneratorTensor_2
 {
     int min_value = 0;
...