Commit 43596386 authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Merge branch 'feature/add-splitkv-instance' into...

Merge branch 'feature/add-splitkv-instance' into feature/support-vllm-kcache-layout-add-splitkv-instance
parents 250399cd af07d650
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -157,7 +157,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -174,7 +174,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -140,7 +140,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -74,8 +74,8 @@ int profile_gemm_impl(int do_verification,
switch(init_method)
{
case 0:
ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
ck::utils::FillConstant<ADataType>{type_convert<ADataType>(1.f)}(a_m_k);
ck::utils::FillConstant<BDataType>{type_convert<BDataType>(1.f)}(b_k_n);
break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
......
......@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
int StrideB,
int StrideC,
int BatchCount,
int KBatch,
int n_warmup,
int n_iter,
uint64_t rotating = 0)
......@@ -147,89 +148,100 @@ bool profile_gemm_universal_batched_impl(int do_verification,
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
{
std::unique_ptr<tensor_operation::device::BaseArgument> argument_ptr;
// false branch for multi d dl kernel
argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{},
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
BatchCount,
StrideA,
StrideB,
{},
StrideC,
BatchStrideA,
BatchStrideB,
{},
BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
std::string op_name = op_ptr->GetTypeString();
std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
float ave_time = invoker_ptr->Run(
argument_ptr.get(),
StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter, true, rotating_count});
if(KBatch > 0)
{
kbatch_list = {KBatch};
}
std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{},
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
BatchCount,
StrideA,
StrideB,
{},
StrideC,
BatchStrideA,
BatchStrideB,
{},
BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
kbatch_curr);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::string op_name = op_ptr->GetTypeString();
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N) *
BatchCount;
float ave_time = invoker_ptr->Run(
argument_ptr.get(),
StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter, true, rotating_count});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N) *
BatchCount;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
float gb_per_sec = num_btype / 1.E6 / ave_time;
if(do_verification)
{
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl;
pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
if(do_log)
if(do_verification)
{
LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
<< std::endl;
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
<< std::endl;
}
}
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}
......@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K
<< " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC
<< ": " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
<< " KBatch = " << best_kbatch << ": " << best_ave_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
}
......
......@@ -144,6 +144,7 @@ bool profile_gemm_universal_impl(int do_verification,
}
std::string best_op_name;
std::optional<std::string> best_op_object_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
......@@ -225,7 +226,8 @@ bool profile_gemm_universal_impl(int do_verification,
}
}
std::string op_name = op_ptr->GetTypeString();
std::string op_name = op_ptr->GetTypeString();
std::optional<std::string> op_obj_name = op_ptr->GetObjectName();
float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
......@@ -251,11 +253,12 @@ bool profile_gemm_universal_impl(int do_verification,
if(tflops > best_tflops && ave_time > 1e-10)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
best_op_name = op_name;
best_op_object_name = op_obj_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
else
......@@ -306,6 +309,9 @@ bool profile_gemm_universal_impl(int do_verification,
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
if(best_op_object_name)
std::cout << best_op_object_name.value() << std::endl;
return pass;
}
......
......@@ -31,7 +31,7 @@ enum struct GemmDataType
int profile_batched_gemm_universal(int argc, char* argv[])
{
if(argc != 18 && argc != 21)
if(argc != 19 && argc != 22)
{
// clang-format off
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
......@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=n0, 1=yes)\n");
printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n");
printf("arg8 to 18: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount, KBatch\n");
printf("optional:\n");
printf("arg18: number of warm-up cycles (default 1)\n");
printf("arg19: number of iterations (default 10)\n");
printf("arg20: memory for rotating buffer (default 0, size in MB)\n");
printf("arg19: number of warm-up cycles (default 1)\n");
printf("arg20: number of iterations (default 10)\n");
printf("arg21: memory for rotating buffer (default 0, size in MB)\n");
// clang-format on
exit(1);
}
......@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
int n_warmup = 1;
int n_iter = 10;
uint64_t rotating = 0;
if(argc == 21)
if(argc == 22)
{
n_warmup = std::stoi(argv[18]);
n_iter = std::stoi(argv[19]);
rotating = std::stoull(argv[20]) * 1024 * 1024;
n_warmup = std::stoi(argv[19]);
n_iter = std::stoi(argv[20]);
rotating = std::stoull(argv[21]) * 1024 * 1024;
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
......@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
const int BatchStrideC = std::stoi(argv[16]);
const int BatchCount = std::stoi(argv[17]);
const int KBatch = std::stoi(argv[18]);
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using F8 = ck::f8_t;
......@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
StrideB_,
StrideC_,
BatchCount,
KBatch,
n_warmup,
n_iter,
rotating);
......
......@@ -82,7 +82,7 @@ def parse_logfile(logfile):
StrideA=[]
StrideB=[]
StrideC=[]
if 'perf_gemm.log' in logfile:
if 'perf_gemm' in logfile and 'gemm_bilinear' not in logfile:
for line in open(logfile):
if 'Best Perf' in line:
lst=line.split()
......@@ -260,7 +260,7 @@ def main():
conn = sqlEngine.connect()
#save gemm performance tests:
if 'perf_gemm.log' in filename:
if 'perf_gemm' in filename and 'gemm_bilinear' not in filename:
#write the ck_gemm_test_params table only needed once the test set changes
#post_test_params(test_list,conn)
for i in range(1,len(results)+1):
......@@ -332,7 +332,7 @@ def main():
table_name="ck_fmha_bwd_tflops"
tflops_base = get_baseline(table_name,conn)
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, conn)
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine)
conn.close()
#compare the results to the baseline if baseline exists
......
......@@ -11,9 +11,22 @@
#process results
python3 process_perf_data.py perf_gemm.log
python3 process_perf_data.py perf_onnx_gemm.log
python3 process_perf_data.py perf_resnet50_N256.log
python3 process_perf_data.py perf_resnet50_N4.log
file=./perf_onnx_gemm_gfx10.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_onnx_gemm_gfx10.log
fi
file=./perf_onnx_gemm_gfx11.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_onnx_gemm_gfx11.log
fi
file=./perf_onnx_gemm_gfx12.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_onnx_gemm_gfx12.log
fi
file=./perf_fmha_fwd_gfx942.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_fmha_fwd_gfx942.log
......
......@@ -24,6 +24,18 @@ python3 process_perf_data.py perf_splitK_gemm.log
python3 process_perf_data.py perf_onnx_gemm.log
python3 process_perf_data.py perf_mixed_gemm.log
file=./perf_onnx_gemm_gfx10.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_onnx_gemm_gfx10.log
fi
file=./perf_onnx_gemm_gfx11.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_onnx_gemm_gfx11.log
fi
file=./perf_onnx_gemm_gfx12.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_onnx_gemm_gfx12.log
fi
file=./perf_fmha_fwd_gfx942.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_fmha_fwd_gfx942.log
......
......@@ -5,7 +5,7 @@
# post your new test results to the database and compare them to the baseline
# please contact Illia.Silin@amd.com for more details
#
# run the script as "./run_full_performance_tests.sh <verification> <tag for your test environment> <branch name> < node name>
# run the script as "./run_full_performance_tests.sh <verification> <tag for your test environment> <branch name> <node name>
# input arguments:
# verification = 0 : do not verify result correctness on CPU
# = 1 : verifuy correctness on CPU (may take a long time)
......
#!/bin/bash
#
# in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/
# run the script as "./run_gemm_performance_tests.sh <verification> <tag for your test environment> <branch name> <node name> <arch>
# input arguments:
# verification = 0 : do not verify result correctness on CPU
# = 1 : verify correctness on CPU (may take a long time)
# environment tag : a string describing the specifics of your test environment
# branch name : name of the branch in git repo (git status | grep -e 'On branch')
# node name : $hostname
# arch : GPU architecture, e.g. "gfx9" or "gfx1100"
#get the command line arguments:
export verify=$1
echo 'Verification: ' $verify
export env_type=$2
echo 'Environment type: ' $env_type
export branch=$3
echo 'Branch name: ' $branch
export host_name=$4
echo 'Host name: ' $host_name
export arch=$5
echo 'GPU architecture: ' $arch
function print_log_header(){
rm -f $1;
echo 'On branch ' $3 &> $1;
echo 'Node name: ' $4 >> $1;
#get GPU_arch and number of compute units from rocminfo
echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
rocminfo | grep "Compute Unit:" >> $1;
hipcc --version | grep -e 'HIP version' >> $1;
echo 'Environment type: ' $2 >> $1;
/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
}
#run ONNX gemm tests
export onnx_log="perf_onnx_gemm_$arch.log"
print_log_header $onnx_log $env_type $branch $host_name
./profile_onnx_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
./profile_onnx_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
#!/bin/bash
#
# in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/
# run the script as "./run_performance_tests.sh <verification> <tag for your test environment> <branch name> < node name>
# run the script as "./run_performance_tests.sh <verification> <tag for your test environment> <branch name> <node name>
# input arguments:
# verification = 0 : do not verify result correctness on CPU
# = 1 : verify correctness on CPU (may take a long time)
......@@ -51,20 +51,11 @@ print_log_header $gemm_log $env_type $branch $host_name
./profile_gemm.sh gemm 2 3 $verify 1 0 1 | tee -a $gemm_log
./profile_gemm.sh gemm 3 3 $verify 1 0 1 | tee -a $gemm_log
#run grouped_fwd fp16 tests
export grouped_conv_fwd_log="perf_grouped_conv_fwd_fp16.log"
print_log_header $conv_fwd_log $env_type $branch $host_name
./profile_grouped_conv_fwd.sh grouped_conv_fwd 1 1 0 $verify 1 0 1 256 2>&1 | tee -a $grouped_conv_fwd_log
#run grouped_bwd_data fp16 tests
export grouped_conv_bwd_data_log="perf_grouped_conv_bwd_data_fp16.log"
print_log_header $grouped_conv_bwd_data_log $env_type $branch $host_name
./profile_grouped_conv_bwd_data.sh grouped_conv_bwd_data 1 1 $verify 1 0 1 256 2>&1 | tee -a $grouped_conv_bwd_data_log
#run grouped_bwd_weight fp16 tests
export grouped_conv_bwd_weight_log="perf_grouped_conv_bwd_weight_fp16.log"
print_log_header $grouped_conv_bwd_weight_log $env_type $branch $host_name
./profile_grouped_conv_bwd_weight.sh grouped_conv_bwd_weight 1 1 $verify 1 0 1 256 1 2>&1 | tee -a $grouped_conv_bwd_weight_log
#run ONNX gemm tests
export onnx_log="perf_onnx_gemm.log"
print_log_header $onnx_log $env_type $branch $host_name
./profile_onnx_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
./profile_onnx_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
#run resnet50 tests
export resnet256_log="perf_resnet50_N256.log"
......
add_subdirectory(image_to_column)
add_subdirectory(gemm)
add_subdirectory(batched_gemm)
add_subdirectory(grouped_gemm)
......@@ -8,35 +8,29 @@
#include "ck_tile/host.hpp"
#include "test_gemm_mem_pipeline_util.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
static constexpr auto Intrawave = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = ck_tile::GemmPipelineScheduler::Interwave;
template <typename Tuple>
class TestCkTileGemmMemPipelineIntrawave : public TestCkTileGemmMemPipeline<Tuple, Intrawave>
{
};
template <typename Tuple>
class TestCkTileGemmMemPipelineInterwave : public TestCkTileGemmMemPipeline<Tuple, Interwave>
{
};
using F16 = ck_tile::half_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Intrawave>;
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Col, Row, F16, F16, F32, F16>,
std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
std::tuple< Col, Col, Row, F16, F16, F32, F16>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGemmMemPipelineIntrawave, KernelTypes);
TYPED_TEST_SUITE(TestCkTileGemmMemPipelineInterwave, KernelTypes);
TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes);
#include "test_gemm_mem_pipeline_ut_cases.inc"
......@@ -3,11 +3,7 @@
#pragma once
//------------------------------------------------------------------------------------------------
// INTERWAVE SCHEDULER
//------------------------------------------------------------------------------------------------
TYPED_TEST(TestCkTileGemmMemPipelineInterwave, SmallM)
TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 1024;
......@@ -17,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, SmallM)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipelineInterwave, MidLargeM)
TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024;
......@@ -27,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, MidLargeM)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipelineInterwave, PaddK)
TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 1024;
......@@ -37,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, PaddK)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipelineInterwave, Regular)
TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 1024;
......@@ -47,46 +43,15 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, Regular)
this->Run(M, N, K);
}
//------------------------------------------------------------------------------------------------
// INTRAWAVE SCHEDULER
//------------------------------------------------------------------------------------------------
TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, SmallM)
TYPED_TEST(TestCkTileGemmMemPipeline, NotSupportedArgument)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 1024;
constexpr int K = 320;
for(int M : Ms)
this->Run(M, N, K);
}
constexpr int M = 512;
constexpr int N = 1025;
constexpr int K = 513;
TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024;
constexpr int K = 320;
for(int M : Ms)
this->Run(M, N, K);
}
constexpr bool PadM = false;
constexpr bool PadN = false;
constexpr bool PadK = false;
TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 1024;
constexpr int K = 432;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 1024;
constexpr int K = 512;
for(int M : Ms)
this->Run(M, N, K);
EXPECT_THROW((this->template Run<PadM, PadN, PadK>(M, N, K)), std::runtime_error);
}
......@@ -11,7 +11,7 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
template <typename Tuple, ck_tile::GemmPipelineScheduler Scheduler_>
template <typename Tuple>
class TestCkTileGemmMemPipeline : public ::testing::Test
{
protected:
......@@ -22,7 +22,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto Scheduler = Scheduler_;
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
// TODO: expose tile size through test t-param ?
struct gemm_args
......@@ -39,6 +39,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
ck_tile::index_t stride_C;
};
template <bool PadM, bool PadN, bool PadK>
void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s)
{
// TODO: This should be parameterized in tests
......@@ -54,9 +55,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr bool kPadM = true;
constexpr bool kPadN = true;
constexpr bool kPadK = true;
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;
constexpr int kBlockPerCu = 1;
......@@ -107,6 +108,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
......@@ -212,6 +218,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
void SetUp() override { k_batches_ = {1}; }
template <bool PadM = true, bool PadN = true, bool PadK = true>
void Run(const int M,
const int N,
const int K,
......@@ -221,10 +228,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
{
for(auto kb : k_batches_)
{
RunSingle(M, N, K, StrideA, StrideB, StrideC, kb);
RunSingle<PadM, PadN, PadK>(M, N, K, StrideA, StrideB, StrideC, kb);
}
}
template <bool PadM, bool PadN, bool PadK>
void RunSingle(const int M,
const int N,
const int K,
......@@ -301,7 +309,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
args.stride_B = stride_B;
args.stride_C = stride_C;
invoke_gemm(args, ck_tile::stream_config{nullptr, false});
invoke_gemm<PadM, PadN, PadK>(args, ck_tile::stream_config{nullptr, false});
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
......
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_grouped_gemm test_grouped_gemm.cpp)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemm, KernelTypes);
#include "test_grouped_gemm_ut_cases.inc"
#pragma once
TYPED_TEST(TestCkTileGroupedGemm, Basic)
{
const int group_count = 16;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
std::vector<int> stride_As;
std::vector<int> stride_Bs;
std::vector<int> stride_Cs;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
}
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, group_count);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment