Unverified Commit 989ffa2b authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #263 from ROCm/develop

Merge from public repo
parents 0dd5d62a cabbacb6
This diff is collapsed.
...@@ -7,6 +7,7 @@ import copy ...@@ -7,6 +7,7 @@ import copy
NS = 'ck_tile' NS = 'ck_tile'
OPS = 'ops' OPS = 'ops'
REF = 'ref'
OPS_COMMON = 'common' # common header will be duplicated into ops/* other module OPS_COMMON = 'common' # common header will be duplicated into ops/* other module
HEADER_COMMON = f"""// SPDX-License-Identifier: MIT HEADER_COMMON = f"""// SPDX-License-Identifier: MIT
...@@ -29,6 +30,9 @@ class submodule_t: ...@@ -29,6 +30,9 @@ class submodule_t:
def push(self, f): def push(self, f):
if len(f.parents) != 1: # ignore ./xxx.hpp if len(f.parents) != 1: # ignore ./xxx.hpp
mod = get_module(f) mod = get_module(f)
# ref is supposed to include one header on demand
if mod == REF:
return
if mod == OPS: if mod == OPS:
if mod not in self.m.keys(): if mod not in self.m.keys():
self.m[mod] = dict() self.m[mod] = dict()
......
...@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = ...@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 160, 64, 8, 8, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 64, 8, 8, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
......
...@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std ...@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef __gfx94__ #ifdef __gfx94__
// Compute friendly // Compute friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
...@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std: ...@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std:
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
......
...@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
int StrideB, int StrideB,
int StrideC, int StrideC,
int BatchCount, int BatchCount,
int KBatch,
int n_warmup, int n_warmup,
int n_iter, int n_iter,
uint64_t rotating = 0) uint64_t rotating = 0)
...@@ -147,14 +148,23 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -147,14 +148,23 @@ bool profile_gemm_universal_batched_impl(int do_verification,
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device op instances // profile device op instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
std::unique_ptr<tensor_operation::device::BaseArgument> argument_ptr; std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
// false branch for multi d dl kernel
argument_ptr = if(KBatch > 0)
{
kbatch_list = {KBatch};
}
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{}, {},
...@@ -173,15 +183,13 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -173,15 +183,13 @@ bool profile_gemm_universal_batched_impl(int do_verification,
BatchStrideC, BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{},
kbatch_curr);
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run( float ave_time = invoker_ptr->Run(
...@@ -199,7 +207,7 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -199,7 +207,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl; << " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl;
if(tflops > best_tflops) if(tflops > best_tflops)
{ {
...@@ -207,6 +215,7 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -207,6 +215,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
best_tflops = tflops; best_tflops = tflops;
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
} }
if(do_verification) if(do_verification)
...@@ -219,7 +228,8 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -219,7 +228,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
{ {
LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") LogRangeAsType<float>(
std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(
std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
...@@ -229,7 +239,9 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -229,7 +239,9 @@ bool profile_gemm_universal_batched_impl(int do_verification,
} }
else else
{ {
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
} }
} }
...@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K
<< " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC << " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC
<< ": " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " KBatch = " << best_kbatch << ": " << best_ave_time << " ms, " << best_tflops
<< " GB/s, " << best_op_name << std::endl; << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return pass;
} }
......
...@@ -144,6 +144,7 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -144,6 +144,7 @@ bool profile_gemm_universal_impl(int do_verification,
} }
std::string best_op_name; std::string best_op_name;
std::optional<std::string> best_op_object_name;
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
...@@ -226,6 +227,7 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -226,6 +227,7 @@ bool profile_gemm_universal_impl(int do_verification,
} }
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
std::optional<std::string> op_obj_name = op_ptr->GetObjectName();
float ave_time = invoker_ptr->Run(argument_ptr.get(), float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, StreamConfig{nullptr,
...@@ -252,6 +254,7 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -252,6 +254,7 @@ bool profile_gemm_universal_impl(int do_verification,
if(tflops > best_tflops && ave_time > 1e-10) if(tflops > best_tflops && ave_time > 1e-10)
{ {
best_op_name = op_name; best_op_name = op_name;
best_op_object_name = op_obj_name;
best_tflops = tflops; best_tflops = tflops;
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
...@@ -306,6 +309,9 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -306,6 +309,9 @@ bool profile_gemm_universal_impl(int do_verification,
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl; << " GB/s, " << best_op_name << std::endl;
if(best_op_object_name)
std::cout << best_op_object_name.value() << std::endl;
return pass; return pass;
} }
......
...@@ -31,7 +31,7 @@ enum struct GemmDataType ...@@ -31,7 +31,7 @@ enum struct GemmDataType
int profile_batched_gemm_universal(int argc, char* argv[]) int profile_batched_gemm_universal(int argc, char* argv[])
{ {
if(argc != 18 && argc != 21) if(argc != 19 && argc != 22)
{ {
// clang-format off // clang-format off
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
...@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n"); printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg7: time kernel (0=n0, 1=yes)\n");
printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n"); printf("arg8 to 18: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount, KBatch\n");
printf("optional:\n"); printf("optional:\n");
printf("arg18: number of warm-up cycles (default 1)\n"); printf("arg19: number of warm-up cycles (default 1)\n");
printf("arg19: number of iterations (default 10)\n"); printf("arg20: number of iterations (default 10)\n");
printf("arg20: memory for rotating buffer (default 0, size in MB)\n"); printf("arg21: memory for rotating buffer (default 0, size in MB)\n");
// clang-format on // clang-format on
exit(1); exit(1);
} }
...@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
int n_warmup = 1; int n_warmup = 1;
int n_iter = 10; int n_iter = 10;
uint64_t rotating = 0; uint64_t rotating = 0;
if(argc == 21) if(argc == 22)
{ {
n_warmup = std::stoi(argv[18]); n_warmup = std::stoi(argv[19]);
n_iter = std::stoi(argv[19]); n_iter = std::stoi(argv[20]);
rotating = std::stoull(argv[20]) * 1024 * 1024; rotating = std::stoull(argv[21]) * 1024 * 1024;
} }
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
...@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
const int BatchStrideC = std::stoi(argv[16]); const int BatchStrideC = std::stoi(argv[16]);
const int BatchCount = std::stoi(argv[17]); const int BatchCount = std::stoi(argv[17]);
const int KBatch = std::stoi(argv[18]);
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) #if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using F8 = ck::f8_t; using F8 = ck::f8_t;
...@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
StrideB_, StrideB_,
StrideC_, StrideC_,
BatchCount, BatchCount,
KBatch,
n_warmup, n_warmup,
n_iter, n_iter,
rotating); rotating);
......
...@@ -332,7 +332,7 @@ def main(): ...@@ -332,7 +332,7 @@ def main():
table_name="ck_fmha_bwd_tflops" table_name="ck_fmha_bwd_tflops"
tflops_base = get_baseline(table_name,conn) tflops_base = get_baseline(table_name,conn)
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, conn) store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine)
conn.close() conn.close()
#compare the results to the baseline if baseline exists #compare the results to the baseline if baseline exists
......
# Currently ck_tile is only built on gfx9 # Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9") if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_gemm_mem_pipeline test_gemm_mem_pipeline.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline test_gemm_pipeline.cpp)
endif() endif()
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "test_gemm_mem_pipeline_util.hpp" #include "test_gemm_pipeline_util.hpp"
using F16 = ck_tile::half_t; using F16 = ck_tile::half_t;
using F32 = float; using F32 = float;
...@@ -16,21 +16,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, ...@@ -16,21 +16,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Intrawave>; ck_tile::GemmPipelineScheduler::Intrawave>;
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>; ck_tile::GemmPipelineScheduler::Interwave>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave> std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>; >;
// clang-format on // clang-format on
TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes); TYPED_TEST_SUITE(TestCkTileGemmPipeline, KernelTypes);
#include "test_gemm_mem_pipeline_ut_cases.inc" #include "test_gemm_pipeline_ut_cases.inc"
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#pragma once #pragma once
TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) TYPED_TEST(TestCkTileGemmPipeline, SmallM)
{ {
std::vector<int> Ms{1, 2, 3, 4, 5, 6}; std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) ...@@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) ...@@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) TYPED_TEST(TestCkTileGemmPipeline, PaddK)
{ {
std::vector<int> Ms{127}; std::vector<int> Ms{127};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) ...@@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, Regular) TYPED_TEST(TestCkTileGemmPipeline, Regular)
{ {
std::vector<int> Ms{512}; std::vector<int> Ms{512};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular) ...@@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, NotSupportedArgument) TYPED_TEST(TestCkTileGemmPipeline, NotSupportedArgument)
{ {
constexpr int M = 512; constexpr int M = 512;
constexpr int N = 1025; constexpr int N = 1025;
......
...@@ -11,8 +11,13 @@ ...@@ -11,8 +11,13 @@
#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm.hpp"
enum struct GemmPipelineType
{
Mem,
Comp
};
template <typename Tuple> template <typename Tuple>
class TestCkTileGemmMemPipeline : public ::testing::Test class TestCkTileGemmPipeline : public ::testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
...@@ -23,6 +28,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -23,6 +28,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using AccDataType = std::tuple_element_t<5, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value; static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
// TODO: expose tile size through test t-param ? // TODO: expose tile size through test t-param ?
struct gemm_args struct gemm_args
...@@ -74,8 +80,13 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -74,8 +80,13 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using BaseGemmPipeline = std::conditional_t<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; PipelineType == GemmPipelineType::Mem,
ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<
ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
...@@ -85,7 +96,18 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -85,7 +96,18 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< using GemmPipeline =
std::conditional_t<PipelineType == GemmPipelineType::Mem,
ck_tile::GemmPipelineAgBgCrMem<
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
Traits,
Scheduler,
has_hot_loop_v,
tail_number_v>>,
ck_tile::GemmPipelineAgBgCrCompV3<
ck_tile::UniversalGemmPipelineProblem<ADataType, ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
...@@ -93,7 +115,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -93,7 +115,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
Traits, Traits,
Scheduler, Scheduler,
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a, auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b, args.p_b,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment