Unverified commit 7450417d authored by Mirza Halilčević, committed by GitHub

Merge branch 'develop' into ck_host_lib

parents 6d597346 da0c21f6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using I8 = int8_t;
using I32 = int;
using F16 = ck::half_t;
using FP8 = ck::f8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = I8;
using B0DataType = I8;
using AccDataType = I32;
using CShuffleDataType = I32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using A0Layout = Row;
using B0Layout = Col;
using D0Layout = Row;
using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
struct MultiplyMultiply
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
ck::half_t& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c * d0 * d1;
e = ck::type_convert<ck::half_t>(x0_f);
}
template <>
__host__ __device__ constexpr void operator()<ck::half_t, int, float, float>(
ck::half_t& e, const int& c, const float& d0, const float& d1) const
{
const float x0_f =
ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
e = ck::type_convert<ck::half_t>(x0_f);
}
template <>
__host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
{
const float x0_f =
ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
e = ck::type_convert<ck::bhalf_t>(x0_f);
}
};
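// Worked scalar illustration (values assumed, not part of the example): with an
// int32 accumulator c = 120 and scales d0 = 0.5f, d1 = 0.25f, the int overload
// stores e = ck::type_convert<ck::half_t>(120.f * 0.5f * 0.25f), i.e. 15.0 in
// half precision; the accumulator is dequantized before narrowing.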
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
///###### RRR
///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>;
///###### RCR
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>;
// clang-format on
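// Assumed reading of the template table above: the selected RCR instance runs a
// 256-thread block on a 256x128 macro tile with a K unroll of 64; AK1 and BK1
// of 16 correspond to 16-byte vector accesses of the int8 A and B tiles.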
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD = 0;
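// StrideD == 0 broadcasts the D tensors: the row-major D0 then acts as a
// per-column scale and the column-major D1 as a per-row scale, matching the
// MultiplyMultiply dequantization above.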
ck::index_t StrideE = N;
ck::index_t KBatch = 1;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 12)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
KBatch = std::stoi(argv[11]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf(
"arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n");
exit(0);
}
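// Example invocation (hypothetical executable name) matching the default shape:
//   ./example_gemm 1 1 1 3840 4096 4096 4096 4096 0 4096 1
// i.e. verify, integer initialization, time the kernel, KBatch of 1.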
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
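// E.g. a 2x3 row-major tensor with stride 3 places element (m, k) at offset
// m*3 + k; the column-major branch instead yields offset m + k*stride.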
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 2});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto I0 = ck::Number<0>{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, NumDTensor>{I0, I0},
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
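// Sanity check of the formulas above (assuming a hypothetical ave_time of 1 ms):
// flop = 2*3840*4096*4096 ~ 1.288e11, so tflops ~ 128.8 (GFlop/ms == TFlop/s);
// num_btype = 3840*4096 + 4096*4096 + 2*3840*4096 = 63,963,136 bytes, so
// gb_per_sec ~ 64.0 (MB/ms == GB/s).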
if(do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false});
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
}
}
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}
@@ -127,44 +127,47 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
     switch(init_method)
     {
     case 0: break;
     case 1:
         a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
         b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
         b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         break;
     default:
         a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
         b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
         b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     }

     DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize());
     DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize());
     DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf_re(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf_re(sizeof(EDataType) *
+                              e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());

     DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize());
     DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize());
     DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf_img(sizeof(EDataType) *
+                               e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());

     // Intermediate Value For E Real and Img
-    DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf_re1(sizeof(EDataType) *
+                               e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf_img1(sizeof(EDataType) *
+                                e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());

     a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
     b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());

@@ -181,7 +184,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
     // set zero for intermediate values
     e_device_buf_re1.SetZero();
     e_device_buf_img1.SetZero();

     auto a_element_op   = AElementOp{};
     auto b_element_op   = BElementOp{};
     auto cde_element_op = CDEElementOp{alpha, beta};

@@ -189,23 +192,24 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
     // device operation
     // For real Intermediate Value re_1
     auto op      = DeviceOpInstance{};
     auto invoker = op.MakeInvoker();
-    auto argument_re1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
-                                        b_device_buf_re.GetDeviceBuffer(),
-                                        std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
-                                        e_device_buf_re1.GetDeviceBuffer(),
-                                        a_ms_ks_lengths,
-                                        a_ms_ks_strides,
-                                        b_ns_ks_lengths,
-                                        b_ns_ks_strides,
-                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
-                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
-                                        e_ms_ns_lengths,
-                                        e_ms_ns_strides,
-                                        a_element_op,
-                                        b_element_op,
-                                        cde_element_op);
+    auto argument_re1 =
+        op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
+                        b_device_buf_re.GetDeviceBuffer(),
+                        std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
+                        e_device_buf_re1.GetDeviceBuffer(),
+                        a_ms_ks_lengths,
+                        a_ms_ks_strides,
+                        b_ns_ks_lengths,
+                        b_ns_ks_strides,
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+                        e_ms_ns_lengths,
+                        e_ms_ns_strides,
+                        a_element_op,
+                        b_element_op,
+                        cde_element_op);

     if(!op.IsSupportedArgument(argument_re1))
     {

@@ -216,7 +220,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
     float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel});

     alpha = -1.f;
     beta  = 1.f;

@@ -228,21 +231,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
     // For real Intermediate Value re_2
     // auto op = DeviceOpInstance{};
     // auto invoker = op.MakeInvoker();
-    auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
-                                        b_device_buf_img.GetDeviceBuffer(),
-                                        std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
-                                        e_device_buf_re.GetDeviceBuffer(),
-                                        a_ms_ks_lengths,
-                                        a_ms_ks_strides,
-                                        b_ns_ks_lengths,
-                                        b_ns_ks_strides,
-                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
-                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
-                                        e_ms_ns_lengths,
-                                        e_ms_ns_strides,
-                                        a_element_op,
-                                        b_element_op,
-                                        cde_element_op);
+    auto argument_re2 =
+        op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
+                        b_device_buf_img.GetDeviceBuffer(),
+                        std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
+                        e_device_buf_re.GetDeviceBuffer(),
+                        a_ms_ks_lengths,
+                        a_ms_ks_strides,
+                        b_ns_ks_lengths,
+                        b_ns_ks_strides,
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+                        e_ms_ns_lengths,
+                        e_ms_ns_strides,
+                        a_element_op,
+                        b_element_op,
+                        cde_element_op);

     if(!op.IsSupportedArgument(argument_re2))
     {

@@ -253,7 +257,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
     float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});

     alpha = 1.f;
     beta  = 1.f;

@@ -261,22 +264,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
     b_element_op   = BElementOp{};
     cde_element_op = CDEElementOp{alpha, beta};

-    auto argument_img1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
-                                         b_device_buf_img.GetDeviceBuffer(),
-                                         std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
-                                         e_device_buf_img1.GetDeviceBuffer(),
-                                         a_ms_ks_lengths,
-                                         a_ms_ks_strides,
-                                         b_ns_ks_lengths,
-                                         b_ns_ks_strides,
-                                         std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
-                                         std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
-                                         e_ms_ns_lengths,
-                                         e_ms_ns_strides,
-                                         a_element_op,
-                                         b_element_op,
-                                         cde_element_op);
+    auto argument_img1 =
+        op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
+                        b_device_buf_img.GetDeviceBuffer(),
+                        std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
+                        e_device_buf_img1.GetDeviceBuffer(),
+                        a_ms_ks_lengths,
+                        a_ms_ks_strides,
+                        b_ns_ks_lengths,
+                        b_ns_ks_strides,
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+                        e_ms_ns_lengths,
+                        e_ms_ns_strides,
+                        a_element_op,
+                        b_element_op,
+                        cde_element_op);

     if(!op.IsSupportedArgument(argument_img1))
     {

@@ -290,23 +293,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
     alpha = 1.f;
     beta  = 1.f;

-    auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
-                                         b_device_buf_re.GetDeviceBuffer(),
-                                         std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
-                                         e_device_buf_img.GetDeviceBuffer(),
-                                         a_ms_ks_lengths,
-                                         a_ms_ks_strides,
-                                         b_ns_ks_lengths,
-                                         b_ns_ks_strides,
-                                         std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
-                                         std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
-                                         e_ms_ns_lengths,
-                                         e_ms_ns_strides,
-                                         a_element_op,
-                                         b_element_op,
-                                         cde_element_op);
+    auto argument_img2 =
+        op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
+                        b_device_buf_re.GetDeviceBuffer(),
+                        std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
+                        e_device_buf_img.GetDeviceBuffer(),
+                        a_ms_ks_lengths,
+                        a_ms_ks_strides,
+                        b_ns_ks_lengths,
+                        b_ns_ks_strides,
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+                        e_ms_ns_lengths,
+                        e_ms_ns_strides,
+                        a_element_op,
+                        b_element_op,
+                        cde_element_op);

     if(!op.IsSupportedArgument(argument_img2))
     {

@@ -317,7 +319,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
     float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel});

     ck::index_t M =
         ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
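Taken together, the four dispatches above build the complex product out of real-valued bilinear contractions (the device op computes E = alpha * A * B + beta * D): re1 and re2 accumulate Re(E), with alpha flipped to -1 for the Im(A) * Im(B) term, while img1 and img2 accumulate Im(E). A scalar sketch of the sign pattern, with illustrative values and the initial d_re/d_img bias terms elided:

// re1 = Re(a)*Re(b);  re = -Im(a)*Im(b) + re1   (alpha = -1, beta = 1)
// im1 = Re(a)*Im(b);  im =  Im(a)*Re(b) + im1   (alpha =  1, beta = 1)
// e.g. a = 2+3i, b = 4+5i:  re = -15 + 8 = -7,  im = 12 + 10 = 22,
// matching (2+3i)(4+5i) = -7 + 22i.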
@@ -331,9 +332,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
     std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                             sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;

-    float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1 ;
+    float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1;
     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
     float gb_per_sec = num_btype / 1.E6 / ave_time;

     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "

@@ -343,7 +344,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
     e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());

     auto isRealOk = 0;
     auto isImgOk  = 0;

     if(do_verification)
     {

@@ -366,17 +367,16 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
         auto ref_op      = ReferenceOpInstance{};
         auto ref_invoker = ref_op.MakeInvoker();

-        auto ref_argument_re =
-            ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);
+        auto ref_argument_re = ref_op.MakeArgument(
+            a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);

         ref_invoker.Run(ref_argument_re);

         alpha = 1.f;
         beta  = 1.f;
         cde_element_op = CDEElementOp{alpha, beta};

         for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0)
         {
             for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1)

@@ -395,11 +395,11 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
         alpha = 1.f;
         beta  = -1.f;
         cde_element_op = CDEElementOp{alpha, beta};

-        auto ref_argument_re1 =
-            ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);
+        auto ref_argument_re1 = ref_op.MakeArgument(
+            a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);

         ref_invoker.Run(ref_argument_re1);

@@ -419,23 +419,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
             }
         }

         isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;

         // Img Part Verification
         Tensor<CShuffleDataType> c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
         Tensor<CShuffleDataType> c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides);

-        auto ref_argument_img =
-            ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);
+        auto ref_argument_img = ref_op.MakeArgument(
+            a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);

         ref_invoker.Run(ref_argument_img);

         alpha = 1.f;
         beta  = 1.f;
         cde_element_op = CDEElementOp{alpha, beta};

         for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)

@@ -454,9 +451,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
             }
         }

-        auto ref_argument_img1 =
-            ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);
+        auto ref_argument_img1 = ref_op.MakeArgument(
+            a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);

         ref_invoker.Run(ref_argument_img1);

         for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)

@@ -475,7 +472,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
             }
         }

         isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;

     return (isRealOk && isImgOk);
 }
@@ -85,9 +85,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
     #only continue if there are some source files left on the list
     if(FILE_NAME)
         if(FILE_NAME MATCHES "_xdl")
-            list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
+            list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
         elseif(FILE_NAME MATCHES "_wmma")
-            list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+            list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
         endif()
         set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
         add_executable(${EXAMPLE_NAME} ${FILE_NAME})

@@ -169,9 +169,9 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
     #only continue if there are some source files left on the list
     if(FILE_NAME)
         if(FILE_NAME MATCHES "_xdl")
-            list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
+            list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
         elseif(FILE_NAME MATCHES "_wmma")
-            list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+            list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
         endif()
         set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
         add_executable(${EXAMPLE_NAME} ${FILE_NAME})
@@ -21,6 +21,14 @@ DTYPE_BITS = {
     "bf8" : 8
 }

+K0_MAX_SUBMAX_MAP = {
+    32 : 32,
+    64 : 64,
+    96 : 128,
+    128: 128,
+    256: 256
+}

 TILE_PARTITIONER_MAP = {
     "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB",
     "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS",

@@ -35,14 +43,13 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
 FMHA_FWD_KERNEL_BODY="""
 using fmha_dtype_{F_idx} = {F_dtype};

-using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>;
-using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
+using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
 using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;

 using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
-                                                  fmha_block_warps_{F_idx},
+                                                  ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
                                                   fmha_warp_tile_{F_idx},
-                                                  fmha_block_warps_{F_idx},
+                                                  ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
                                                   fmha_warp_tile_{F_idx},
                                                   {F_vlayout}>;

@@ -88,7 +95,7 @@ using fmha_kernel_{F_idx} =
     fmha_pipeline_{F_idx},
     fmha_epilogue_{F_idx}>;

-using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
+using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
                                        {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

 #include <iostream>

@@ -126,7 +133,7 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
 FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
                 ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+        using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
         return fmha_fwd_<trait_>(s, a);
     }}
 """

@@ -143,7 +150,7 @@ class FmhaFwdApiTrait:
     bk0 : int # tile size along qk gemm unroll
     bn1 : int # tile size along v head_dim
     bk1 : int # tile size along kv gemm unroll
-    bk0blen : int
+    bk0max : int
     vlayout : str
     mask : str
     bias : str #

@@ -157,7 +164,7 @@ class FmhaFwdApiTrait:
     @property
     def name(self) -> str:
-        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\
+        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
             f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'

     @property

@@ -189,8 +196,9 @@ class FmhaFwdApiTrait:
             if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
             else : assert False
         elif self.pipeline_tag in ['qr']:
-            if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
-            else : return f'a.hdim_q % {self.bk0blen} == 0'
+            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
+            if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else : return f'a.hdim_q % {bk0submax} == 0'
         else: assert False

     @property

@@ -200,8 +208,9 @@ class FmhaFwdApiTrait:
             if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
             else : assert False
         elif self.pipeline_tag in ['qr']:
-            if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
-            else : return f'a.hdim_v % {self.bk0blen} == 0'
+            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
+            if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else : return f'a.hdim_v % {bk0submax} == 0'
         else: assert False
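The K0_MAX_SUBMAX_MAP indirection above changes the generated head-dim guards for the 'qr' pipeline: they now test divisibility by the padded sub-tile looked up in the map rather than by bk0max itself. For example (illustrative, for bk0max = 96 and dpad == 'f'), the emitted dispatch condition becomes

a.hdim_q % 128 == 0

so a head dimension of 96 no longer satisfies the unpadded variant and falls through to the padded one.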
 @dataclass

@@ -272,7 +281,7 @@ class FmhaFwdApiPool:
                     F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
                     F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                     F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
-                    F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
+                    F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
                     F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
                 if_j = 'if' if j == 0 else 'else if'
                 per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)

@@ -290,19 +299,22 @@ class FmhaFwdTileSize:
     F_bk0 : int # tile size along qk gemm unroll
     F_bn1 : int # tile size along v head_dim
     F_bk1 : int # tile size along kv gemm unroll
-    F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
-    F_rm : int # number of warps along q seqlen (block warps)
-    F_rn : int # number of warps along k seqlen(not used)
-    F_rk : int # number of warps along gemm-k(not used)
+    F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
+    F_rm0 : int # number of warps for gemm0 along q seqlen
+    F_rn0 : int # number of warps for gemm0 along k seqlen
+    F_rk0 : int # number of warps for gemm0 along head dim q (not used)
+    F_rm1 : int # number of warps for gemm1 along q seqlen
+    F_rn1 : int # number of warps for gemm1 along head dim v
+    F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
     F_wm : int # warp size along m (warp size)
     F_wn : int # warp size along n
     F_wk : int # warp size along k
     F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy

     @property
     def name(self) -> str:
-        return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\
-            f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\
-            ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
+        return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
+            f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
+            f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")

 @dataclass
 class FmhaFwdKernel:

@@ -333,10 +345,13 @@ class FmhaFwdKernel:
             F_bk0 = self.F_tile.F_bk0,
             F_bn1 = self.F_tile.F_bn1,
             F_bk1 = self.F_tile.F_bk1,
-            F_bk0blen = self.F_tile.F_bk0blen,
-            F_rm = self.F_tile.F_rm,
-            F_rn = self.F_tile.F_rn,
-            F_rk = self.F_tile.F_rk,
+            F_bk0max = self.F_tile.F_bk0max,
+            F_rm0 = self.F_tile.F_rm0,
+            F_rn0 = self.F_tile.F_rn0,
+            F_rk0 = self.F_tile.F_rk0,
+            F_rm1 = self.F_tile.F_rm1,
+            F_rn1 = self.F_tile.F_rn1,
+            F_rk1 = self.F_tile.F_rk1,
             F_wm = self.F_tile.F_wm,
             F_wn = self.F_tile.F_wn,
             F_wk = self.F_tile.F_wk,

@@ -377,7 +392,7 @@ class FmhaFwdKernel:
                 bk0=self.F_tile.F_bk0,
                 bn1=self.F_tile.F_bn1,
                 bk1=self.F_tile.F_bk1,
-                bk0blen=self.F_tile.F_bk0blen,
+                bk0max=self.F_tile.F_bk0max,
                 vlayout=self.F_pipeline.F_vlayout,
                 mask=self.F_pipeline.F_mask,
                 bias=self.F_pipeline.F_bias,

@@ -394,16 +409,17 @@ class FmhaFwdKernel:
 def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
     if dtype == 'fp16' or dtype == 'bf16':
         return {
-            '32'  : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1),
-            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1),
-            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1),
-            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1),
+            '32'  : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1),
+            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
+            ## '96'  : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
+            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
+            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
         }
     elif dtype == 'fp8' or dtype == 'bf8':
         return {
-            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
-            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
-            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1)
+            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1),
+            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1),
+            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1)
         }
     else:
         return None

@@ -505,4 +521,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
     _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
     for kernel in kernels:
         f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
-    f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
\ No newline at end of file
+    f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
@@ -29,6 +29,14 @@ DTYPE_BITS = {
     "bf8" : 8
 }

+K0_MAX_SUBMAX_MAP = {
+    32 : 32,
+    64 : 64,
+    96 : 128,
+    128: 128,
+    256: 256
+}

 FMHA_FWD_SPLITKV_PIPELINE_MAP = {
     "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
     "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",

@@ -41,14 +49,13 @@ using fmha_mask_{F_idx} = {F_mask};
 namespace {{

 template <bool kHasUnevenSplits>
 struct kernel_runner {{
-using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>;
-using fmha_block_warps = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
+using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
 using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;

 using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
-                                          fmha_block_warps,
+                                          ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
                                           fmha_warp_tile,
-                                          fmha_block_warps,
+                                          ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
                                           fmha_warp_tile,
                                           {F_vlayout}>;

@@ -104,7 +111,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 }};
 }}

-using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
+using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
                                                {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
                                                {F_dvpad}>;

@@ -162,10 +169,12 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
 using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
     fmha_pipeline_problem>;

+/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
+/// store_tile_raw() data corruption issue
 using fmha_epilogue =
     ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
                                                                  typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
-                                                                 {F_spad}, {F_dvpad}>>;
+                                                                 false, false>>;

 using fmha_kernel =
     ck_tile::FmhaFwdSplitKVCombineKernel<ck_tile::FmhaFwdSplitKVCombineTilePartitioner<{F_bm0}, {F_bn1}>,

@@ -191,7 +200,9 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
 template<>
 void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
-    if (a.num_splits <= 16) {{
+    if (a.num_splits <= 8) {{
+        kernel_runner<3>::run(s, a);
+    }} else if (a.num_splits <= 16) {{
         kernel_runner<4>::run(s, a);
     }} else if (a.num_splits <= 32) {{
         kernel_runner<5>::run(s, a);
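The added branch follows the apparent convention that the combine kernel_runner's template argument is ceil(log2(num_splits)), i.e. the bit width needed to index a split; this is an inference from the surrounding branches, not stated in the diff:

// num_splits <= 8  -> kernel_runner<3>   (2^3 = 8)
// num_splits <= 16 -> kernel_runner<4>   (2^4 = 16)
// num_splits <= 32 -> kernel_runner<5>   (2^5 = 32)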
@@ -238,8 +249,8 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
 FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
                 ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
-        using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
+        using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+        using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
         return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
     }}

@@ -257,7 +268,7 @@ class FmhaFwdSplitKVApiTrait:
     bk0 : int # tile size along qk gemm unroll
     bn1 : int # tile size along v head_dim
     bk1 : int # tile size along kv gemm unroll
-    bk0blen : int
+    bk0max : int
     vlayout : str
     mask : str
     bias : str #

@@ -267,11 +278,11 @@ class FmhaFwdSplitKVApiTrait:
     skpad : str
     dpad : str
     dvpad : str
     pagedkv : str

     @property
     def name(self) -> str:
-        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\
+        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
             f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
             f'{self.dvpad}-{self.pagedkv}'

@@ -304,8 +315,9 @@ class FmhaFwdSplitKVApiTrait:
             if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
             else : assert False
         elif self.pipeline_tag in ['qr']:
-            if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
-            else : return f'a.hdim_q % {self.bk0blen} == 0'
+            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
+            if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else : return f'a.hdim_q % {bk0submax} == 0'
         else: assert False

     @property

@@ -315,8 +327,9 @@ class FmhaFwdSplitKVApiTrait:
             if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
             else : assert False
         elif self.pipeline_tag in ['qr']:
-            if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
-            else : return f'a.hdim_v % {self.bk0blen} == 0'
+            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
+            if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else : return f'a.hdim_v % {bk0submax} == 0'
         else: assert False

 @dataclass

@@ -411,7 +424,7 @@ class FmhaFwdSplitKVApiPool:
                     F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
                     F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                     F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
-                    F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
+                    F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
                     F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
                 if_j = 'if' if j == 0 else 'else if'
                 per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)

@@ -455,10 +468,13 @@ class FmhaFwdSplitKVKernel:
             F_bk0 = self.F_tile.F_bk0,
             F_bn1 = self.F_tile.F_bn1,
             F_bk1 = self.F_tile.F_bk1,
-            F_bk0blen = self.F_tile.F_bk0blen,
-            F_rm = self.F_tile.F_rm,
-            F_rn = self.F_tile.F_rn,
-            F_rk = self.F_tile.F_rk,
+            F_bk0max = self.F_tile.F_bk0max,
+            F_rm0 = self.F_tile.F_rm0,
+            F_rn0 = self.F_tile.F_rn0,
+            F_rk0 = self.F_tile.F_rk0,
+            F_rm1 = self.F_tile.F_rm1,
+            F_rn1 = self.F_tile.F_rn1,
+            F_rk1 = self.F_tile.F_rk1,
             F_wm = self.F_tile.F_wm,
             F_wn = self.F_tile.F_wn,
             F_wk = self.F_tile.F_wk,

@@ -498,7 +514,7 @@ class FmhaFwdSplitKVKernel:
                 bk0=self.F_tile.F_bk0,
                 bn1=self.F_tile.F_bn1,
                 bk1=self.F_tile.F_bk1,
-                bk0blen=self.F_tile.F_bk0blen,
+                bk0max=self.F_tile.F_bk0max,
                 vlayout=self.F_pipeline.F_vlayout,
                 mask=self.F_pipeline.F_mask,
                 bias=self.F_pipeline.F_bias,

@@ -551,16 +567,17 @@ class FmhaFwdSplitKVCombineKernel:
 def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
     if dtype == 'fp16' or dtype == 'bf16':
         return {
-            '32'  : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1),
-            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1),
-            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1),
-            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1),
+            '32'  : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1),
+            '64'  : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
+            ## '96'  : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
+            '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
+            '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
         }
     elif dtype == 'fp8' or dtype == 'bf8':
         return {
-            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
-            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
-            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1)
+            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1),
+            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1),
+            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1)
         }
     else:
         return None

@@ -568,16 +585,17 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
 def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
     if dtype == 'fp16' or dtype == 'bf16':
         return {
-            '32'  : FmhaFwdSplitKVCombineTileSize(64, 32, -1),
-            '64'  : FmhaFwdSplitKVCombineTileSize(64, 64, -1),
-            '128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1),
-            '256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1),
+            '32'  : FmhaFwdSplitKVCombineTileSize(16, 16, -1),
+            '64'  : FmhaFwdSplitKVCombineTileSize(32, 32, -1),
+            ## '96'  : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
+            '128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
+            '256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1),
         }
     elif dtype == 'fp8' or dtype == 'bf8':
         return {
-            '64'  : FmhaFwdSplitKVCombineTileSize(64, 64, -1),
-            '128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1),
-            '256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1),
+            '64'  : FmhaFwdSplitKVCombineTileSize(64, 32, -1),
+            '128' : FmhaFwdSplitKVCombineTileSize(64, 64, -1),
+            '256' : FmhaFwdSplitKVCombineTileSize(64, 128, -1),
        }
     else:
         return None

@@ -598,7 +616,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
     if dtype in ['fp16', 'bf16']:
         for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
             # TODO: use async pipeline when compiler is more stable
-            if hdim == 256 or hdim in [32, 64, 128]:
+            if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
             # if True:
                 pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
                 pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))

@@ -737,4 +755,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
     _, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl)
     for kernel in kernels:
         f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
-    f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n")
\ No newline at end of file
+    f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n")
@@ -557,33 +557,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
#endif
static const auto get_lengths = [](bool permute,
                                   ck_tile::index_t b /*batch*/,
                                   ck_tile::index_t h /*nhead*/,
                                   ck_tile::index_t s /*seqlen*/,
                                   ck_tile::index_t d /*hdim*/) {
    if(permute)
        return std::array<ck_tile::index_t, 4>{b, h, s, d};
    else
        return std::array<ck_tile::index_t, 4>{b, s, h, d};
};
bool is_v_rowmajor = vlayout == std::string("r");
@@ -635,12 +618,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<LSEDataType> lse_acc_host(
    1 < num_splits || use_kvcache
        ? std::array<ck_tile::index_t, 4>{shape_batch, nhead, num_splits, shape_seqlen_q}
        : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
    1 < num_splits || use_kvcache
        ? std::array<ck_tile::index_t, 5>{shape_batch, nhead, num_splits, shape_seqlen_q, hdim_v}
        : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
@@ -880,7 +866,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o_acc = (hdim_v);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
@@ -906,8 +892,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
@@ -922,13 +908,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
...
@@ -47,6 +47,9 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter
assert output_file is not None
file_path = Path(output_file)
# create an empty file / drop its contents if it exists
open(file_path, "w").close()
for api in api_list:
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, mask_impl)
...
@@ -29,14 +29,14 @@ while getopts ":sa" opt; do
done
run_fp16_bf16_tests() {
local NUM_SPLITS="1"
local PAGE_BLOCK_SIZE="0"
local CACHE_BATCH_IDX="0"
if [ $TEST_SPLITKV -eq 1 ] ; then
NUM_SPLITS="$NUM_SPLITS 2 3"
PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
fi
for prec in "fp16" "bf16" ; do
@@ -47,9 +47,9 @@ run_fp16_bf16_tests() {
for lse in 0 1 ; do
for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2 ; do
for num_splits in $NUM_SPLITS ; do
for page_block_size in $PAGE_BLOCK_SIZE ; do
for cache_batch_idx in $CACHE_BATCH_IDX ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
@@ -103,4 +103,4 @@ if [ $TEST_APPENDKV -eq 1 ] ; then
run_fp16_appendkv_tests
fi
set +x
\ No newline at end of file
set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd")
set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
  "semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
if(LAYERNORM2D_FWD_ENABLE_APIS STREQUAL "all")
set(LAYERNORM2D_FWD_ENABLE_APIS ${LAYERNORM2D_FWD_KNOWN_APIS})
endif()
# generate a list of kernels, but do not actually emit files at config stage
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to generate kernels via Python. ${ret}")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/layernorm2d_fwd_blobs.txt LAYERNORM2D_FWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${LAYERNORM2D_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs
)
set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let the source compile without explicitly declaring function specializations
list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global property, otherwise the progress bar generated
# by cmake will print too many files, causing "execvp: /bin/sh: Argument list too long";
# however, this property affects the build globally
# TODO: consider generating a makefile ourselves instead
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
# Layernorm2D forward
This folder contains an example of Layernorm2D forward using the `ck_tile` tile-programming implementation.
# Implementation and feature support
## Welford online algorithm
We use the Welford algorithm to update `mean`/`variance` block by block. For the `N <= 4096` case we can compute `mean`/`var`/`normalization` within one loop; we call this `one-pass`. For the large-N case it is hard to keep `mean`/`var` inside register/LDS while also computing the `normalization`, so we need to load the input twice: the first time to compute `mean`/`var` block by block, then another time to compute the `normalization`. We call this `two-pass`.
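Below is an illustrative Python sketch of the block-wise Welford accumulation (our own toy code under the assumptions above, not the kernel source):
```
# illustrative sketch of block-wise Welford; not the kernel source
def welford_block(block, count, mean, m2):
    # fold one block of values into the running (count, mean, m2) state
    for x in block:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)
    return count, mean, m2

count, mean, m2 = 0, 0.0, 0.0
row = [0.5, -1.0, 2.0, 3.5, -0.25, 1.25]
for i in range(0, len(row), 2):  # walk one row block by block
    count, mean, m2 = welford_block(row[i:i + 2], count, mean, m2)
var = m2 / count  # population variance, as layernorm uses
```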
## mean/variance save
In the training case the mean/variance need to be stored out (TBD, not supported yet).
## prenorm/postnorm
![](misc/pnorm.png)
Since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite common in LLM blocks, this example supports this feature through kernel fusion. Note that `prenorm`/`postnorm` always needs an elementwise add of a `shortcut` tensor before the actual layernorm computation, and optionally stores that result out to global memory. You can use `-fadd=1` to test `pre-add+store`, or `-fadd=2` to test `pre-add` without the store (not codegen'd by default).
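A rough torch-style sketch of the fused pre-add (variable names like `shortcut` and `residual_out` are illustrative, not the kernel's identifiers):
```
import torch

# illustrative pseudocode for -fadd; names are ours, not the kernel's
m, n = 4, 8
hidden_states = torch.randn(m, n, dtype=torch.float16)
shortcut = torch.randn(m, n, dtype=torch.float16)  # the residual input
gamma, beta = torch.ones(n), torch.zeros(n)

x = hidden_states + shortcut  # elementwise pre-add of the shortcut
residual_out = x              # stored to global when -fadd=1; skipped when -fadd=2
y = torch.nn.functional.layer_norm(x.float(), (n,), weight=gamma, bias=beta, eps=1e-5)
```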
## smooth-quant/dynamic-quant
We support smooth/dynamic quantization for `int8` output by setting `-fquant=1` and `-prec_o=int8`. In this case the output goes through a rowwise dynamic quantization as shown below. Note that smooth-quant requires a `(1*N)`-sized per-channel input scale (fp32 in our example, though this is customizable), which is element-wise multiplied into each row of the tensor before the rowwise dynamic quant is computed. Setting `-fquant=2` skips the input per-channel scale stage and does only the dynamic quant. This case is supported by our kernel but not generated by default (TBD: add a filter in generate.py to support on-demand codegen).
![](misc/dquant.png)
```
# assume output int8, hidden_states is [m, n] shape and in fp16/bf16
# [m, 1]
per_token_amax, _ = torch.max(
input=torch.abs(hidden_states),
dim=-1,
keepdim=True
)
per_token_scale = per_token_amax.to(dtype=torch.float32) / 127.0
# quant hidden_states
hidden_states = (hidden_states / per_token_scale).to(dtype=torch.int8)
return hidden_states, per_token_scale
# hidden_states is now int8 and will feed the next layer as input
# per_token_scale will be used as the dequant factor in a later layer
```
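For `-fquant=1` (smooth-quant) the only extra step is the per-channel input-scale stage before the rowwise dynamic quantization above; a hedged sketch in the same pseudocode style (`x_scale` is an illustrative name, not the kernel's identifier):
```
# smooth-quant (-fquant=1): x_scale is a (1, n) per-channel scale in fp32
hidden_states = hidden_states.to(dtype=torch.float32) * x_scale
# ...then apply the same rowwise dynamic quantization as above
```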
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace <arch> with gfx90a, gfx942, ...
make tile_example_layernorm2d_fwd -j
```
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
@@ -16,8 +51,35 @@ This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
```
args:
-m m dimension (default:3328)
-n n dimension (default:4096)
-stride stride per row, if -1 then equal to n (default:-1)
-e epsilon (default:1e-5)
-save_mv save mean/variance(invstd) or not. set to 1 in training case (default:0)
-v cpu validation or not (default:1)
-kname print kernel name or not (default:1)
-prec_i input precision (default:fp16)
-prec_o output precision, set auto will be the same as input (default:auto)
-prec_sx output quant scale type, set auto will use fp32. used when fquant=1 (default:auto)
-prec_sy output quant scale type, set auto will use fp32. used when fquant=1 or 2 (default:auto)
-fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
-warmup cold iter (default:5)
-repeat hot iter (default:20)
```
## limitations
Note that `fquant=2`, `fadd=2`, and `prec_sx`/`prec_sy` other than `fp32` are not generated by default, though our kernel template supports them (TBD: add a flag in generate.py to generate those instances on demand). Besides, the `N>8192` case will by default use the two-pass pipeline, where `-fquant=1/2` is not supported yet. If you need `N>8192` together with `fused+residual+store`, you can use this example together with `12_smoothquant` to construct layernorm+residual and smoothquant as 2 kernels for this purpose.
```
# some case
# standard fp16 layernorm 2d, m=10, n=1024
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024
# standard fp16 layernorm 2d, m=10, n=1024, fused-smooth-quant, output in int8
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1
# standard fp16 layernorm 2d, m=10, n=1024, fused-smooth-quant + fused-add-store, output in int8
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 -fadd=1
```
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import argparse
from enum import IntEnum
from pathlib import Path
import sys
from typing import List, Optional, Any
import functools
import itertools
import copy
from dataclasses import dataclass
def get_if_str(idx, total, last_else = True):
if idx == 0:
return 'if'
elif idx < total - 1:
return 'else if'
else:
if last_else:
return 'else'
else:
return 'else if'
FUSED_ADD_ENUM_STR_MAP = [
'no',
'pras', # pre-add + store (pre-norm)
'pra' ] # pre-add only (post-norm)
FUSED_SWEEP_STR_MAP = [
'no',
'dquant' ]
DATA_TYPE_MAP = {'fp32' : 'float',
'fp16' : 'ck_tile::fp16_t',
'bf16' : 'ck_tile::bf16_t',
'int8' : 'ck_tile::int8_t'}
def BOOL_MAP(b_) -> str:
if b_:
return 'true'
else:
return 'false'
class layernorm_fwd_codegen:
API_TRAITS_DEFINE = """
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <typename XDataType_,
typename YDataType_,
typename XScaleDataType_,
typename YScaleDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kTwoPass_,
ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedQuant_ = 0>
struct layernorm2d_fwd_traits_
{
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
using XScaleDataType = ck_tile::remove_cvref_t<XScaleDataType_>;
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
};
template <typename XDataType_,
typename YDataType_,
typename XScaleDataType_,
typename YScaleDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kTwoPass_,
int kFusedAdd_,
int kFusedQuant_>
using traits_ = layernorm2d_fwd_traits_<XDataType_,
YDataType_,
XScaleDataType_,
YScaleDataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveMeanInvStd_,
kFastFDiv_,
kTwoPass_,
kFusedAdd_,
kFusedQuant_>;
"""
API_COMMON_HEADER = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
#include <ck_tile/ops/epilogue.hpp>
#include <iostream>
#pragma once
using S = ck_tile::stream_config;
using A = layernorm2d_fwd_args;
{F_traits_define}
template <typename Traits_>
float layernorm2d_fwd_(const S& s, A a)
{{
using XDataType = typename Traits_::XDataType;
using YDataType = typename Traits_::YDataType;
using XScaleDataType = typename Traits_::XScaleDataType;
using YScaleDataType = typename Traits_::YScaleDataType;
using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType;
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveMeanInvStd,
Traits_::kFastFDiv,
Traits_::kTwoPass,
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::GammaDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BetaDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::YDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::MeanDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::InvStdDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XScaleDataType,
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::YScaleDataType,
typename Traits_::Shape,
PipelineTraits>;
using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, XScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
using Epilogue = std::conditional_t<Traits_::kFusedQuant == 1, DynamicQuantEpilogue, Default2DEpilogue>;
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
}}
"""
API_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
{F_traits_define}
// Note: this internal API is only declared, not defined, here; otherwise it would block `make -j`
template <typename Traits_>
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
float layernorm2d_fwd(layernorm2d_fwd_traits t,
layernorm2d_fwd_args a,
const ck_tile::stream_config& s)
{{
float r = -1;
{F_dispatch}
return r;
}}
"""
API_PER_DTYPE=""" {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{
{F_per_n_case}
}}
"""
API_PER_N_CASE=""" {F_if} {F_N_COND} {{
{F_inner_dispatch}
}}
"""
API_INNER_CASE=""" {F_if} {F_VEC_COND}
r={F_instance_func}(s, a);
"""
INSTANCE_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_api_common.hpp"
// clang-format off
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep
{F_instance_def}
// clang-format on
"""
def __init__(self, working_path, kernel_filter):
self.working_path = working_path
self.kernel_filter = kernel_filter
class k_fused_add_enum(IntEnum):
F_NO_ADD = 0
F_PRE_ADD = 1
F_PRE_ADD_STORE_RESIDUAL = 2
class k_fused_sweep_enum(IntEnum):
F_NO_SWEEP = 0
F_RENORM = 1
F_DYNAMIC_QUANT = 2
@dataclass
class k_traits:
F_kPadN : bool
F_kSaveMeanInvStd : bool
F_kTwoPass : bool
F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fused_add_enum
F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum
@dataclass
class k_shape:
F_BlockTile : List[int]
F_WarpPerBlock : List[int]
F_WarpTile : List[int]
F_Vector_ : List[int]
@property
def F_BlockSize(self) -> int:
return functools.reduce(lambda a, b: a*b, self.F_WarpTile)
@dataclass
class k_problem:
F_XDataType : str
F_GammaDataType : str
F_BetaDataType : str
F_ComputeDataType : str
F_YDataType : str
F_MeanDataType : str
F_InvStdDataType : str
F_BlockShape : str
F_Traits : Any #k_traits
@dataclass
class k_pipeline_one_pass:
F_Problem : Any #k_problem
@dataclass
class k_pipeline_two_pass:
F_Problem : Any #k_problem
@dataclass
class default_2d_epilogue_problem:
F_AccDataType : str
F_ODataType : str
F_kPadM : bool
F_kPadN : bool
@dataclass
class default_2d_epilogue:
F_problem : Any
@dataclass
class k_kernel:
F_pipeline : Any
F_epilogue : Any
@dataclass
class h_traits:
F_XDataType : str
F_YDataType : str
F_XScaleDataType : str
F_YScaleDataType : str
F_Repeat_M : int
F_Repeat_N : int
F_ThreadPerBlock_M : int
F_ThreadPerBlock_N : int
F_Vector_N : int
F_kPadN : bool
F_kSaveMeanInvStd_ : bool
F_kFastFDiv_ : bool
F_kTwoPass_ : bool
F_kFusedAdd : int
F_kFusedQuant : int
@property
def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
return t_
# string when calling this kernel
@property
def call_name(self) -> str:
return f'layernorm2d_fwd_<traits_<{self.trait_name}>>'
# string when define this kernel
@property
def def_name(self) -> str:
return f'template float layernorm2d_fwd_<traits_<{self.trait_name}>>(const S&, A);'
# this class hold kernel under same source file
@dataclass
class h_instance:
F_DataTypePair : str
F_N : str
F_add : int
F_sweep : int
instance_list : List[Any] # List[h_traits]
@property
def name(self) -> str:
prec_i, prec_o = self.F_DataTypePair.split(',')
dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}'
nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}'
if self.F_add != 0:
nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
if self.F_sweep != 0:
nnn = nnn + '_' + FUSED_SWEEP_STR_MAP[self.F_sweep]
return nnn
@property
def instance_name(self) ->str:
return self.name
@property
def content(self) ->str:
instance_defs = ''
for ins in self.instance_list:
instance_defs += ins.def_name + '\n'
return layernorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs)
@property
def name_api(self) -> str:
return 'layernorm2d_fwd_api'
@property
def name_common_header(self) -> str:
return 'layernorm2d_fwd_api_common'
@property
def content_api(self) -> str:
# 1 sort based on dtype
t_dtype_dict = dict()
blobs = self.get_blobs()
for blob in blobs:
if blob.F_DataTypePair not in t_dtype_dict:
t_dtype_dict[blob.F_DataTypePair] = {}
if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]:
t_dtype_dict[blob.F_DataTypePair][blob.F_N] = []
t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob)
d_str = ''
for i_d, dtype_ in enumerate(t_dtype_dict):
blob_per_t = t_dtype_dict[dtype_]
n_str = ''
for i_n, n_ in enumerate(blob_per_t):
blob_per_n = blob_per_t[n_]
inner_str = ""
for i_b, b_ in enumerate(blob_per_n):
# generate single kernel instance file
#vec_str = ""
for i_ins, ins in enumerate(b_.instance_list):
idx_in_n = i_b * len(b_.instance_list) + i_ins
len_in_n = len(blob_per_n) * len(b_.instance_list)
# _if = 'if' if i_ins == 0 else 'else if'
if ins.F_kFusedQuant == 0:
_sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
elif ins.F_kFusedQuant == 1:
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sx == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_XScaleDataType, f_sy_type=ins.F_YScaleDataType)
elif ins.F_kFusedQuant == 2:
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
f_sweep_cond = _sweep_cond)
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
F_VEC_COND = _cond, F_instance_func=ins.call_name)
#inner_str = inner_str + vec_str
n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
prec_i, prec_o = dtype_.split(',')
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str)
return api_base
@property
def content_common_header(self) -> str:
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
def get_blobs(self):
h_traits = layernorm_fwd_codegen.h_traits
h_instance = layernorm_fwd_codegen.h_instance
dynamic_quant_out_dtype = ['int8']
# some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list = [('fp32,fp32')]
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
#fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
fused_add_list = [0, 1]
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
# rm rn tm tn vn pd mv fdiv 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)],
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)],
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)],
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)],
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]}
total_blob = list()
for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key]
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list):
prec_i, prec_o = dtype.split(',')
scale_x, scale_y = scale_type.split(',')
if prec_o in dynamic_quant_out_dtype and fused_quant != 1:
continue # skip non dynamic quant case
if fused_quant == 1 and hs_key == 'big':
continue
current_hs = list()
for chs_ in hs:
h_ = copy.copy(chs_) # copy the base instance out
h_.F_XDataType = prec_i
h_.F_YDataType = prec_o
h_.F_XScaleDataType = scale_y
h_.F_YScaleDataType = scale_x
h_.F_kFusedAdd = fused_add
h_.F_kFusedQuant = fused_quant
current_hs.append(h_) # + "\n"
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
current_n_str = 'big' if hs_key == 'big' else current_n
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs))
return total_blob
def list_blobs(self) -> None:
w_p = Path(self.working_path)
list_p = w_p / 'layernorm2d_fwd_blobs.txt'
blobs = self.get_blobs()
with list_p.open('w') as list_f:
# api related file
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
# kernel instance file
for b in blobs:
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
def gen_blobs(self) -> None:
w_p = Path(self.working_path)
(w_p / (self.name_api + ".cpp")).write_text(self.content_api)
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
blobs = self.get_blobs()
for b in blobs:
(w_p / (b.name + ".cpp")).write_text(b.content)
def list_blobs(args):
api_list = args.api.split(',')
for api in api_list:
if api == 'fwd':
layernorm_fwd_codegen(args.working_path, args.filter).list_blobs()
def gen_blobs(args):
api_list = args.api.split(',')
for api in api_list:
if api == 'fwd':
layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
description="gen API for CK layernorm kernel",
)
parser.add_argument(
"-a",
"--api",
default='fwd[all]',
required=False,
help="supply API(s) to generate (default: fwd). separated by comma."
)
# the directory for list_blobs/gen_blobs to write files into
parser.add_argument(
"-w",
"--working_path",
default="./",
required=False,
help="the path where all the blobs are going to be generated"
)
# this script has 2 modes
# 1) list_blobs mode: generate a txt file listing all the files that are going to be
# generated. This is useful in build systems like cmake to construct source code
# dependencies by reading the content of this file.
# 2) gen_blobs mode: generate the actual kernel instances and api. A framework
# like FA only needs to use this mode.
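# hedged usage sketch (flag names per the argparse definitions below; the paths are ours):
#   python3 generate.py --api fwd --working_path ./blobs --list_blobs  # writes layernorm2d_fwd_blobs.txt
#   python3 generate.py --api fwd --working_path ./blobs --gen_blobs   # emits the api/instance .cpp/.hpp files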
parser.add_argument(
"-l",
"--list_blobs",
action='store_true',
help="list all the kernels to a file, "
)
parser.add_argument(
"-g",
"--gen_blobs",
action='store_true',
help="generate all kernels into different tile"
)
# TODO: if using filter, must apply same value to output_dir and list_blobs
parser.add_argument(
"-f",
"--filter",
required=False,
help="filter out kernels that need to generate, using fnmatch module"
)
parser.add_argument(
"-t",
"--traits",
default="all",
required=False,
help="enable/disable some feature. default generate all"
)
parser.add_argument(
"-r",
"--receipt",
default=0,
required=False,
help="codegen receipt."
)
args = parser.parse_args()
# print(f'{args.list_blobs}-{args.gen_blobs}')
if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)):
print('exactly one of --gen_blobs/--list_blobs must be specified')
sys.exit()
p = Path(args.working_path)
if not p.exists():
p.mkdir()
if args.list_blobs:
list_blobs(args)
else:
gen_blobs(args)
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp" #include "layernorm2d_fwd.hpp"
#include <algorithm>
#include <cstring> #include <cstring>
// Host API implementation // different threshold for different dtype
float layernorm2d_fwd(layernorm2d_fwd_traits t, template <typename DataType>
layernorm2d_fwd_args a, auto get_elimit()
const ck_tile::stream_config& s)
{ {
if(t.data_type.compare("fp16") == 0) double rtol = 1e-2;
{ double atol = 1e-2;
using XDataType = ck_tile::half_t; return ck_tile::make_tuple(rtol, atol);
using YDataType = ck_tile::half_t; }
using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t;
#ifdef SAVE_MEAN_INV_STD
using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t;
#else
using MeanDataType = ck_tile::null_type;
using InvStdDataType = ck_tile::null_type;
#endif
using ComputeDataType = float;
using thread_tile = ck_tile::sequence<4, 4>;
using warp_tile = ck_tile::sequence<8, 128>;
using block_tile = ck_tile::sequence<32, 128>;
using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType,
Shape,
true,
true>;
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
auto kargs = Kernel::MakeKargs(
a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);
const dim3 grids = Kernel::GridSize(a.M);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock;
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
return 0; template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
} }
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("x_stride", "-1", "x row_stride, if -1 then equal to n")
.insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n")
.insert("y_stride", "-1", "y row_stride, if -1 then equal to n")
.insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n")
.insert("e", "1e-5", "epsilon") .insert("e", "1e-5", "epsilon")
.insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision"); .insert("kname", "1", "print kernel name or not")
.insert("prec_i", "fp16", "input precision")
.insert("prec_o", "auto", "output precision, set auto will be the same as input")
.insert("prec_sx",
"auto",
"output quant scale type, set auto will use fp32. used when fquant=1")
.insert("prec_sy",
"auto",
"output quant scale type, set auto will use fp32. used when fquant=1 or 2")
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename InDataType,
          typename OutDataType,
          typename XScaleDataType,
          typename YScaleDataType,
          bool SaveMeanVar>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
x_stride = n;
ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride");
if(xr_stride < 0)
xr_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride");
if(yr_stride < 0)
yr_stride = n;
float epsilon = arg_parser.get_float("e");
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_sx = arg_parser.get_str("prec_sx");
std::string prec_sy = arg_parser.get_str("prec_sy");
if(prec_o == "auto")
{
prec_o = prec_i;
}
if(prec_sx == "auto")
{
prec_sx = "fp32";
}
if(prec_sy == "auto")
{
prec_sy = "fp32";
}
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
if(fused_quant == 1 && prec_o != "int8")
{
std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl;
return false;
}
assert(x_stride >= n);
using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>;

using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType;
using GammaDataType = typename TypeConfig::GammaDataType;
using BetaDataType = typename TypeConfig::BetaDataType;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;

using MeanDataType =
    std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>;
using InvStdDataType =
    std::conditional_t<SaveMeanVar, typename TypeConfig::InvStdDataType, ck_tile::null_type>;

using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<BetaDataType> beta_host({n});
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {xr_stride, 1});
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {yr_stride, 1});
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {y_stride, 1});
ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
ck_tile::HostTensor<XScaleDataType> x_scale_host({n});
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});

ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data());
x_residual_buf.ToDevice(x_residual_host.data());
x_scale_buf.ToDevice(x_scale_host.data());
auto prec_str = [&]() {
auto base_str = prec_i;
if(prec_i != prec_o)
{
base_str += "|" + prec_o;
}
if(fused_quant == 1)
{
base_str += std::string("(") + prec_sy + ")";
}
return base_str;
}();
std::cout << "[" << prec_str << "]"
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
<< ", yr_stride:" << yr_stride << std::flush;
layernorm2d_fwd_traits traits{
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(),
fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
nullptr, // p_mean, unsupported yet
nullptr, // p_invStd, unsupported yet
epsilon,
m,
n,
x_stride, // x row_stride
xr_stride, // x residual row stride
y_stride, // y row stride
yr_stride}; // y residual row stride
float ave_time = layernorm2d_fwd(
    traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0)
{
std::cout << " not supported!" << std::endl << std::flush;
return false;
}
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
                       sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
bool pass = true;
if(do_validation)
{
// reference
if(fused_add != 0)
{
// fused pre_add/pre_add_store
// TODO we accumulate directly to x_host for simplicity here...
std::transform(x_host.mData.cbegin(),
x_host.mData.cend(),
x_residual_host.mData.cbegin(),
x_host.mData.begin(),
[](auto x_, auto r_) {
auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
ck_tile::type_convert<ComputeDataType>(r_);
return ck_tile::type_convert<XDataType>(o_);
});
}
ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType,
BetaDataType,
@@ -172,22 +245,187 @@ int main(int argc, char* argv[])
InvStdDataType>(
x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
if(fused_quant != 0)
{
auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
int N_ = acc_.mDesc.get_lengths()[1];
if(fused_quant == 1)
{
for(int n_ = 0; n_ < N_; n_++)
{
// input smooth outlier
acc_(m_, n_) =
acc_(m_, n_) * ck_tile::type_convert<ComputeDataType>(x_scale_host(n_));
}
}
ComputeDataType absmax = static_cast<ComputeDataType>(0);
for(int n_ = 0; n_ < N_; n_++)
{
const auto a = ck_tile::abs(acc_(m_, n_));
absmax = a > absmax ? a : absmax;
}
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
for(int n_ = 0; n_ < N_; n_++)
{
o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
}
};
ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType>(x_host,
gamma_host,
beta_host,
y_host_ref,
mean_host_ref,
invStd_host_ref,
epsilon,
dquant_functor);
}
else
{
ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType>(
x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
}
        y_buf.FromDevice(y_host_dev.data());
        ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {yr_stride, 1});
if(fused_add == 1)
{
y_residual_buf.FromDevice(y_residual_host_dev.data());
}
auto [rtol, atol] = get_elimit<InDataType>();
if(x_stride == n)
{
pass = ck_tile::check_err(
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
if(fused_add == 1)
{
pass &= ck_tile::check_err(y_residual_host_dev,
x_host,
std::string("ADD Error: Incorrect results!"),
rtol,
atol);
}
}
else
{
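            // x_stride != n: rows carry padding, so slice out and compare one row at a time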
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * y_stride,
y_host_dev.begin() + i_r * y_stride + n);
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * y_stride,
y_host_ref.begin() + i_r * y_stride + n);
pass &= ck_tile::check_err(y_host_dev_row,
y_host_ref_row,
std::string("OUT[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
if(fused_add == 1)
{
std::vector<YResidualDataType> y_residual_host_dev_row(
y_residual_host_dev.begin() + i_r * yr_stride,
y_residual_host_dev.begin() + i_r * yr_stride + n);
std::vector<YResidualDataType> y_residual_host_ref_row(
x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n);
pass &= ck_tile::check_err(y_residual_host_dev_row,
y_residual_host_ref_row,
std::string("ADD[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
if(fused_quant == 1)
{
y_scale_buf.FromDevice(y_scale_host_dev.data());
pass &= ck_tile::check_err(y_scale_host_dev,
y_scale_host_ref,
std::string("SCALE Error: Incorrect results!"),
rtol,
atol);
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    std::string prec_i  = arg_parser.get_str("prec_i");
    std::string prec_o  = arg_parser.get_str("prec_o");
    std::string prec_sx = arg_parser.get_str("prec_sx");
    std::string prec_sy = arg_parser.get_str("prec_sy");

    if(prec_o == "auto")
    {
        prec_o = prec_i;
    }
    if(prec_sx == "auto")
    {
        prec_sx = "fp32";
    }
    if(prec_sy == "auto")
    {
        prec_sy = "fp32";
    }

    int save_mv = arg_parser.get_int("save_mv");

    // no dynamic quant case
if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && save_mv)
{
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" &&
!save_mv)
{
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" &&
save_mv)
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" &&
!save_mv)
{
        return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, false>(arg_parser) ? 0 : -2;
}
// dynamic quant case, only in inference
else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" &&
!save_mv)
{
return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" &&
!save_mv)
{
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
}
    return -3;
}
@@ -8,23 +8,57 @@
#include "ck_tile/ops/layernorm2d.hpp" #include "ck_tile/ops/layernorm2d.hpp"
#include <string> #include <string>
struct layernorm2d_fwd_traits template <typename InType, typename OutType, typename XScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig;
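// Maps an input precision (plus output and x/y-scale precisions) to the concrete
// data types used by the kernel; specialized below for fp16 and bf16 inputs.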
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleDataType_>
{
using XDataType = ck_tile::half_t;
using YDataType = OutType;
using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t;
using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t;
using ComputeDataType = float;
using XScaleDataType = XScaleDataType_;
using YScaleDataType = YScaleDataType_;
};
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_>
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleDataType_>
{
    using XDataType = ck_tile::bf16_t;
using YDataType = OutType;
using GammaDataType = ck_tile::bf16_t;
using BetaDataType = ck_tile::bf16_t;
using MeanDataType = ck_tile::bf16_t;
using InvStdDataType = ck_tile::bf16_t;
using ComputeDataType = float;
using XScaleDataType = XScaleDataType_;
using YScaleDataType = YScaleDataType_;
};

// runtime args
struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
{
};
// This is the public API; it will be generated by a script
struct layernorm2d_fwd_traits
{
std::string prec_i; // input precision
std::string prec_o; // output precision
    // If fused_quant == 1, prec_sx and prec_sy must be set to proper strings; otherwise they can
    // be arbitrary (the check is skipped). If fused_quant == 2, only prec_sy must be set;
    // otherwise it can be arbitrary (the check is skipped).
    std::string prec_sx; // x-scale, used for [1*N] input smooth quant
    std::string prec_sy; // y-scale, used for [M*1] output for next layer
    bool save_mean_var;  // save the computed mean/inv-std
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
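// A minimal host-side usage sketch (hypothetical values; device buffer setup omitted,
// argument order following the example's main()):
//
//   layernorm2d_fwd_traits traits{"fp16", "fp16", "fp32", "fp32",
//                                 /*save_mean_var=*/false,
//                                 /*fused_add=*/0,
//                                 /*fused_quant=*/0};
//   layernorm2d_fwd_args args{p_x,     nullptr, nullptr, p_gamma, p_beta, p_y,
//                             nullptr, nullptr, nullptr, nullptr,
//                             1e-5f,   m,       n, /*x_stride=*/n, /*xr_stride=*/n,
//                             /*y_stride=*/n, /*yr_stride=*/n};
//   float ms = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true});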
#!/bin/sh
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
#!/bin/sh
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp)
@@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementation
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method for the GEMM calculation
make tile_example_gemm_basic -j
# The memory-bound pipeline for the GEMM calculation
make tile_example_gemm_mem_pipeline -j
```
This will result in an executable `build/bin/tile_example_gemm_basic`
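
Once built, the example can be run directly. A typical invocation looks like the sketch below; the flags mirror those defined in the example's argument parser (`-m/-n/-k` for the GEMM sizes, `-prec` for the data type, `-v` for validation), and the values here are just an illustration:
```
# 1024x2048x64 fp16 GEMM, validated on the CPU
./bin/tile_example_gemm_basic -m=1024 -n=2048 -k=64 -prec=fp16 -v=1
```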
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <hip/hip_runtime.h>

#include <cstring>
@@ -10,58 +9,56 @@
#include <string>
#include <tuple>

#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"

template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
    // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
    constexpr bool kPadM = false;
    constexpr bool kPadN = false;
    constexpr bool kPadK = false;

    constexpr bool kTilePermute = false;
    // The rank and permutation will also be generated by the CodeGen part.
    constexpr ck_tile::index_t kOutputRank = 2;

    constexpr int kBlockPerCu = 1;

    // This part comes from the Codegen
    constexpr ck_tile::index_t M_Tile = 128;
    constexpr ck_tile::index_t N_Tile = 128;
    constexpr ck_tile::index_t K_Tile = 32;

    constexpr ck_tile::index_t M_Warp = 2;
    constexpr ck_tile::index_t N_Warp = 2;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 8;
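    // A sketch of the decomposition implied by the constants above: each workgroup
    // computes one 128x128 tile of C, stepping over K in chunks of 32. With a 2x2x1
    // warp layout and 32x32x8 warp tiles, each warp owns a 64x64 sub-tile, i.e. 2x2
    // warp-gemm repeats per block tile and 32 / 8 = 4 warp-gemm steps per K chunk.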
    // Whether doing the CShuffle (transpose before the global memory), depending on the output
    // layout.
    constexpr bool CShuffleEpilogue =
        std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
    using GemmEpilogue = std::conditional_t<
        CShuffleEpilogue,
        ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
                                                                   CDataType,
                                                                   kPadM,
                                                                   kPadN,
                                                                   kTilePermute,
                                                                   kOutputRank,
                                                                   1,
@@ -69,15 +66,22 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
                                                                   TilePartitioner::kM,
                                                                   TilePartitioner::kN>>,
        ck_tile::Default2DEpilogue<
            ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
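    // As the pipeline name suggests, A and B tiles are fetched from global memory and the
    // C tile is accumulated in registers; the policy above picks the default block-gemm layout.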
    // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
    // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
    using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;

    auto kargs = Kernel::MakeKargs(args.p_a,
                                   args.p_b,
                                   args.p_c,
                                   args.M,
                                   args.N,
                                   args.K,
@@ -88,299 +92,20 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
    const dim3 grids      = Kernel::GridSize(args.M, args.N, args.kbatch);
    constexpr dim3 blocks = Kernel::BlockSize();
    if(s.log_level_ > 0)
    {
        std::cout << "Launching kernel with args:"
                  << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                  << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                  << std::endl;
    }

    float ave_time = ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

    return ave_time;
}

#include "run_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }