Commit cf9bcb31 authored by root

minimize arg size

parent 09cc45d3
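The commit message is terse, so here is a hedged reading of what "minimize arg size" means in this change: the per-group descriptor copied into the device workspace is trimmed down to the grid pointers plus the block-to-C-tile map, while values shared by every group travel once, by value, in the kernel argument (for now only the number of workgroups per group; the GEMM sizes are still hard-coded inside the kernel below). A minimal sketch of that layout, using illustrative placeholder names rather than the CK types in this diff:

#include <cstdint>

struct PerGroupArgSketch        // one entry per group, copied to the device workspace
{
    const void* p_a;
    const void* p_b;
    void*       p_c;
    // the block-to-C-tile map would sit here as well
};

struct SharedArgSketch          // passed by value to the kernel, once for all groups
{
    std::int32_t blocks_per_group; // every group is assumed to span the same grid slice
    // M, N, K, strides, padded sizes, K0 and k_batch could also move here
};

// Device side, under the uniform-grid assumption:
//   group_id = get_block_1d_id() / shared.blocks_per_group;
// which replaces the earlier binary search over per-group [block_start_, block_end_) ranges.

That assumption matches the example change below, where every entry of Ms is set to 2 so all groups produce the same number of workgroups.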
@@ -54,7 +54,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSpl
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>;
// clang-format on
#include "run_grouped_gemm_example.inc"
@@ -66,8 +68,8 @@ int main(int argc, char* argv[])
problem_size.group_count = 16;
problem_size.Ms = {
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
problem_size.Ms = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
// problem_size.Ms = {2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
for(int i = 0; i < problem_size.group_count; i++)
{
......
@@ -23,8 +23,10 @@ namespace ck {
namespace tensor_operation {
namespace device {
#if 1
template <typename GridwiseGemm,
typename GemmDesc,
typename GemmSharedArgs,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
@@ -32,7 +34,7 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count)
const GemmSharedArgs gemm_shared_args)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
@@ -43,6 +45,87 @@ __global__ void
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
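// Every group is assumed to occupy gemm_shared_args.block_size consecutive workgroups,
// so a single division recovers the owning group instead of the earlier binary search.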
const index_t group_id = block_id / gemm_shared_args.block_size;
#if 1
// const auto M = gemm_shared_args.M;
// const auto N = gemm_shared_args.N;
// const auto K = gemm_shared_args.K;
// const auto StrideA = gemm_shared_args.StrideA;
// const auto StrideB = gemm_shared_args.StrideB;
// const auto StrideC = gemm_shared_args.StrideC;
// const auto MPadded = gemm_shared_args.MPadded;
// const auto NPadded = gemm_shared_args.NPadded;
// const auto KPadded = gemm_shared_args.KPadded;
// const auto K0 = gemm_shared_args.K0;
// const auto k_batch = gemm_shared_args.k_batch;
// M = 2 N = 768 K = 4608 StrideA = 4608 StrideB = 4608 StrideC = 768 MPadded = 32 NPadded = 768
// KPadded = 4608 K0 = 576 k_batch = 1
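// Debug shortcut in this commit: hard-code the sizes of the 2 x 768 x 4608 test problem
// recorded in the print-out above instead of reading them from gemm_shared_args.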
const auto M = 2;
const auto N = 768;
const auto K = 4608;
const auto StrideA = 4608;
const auto StrideB = 4608;
const auto StrideC = 768;
const auto MPadded = 32;
const auto NPadded = 768;
const auto KPadded = 4608;
const auto K0 = 576;
const auto k_batch = 1;
// const auto block_2_ctile_map = gemm_shared_args.block_2_ctile_map;
#endif
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_.p_a_grid,
gemm_desc_ptr[group_id].karg_.p_b_grid,
gemm_desc_ptr[group_id].karg_.p_c_grid,
M,
N,
K,
StrideA,
StrideB,
StrideC,
MPadded,
NPadded,
KPadded,
K0,
k_batch,
static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].karg_.block_2_ctile_map);
#else
ignore = gemm_descs_const;
ignore = gemm_shared_args;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
#else
template <typename GridwiseGemm,
typename GemmDesc,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
// const index_t group_count
const index_t block_size)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
const index_t block_id = get_block_1d_id();
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
#if 0
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
@@ -60,16 +143,51 @@ __global__ void
}
group_id = index_t((left + right) / 2);
}
#else
const index_t group_id = block_id / block_size;
#endif
#if 0
const auto N = gemm_desc_ptr[0].karg_.N;
const auto K = gemm_desc_ptr[0].karg_.K;
const auto StrideB = gemm_desc_ptr[0].karg_.StrideB;
const auto NPadded = gemm_desc_ptr[0].karg_.NPadded;
const auto KPadded = gemm_desc_ptr[0].karg_.KPadded;
const auto K0 = gemm_desc_ptr[0].karg_.KPadded;
const auto k_batch = gemm_desc_ptr[0].karg_.k_batch;
const auto block_2_ctile_map = gemm_desc_ptr[0].block_2_ctile_map_;
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_.p_a_grid,
gemm_desc_ptr[group_id].karg_.p_b_grid,
gemm_desc_ptr[group_id].karg_.p_c_grid,
gemm_desc_ptr[group_id].karg_.M,
N,
K,
gemm_desc_ptr[group_id].karg_.StrideA,
StrideB,
gemm_desc_ptr[group_id].karg_.StrideC,
gemm_desc_ptr[group_id].karg_.MPadded,
NPadded,
KPadded,
K0,
k_batch,
static_cast<void*>(p_shared),
block_2_ctile_map);
#else
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_,
static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].block_2_ctile_map_);
#endif
#else
ignore = gemm_descs_const;
ignore = block_size;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
#endif
template <typename ALayout,
typename BLayout,
@@ -406,10 +524,104 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
}
struct ArgumentMsN1K1
{
const ADataType* p_a_grid;
const BDataType* p_b_grid;
EDataType* p_c_grid;
// index_t M;
// index_t StrideA;
// index_t StrideC;
// index_t MPadded;
GroupedGemmBlock2ETileMap block_2_ctile_map;
};
struct GemmTransKernelArgMsN1K1
{
ArgumentMsN1K1 karg_;
};
#if 1
std::vector<GemmTransKernelArgMsN1K1> gemm_kernel_args_msn1k1_;
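// Assumption behind the smaller kernel argument: every group spans the same number of
// workgroups, so the block count of group 0 is enough to recover the group id on the device.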
index_t all_gemm_block_size =
arg.gemm_kernel_args_[0].block_end_ - arg.gemm_kernel_args_[0].block_start_;
for(const auto& trans_arg : arg.gemm_kernel_args_)
{
auto karg = ArgumentMsN1K1{trans_arg.karg_.p_a_grid,
trans_arg.karg_.p_b_grid,
trans_arg.karg_.p_c_grid,
trans_arg.block_2_ctile_map_};
// auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
// std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
// << " trans_arg.block_end_: " << trans_arg.block_end_
// << " block_size: " << block_size << std::endl;
gemm_kernel_args_msn1k1_.push_back({karg});
}
#endif
#if 0
hip_check_error(hipMemcpy(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice));
#else
struct GemmSharedArgs
{
index_t block_size;
// index_t M;
// index_t N;
// index_t K;
// index_t StrideA;
// index_t StrideB;
// index_t StrideC;
// index_t MPadded;
// index_t NPadded;
// index_t KPadded;
// index_t K0;
// index_t k_batch;
// GroupedGemmBlock2ETileMap block_2_ctile_map;
#if 0
void print()
{
std::cout << "block_size = " << block_size << " M = " << M << " N = " << N
<< " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC
<< " MPadded = " << MPadded << " NPadded = " << NPadded
<< " KPadded = " << KPadded << " K0 = " << K0
<< " k_batch = " << k_batch << std::endl;
}
#endif
};
auto shared_karg = GemmSharedArgs{
all_gemm_block_size,
// arg.gemm_kernel_args_[0].karg_.M,
// arg.gemm_kernel_args_[0].karg_.N,
// arg.gemm_kernel_args_[0].karg_.K,
// arg.gemm_kernel_args_[0].karg_.StrideA,
// arg.gemm_kernel_args_[0].karg_.StrideB,
// arg.gemm_kernel_args_[0].karg_.StrideC,
// arg.gemm_kernel_args_[0].karg_.MPadded,
// arg.gemm_kernel_args_[0].karg_.NPadded,
// arg.gemm_kernel_args_[0].karg_.KPadded,
// arg.gemm_kernel_args_[0].karg_.K0,
// arg.gemm_kernel_args_[0].karg_.k_batch,
// arg.gemm_kernel_args_[0].block_2_ctile_map_,
};
// shared_karg.print();
hip_check_error(
hipMemcpy(arg.p_workspace_,
gemm_kernel_args_msn1k1_.data(),
gemm_kernel_args_msn1k1_.size() * sizeof(GemmTransKernelArgMsN1K1),
hipMemcpyHostToDevice));
#endif
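// With the reduced ArgumentMsN1K1 above, the copy into the workspace carries only the
// three grid pointers and the tile map per group; the scalar GEMM sizes no longer travel
// per group. In this experimental state they are hard-coded in the kernel, and the
// commented-out GemmSharedArgs fields mark where they are meant to live eventually.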
float ave_time = 0;
@@ -431,9 +643,14 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_kernel_args_.size());
shared_karg
// all_gemm_block_size
);
};
std::cout << "all_have_main_k0_block_loop: " << all_have_main_k0_block_loop
<< " all_have_kbatch_gt_one: " << all_have_kbatch_gt_one << std::endl;
#if 0
if(all_have_main_k0_block_loop)
{
if(all_have_kbatch_gt_one)
......@@ -480,6 +697,58 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
Run(kernel);
}
}
#else
if(all_have_main_k0_block_loop)
{
if(all_have_kbatch_gt_one)
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArgMsN1K1,
GemmSharedArgs,
true,
InMemoryDataOperationEnum::AtomicAdd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArgMsN1K1,
GemmSharedArgs,
true,
InMemoryDataOperationEnum::Set>;
Run(kernel);
}
}
else
{
if(all_have_kbatch_gt_one)
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArgMsN1K1,
GemmSharedArgs,
false,
InMemoryDataOperationEnum::AtomicAdd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArgMsN1K1,
GemmSharedArgs,
false,
InMemoryDataOperationEnum::Set>;
Run(kernel);
}
}
#endif
return ave_time;
}
@@ -614,7 +883,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
{
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
}
};
}; // namespace device
} // namespace device
} // namespace tensor_operation
......
@@ -8,7 +8,7 @@ MY_PROJECT_SOURCE=$1
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker \
-D CMAKE_CXX_FLAGS="-std=c++20 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker \
-save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
......