Commit 0c9cdbce authored by aska-0096

format

parent 0517cf08
@@ -49,8 +49,7 @@ static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization
 static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;

 using DeviceOpInstanceKKNN =
-    ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle<
-        NumDimG,
-        NumDimM,
-        NumDimN,
-        NumDimK,
+    ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle<NumDimG,
+                                                                                  NumDimM,
+                                                                                  NumDimN,
+                                                                                  NumDimK,
@@ -311,7 +310,8 @@ int main(int argc, char* argv[])
     DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
     DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
     DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf(sizeof(EDataType) * e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf(sizeof(EDataType) *
+                           e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());

     a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
     b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
@@ -363,7 +363,7 @@ int main(int argc, char* argv[])
     ck::index_t K = ck::accumulate_n<ck::index_t>(
         a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});
-    std::cout<<"GMNK="<<G<<", "<<M<<", "<<N<<", "<<K<<std::endl;
+    std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl;
     std::size_t flop = std::size_t(2) * G * M * N * K;
     std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
                             sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
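Aside on the counters in this hunk: flop is the usual 2·G·M·N·K multiply-accumulate count for a batched contraction, and num_btype totals the bytes moved for A, B, D, and E. A minimal sketch of how such counters convert into throughput, assuming a hypothetical ave_time_ms measured for one kernel invocation (illustration only, not CK's own utility code):

    #include <cstddef>

    // Hypothetical helpers: convert the FLOP/byte counters above into rates.
    // 'ave_time_ms' is an assumed measured kernel time in milliseconds.
    inline double tflops(std::size_t flop, double ave_time_ms)
    {
        return static_cast<double>(flop) / (ave_time_ms * 1e-3) / 1e12;
    }

    inline double gb_per_sec(std::size_t num_btype, double ave_time_ms)
    {
        return static_cast<double>(num_btype) / (ave_time_ms * 1e-3) / 1e9;
    }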
...
@@ -605,9 +605,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
             b_grid_desc_n_k_ =
                 DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
-            ds_grid_desc_m_n_ = DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
-            e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
+            ds_grid_desc_m_n_ =
+                DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
+            e_grid_desc_m_n_ =
+                DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);

             a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
             b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);
@@ -619,8 +621,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                     ds_grid_desc_m_n_);
             e_grid_desc_mblock_mperblock_nblock_nperblock =
-                GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                    e_grid_desc_m_n_);
+                GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);

             // for sanity check of vector memory access
             a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1];
@@ -696,9 +697,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
         {
             const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);
-            const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
+            const index_t grid_size =
+                arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;

-            const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             auto launch_kernel = [&](auto has_main_k_block_loop) {
                 constexpr bool has_main_loop = has_main_k_block_loop.value;
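Aside: grid_size here is the per-batch C-tile count scaled by G, so every batch owns an equal, contiguous range of workgroups. A hypothetical sketch of the resulting block-id decomposition, mirroring the num_blocks_per_batch / g_idx arithmetic in the kernel hunks further down:

    #include <cassert>

    // Hypothetical illustration: split a flat block id into (batch, tile)
    // when the grid is sized as tiles_per_batch * batch_count, as above.
    struct BlockCoord
    {
        int g_idx;    // batch index along the G dimension
        int tile_idx; // C-tile index within that batch
    };

    inline BlockCoord decompose_block_id(int block_id, int grid_size, int batch_count)
    {
        assert(grid_size % batch_count == 0);
        const int num_blocks_per_batch = grid_size / batch_count;
        return {block_id / num_blocks_per_batch, block_id % num_blocks_per_batch};
    }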
...
@@ -262,9 +262,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
         const auto AK1 = K1;
         const auto AK0 = K / AK1;

-        return transform_tensor_descriptor(
-            a_grid_desc_m_k,
-            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_pass_through_transform(M)),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        return transform_tensor_descriptor(a_grid_desc_m_k,
+                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                                                      make_pass_through_transform(M)),
+                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
+                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }
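Aside on what this (unchanged) logic does: the unmerge transform splits the K dimension of the M-K descriptor into (AK0, AK1) while M passes through, so a source coordinate (m, k) maps to (k0, m, k1) with k = k0 * AK1 + k1; the Sequence arguments wire lower dim 1 (K) to upper dims {0, 2} and lower dim 0 (M) to upper dim 1. A hypothetical standalone sketch of that index map (not the CK API):

    #include <array>

    // Hypothetical index map for the K -> (AK0, AK1) unmerge above:
    // (m, k) in the M-K view corresponds to (k0, m, k1) in the K0-M-K1 view.
    inline std::array<int, 3> to_k0_m_k1(int m, int k, int AK1)
    {
        return {k / AK1, m, k % AK1}; // {k0, m, k1}, with k == k0 * AK1 + k1
    }

The B-descriptor hunk below is the same transform with N in place of M.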
@@ -280,9 +280,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
         const auto BK1 = K1;
         const auto BK0 = K / BK1;

-        return transform_tensor_descriptor(
-            b_grid_desc_n_k,
-            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_pass_through_transform(N)),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        return transform_tensor_descriptor(b_grid_desc_n_k,
+                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                                                      make_pass_through_transform(N)),
+                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
+                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }
@@ -390,10 +390,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
           ds_grid_desc_m_n_{},
           e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
                                                                       e_g_n_k_wos_strides)},
-          a_grid_desc_ak0_m_ak1_{
-              DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
-          b_grid_desc_bk0_n_bk1_{
-              DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
+          a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
+          b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
           ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
           e_grid_desc_mblock_mperblock_nblock_nperblock_{},
           block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
@@ -432,12 +430,12 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
             });

             // D desc
-            ds_grid_desc_m_n_ = DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides);
+            ds_grid_desc_m_n_ =
+                DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides);

             // populate desc for Ds/E
             e_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                    e_grid_desc_m_n_);
+                GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);

             ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
                 GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                     ds_grid_desc_m_n_);
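Aside: the MBlock_MPerBlock_NBlock_NPerBlock descriptors built here view the M x N output as a grid of MPerBlock x NPerBlock workgroup tiles. A hypothetical sketch of the coordinate split, assuming M and N divide evenly by the tile sizes:

    #include <array>

    // Hypothetical illustration of the 4-D view above: output coordinate (m, n)
    // becomes (m / MPerBlock, m % MPerBlock, n / NPerBlock, n % NPerBlock).
    inline std::array<int, 4> to_tile_coords(int m, int n, int MPerBlock, int NPerBlock)
    {
        return {m / MPerBlock, m % MPerBlock, n / NPerBlock, n % NPerBlock};
    }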
...
@@ -148,14 +148,14 @@ __global__ void
         const Block2CTileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
-    //printf("entry kernel launch");
+    // printf("entry kernel launch");
     __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    //printf("before compute_ptr_offset call");
+    // printf("before compute_ptr_offset call");
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
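Aside: g_idx and the batch offsets are identical across a wavefront, so __builtin_amdgcn_readfirstlane lets the compiler keep them in scalar registers. Stripped of that detail, the offset arithmetic amounts to the hypothetical sketch below, where batch_stride stands in for the per-batch (G-dimension) element stride that the compute_ptr_offset_of_batch functor reads from the tensor strides:

    #include <cstdint>

    // Hypothetical equivalent of the GetAPtrOffset/GetBPtrOffset calls above:
    // each batch's base pointer is shifted by g_idx times that tensor's
    // per-batch element stride.
    inline std::int64_t batch_offset(std::int32_t g_idx, std::int64_t batch_stride)
    {
        return static_cast<std::int64_t>(g_idx) * batch_stride;
    }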
@@ -170,12 +170,12 @@ __global__ void
     DsPointer p_ds_grid_grp;

-    //printf("before allocate pointer d");
+    // printf("before allocate pointer d");
     static_for<0, NumDTensor, 1>{}(
         [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });

-    //printf("before entry");
+    // printf("before entry");
     GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                 p_b_grid + b_batch_offset,
@@ -570,7 +570,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                const CDEElementwiseOperation& cde_element_op,
                const Block2CTileMap& block_2_ctile_map)
     {
-        //printf("safe entry");
+        // printf("safe entry");
         // clang-format off
         /*******************************************************************************/
         // Memory buffer zone.
...