Commit 0144b4f4 authored by letaoqin's avatar letaoqin
Browse files

Merge branch 'develop' into jizhan/reduce_threadwise_multi_d

parents 300337cd 34f3dfdd
...@@ -911,9 +911,8 @@ pipeline { ...@@ -911,9 +911,8 @@ pipeline {
execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx90a;gfx1030;gfx1101" \
-D INSTANCES_ONLY=ON \ -D INSTANCES_ONLY=ON \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j32 """ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
} }
steps{ steps{
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
......
rocm-docs-core==1.1.3 rocm-docs-core==1.2.0
sphinxcontrib-bibtex==2.6.2 sphinxcontrib-bibtex==2.6.2
...@@ -103,7 +103,7 @@ requests==2.31.0 ...@@ -103,7 +103,7 @@ requests==2.31.0
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==1.1.3 rocm-docs-core==1.2.0
# via -r requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via # via
......
add_example_executable(example_gemm_multiply_multiply_xdl_fp16 gemm_multiply_multiply_xdl_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using FP8 = ck::f8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = FP8;
using B0DataType = FP8;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using A0Layout = Row;
using B0Layout = Col;
using D0Layout = Row;
using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
struct MultiplyMultiply
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
ck::half_t& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c * d0 * d1;
e = ck::type_convert<ck::bhalf_t>(x0_f);
}
};
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
///###### RRR
///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
///###### RCR
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-5, 5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto I0 = ck::Number<0>{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, NumDTensor>{I0, I0},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}
...@@ -34,6 +34,7 @@ args: ...@@ -34,6 +34,7 @@ args:
if not equal to h, then this is GQA/MQA case if not equal to h, then this is GQA/MQA case
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode)
-s_k seqlen_k, -1 means equal to s (default:-1) -s_k seqlen_k, -1 means equal to s (default:-1)
-d head dim for q, k (default:128) -d head dim for q, k (default:128)
-d_v head dim for v, -1 means equal to d (default:-1) -d_v head dim for v, -1 means equal to d (default:-1)
......
...@@ -44,11 +44,18 @@ auto create_args(int argc, char* argv[]) ...@@ -44,11 +44,18 @@ auto create_args(int argc, char* argv[])
"-1", "-1",
"num of head, for k/v, -1 means equal to h\n" "num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case") "if not equal to h, then this is GQA/MQA case")
.insert("s", .insert(
"3328", "s",
"seqlen_q. if group-mode, means the average value of seqlen_q\n" "3328",
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") "seqlen_q. if group-mode, means the average value of seqlen_q\n"
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
"also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s") .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("s_kpad",
"-1",
"seqlen_k stride between 2 tokens, currently used in group-mode only\n"
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
"along seqlen, instead of packed. same as xformer kv_padding")
.insert("d", "128", "head dim for q, k") .insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d") .insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale_s", .insert("scale_s",
...@@ -103,6 +110,7 @@ auto create_args(int argc, char* argv[]) ...@@ -103,6 +110,7 @@ auto create_args(int argc, char* argv[])
"11939", "11939",
"random seed used for initializing input tensors. 0 for " "random seed used for initializing input tensors. 0 for "
"non-deterministic seed") "non-deterministic seed")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel"); .insert("repeat", "20", "number of iterations to benchmark the kernel");
...@@ -177,10 +185,20 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -177,10 +185,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
return false; return false;
} }
ck_tile::index_t seqlen_q = arg_parser.get_int("s"); auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode,
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); batch,
if(seqlen_k < 0) arg_parser.get_str("s"),
seqlen_k = seqlen_q; arg_parser.get_str("s_k"),
arg_parser.get_str("s_kpad"));
#if 0
// clang-format off
std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl;
std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl;
std::cout << "seqlen_kpads:"; for(auto xx : seqlen_kpads) { std::cout << xx << ","; } std::cout << std::endl;
// clang-format on
#endif
ck_tile::index_t hdim_q = arg_parser.get_int("d"); ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0) if(hdim_v < 0)
...@@ -229,7 +247,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -229,7 +247,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
bool lse = arg_parser.get_bool("lse"); bool lse = arg_parser.get_bool("lse");
bias_info bias = bias_info::decode(arg_parser.get_str("bias")); bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); mask_info mask = mask_info::decode(
arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore
std::string init_method = arg_parser.get_str("init"); std::string init_method = arg_parser.get_str("init");
std::optional<uint32_t> seed = arg_parser.get_uint32("seed"); std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
...@@ -242,11 +261,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -242,11 +261,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
int stream_repeat = arg_parser.get_int("repeat"); int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname"); bool kname = arg_parser.get_bool("kname");
ck_tile::stream_config stream_config{ ck_tile::stream_config stream_config{nullptr,
nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat}; true,
/* log_level = */ (kname ? 1 : 0),
stream_warmup,
stream_repeat,
arg_parser.get_str("timer") == std::string("gpu")};
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); const auto seqstart_q_host = to_seqstarts(seqlen_qs);
const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
using TypeConfig = FmhaFwdTypeConfig<DataType>; using TypeConfig = FmhaFwdTypeConfig<DataType>;
...@@ -302,9 +326,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -302,9 +326,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host memory for storing all the tensor elements // host memory for storing all the tensor elements
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
const ck_tile::index_t shape_seqlen_q = const ck_tile::index_t shape_seqlen_q =
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
const ck_tile::index_t shape_seqlen_k = const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); (mode == mode_enum::batch ? seqlen_ks[0]
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
: seqstart_k_with_padding_host.back()));
ck_tile::HostTensor<QDataType> q_host( ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
...@@ -407,6 +433,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -407,6 +433,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t));
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data()); q_buf.ToDevice(q_host.data());
...@@ -414,7 +441,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -414,7 +441,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
v_buf.ToDevice(v_host.data()); v_buf.ToDevice(v_host.data());
bias_buf.ToDevice(bias_host.data()); bias_buf.ToDevice(bias_host.data());
seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data()); seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
: seqstart_k_with_padding_host.data());
seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data());
alibi_slope_buf.ToDevice(alibi_slope_host.data()); alibi_slope_buf.ToDevice(alibi_slope_host.data());
// clang-format off // clang-format off
...@@ -430,7 +459,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -430,7 +459,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
const std::string prec = arg_parser.get_str("prec"); const std::string prec = arg_parser.get_str("prec");
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0]
<< (seqlen_kpads[0] < 0 ? ""
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
<< ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout
<< std::flush; << std::flush;
...@@ -460,7 +491,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -460,7 +491,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
return ck_tile::identity{}; return ck_tile::identity{};
}(); }();
auto fmha_args = [&]() { auto fmha_args = [&, k_paddings_ = seqlen_kpads]() {
assert(nhead % nhead_k == 0); assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' & /// seqlen_k] in this example, hence both the 'batch_stride_bias' &
...@@ -506,7 +537,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -506,7 +537,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(),
nullptr, k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(),
shape_seqlen_q, shape_seqlen_q,
shape_seqlen_k, shape_seqlen_k,
batch, batch,
...@@ -576,7 +607,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -576,7 +607,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
// adjust matrix index according to the mode // adjust matrix index according to the mode
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); const ck_tile::index_t key_offset =
(mode == mode_enum::batch
? 0
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb]));
const auto v_host_ref_lengths = const auto v_host_ref_lengths =
std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k}; std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
...@@ -661,7 +695,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -661,7 +695,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else else
{ {
return ck_tile::Alibi<SaccDataType, true>{ return ck_tile::Alibi<SaccDataType, true>{
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL}; 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT};
} }
}(); }();
...@@ -671,7 +705,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -671,7 +705,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(auto i_h = 0; i_h < nhead; i_h++) for(auto i_h = 0; i_h < nhead; i_h++)
{ {
SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h); SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
alibi_host.slope = current_slope; alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? current_slope
: -current_slope;
for(auto i_r = 0; i_r < real_seqlen_q; i_r++) for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
{ {
for(auto i_c = 0; i_c < real_seqlen_k; i_c++) for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
......
...@@ -78,6 +78,11 @@ BOOL_MAP = { ...@@ -78,6 +78,11 @@ BOOL_MAP = {
"f" : "false" "f" : "false"
} }
TILE_PARTITIONER_MAP = {
"shb" : "ck_tile::FmhaFwdTilePartitioner_SHB",
"hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS",
}
DIRECTIONS = ["fwd"] DIRECTIONS = ["fwd"]
GEN_DIR = "" # in Cmake, have to generate files in same folder GEN_DIR = "" # in Cmake, have to generate files in same folder
...@@ -107,7 +112,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, ...@@ -107,7 +112,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_dvpad}, {F_dvpad},
{F_bias}, {F_bias},
{F_lse}, {F_lse},
{F_squant}, {F_squant},
{F_occupancy}>; {F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask}; using fmha_mask_{F_idx} = {F_mask};
...@@ -136,7 +141,7 @@ using fmha_epilogue_{F_idx} = ...@@ -136,7 +141,7 @@ using fmha_epilogue_{F_idx} =
{F_spad}, {F_dvpad}>>; {F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} = using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner<fmha_shape_{F_idx}>, ck_tile::FmhaFwdKernel<{F_tile_partitioner}<fmha_shape_{F_idx}>,
fmha_pipeline_{F_idx}, fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>; fmha_epilogue_{F_idx}>;
...@@ -154,7 +159,7 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a) ...@@ -154,7 +159,7 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a); auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize(); constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel<blocks.x, kBlockPerCu>(s, k_{{}}, grids, blocks, 0, kargs); return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}} }}
""" """
...@@ -389,6 +394,12 @@ class FmhaFwdKernel: ...@@ -389,6 +394,12 @@ class FmhaFwdKernel:
F_pipeline : FmhaFwdPipeline F_pipeline : FmhaFwdPipeline
mask_impl : str mask_impl : str
def get_tp(self) -> str:
if self.F_mode == 'group':
return 'hbs'
else:
return 'shb'
@property @property
def template(self) -> str: def template(self) -> str:
kernel_body = str() kernel_body = str()
...@@ -413,7 +424,7 @@ class FmhaFwdKernel: ...@@ -413,7 +424,7 @@ class FmhaFwdKernel:
F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_squant = BOOL_MAP[self.F_pipeline.F_squant], F_squant = BOOL_MAP[self.F_pipeline.F_squant],
...@@ -421,12 +432,13 @@ class FmhaFwdKernel: ...@@ -421,12 +432,13 @@ class FmhaFwdKernel:
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode], F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()])
@property @property
def name(self) -> str: def name(self) -> str:
# TODO: we don't encode idx here # TODO: we don't encode idx here
return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" +\ return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \
self.F_tile.name + '_' + self.F_pipeline.name self.F_tile.name + '_' + self.F_pipeline.name
@property @property
......
...@@ -28,6 +28,7 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias ...@@ -28,6 +28,7 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done done
done done
......
...@@ -4,12 +4,14 @@ ...@@ -4,12 +4,14 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <cstdlib>
#include <optional> #include <optional>
#include <ostream> #include <ostream>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <functional> #include <functional>
#include <string>
#include "ck_tile/core/container/span.hpp" #include "ck_tile/core/container/span.hpp"
...@@ -37,12 +39,14 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens) ...@@ -37,12 +39,14 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
std::vector<int32_t> generate_seqlens(mode_enum mode, std::vector<int32_t> generate_seqlens(mode_enum mode,
unsigned count, unsigned count,
int32_t seqlens_sum, int32_t seqlen_avg,
int32_t seqlen_max = -1, // if not negative, clamp max
std::optional<unsigned> seed = std::nullopt) std::optional<unsigned> seed = std::nullopt)
{ {
assert(0 < count); assert(0 < count);
std::vector<int32_t> seqlens(count, seqlens_sum); std::vector<int32_t> seqlens(
count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg);
if(mode == mode_enum::group && 1 < count) if(mode == mode_enum::group && 1 < count)
{ {
...@@ -55,7 +59,7 @@ std::vector<int32_t> generate_seqlens(mode_enum mode, ...@@ -55,7 +59,7 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
std::uniform_int_distribution<size_type> step_dist(1, count - 1); std::uniform_int_distribution<size_type> step_dist(1, count - 1);
auto next_step = std::bind(step_dist, std::ref(random_engine)); auto next_step = std::bind(step_dist, std::ref(random_engine));
for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat) for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
{ {
const size_type to_decrease = next_idx(); const size_type to_decrease = next_idx();
// make sure each elements of seqlens is always greater than 0 // make sure each elements of seqlens is always greater than 0
...@@ -66,6 +70,11 @@ std::vector<int32_t> generate_seqlens(mode_enum mode, ...@@ -66,6 +70,11 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
const size_type to_increase = (to_decrease + next_step()) % count; const size_type to_increase = (to_decrease + next_step()) % count;
if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max)
{
continue;
}
--seqlens[to_decrease]; --seqlens[to_decrease];
++seqlens[to_increase]; ++seqlens[to_increase];
} }
...@@ -76,10 +85,91 @@ std::vector<int32_t> generate_seqlens(mode_enum mode, ...@@ -76,10 +85,91 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
std::vector<int32_t> generate_seqstarts(mode_enum mode, std::vector<int32_t> generate_seqstarts(mode_enum mode,
unsigned count, unsigned count,
int32_t seqlens_sum, int32_t seqlen_avg,
int32_t seqlen_max = -1,
std::optional<unsigned> seed = std::nullopt) std::optional<unsigned> seed = std::nullopt)
{ {
return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed)); return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed));
}
/*
* decode the seqlen string from cmdline
* example (assume batch=3)
* q_val=1,2,3 k_val=4,5,6 -> OK
* q_val=1,2,3 -> OK, k same as q
* q_val=1,2 -> OK, q will rand remaining 1 element, k same as q
* q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element
* q_val=1,2,3,4 -> OK, but ignore exceed one
*
* q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
* q_val=1,2 k_val=4 -> not OK, k must have same splits with q
*/
std::tuple<std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>>
decode_seqlen(mode_enum mode,
ck_tile::index_t batch,
std::string q_val,
std::string k_val,
std::string k_pad_val,
std::optional<unsigned> seed = std::nullopt)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
if(mode == mode_enum::batch)
{
ck_tile::index_t q = _S2I_(q_val);
ck_tile::index_t k = _S2I_(k_val);
auto s_q = std::vector<ck_tile::index_t>(batch, q);
auto s_k = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k);
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
return std::make_tuple(s_q, s_k, s_kpad);
}
else
{
ck_tile::index_t idx = 0;
std::string::size_type pos_q = 0;
std::string::size_type pos_k = 0;
std::string::size_type pos_kp = 0;
std::vector<ck_tile::index_t> s_q;
std::vector<ck_tile::index_t> s_k;
std::vector<ck_tile::index_t> s_kpad;
while(true)
{
auto found_q = q_val.find(',', pos_q);
auto found_k = k_val.find(',', pos_k);
auto found_kp = k_pad_val.find(',', pos_kp);
ck_tile::index_t q = _S2I_(
q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q));
ck_tile::index_t k = _S2I_(
k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k));
ck_tile::index_t kp = _S2I_(k_pad_val.substr(
pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp));
s_q.push_back(q);
s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp);
idx++;
if(found_q == std::string::npos || idx >= batch)
{
break;
}
pos_q = found_q + 1;
pos_k = found_k == std::string::npos ? pos_k : found_k + 1;
pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1;
}
if(idx < batch)
{
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed);
auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed);
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
}
return std::make_tuple(s_q, s_k, s_kpad);
}
#undef _S2I_
} }
int env_get_int(const char* var_name, int default_int) int env_get_int(const char* var_name, int default_int)
...@@ -87,6 +177,6 @@ int env_get_int(const char* var_name, int default_int) ...@@ -87,6 +177,6 @@ int env_get_int(const char* var_name, int default_int)
char* v = getenv(var_name); char* v = getenv(var_name);
int r = default_int; int r = default_int;
if(v) if(v)
r = atoi(v); r = std::atoi(v);
return r; return r;
} }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck {
// Thread-group level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//
// Does following things to avoid scratch memory issue
// 1. Pass tensor descritpors by reference (or tuple of references)
// 2. Does not keep reference to tensor descriptor
// 3. Does not construct new tensor coordinate when call Run()
template <typename ThreadGroup,
typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
typename SrcScalarPerVectors,
index_t DstScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v7r3
{
static constexpr index_t nDim =
remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
using Index = MultiIndex<nDim>;
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
__device__ constexpr ThreadGroupTensorSliceTransfer_v7r3(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_descs,
StaticallyIndexedArray<Index, nSrc>{},
dst_descs,
StaticallyIndexedArray<Index, nDst>{},
element_op)
{
static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
"wrong!");
static_for<0, nSrc, 1>{}([&](auto i) {
static_assert(
nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
"wrong!");
});
static_for<0, nDst, 1>{}([&](auto i) {
static_assert(
nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
"wrong!");
});
static_assert(nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
const auto src_thread_slice_origins = generate_tuple(
[&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
Number<nSrc>{});
const auto dst_thread_slice_origins = generate_tuple(
[&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
Number<nDst>{});
threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
}
}
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
}
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
else
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
}
}
template <typename SrcBuffers, typename DstBuffers>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs)
{
RunRead(src_descs, src_bufs);
RunWrite(dst_descs, dst_bufs);
}
template <index_t ISrc>
__device__ void
MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
}
}
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
{
static_for<0, SrcDescs::Size(), 1>{}(
[&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
}
template <index_t IDst>
__device__ void
MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
}
}
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
{
static_for<0, DstDescs::Size(), 1>{}(
[&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v7r3<SrcDatas,
DstDatas,
SrcDescs,
DstDescs,
ElementwiseOperation,
DstInMemOps,
decltype(thread_slice_lengths),
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVectors,
DstScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
...@@ -29,6 +29,25 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz ...@@ -29,6 +29,25 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
return __builtin_bit_cast(int32x4_t, res); return __builtin_bit_cast(int32x4_t, res);
} }
namespace impl {
// below type indicate the data type used for buffer load inline asm
// clang-format off
template<index_t N, typename T> struct buffer_load_trait;
template<typename T> struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; };
template<typename T> struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; };
template<typename T> struct buffer_load_trait<4 , T> { using payload_t = float; };
template<typename T> struct buffer_load_trait<2 , T> { using payload_t = float; };
template<typename T> struct buffer_load_trait<1 , T> { using payload_t = float; };
#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA
template<> struct buffer_load_trait<16, thread_buffer<bf16_t, 8>> { using payload_t = bf16x8_t; };
template<> struct buffer_load_trait<8 , thread_buffer<bf16_t, 4>> { using payload_t = bf16x4_t; };
template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payload_t = bf16x2_t; };
#endif
// clang-format on
} // namespace impl
// TODO: glc/slc/... // TODO: glc/slc/...
template <index_t bytes> template <index_t bytes>
struct buffer_load; struct buffer_load;
...@@ -48,7 +67,7 @@ struct buffer_load<16> ...@@ -48,7 +67,7 @@ struct buffer_load<16>
index_t /*flag*/ = 0) index_t /*flag*/ = 0)
{ {
static_assert(sizeof(T) == 16); static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t; using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value)) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
...@@ -68,7 +87,7 @@ struct buffer_load<8> ...@@ -68,7 +87,7 @@ struct buffer_load<8>
index_t /*flag*/ = 0) index_t /*flag*/ = 0)
{ {
static_assert(sizeof(T) == 8); static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t; using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value)) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
...@@ -88,7 +107,7 @@ struct buffer_load<4> ...@@ -88,7 +107,7 @@ struct buffer_load<4>
index_t /*flag*/ = 0) index_t /*flag*/ = 0)
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
using mbuf_t = float; using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value)) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
...@@ -108,7 +127,7 @@ struct buffer_load<2> ...@@ -108,7 +127,7 @@ struct buffer_load<2>
index_t /*flag*/ = 0) index_t /*flag*/ = 0)
{ {
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = float; using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value)) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
...@@ -128,7 +147,7 @@ struct buffer_load<1> ...@@ -128,7 +147,7 @@ struct buffer_load<1>
index_t /*flag*/ = 0) index_t /*flag*/ = 0)
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
using mbuf_t = float; using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value)) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
...@@ -152,7 +171,7 @@ struct buffer_load_if<16> ...@@ -152,7 +171,7 @@ struct buffer_load_if<16>
{ {
static_assert(sizeof(T) == 16); static_assert(sizeof(T) == 16);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x4_t; using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
static_assert(sizeof(mbuf_t) == sizeof(T)); static_assert(sizeof(mbuf_t) == sizeof(T));
asm volatile( asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n" "v_cmpx_le_u32 exec, 1, %5\n"
...@@ -177,7 +196,7 @@ struct buffer_load_if<8> ...@@ -177,7 +196,7 @@ struct buffer_load_if<8>
{ {
static_assert(sizeof(T) == 8); static_assert(sizeof(T) == 8);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x2_t; using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile( asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n" "v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
...@@ -201,7 +220,7 @@ struct buffer_load_if<4> ...@@ -201,7 +220,7 @@ struct buffer_load_if<4>
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float; using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile( asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n" "v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"
...@@ -225,7 +244,7 @@ struct buffer_load_if<2> ...@@ -225,7 +244,7 @@ struct buffer_load_if<2>
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float; using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile( asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n" "v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"
...@@ -249,7 +268,7 @@ struct buffer_load_if<1> ...@@ -249,7 +268,7 @@ struct buffer_load_if<1>
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float; using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile( asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n" "v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"
......
...@@ -171,3 +171,7 @@ ...@@ -171,3 +171,7 @@
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2 #ifndef CK_TILE_FMHA_FWD_FAST_EXP2
#define CK_TILE_FMHA_FWD_FAST_EXP2 0 #define CK_TILE_FMHA_FWD_FAST_EXP2 0
#endif #endif
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif
...@@ -20,3 +20,4 @@ ...@@ -20,3 +20,4 @@
#include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
...@@ -27,7 +27,14 @@ struct DeviceMem ...@@ -27,7 +27,14 @@ struct DeviceMem
DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {} DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
DeviceMem(std::size_t mem_size) : mMemSize(mem_size) DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{ {
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize)); if(mMemSize != 0)
{
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
else
{
mpDeviceBuf = nullptr;
}
} }
void Realloc(std::size_t mem_size) void Realloc(std::size_t mem_size)
{ {
...@@ -36,7 +43,14 @@ struct DeviceMem ...@@ -36,7 +43,14 @@ struct DeviceMem
HIP_CHECK_ERROR(hipFree(mpDeviceBuf)); HIP_CHECK_ERROR(hipFree(mpDeviceBuf));
} }
mMemSize = mem_size; mMemSize = mem_size;
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize)); if(mMemSize != 0)
{
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
else
{
mpDeviceBuf = nullptr;
}
} }
void* GetDeviceBuffer() const { return mpDeviceBuf; } void* GetDeviceBuffer() const { return mpDeviceBuf; }
std::size_t GetBufferSize() const { return mMemSize; } std::size_t GetBufferSize() const { return mMemSize; }
...@@ -47,15 +61,18 @@ struct DeviceMem ...@@ -47,15 +61,18 @@ struct DeviceMem
HIP_CHECK_ERROR( HIP_CHECK_ERROR(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice)); hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
} }
else // else
{ // {
throw std::runtime_error("ToDevice with an empty pointer"); // throw std::runtime_error("ToDevice with an empty pointer");
} // }
} }
void ToDevice(const void* p, const std::size_t cpySize) const void ToDevice(const void* p, const std::size_t cpySize) const
{ {
HIP_CHECK_ERROR( if(mpDeviceBuf)
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice)); {
HIP_CHECK_ERROR(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
}
} }
void FromDevice(void* p) const void FromDevice(void* p) const
{ {
...@@ -63,14 +80,17 @@ struct DeviceMem ...@@ -63,14 +80,17 @@ struct DeviceMem
{ {
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
} }
else // else
{ // {
throw std::runtime_error("FromDevice with an empty pointer"); // throw std::runtime_error("FromDevice with an empty pointer");
} // }
} }
void FromDevice(void* p, const std::size_t cpySize) const void FromDevice(void* p, const std::size_t cpySize) const
{ {
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); if(mpDeviceBuf)
{
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}
} }
void SetZero() const void SetZero() const
{ {
...@@ -82,13 +102,16 @@ struct DeviceMem ...@@ -82,13 +102,16 @@ struct DeviceMem
template <typename T> template <typename T>
void SetValue(T x) const void SetValue(T x) const
{ {
if(mMemSize % sizeof(T) != 0) if(mpDeviceBuf)
{ {
throw std::runtime_error("wrong! not entire DeviceMem will be set"); if(mMemSize % sizeof(T) != 0)
} {
throw std::runtime_error("wrong! not entire DeviceMem will be set");
}
// TODO: call a gpu kernel to set the value (?) // TODO: call a gpu kernel to set the value (?)
set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T)); set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
}
} }
~DeviceMem() ~DeviceMem()
{ {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "ck_tile/core/config.hpp" #include "ck_tile/core/config.hpp"
#include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/timer.hpp"
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <cstddef> #include <cstddef>
...@@ -14,153 +15,92 @@ template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... ...@@ -14,153 +15,92 @@ template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename...
#if CK_TILE_USE_LAUNCH_BOUNDS #if CK_TILE_USE_LAUNCH_BOUNDS
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
#endif #endif
__global__ void kentry(Kernel f, Args... args) __global__ void kentry(Args... args)
{ {
f(args...); Kernel{}(args...);
} }
template <typename... Args, typename F> //
CK_TILE_HOST float launch_and_time_kernel(const stream_config& s, // return a anonymous functor(lambda) to be called later
F kernel, // the KernelImpl should be a class without non-static data member, or let's say
dim3 grid_dim, // can be instantiate with "KernelImpl{}"
dim3 block_dim, //
std::size_t lds_byte, // the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
Args... args) //
template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
typename KernelImpl,
typename... Args>
CK_TILE_HOST auto
make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{ {
#if CK_TILE_TIME_KERNEL const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
if(s.time_kernel_)
{
// warm up
for(int i = 0; i < s.cold_niters_; ++i)
{
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
}
const int nrepeat = s.nrepeat_;
hipEvent_t start, stop;
HIP_CHECK_ERROR(hipEventCreate(&start));
HIP_CHECK_ERROR(hipEventCreate(&stop));
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));
for(int i = 0; i < nrepeat; ++i)
{
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
}
HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
HIP_CHECK_ERROR(hipEventSynchronize(stop));
float total_time = 0;
HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));
return total_time / nrepeat; return [=](const stream_config& s) {
}
else
{
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError()); };
return 0;
}
#else
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
return 0;
#endif
} }
template <typename... Args, typename F, typename PreProcessFunc> // clang-format off
CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s, /*
PreProcessFunc preprocess, * launch_kernel()
F kernel, *
dim3 grid_dim, * this is the function to launch arbitrary number of kernels with optional timer(selected by stream_config)
dim3 block_dim, * the callables should have signature as "operator()(const stream_config& s){ ... }" to call
std::size_t lds_byte, *
Args... args) * the simplest way is pass in a lambda function, with "[=](const stream_config& s){ call_your_kernel_here() }"
* as signature, for the callable (pay attention to the capture list)
*
* e.g.
* ck_tile::launch_kernel(s,
* [=](const stream_config& s){ hipMemset(ptr, 0, size) },
* [=](const stream_config& s){ some_kernel<<<grids, blocks>>>(arg); }
* );
*
* if you use ck_tile kernel, or similiar to this style (structure with "static __device__ operator()(...){}")
* you can pass your kernel to ck_tile::make_kernel(), which will create a anonymous functor for you,
* then pass it to ck_tile::launch_kernel()
*
* e.g.
* ck_tile::launch_kernel(s,
* ck_tile::make_kernel<T0, B0>(kernel_0{}, grids0, blocks0, 0, kargs0),
* ck_tile::make_kernel<T0, B1>(kernel_1{}, grids1, blocks1, 0, kargs1),
* ...);
**/
// clang-format on
template <typename... Callables>
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
{ {
#if CK_TILE_TIME_KERNEL // clang-format off
if(s.time_kernel_) if(!s.time_kernel_) {
{ (callables(s),...); hip_check_error(hipGetLastError());
#if CK_TILE_DEBUG_LOG return 0;
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", }
__func__, if(s.is_gpu_timer_) {
grid_dim.x, gpu_timer timer {};
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z);
printf("Warm up 1 time\n");
#endif
// warm up
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
const int nrepeat = 10;
#if CK_TILE_DEBUG_LOG
printf("Start running %d times...\n", nrepeat);
#endif
hipEvent_t start, stop;
HIP_CHECK_ERROR(hipEventCreate(&start));
HIP_CHECK_ERROR(hipEventCreate(&stop));
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));
for(int i = 0; i < nrepeat; ++i) // warmup
{ for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
}
HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_)); timer.start(s.stream_id_);
HIP_CHECK_ERROR(hipEventSynchronize(stop)); for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
timer.stop(s.stream_id_);
float total_time = 0; return timer.duration() / s.nrepeat_;
}
else {
cpu_timer timer {};
HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); // warmup
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
return total_time / nrepeat; timer.start(s.stream_id_);
} for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
else timer.stop(s.stream_id_);
{
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
return 0; return timer.duration() / s.nrepeat_;
} }
#else // clang-format on
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
return 0;
#endif
} }
template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
typename KernelImpl,
typename... Args>
CK_TILE_HOST float launch_kernel(const stream_config& s,
KernelImpl kernel_impl,
dim3 grid_dim,
dim3 block_dim,
std::size_t dynamic_smem_byte,
Args... args)
{
const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
return launch_and_time_kernel(
s, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...);
}
} // namespace ck_tile } // namespace ck_tile
...@@ -6,6 +6,22 @@ ...@@ -6,6 +6,22 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
namespace ck_tile { namespace ck_tile {
/*
* construct this structure with behavior as:
*
* // create stream config with default stream(NULL), and not timing the kernel
* stream_config s = stream_config{};
*
* // create stream config with _some_stream_id_, and not timing the kernel
* stream_config s = stream_config{_some_stream_id_};
*
* // create stream config with _some_stream_id_, and benchmark with warmup/repeat as default
* stream_config s = stream_config{_some_stream_id_, true};
*
* // create stream config with _some_stream_id_, and benchmark using cpu timer
* stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false};
**/
struct stream_config struct stream_config
{ {
hipStream_t stream_id_ = nullptr; hipStream_t stream_id_ = nullptr;
...@@ -13,5 +29,6 @@ struct stream_config ...@@ -13,5 +29,6 @@ struct stream_config
int log_level_ = 0; int log_level_ = 0;
int cold_niters_ = 3; int cold_niters_ = 3;
int nrepeat_ = 10; int nrepeat_ = 10;
bool is_gpu_timer_ = true; // keep compatible
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment