Unverified commit 4d8fce33 authored by Bartłomiej Kocot, committed by GitHub

Add SplitK support into Batched GEMM V3 (#1729)



* add bmm api

* add bf16 multi_d

* add ckProfiler for bf16

* add ckProfiler files

* add more instances; fix 64-bit index issue

* fixed naming

* enabled batched Ds

* use long_index for ds offsets

* clean

* add bmm fp8 ckProfiler

* Update example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>

* Update example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>

* Update example/24_batched_gemm/run_batched_gemm_example_rowwise.inc
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>

* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>

* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>

* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>

* Update profiler/src/profile_gemm_universal_batched.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>

* Update profiler/include/profiler/profile_gemm_universal_batched_impl.hpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>

* clean

* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp

* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp

* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp

* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp

* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp

* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp

* refactor batch offset func

* add splitk support into bmm_v3

* clean

* clean

* format

* fixed

* fix

---------
Co-authored-by: Jing Zhang <jizhan@fb.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
parent 4e731776
@@ -78,14 +78,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
     2,              // ABlockTransferSrcVectorDim
     8,              // ABlockTransferSrcScalarPerVector
     8,              // ABlockTransferDstScalarPerVector_AK1
-    1,              // ABlockLdsExtraM
+    0,              // ABlockLdsExtraM
     S<4, 64, 1>,    // BBlockTransferThreadClusterLengths_BK0_N_BK1
     S<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
     S<1, 0, 2>,     // BBlockTransferSrcAccessOrder
     2,              // BBlockTransferSrcVectorDim
     8,              // BBlockTransferSrcScalarPerVector
     8,              // BBlockTransferDstScalarPerVector_BK1
-    1,              // BBlockLdsExtraN
+    0,              // BBlockLdsExtraN
     1,              // CShuffleMXdlPerWavePerShuffle
     1,              // CShuffleNXdlPerWavePerShuffle
     S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
......
@@ -89,7 +89,8 @@ struct DeviceBatchedGemmV2MultiD : public BaseOperator
                         index_t BatchStrideE,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CDEElementwiseOperation cde_element_op) = 0;
+                        CDEElementwiseOperation cde_element_op,
+                        index_t KBatch) = 0;
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
......
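For readers unfamiliar with split-K: the extra KBatch argument asks the device op to split the K (reduction) dimension across KBatch partial GEMMs per batch, each accumulating into the same E/C tensor, which is why the invoker further down zero-fills the output when KBatch > 1. The snippet below is only a host-side model of that idea, not CK code; all names in it are hypothetical.

```cpp
// Hypothetical host-side model of split-K accumulation (not CK code).
// Each of the KBatch "splits" multiplies a K-chunk of A and B and adds its
// partial product into C; C must start at zero for the sums to be correct,
// mirroring the hipMemsetAsync the device op performs when KBatch > 1.
#include <cassert>
#include <vector>

int main()
{
    const int M = 4, N = 3, K = 8, KBatch = 2;
    const int KRead = K / KBatch; // per-split K chunk (assumes K % KBatch == 0 here)

    std::vector<float> A(M * K), B(K * N), C(M * N, 0.0f), C_ref(M * N, 0.0f);
    for(int i = 0; i < M * K; ++i) A[i] = static_cast<float>(i % 5);
    for(int i = 0; i < K * N; ++i) B[i] = static_cast<float>(i % 7);

    // reference: one pass over the full K
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                C_ref[m * N + n] += A[m * K + k] * B[k * N + n];

    // split-K: KBatch partial GEMMs accumulating into the same (zero-initialized) C
    for(int k_id = 0; k_id < KBatch; ++k_id)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k = k_id * KRead; k < (k_id + 1) * KRead; ++k)
                    C[m * N + n] += A[m * K + k] * B[k * N + n];

    for(int i = 0; i < M * N; ++i) assert(C[i] == C_ref[i]);
    return 0;
}
```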
@@ -41,12 +41,15 @@ __global__ void
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t g_idx = blockIdx.z % karg.Batch;
+    const index_t k_idx = blockIdx.z / karg.Batch;
     const auto a_batch_offset  = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
     const auto b_batch_offset  = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
     const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
     const auto c_batch_offset  = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
     // populate pointer, desc for Ds
     static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
         // D pointer
@@ -54,8 +57,8 @@ __global__ void
     });
     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        karg.p_a_grid + a_batch_offset,
-        karg.p_b_grid + b_batch_offset,
+        karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
+        karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
         karg.p_ds_grid,
         karg.p_c_grid + c_batch_offset,
         p_shared,
@@ -87,12 +90,15 @@ __global__ void
     __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t g_idx = blockIdx.z % karg.Batch;
+    const index_t k_idx = blockIdx.z / karg.Batch;
     const auto a_batch_offset  = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
     const auto b_batch_offset  = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
     const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
     const auto c_batch_offset  = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
     // populate pointer, desc for Ds
     static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
         // D pointer
@@ -100,8 +106,8 @@ __global__ void
     });
     GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        karg.p_a_grid + a_batch_offset,
-        karg.p_b_grid + b_batch_offset,
+        karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
+        karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
         karg.p_ds_grid,
         karg.p_c_grid + c_batch_offset,
         p_shared_0,
@@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                  index_t Batch_,
                  AElementwiseOperation a_element_op_,
                  BElementwiseOperation b_element_op_,
-                 CElementwiseOperation c_element_op_)
+                 CElementwiseOperation c_element_op_,
+                 index_t KBatch_)
             : GridwiseGemm::Argument{p_a_grid_,
                                      p_b_grid_,
                                      p_ds_grid_,
@@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                                      StrideB_,
                                      StrideDs_,
                                      StrideE_,
-                                     1,
+                                     KBatch_,
                                      a_element_op_,
                                      b_element_op_,
                                      c_element_op_},
@@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                 arg.Print();
             }
-            if(!GridwiseGemm::CheckValidity(arg) || arg.KBatch > 1)
+            if(!GridwiseGemm::CheckValidity(arg))
             {
                 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
             }
             index_t gdx, gdy, gdz;
-            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch);
+            std::tie(gdx, gdy, gdz) =
+                GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch);
             float ave_time = 0;
@@ -387,10 +395,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                     rotating_mem.Next();
                     // clear c mem
                     if(arg_.KBatch > 1)
-                        hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
-                                                         0,
-                                                         arg_.M * arg_.N * sizeof(CDataType),
-                                                         stream_config.stream_id_));
+                        hipGetErrorString(
+                            hipMemsetAsync(arg_.p_c_grid,
+                                           0,
+                                           arg.Batch * arg_.M * arg_.N * sizeof(CDataType),
+                                           stream_config.stream_id_));
                 };
                 ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
@@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                             index_t BatchStrideE,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
-                            CElementwiseOperation c_element_op)
+                            CElementwiseOperation c_element_op,
+                            index_t KBatch = 1)
    {
        return Argument{static_cast<const ADataType*>(p_a),
                        static_cast<const BDataType*>(p_b),
@@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                        Batch,
                        a_element_op,
                        b_element_op,
-                        c_element_op};
+                        c_element_op,
+                        KBatch};
    }
    static auto MakeInvoker() { return Invoker{}; }
@@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                        index_t BatchStrideE,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op) override
+                        CElementwiseOperation c_element_op,
+                        index_t KBatch = 1) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
@@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                                          Batch,
                                          a_element_op,
                                          b_element_op,
-                                          c_element_op);
+                                          c_element_op,
+                                          KBatch);
    }
    // polymorphic
......
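As the hunks above show, split-K is folded into the grid's z-dimension: the grid is sized with Batch * KBatch, each workgroup recovers its batch index as blockIdx.z % Batch and its split index as blockIdx.z / Batch, and for KBatch > 1 the whole Batch * M * N output is cleared before accumulation. A small host-side model of that index mapping (hypothetical names, not CK code):

```cpp
// Hypothetical model of how blockIdx.z encodes (batch, k-split) after this change.
#include <cassert>

int main()
{
    const int Batch = 3, KBatch = 4;
    const int grid_z = Batch * KBatch; // mirrors CalculateGridSize(M, N, Batch * KBatch)

    for(int z = 0; z < grid_z; ++z)
    {
        const int g_idx = z % Batch; // which batch this workgroup handles
        const int k_idx = z / Batch; // which K split this workgroup handles
        assert(g_idx >= 0 && g_idx < Batch);
        assert(k_idx >= 0 && k_idx < KBatch);
        assert(z == k_idx * Batch + g_idx); // every (batch, split) pair is covered exactly once
    }
    return 0;
}
```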
@@ -41,7 +41,7 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -76,7 +76,7 @@ __global__ void
     __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
     GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
     struct SplitKBatchOffset
     {
-        __device__ SplitKBatchOffset(Argument& karg)
+        __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
         {
             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
             {
-                a_k_split_offset = blockIdx.z * karg.KRead;
+                a_k_split_offset = k_id * karg.KRead;
             }
             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
             {
-                a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
+                a_k_split_offset = k_id * karg.KRead * karg.StrideA;
             }
             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
             {
-                b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
+                b_k_split_offset = k_id * karg.KRead * karg.StrideB;
             }
             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
-                b_k_split_offset = blockIdx.z * karg.KRead;
+                b_k_split_offset = k_id * karg.KRead;
             }
-            if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
+            if(k_id < karg.KBatch - 1)
             {
                 karg.K = karg.KRead;
             }
......
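The SplitKBatchOffset constructor above turns a split index into element offsets into A and B: for a row-major A the k-th split starts KRead elements further along a row, for a column-major A it starts KRead * StrideA elements further, and symmetrically for B; every split except the last also shrinks the effective K to KRead. A host-side sketch of that arithmetic follows; KRead's exact derivation lives elsewhere in GridwiseGemm and is modeled here as an even split, and all names are hypothetical.

```cpp
// Hypothetical model of the SplitKBatchOffset arithmetic for one split (k_id).
// The hunk only shows that non-last splits use KRead; the last split's K is
// set outside the shown lines and is modeled here as the remainder.
#include <cstdint>
#include <iostream>

struct SplitKOffsets
{
    std::int64_t a_k_split_offset;
    std::int64_t b_k_split_offset;
    int k_for_this_split;
};

SplitKOffsets make_offsets(bool a_row_major,
                           bool b_row_major,
                           int k_id,
                           int KBatch,
                           int K,
                           std::int64_t StrideA,
                           std::int64_t StrideB)
{
    const int KRead = K / KBatch; // illustrative even split

    SplitKOffsets o{};
    // Row-major A: stepping along K stays within a row, so the offset is KRead per split.
    // Column-major A: stepping along K jumps whole columns, hence the StrideA factor.
    o.a_k_split_offset = a_row_major ? std::int64_t(k_id) * KRead
                                     : std::int64_t(k_id) * KRead * StrideA;
    // B mirrors A: a row-major B strides by StrideB per K step.
    o.b_k_split_offset = b_row_major ? std::int64_t(k_id) * KRead * StrideB
                                     : std::int64_t(k_id) * KRead;
    // Non-last splits work on exactly KRead; the last split takes whatever K remains.
    o.k_for_this_split = (k_id < KBatch - 1) ? KRead : K - k_id * KRead;
    return o;
}

int main()
{
    const auto o = make_offsets(/*a_row_major=*/true, /*b_row_major=*/false,
                                /*k_id=*/2, /*KBatch=*/4, /*K=*/4096,
                                /*StrideA=*/4096, /*StrideB=*/4096);
    std::cout << o.a_k_split_offset << " " << o.b_k_split_offset << " "
              << o.k_for_this_split << "\n";
    return 0;
}
```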
@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
         DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
         DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
+        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 160, 64, 8, 8, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
+        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
+        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 64, 8, 8, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
         DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
......
@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std
 //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
 //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 #ifdef __gfx94__
         // Compute friendly
         DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std:
 //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
 //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
         DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
         DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
......
@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
                                          int StrideB,
                                          int StrideC,
                                          int BatchCount,
+                                         int KBatch,
                                          int n_warmup,
                                          int n_iter,
                                          uint64_t rotating = 0)
@@ -147,89 +148,100 @@ bool profile_gemm_universal_batched_impl(int do_verification,
     float best_ave_time   = 0;
     float best_tflops     = 0;
     float best_gb_per_sec = 0;
+    float best_kbatch     = 0;
     // profile device op instances
     for(auto& op_ptr : op_ptrs)
     {
-        std::unique_ptr<tensor_operation::device::BaseArgument> argument_ptr;
-        // false branch for multi d dl kernel
-        argument_ptr =
-            op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-                                        static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                        {},
-                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                        M,
-                                        N,
-                                        K,
-                                        BatchCount,
-                                        StrideA,
-                                        StrideB,
-                                        {},
-                                        StrideC,
-                                        BatchStrideA,
-                                        BatchStrideB,
-                                        {},
-                                        BatchStrideC,
-                                        ck::tensor_operation::element_wise::PassThrough{},
-                                        ck::tensor_operation::element_wise::PassThrough{},
-                                        ck::tensor_operation::element_wise::PassThrough{});
-        auto invoker_ptr = op_ptr->MakeInvokerPointer();
-        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
-        {
-            // re-init C to zero before profiling next kernel
-            c_device_buf.SetZero();
-            std::string op_name = op_ptr->GetTypeString();
-            float ave_time = invoker_ptr->Run(
-                argument_ptr.get(),
-                StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter, true, rotating_count});
-            std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
-            std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
-                                     sizeof(CDataType) * M * N) *
-                                    BatchCount;
-            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
-            float gb_per_sec = num_btype / 1.E6 / ave_time;
-            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
-                      << " GB/s, " << op_name << std::endl;
-            if(tflops > best_tflops)
-            {
-                best_op_name    = op_name;
-                best_tflops     = tflops;
-                best_ave_time   = ave_time;
-                best_gb_per_sec = gb_per_sec;
-            }
-            if(do_verification)
-            {
-                c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
-                pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
-                if(do_log)
-                {
-                    LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
-                    LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
-                    LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
-                        << std::endl;
-                    LogRangeAsType<float>(
-                        std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
-                        << std::endl;
-                }
-            }
-        }
-        else
-        {
-            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
-        }
+        std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
+        if(KBatch > 0)
+        {
+            kbatch_list = {KBatch};
+        }
+        for(std::size_t i = 0; i < kbatch_list.size(); i++)
+        {
+            auto kbatch_curr = kbatch_list[i];
+            auto argument_ptr =
+                op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
+                                            static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
+                                            {},
+                                            static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+                                            M,
+                                            N,
+                                            K,
+                                            BatchCount,
+                                            StrideA,
+                                            StrideB,
+                                            {},
+                                            StrideC,
+                                            BatchStrideA,
+                                            BatchStrideB,
+                                            {},
+                                            BatchStrideC,
+                                            ck::tensor_operation::element_wise::PassThrough{},
+                                            ck::tensor_operation::element_wise::PassThrough{},
+                                            ck::tensor_operation::element_wise::PassThrough{},
+                                            kbatch_curr);
+            auto invoker_ptr = op_ptr->MakeInvokerPointer();
+            if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+            {
+                std::string op_name = op_ptr->GetTypeString();
+                float ave_time = invoker_ptr->Run(
+                    argument_ptr.get(),
+                    StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter, true, rotating_count});
+                std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
+                std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+                                         sizeof(CDataType) * M * N) *
+                                        BatchCount;
+                float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+                float gb_per_sec = num_btype / 1.E6 / ave_time;
+                std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
+                          << " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl;
+                if(tflops > best_tflops)
+                {
+                    best_op_name    = op_name;
+                    best_tflops     = tflops;
+                    best_ave_time   = ave_time;
+                    best_gb_per_sec = gb_per_sec;
+                    best_kbatch     = kbatch_curr;
+                }
+                if(do_verification)
+                {
+                    c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
+                    pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
+                    if(do_log)
+                    {
+                        LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
+                        LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
+                        LogRangeAsType<float>(
+                            std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
+                            << std::endl;
+                        LogRangeAsType<float>(
+                            std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
+                            << std::endl;
+                    }
+                }
+            }
+            else
+            {
+                std::cout << op_ptr->GetTypeString() << " does not support this problem"
+                          << std::endl;
+            }
+        }
     }
@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
     std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K
               << " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC
-              << ": " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
-              << " GB/s, " << best_op_name << std::endl;
+              << " KBatch = " << best_kbatch << ": " << best_ave_time << " ms, " << best_tflops
+              << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
     return pass;
 }
......
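In the profiler change above, each instance is now timed over a fixed list of split factors unless the caller pins one (a positive KBatch on the command line), and the best-performing (instance, KBatch) pair is reported. A trimmed-down, standalone model of that selection logic, with a fake measurement standing in for the actual kernel timing (all names hypothetical):

```cpp
// Hypothetical model of the profiler's KBatch handling: sweep a default list
// of split factors unless the user pins a single value, and remember the best.
#include <cstdlib>
#include <iostream>
#include <vector>

int main()
{
    const int user_kbatch = 0; // 0 (or negative) means "sweep"; > 0 pins a single value

    std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
    if(user_kbatch > 0)
    {
        kbatch_list = {user_kbatch};
    }

    float best_tflops = 0;
    int best_kbatch   = 0;

    for(int kbatch : kbatch_list)
    {
        // Stand-in for building the argument with `kbatch` and timing the kernel.
        const float tflops = 100.0f / (1.0f + std::abs(kbatch - 8)); // fake measurement
        if(tflops > best_tflops)
        {
            best_tflops = tflops;
            best_kbatch = kbatch;
        }
    }

    std::cout << "best KBatch = " << best_kbatch << ", " << best_tflops << " TFlops\n";
    return 0;
}
```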
@@ -31,7 +31,7 @@ enum struct GemmDataType
 int profile_batched_gemm_universal(int argc, char* argv[])
 {
-    if(argc != 18 && argc != 21)
+    if(argc != 19 && argc != 22)
     {
         // clang-format off
         printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
         printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
         printf("arg6: print tensor value (0: no; 1: yes)\n");
         printf("arg7: time kernel (0=n0, 1=yes)\n");
-        printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n");
+        printf("arg8 to 18: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount, KBatch\n");
         printf("optional:\n");
-        printf("arg18: number of warm-up cycles (default 1)\n");
-        printf("arg19: number of iterations (default 10)\n");
-        printf("arg20: memory for rotating buffer (default 0, size in MB)\n");
+        printf("arg19: number of warm-up cycles (default 1)\n");
+        printf("arg20: number of iterations (default 10)\n");
+        printf("arg21: memory for rotating buffer (default 0, size in MB)\n");
         // clang-format on
         exit(1);
     }
@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
     int n_warmup      = 1;
     int n_iter        = 10;
     uint64_t rotating = 0;
-    if(argc == 21)
+    if(argc == 22)
     {
-        n_warmup = std::stoi(argv[18]);
-        n_iter   = std::stoi(argv[19]);
-        rotating = std::stoull(argv[20]) * 1024 * 1024;
+        n_warmup = std::stoi(argv[19]);
+        n_iter   = std::stoi(argv[20]);
+        rotating = std::stoull(argv[21]) * 1024 * 1024;
     }
     const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
     const int BatchStrideC = std::stoi(argv[16]);
     const int BatchCount   = std::stoi(argv[17]);
+    const int KBatch       = std::stoi(argv[18]);
 #if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
     using F8 = ck::f8_t;
@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
                                              StrideB_,
                                              StrideC_,
                                              BatchCount,
+                                             KBatch,
                                              n_warmup,
                                              n_iter,
                                              rotating);
......