Commit abfc94b2 authored by aska-0096

batchedgemm[OK], groupconv[debug]

parent 07180cb7
@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method = 1;
-    bool time_kernel = false;
+    bool time_kernel = true;
     // GEMM shape
     ck::index_t M = 3840;
...
@@ -56,7 +56,7 @@ using DeviceOpInstanceKKNN =
         NumDimK,
         ADataType,
         BDataType,
-        ck::Tuple<DDataType>,
+        DsDataType,
         EDataType,
         AccDataType,
         CShuffleDataType,
@@ -239,18 +239,18 @@ int main(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method = 1;
-    bool time_kernel = false;
+    bool time_kernel = true;
     ck::index_t G0 = 1;
-    ck::index_t G1 = 1;
-    ck::index_t M0 = 1;
-    ck::index_t M1 = 1;
-    ck::index_t N0 = 1;
-    ck::index_t N1 = 1;
-    ck::index_t K0 = 1;
+    ck::index_t G1 = 2;
+    ck::index_t M0 = 4;
+    ck::index_t M1 = 128;
+    ck::index_t N0 = 16;
+    ck::index_t N1 = 256;
+    ck::index_t K0 = 2048;
     // A[G0, G1, M0, M1, K0]
     std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
@@ -284,13 +284,11 @@ int main(int argc, char* argv[])
         printf("arg3: time kernel (0=no, 1=yes)\n");
         exit(0);
     }
-    std::cout<<"CP -4 "<<std::endl;
     Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
     Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
     Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
     Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
     Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
-    std::cout<<"CP -3 "<<std::endl;
     std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
     std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
     std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
@@ -310,17 +308,14 @@ int main(int argc, char* argv[])
         d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
         break;
     }
-    std::cout<<"CP -2 "<<std::endl;
     DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
     DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
     DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf(sizeof(EDataType) *
-                           e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf(sizeof(EDataType) * e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
     a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
     b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
     d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
-    std::cout<<"CP -1 "<<std::endl;
     // set zero
     e_device_buf.SetZero();
@@ -330,7 +325,6 @@ int main(int argc, char* argv[])
     auto cde_element_op = CDEElementOp{};
     // device operation
-    std::cout<<"CP 0 "<<std::endl;
     auto op = DeviceOpInstance{};
     auto invoker = op.MakeInvoker();
     auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
@@ -348,7 +342,6 @@ int main(int argc, char* argv[])
                                     a_element_op,
                                     b_element_op,
                                     cde_element_op);
-    std::cout<<"CP 1 "<<std::endl;
     if(!op.IsSupportedArgument(argument))
     {
@@ -358,7 +351,6 @@ int main(int argc, char* argv[])
     }
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-    std::cout<<"CP 2 "<<std::endl;
     ck::index_t G =
         ck::accumulate_n<ck::index_t>(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{});
@@ -371,7 +363,7 @@ int main(int argc, char* argv[])
     ck::index_t K = ck::accumulate_n<ck::index_t>(
         a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});
+    std::cout<<"GMNK="<<G<<", "<<M<<", "<<N<<", "<<K<<std::endl;
     std::size_t flop = std::size_t(2) * G * M * N * K;
     std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
                             sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
...
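For reference, the performance bookkeeping at the end of this example flattens the grouped extents before reporting: G, M, N and K are each the product of their dimension group, and the batched contraction performs 2 * G * M * N * K flops. A minimal standalone sketch of that arithmetic, using std::accumulate in place of ck::accumulate_n and assuming 2-byte element types purely for illustration:

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    // Dimension split used by the example: G = G0*G1, M = M0*M1, N = N0*N1, K = K0.
    const std::vector<std::size_t> g_dims{1, 2};    // G0, G1
    const std::vector<std::size_t> m_dims{4, 128};  // M0, M1
    const std::vector<std::size_t> n_dims{16, 256}; // N0, N1
    const std::vector<std::size_t> k_dims{2048};    // K0

    auto product = [](const std::vector<std::size_t>& v) {
        return std::accumulate(v.begin(), v.end(), std::size_t{1}, std::multiplies<>{});
    };

    const std::size_t G = product(g_dims);
    const std::size_t M = product(m_dims);
    const std::size_t N = product(n_dims);
    const std::size_t K = product(k_dims);

    // 2 flops (multiply + add) per output element per K step, summed over batches.
    const std::size_t flop = 2 * G * M * N * K;

    // Bytes moved for A, B, D and E; element size of 2 bytes is an assumption here.
    const std::size_t elem_bytes = 2;
    const std::size_t num_bytes =
        elem_bytes * (G * M * K + G * K * N + G * M * N + G * M * N);

    std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << '\n'
              << "flop=" << flop << ", bytes=" << num_bytes << '\n';
    return 0;
}
```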
@@ -397,43 +397,10 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
     using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
     using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
     using DsGridDesc_G_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_G_M_N({}, {}))>;
     using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {}));
-    // A desc for source in blockwise copy
-    template <typename AGridDesc_M_K>
-    __host__ __device__ static constexpr auto
-    MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k)
-    {
-        const auto M = a_grid_desc_m_k.GetLength(I0);
-        const auto K = a_grid_desc_m_k.GetLength(I1);
-        const auto AK0 = K / K1;
-        return transform_tensor_descriptor(
-            a_grid_desc_m_k,
-            make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-    }
-    // B desc for source in blockwise copy
-    template <typename BGridDesc_N_K>
-    __host__ __device__ static constexpr auto
-    MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k)
-    {
-        const auto N = b_grid_desc_n_k.GetLength(I0);
-        const auto K = b_grid_desc_n_k.GetLength(I1);
-        const auto BK0 = K / K1;
-        return transform_tensor_descriptor(
-            b_grid_desc_n_k,
-            make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-    }
     struct ComputePtrOffsetOfStridedBatch
     {
         ComputePtrOffsetOfStridedBatch(index_t batch_stride_A,
@@ -482,6 +449,40 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
         EGridDesc_G_M_N e_grid_desc_g_m_n_;
     };
+    // A desc for source in blockwise copy
+    template <typename AGridDesc_M_K>
+    __host__ __device__ static constexpr auto
+    MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k)
+    {
+        const auto M = a_grid_desc_m_k.GetLength(I0);
+        const auto K = a_grid_desc_m_k.GetLength(I1);
+        const auto AK0 = K / K1;
+        return transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }
+    // B desc for source in blockwise copy
+    template <typename BGridDesc_N_K>
+    __host__ __device__ static constexpr auto
+    MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k)
+    {
+        const auto N = b_grid_desc_n_k.GetLength(I0);
+        const auto K = b_grid_desc_n_k.GetLength(I1);
+        const auto BK0 = K / K1;
+        return transform_tensor_descriptor(
+            b_grid_desc_n_k,
+            make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }
     using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{}));
     using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{}));
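For orientation, the K0_M_K1 / K0_N_K1 descriptors built above only re-index the K dimension: an element at (m, k) in the M x K view is addressed as (k0, m, k1) with k = k0 * K1 + k1. A minimal sketch of that index split in plain C++, independent of CK's descriptor machinery (names are illustrative):

```cpp
#include <cassert>
#include <cstdint>

// Illustrative only: split a flat K index into (K0, K1) blocks, as the
// K0_M_K1 descriptors do, assuming K is divisible by K1.
struct K0MK1Index
{
    int64_t k0;
    int64_t m;
    int64_t k1;
};

constexpr K0MK1Index to_k0_m_k1(int64_t m, int64_t k, int64_t K1)
{
    return {k / K1, m, k % K1};
}

constexpr int64_t to_flat_k(const K0MK1Index& idx, int64_t K1)
{
    return idx.k0 * K1 + idx.k1; // inverse of the unmerge transform
}

int main()
{
    constexpr int64_t K1 = 8;
    const auto idx = to_k0_m_k1(/*m=*/3, /*k=*/21, K1);
    assert(idx.k0 == 2 && idx.k1 == 5);
    assert(to_flat_k(idx, K1) == 21);
    return 0;
}
```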
@@ -592,33 +593,27 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
               compute_ptr_offset_of_batch_{
                   a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_}
         {
-            a_grid_desc_m_k_ =
-                DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
-            a_grid_desc_m_k_ =
-                DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
-            a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
-            b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);
             static_for<0, NumDTensor, 1>{}([&](auto i) {
                 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
                 // D pointer
                 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
-                // D desc
-                ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i],
-                                                                         ds_gs_ms_ns_strides[i]);
             });
-            e_grid_desc_m_n_ =
-                DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
+            a_grid_desc_m_k_ =
+                DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
+            b_grid_desc_n_k_ =
+                DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
+            ds_grid_desc_m_n_ = DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
+            e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
+            a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
+            b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);
             block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
-            if(GridwiseOp::CheckValidity(a_grid_desc_k0_m_k1_,
-                                         b_grid_desc_k0_n_k1_,
-                                         ds_grid_desc_m_n_,
-                                         e_grid_desc_m_n_,
-                                         block_2_ctile_map_))
-            {
                 ds_grid_desc_mblock_mperblock_nblock_nperblock =
                     GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         ds_grid_desc_m_n_);
@@ -626,7 +621,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                 e_grid_desc_mblock_mperblock_nblock_nperblock =
                     GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         e_grid_desc_m_n_);
-            }
             // for sanity check of vector memory access
             a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1];
@@ -700,65 +694,33 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
-#if 0
-            {
-                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
-                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
-                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
-                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
-                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
-                          << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
-                          << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
-            }
-#endif
-            if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
-                                          arg.b_grid_desc_k0_n_k1_,
-                                          arg.ds_grid_desc_m_n_,
-                                          arg.e_grid_desc_m_n_,
-                                          arg.block_2_ctile_map_))
-            {
-                throw std::runtime_error(
-                    "wrong! GridwiseGemmMultipleD_wmma_cshuffle has invalid setting");
-            }
            const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);
-            const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
+            const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
-            const auto K =
-                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
+            const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
-            float ave_time = 0;
-            if(GridwiseOp::CalculateHasMainKBlockLoop(K))
-            {
+            auto launch_kernel = [&](auto has_main_k_block_loop) {
+                constexpr bool has_main_loop = has_main_k_block_loop.value;
                const auto kernel = kernel_contraction_multiple_d_wmma_cshuffle<
                    GridwiseOp,
                    ADataType,
                    BDataType,
                    typename GridwiseOp::DsGridPointer,
                    EDataType,
-                    remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
-                    remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
-                    remove_reference_t<
-                        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
-                    remove_reference_t<
-                        typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
+                    DeviceOp::AGridDesc_K0_M_K1,
+                    DeviceOp::BGridDesc_K0_N_K1,
+                    typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    ComputePtrOffsetOfStridedBatch,
-                    remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
-                    true>; // Last Option is W/O
+                    typename GridwiseOp::DefaultBlock2CTileMap,
+                    has_main_loop>;
-                ave_time =
-                    launch_and_time_kernel(stream_config,
+                return launch_and_time_kernel(stream_config,
                                           kernel,
                                           dim3(grid_size),
                                           dim3(BlockSize),
@@ -777,51 +739,16 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                                           arg.cde_element_op_,
                                           arg.compute_ptr_offset_of_batch_,
                                           arg.block_2_ctile_map_);
+            };
+            if(GridwiseOp::CalculateHasMainKBlockLoop(K))
+            {
+                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
-                const auto kernel = kernel_contraction_multiple_d_wmma_cshuffle<
-                    GridwiseOp,
-                    ADataType,
-                    BDataType,
-                    typename GridwiseOp::DsGridPointer,
-                    EDataType,
-                    remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
-                    remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
-                    remove_reference_t<
-                        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
-                    remove_reference_t<
-                        typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
-                    AElementwiseOperation,
-                    BElementwiseOperation,
-                    CDEElementwiseOperation,
-                    ComputePtrOffsetOfStridedBatch,
-                    remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
-                    false>;
-                ave_time =
-                    launch_and_time_kernel(stream_config,
-                                           kernel,
-                                           dim3(grid_size),
-                                           dim3(BlockSize),
-                                           0,
-                                           arg.p_a_grid_,
-                                           arg.p_b_grid_,
-                                           arg.p_ds_grid_,
-                                           arg.p_e_grid_,
-                                           G,
-                                           arg.a_grid_desc_k0_m_k1_,
-                                           arg.b_grid_desc_k0_n_k1_,
-                                           arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                                           arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
-                                           arg.a_element_op_,
-                                           arg.b_element_op_,
-                                           arg.cde_element_op_,
-                                           arg.compute_ptr_offset_of_batch_,
-                                           arg.block_2_ctile_map_);
+                return launch_kernel(integral_constant<bool, false>{});
            }
-            return ave_time;
        }
        // polymorphic
...
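The refactored Run() above builds the kernel once inside a generic lambda and passes the has-main-K-block-loop flag in as a compile-time constant, so the long template-argument list no longer has to be duplicated in an else branch. A minimal standalone sketch of that dispatch pattern in plain C++17 (illustrative names, not CK's API):

```cpp
#include <iostream>
#include <type_traits>

// Stand-in for a kernel that is specialized at compile time on a bool flag.
template <bool HasMainLoop>
float run_kernel(int k_blocks)
{
    // In the real device op this would instantiate and launch a __global__
    // kernel template; here we just report which specialization ran.
    std::cout << (HasMainLoop ? "main-loop" : "tail-only")
              << " kernel, k_blocks=" << k_blocks << '\n';
    return 0.0f; // pretend elapsed time
}

float run(int k_blocks)
{
    // Single place that knows how to launch; the flag arrives as a
    // std::integral_constant so it is usable as a template argument.
    auto launch = [&](auto has_main_loop) {
        constexpr bool flag = decltype(has_main_loop)::value;
        return run_kernel<flag>(k_blocks);
    };

    // Runtime decision selects one of the two compile-time specializations.
    return k_blocks > 1 ? launch(std::integral_constant<bool, true>{})
                        : launch(std::integral_constant<bool, false>{});
}

int main()
{
    run(4); // main-loop kernel
    run(1); // tail-only kernel
    return 0;
}
```

The runtime branch only chooses which integral_constant to hand to the lambda; each call still instantiates a distinct compile-time specialization of the kernel.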
@@ -147,13 +147,15 @@ __global__ void
    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
    const Block2CTileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
+    //printf("entry kernel launch");
    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+    //printf("before compute_ptr_offset call");
    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
@@ -163,14 +165,18 @@ __global__ void
    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
-    DsPointer p_ds_grid_grp;
    static constexpr index_t NumDTensor =
        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
+    DsPointer p_ds_grid_grp;
+    //printf("before allocate pointer d");
    static_for<0, NumDTensor, 1>{}(
        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
+    //printf("before entry");
    GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                p_b_grid + b_batch_offset,
                                                p_ds_grid_grp,
@@ -564,6 +570,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                               const CDEElementwiseOperation& cde_element_op,
                               const Block2CTileMap& block_2_ctile_map)
    {
+        //printf("safe entry");
        // clang-format off
/*******************************************************************************/
// Memory buffer zone.
@@ -709,6 +716,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                c_thread_buf,
                K0BlockMainLoop);
/*******************************************************************************/
+        //printf("safe 1");
        // write out to C, implement shuffle
        {
            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
...
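The kernel guarded above splits the flat launch grid evenly across batches: each workgroup derives its batch index from its block id and then offsets the A/B/D/E pointers by that batch's stride. A minimal host-side sketch of the same index arithmetic (plain C++, illustrative names, assuming grid_size is a multiple of batch_count):

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>

struct BatchMapping
{
    int64_t g_idx;          // which batch this block works on
    int64_t block_in_batch; // block index within that batch
    int64_t a_offset;       // element offset into A for this batch
    int64_t b_offset;       // element offset into B for this batch
};

// Mirrors: num_blocks_per_batch = grid_size / batch_count;
//          g_idx                = block_id / num_blocks_per_batch;
//          a_offset             = g_idx * a_batch_stride; (same idea for B/Ds/E)
BatchMapping map_block_to_batch(int64_t block_id,
                                int64_t grid_size,
                                int64_t batch_count,
                                int64_t a_batch_stride,
                                int64_t b_batch_stride)
{
    assert(grid_size % batch_count == 0);
    const int64_t num_blocks_per_batch = grid_size / batch_count;
    const int64_t g_idx                = block_id / num_blocks_per_batch;
    return {g_idx,
            block_id % num_blocks_per_batch,
            g_idx * a_batch_stride,
            g_idx * b_batch_stride};
}

int main()
{
    // e.g. 8 blocks, 2 batches -> blocks 0..3 serve batch 0, blocks 4..7 batch 1
    const auto m = map_block_to_batch(/*block_id=*/5, /*grid_size=*/8,
                                      /*batch_count=*/2,
                                      /*a_batch_stride=*/1024, /*b_batch_stride=*/2048);
    std::cout << "batch " << m.g_idx << ", block " << m.block_in_batch
              << ", A offset " << m.a_offset << ", B offset " << m.b_offset << '\n';
    return 0;
}
```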