Commit f6ceef78 authored by ThomasNing's avatar ThomasNing
Browse files

merge with the develop branch

parents 536c5458 25935b57
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using ConvOutDataType = float; // data type of convolution result
using OutDataType = ck::f8_t; // data type of final result
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScaleRelu;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
ConvOutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8,
AComputeDataType,
BComputeDataType>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_convnd_fwd_example(int argc, char* argv[])
{
print_helper_msg();
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::utils::conv::ConvParam conv_param{
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
if(argc == 1)
{
// use default
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
}
// instantiate in and wei element ops, will
// instantiate out_element_op below for every iteration
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto run = [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto out_layout) {
constexpr ck::index_t ndim_spatial_value = ndim_spatial.value;
using InLayout = decltype(in_layout);
using WeiLayout = decltype(wei_layout);
using OutLayout = decltype(out_layout);
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_grouped_conv_fwd<
ndim_spatial_value,
InDataType,
WeiDataType,
ConvOutDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>>(
do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op);
};
namespace ctc = ck::tensor_layout::convolution;
if(conv_param.num_dim_spatial_ == 1)
{
return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GNWK{});
}
else if(conv_param.num_dim_spatial_ == 2)
{
return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GNHWK{});
}
else if(conv_param.num_dim_spatial_ == 3)
{
return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GNDHWK{});
}
return true;
}
......@@ -208,6 +208,7 @@ int main(int argc, char* argv[])
StrideB,
std::array<ck::index_t, NumDTensor>{StrideD, StrideD},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
......
......@@ -69,7 +69,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off
......@@ -99,6 +99,8 @@ int main(int argc, char* argv[])
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
ck::index_t KBatch = 1;
if(argc == 1)
{
// use default case
......@@ -109,7 +111,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
else if(argc == 12)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
......@@ -123,13 +125,16 @@ int main(int argc, char* argv[])
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
KBatch = std::stoi(argv[11]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
printf(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n");
exit(0);
}
......@@ -212,6 +217,7 @@ int main(int argc, char* argv[])
StrideB,
std::array<ck::index_t, NumDTensor>{I0, I0},
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
......@@ -236,10 +242,12 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false});
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
......
......@@ -72,10 +72,24 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any FP8 examples if CK_ENABLE_FP8 not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8")
message("removing fp8 example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any BF8 examples if CK_ENABLE_BF8 not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8")
message("removing bf8 example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif()
......@@ -162,7 +176,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
#only continue if there are some source files left on the list
if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif()
......
......@@ -6,7 +6,7 @@ execute_process(
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3
)
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
......@@ -23,7 +23,7 @@ add_custom_command(
add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3
)
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
......@@ -55,11 +55,10 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
# ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
# Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
......
......@@ -66,6 +66,22 @@ BIAS_CHECK_MAP = {
"alibi" : "bias_enum::alibi"
}
DROPOUT_MAP = {
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
}
DROPOUT_CHECK_MAP = {
"no" : "t.has_dropout == false",
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
}
MODE_MAP = {
"batch" : "false",
"group" : "true"
......
......@@ -14,15 +14,13 @@ from codegen.cpp_symbol_map import *
BWD_DQDKDV_PIPELINE_MAP = {
"ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR",
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
"ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR",
"kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP",
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR",
}
BWD_DQDKDV_PIPELINE_ENUM_MAP = {
"ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR",
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS",
"ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR",
"kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP",
"kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR",
}
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
......@@ -34,39 +32,42 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
FMHA_BWD_DQ_DK_DV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_tile_{F_idx} = ck_tile::
sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps2_{F_idx},
fmha_warp_tile_{F_idx}>;
fmha_block_warps0_{F_idx},
fmha_warp_tile0_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile1_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile0_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile1_{F_idx},
fmha_block_warps2_{F_idx},
fmha_warp_tile0_{F_idx}>;
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_bias},
{F_dbias},
false,
{F_dropout},
false,
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
{F_skpad},
{F_dpad},
{F_dvpad},
{F_bias},
{F_dbias},
false,
false,
false,
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_dropout_{F_idx} = {F_dropout};
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
......@@ -86,55 +87,72 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
fmha_bwd_shape_{F_idx},
{F_mode},
{F_deterministic},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
fmha_bwd_trait_{F_idx}>;
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<
fmha_bwd_pipeline_problem_{F_idx}>;
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<fmha_bwd_pipeline_problem_{F_idx}>;
using fmha_bwd_dk_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
false, false>>;
using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
{F_skpad},
{F_dpad}>>;
using fmha_bwd_dv_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
false, false>>;
using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
{F_skpad},
{F_dvpad}>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>,
fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dtype},
{F_mode},
{F_pipeline_enum},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
{F_bias},
{F_dbias},
{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_deterministic}>;
#include <iostream>
template<>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template<>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template<>
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
......@@ -146,14 +164,15 @@ FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_API="""
#include <iostream>
template<typename dot_do_o_trait_, typename dq_dk_dv_trait_>
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
}}
......@@ -173,38 +192,36 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}}
"""
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_>(s, a);
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r;
}}
"""
@dataclass
class FmhaBwdDQDKDVApiTrait:
pipeline : str
pipeline : str
# sync with fmha_bwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along k seqlen
bhdq : int # q head_dim
bhdv : int # v head_dim
mask : str
bias : str
dbias : str
dropout : str
spad : str
skpad : str
dpad : str
dvpad : str
@property
def name(self) -> str:
return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along k seqlen
bhdq : int # q head_dim
bhdv : int # v head_dim
mask : str
bias : str
dbias : str
dropout : str
spad : str
skpad : str
dpad : str
dvpad : str
deterministic : str
def scheck(self, spad1 : str) -> str:
if self.mode == 'group':
......@@ -212,9 +229,9 @@ class FmhaBwdDQDKDVApiTrait:
elif self.spad == 't' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} != 0'
elif self.spad == 'f' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
else: # self.skpad == 'f' and skpad1 == 'f'
return f'a.seqlen_q % 256 == 0' # BlockSize
return f'a.seqlen_q % 64 == 0'
@property
def skcheck(self) -> str:
......@@ -256,16 +273,19 @@ class FmhaBwdApiPool:
per_hdim_case=str()
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
traits=self.dq_dk_dv_pool[dtype][hdim]
hdim_int = int(hdim)
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
for spad1 in ["t", "f"]:
if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")):
if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
continue
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad])
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
......@@ -295,81 +315,89 @@ class FmhaBwdDQDKDVTileSize:
F_bhdv : int # v head_dim
F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2
F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2
F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2
F_rk0 : int # number of warps along headdim_qk/v (not used) in gemm0/gemm2
F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3
F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3
F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3
F_rm2 : int # number of warps along k seqlen (block warps) in gemm4
F_rn2 : int # number of warps along q seqlen (block warps) in gemm4
F_rk2 : int # number of warps along gemm-k (not used) in gemm4
F_wm : int # warp size along m (warp size)
F_wn : int # warp size along n
F_wk : int # warp size along k
F_rn1 : int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3
F_rk1 : int # number of warps along q seqlen (not used) in gemm1/gemm3
F_rm2 : int # number of warps along q seqlen (block warps) in gemm4
F_rn2 : int # number of warps along headdim_qk (block warps) in gemm4
F_rk2 : int # number of warps along k seqlen (not used) in gemm4
F_wm0 : int # warp size along m in gemm0/gemm2/gemm4
F_wn0 : int # warp size along n in gemm0/gemm2/gemm4
F_wk0 : int # warp size along k in gemm0/gemm2/gemm4
F_wm1 : int # warp size along m in gemm1/gemm3
F_wn1 : int # warp size along n in gemm1/gemm3
F_wk1 : int # warp size along k in gemm1/gemm3
F_occupancy : int # occupancy
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}"
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
@dataclass
class FmhaBwdDQDKDVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaBwdDQDKDVTileSize
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_bias : str #
F_dbias : str #
F_dropout : str #
F_mask : str # value from MASK_MAP
F_mode : str # value from MODE_MAP
F_pipeline : str
mask_impl : str
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaBwdDQDKDVTileSize
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_bias : str #
F_dbias : str #
F_dropout : str #
F_mask : str # value from MASK_MAP
F_mode : str # value from MODE_MAP
F_deterministic : str #
F_pipeline : str #
mask_impl : str #
@property
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + \
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bk1 = self.F_tile.F_bk1,
F_bk2 = self.F_tile.F_bk2,
F_bk3 = self.F_tile.F_bk3,
F_bk4 = self.F_tile.F_bk4,
F_bhdq = self.F_tile.F_bhdq,
F_bhdv = self.F_tile.F_bhdv,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_rm2 = self.F_tile.F_rm2,
F_rn2 = self.F_tile.F_rn2,
F_rk2 = self.F_tile.F_rk2,
F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk,
F_spad = BOOL_MAP[self.F_spad],
F_skpad = BOOL_MAP[self.F_skpad],
F_dpad = BOOL_MAP[self.F_dpad],
F_dvpad = BOOL_MAP[self.F_dvpad],
F_bias = BIAS_MAP[self.F_bias],
F_dbias = BOOL_MAP[self.F_dbias],
F_dropout = BOOL_MAP[self.F_dropout],
F_occupancy = self.F_tile.F_occupancy,
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bk1 = self.F_tile.F_bk1,
F_bk2 = self.F_tile.F_bk2,
F_bk3 = self.F_tile.F_bk3,
F_bk4 = self.F_tile.F_bk4,
F_bhdq = self.F_tile.F_bhdq,
F_bhdv = self.F_tile.F_bhdv,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_rm2 = self.F_tile.F_rm2,
F_rn2 = self.F_tile.F_rn2,
F_rk2 = self.F_tile.F_rk2,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_spad = BOOL_MAP[self.F_spad],
F_skpad = BOOL_MAP[self.F_skpad],
F_dpad = BOOL_MAP[self.F_dpad],
F_dvpad = BOOL_MAP[self.F_dvpad],
F_bias = BIAS_MAP[self.F_bias],
F_dbias = BOOL_MAP[self.F_dbias],
F_dropout = DROPOUT_MAP[self.F_dropout],
F_occupancy = self.F_tile.F_occupancy,
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_deterministic = BOOL_MAP[self.F_deterministic],
F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
@property
def name(self) -> str:
......@@ -382,7 +410,7 @@ class FmhaBwdDQDKDVKernel:
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}'
if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_dbias == 't' : n += '_dbias'
......@@ -390,7 +418,8 @@ class FmhaBwdDQDKDVKernel:
if self.F_mask == 's_mask': n += f'_mask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_dropout == 't' : n += '_dropout'
if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
if self.F_deterministic == 't' : n += '_deterministic'
return n
@property
......@@ -413,19 +442,23 @@ class FmhaBwdDQDKDVKernel:
spad=self.F_spad,
skpad=self.F_skpad,
dpad=self.F_dpad,
dvpad=self.F_dvpad)
dvpad=self.F_dvpad,
deterministic=self.F_deterministic
)
# TODO: design a more practical way to do it
# this is current supported tile size & pipeline.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1),
"qs_ks_vr_dos"],
'64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
"qs_ks_vr_dos"],
'128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
"ks_vr"]
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"]
}
else:
return None
......@@ -440,7 +473,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
tile = d[hdim_str][0]
ppl = d[hdim_str][1]
hdim = int(hdim_str)
......@@ -448,16 +481,29 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue
if ("wg32" in dropout):
continue
if (dpad == "t" or dvpad == "t"):
ppl = d[hdim_str][2]
k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
F_pipeline=ppl, mask_impl=mask_impl)
F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
if not cond:
continue
if receipt == 3:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond:
continue
api_pool.register_dq_dk_dv_traits(k.api_trait())
......@@ -468,53 +514,54 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad},
{F_dvpad},
{F_occupancy}>;
using fmha_bwd_dot_do_o_trait_{F_idx} =
ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>;
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
/* BlockSize = */ 256,
/* BlockSize = */ 64,
{F_hdim},
{F_mode},
fmha_bwd_dot_do_o_trait_{F_idx}>;
using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO<
fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_{F_idx} =
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_kernel_{F_idx} =
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdOGradDotOTilePartitioner</* BlockSize = */ 256>,
fmha_bwd_dot_do_o_{F_idx}>;
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;
using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
using dot_do_o_trait_{F_idx} =
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
#include <iostream>
template<>
template <>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template<>
template <>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template<>
template <>
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
......@@ -584,12 +631,150 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
return gen
FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_convert_dq_trait_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;
using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
/* BlockSize = */ 256,
{F_bm0},
{F_bn0},
{F_hdim},
{F_mode},
{F_deterministic},
fmha_bwd_convert_dq_trait_{F_idx}>;
using fmha_bwd_convert_dq_{F_idx} =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
using fmha_bwd_convert_dq_kernel_{F_idx} =
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_dtype},
{F_mode},
{F_spad},
{F_dpad},
{F_deterministic}>;
#include <iostream>
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
return k_::GetName();
}}
"""
@dataclass
class FmhaBwdConvertQGradKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_spad : str # true/false
F_dpad : str #
F_mode : str # value from MODE_MAP
F_occupancy : int #
F_deterministic : str #
@property
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + \
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_bm0,
F_bn0 = self.F_bn0,
F_spad = BOOL_MAP[self.F_spad],
F_dpad = BOOL_MAP[self.F_dpad],
F_mode = MODE_MAP[self.F_mode],
F_occupancy = self.F_occupancy,
F_deterministic = BOOL_MAP[self.F_deterministic])
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_dpad == 't' : n += 'd'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
if pn != '' : n += f'_{pn}'
if self.F_deterministic == 't' : n += f'_deterministic'
return n
@property
def filename(self) -> str:
return self.name + ".cpp"
def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
# support this in future
def get_occupancy(dtype, hdim):
return 2
gen = list()
for dtype in DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
hdim = int(hdim_str)
tile = d[hdim_str][0]
if (mode == "group" and spad == "f"):
continue
k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
gen.append(k)
return gen
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
......@@ -597,6 +782,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels:
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
kernels = get_bwd_convert_dq_blobs()
for kernel in kernels:
write_single_bwd_convert_dq_kernel(kernel, output_dir)
api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
......@@ -605,6 +793,9 @@ def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
with file_path.open('a') as f:
kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
kernels = get_bwd_convert_dq_blobs()
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
......
......@@ -428,11 +428,18 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
if receipt == 1:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
......
......@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[])
.insert("drop_offset", "0", "offset for random number generator")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("deterministic",
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
"will not be used");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
......@@ -128,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0)
hdim_v = hdim_q;
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
{
std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
return false;
}
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
......@@ -177,9 +176,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
seed.reset();
}
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
bool deterministic = arg_parser.get_bool("deterministic");
ck_tile::stream_config stream_config{nullptr,
true,
......@@ -265,6 +265,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
const ck_tile::index_t nsplits =
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
......@@ -284,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<LSEDataType> lse_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<DDataType> d_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<RandValOutputDataType> randval_host(
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
......@@ -302,6 +305,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
if(init_method == 0)
{
......@@ -362,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
......@@ -387,8 +395,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask
<< std::flush;
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
<< ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
std::size_t workspace_size =
dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024);
if(deterministic == 1)
{
std::cout << "\nDeterministic mode ON: " << workspace_size
<< " MByte memory workspace allocated" << std::endl;
}
auto fmha_traits = fmha_bwd_traits{hdim_q,
hdim_v,
......@@ -397,7 +414,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.type,
bias.type,
use_dbias,
p_drop > 0.0f};
p_drop > 0.0f,
s_randval,
deterministic};
auto fmha_args = [&]() {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
......@@ -422,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_lsed = max_seqlen_q;
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
// setup batch_stride_* arguments
......@@ -433,10 +452,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
......@@ -452,6 +473,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
dk_buf.GetDeviceBuffer(),
dv_buf.GetDeviceBuffer(),
dbias_buf.GetDeviceBuffer(),
dq_acc_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
......@@ -473,6 +495,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_o,
stride_randval,
stride_do,
stride_q, // stride_dq_acc
stride_q, // stride_dq
stride_dk,
stride_dv,
stride_dbias,
......@@ -484,6 +508,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_q, // nhead_stride_dq_acc
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
......@@ -493,15 +521,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_q, // batch_stride_dq_acc
batch_stride_q, // batch_stride_dq
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_drop,
p_undrop,
s_randval,
{drop_seed, drop_offset}};
}();
......@@ -719,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); });
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
// clang-format on
q_host_refs.push_back(q_host_ref);
......@@ -738,6 +768,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
lse_buf.ToDevice(lse_host.data());
dq_buf.SetZero();
dbias_buf.SetZero();
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
......
......@@ -77,6 +77,7 @@ struct fmha_bwd_args
void* dk_ptr;
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
......@@ -97,6 +98,8 @@ struct fmha_bwd_args
ck_tile::index_t stride_o;
ck_tile::index_t stride_randval;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias;
......@@ -108,6 +111,10 @@ struct fmha_bwd_args
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
ck_tile::index_t nhead_stride_dbias;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
......@@ -117,15 +124,17 @@ struct fmha_bwd_args
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias;
ck_tile::index_t split_stride_dq_acc;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
float p_undrop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
......@@ -145,10 +154,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
......@@ -163,6 +172,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
......@@ -173,13 +183,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.batch_stride_lsed,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
......@@ -192,10 +204,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
......@@ -209,6 +221,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
......@@ -219,6 +232,9 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.batch_stride_q,
args.batch_stride_k,
......@@ -227,14 +243,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
args.batch_stride_dq_acc,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
......@@ -260,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_lsed);
args.nhead_stride_lsed);
}
else
{ // create batch mode kernel arguments
......@@ -286,19 +302,59 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaBwdConvertQGradKernel>
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
{
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.split_stride_dq_acc);
}
else
{ // create batch mode kernel arguments
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.batch_stride_dq,
args.batch_stride_dq_acc,
args.split_stride_dq_acc);
}
}();
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
typename FmhaMask_,
typename FmhaDropout_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kHasDropout_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
bool kPadDv_,
bool kIsDeterministic_>
struct fmha_bwd_dq_dk_dv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
......@@ -306,13 +362,14 @@ struct fmha_bwd_dq_dk_dv_traits_
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
......@@ -343,6 +400,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dot_do_o_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_>
struct fmha_bwd_convert_dq_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_convert_dq_get_name_();
// This is the public API, will be generated by script
struct fmha_bwd_traits
{
......@@ -354,6 +436,8 @@ struct fmha_bwd_traits
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_dbias;
bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// TODO: padding check is inside this api
};
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
......@@ -479,16 +479,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 2>{1, 1});
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
1 < num_splits
? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// self define lse data layout as [batch, nhead, max_seqlen_q]
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<ODataType> o_host(
......@@ -669,8 +671,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
......@@ -679,12 +681,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
......@@ -996,8 +998,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b, idx[0], idx[1] + query_offset);
});
cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
......
......@@ -185,7 +185,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse,
args.window_size_left,
args.window_size_right,
args.mask_type,
......@@ -284,7 +283,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
......@@ -376,9 +374,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
......
#!/bin/bash
#
# in order to run this script you'd first need to build the tile_example_fmha_fwd and tile_eaxmple_fmha_bwd executables in ../build/bin/
#
# run the script as "./run_full_test.sh <tag for your test environment> <branch name> <host name> <gpu_arch>
# input arguments:
# environment tag : a string describing the specifics of your test environment
# branch name : name of the branch in git repo (git status | grep -e 'On branch')
# host name : $hostname
# gpu architecture: e.g., gfx90a, or gfx942, etc.
#get the command line arguments:
export env_type=$1
echo 'Environment type: ' $env_type
export branch=$2
echo 'Branch name: ' $branch
export host_name=$3
echo 'Host name: ' $host_name
export GPU_arch=$4
echo 'GPU_arch: ' $GPU_arch
function print_log_header(){
rm -f $1;
echo 'On branch ' $3 &> $1;
echo 'Node name: ' $4 >> $1;
#get GPU_arch and number of compute units from rocminfo
echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
rocminfo | grep "Compute Unit:" >> $1;
hipcc --version | grep -e 'HIP version' >> $1;
echo 'Environment type: ' $2 >> $1;
/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
}
#run verification tests
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
#run performance benchmarks
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
print_log_header $fmha_fwd_log $env_type $branch $host_name
example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log
export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log"
print_log_header $fmha_bwd_log $env_type $branch $host_name
example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log
......@@ -11,18 +11,19 @@ COMMON_ARGS='-v=1'
set -x
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 32 64 128 ; do
for hdim in 32 64 128 256 ; do
for mode in 0 1 ; do
for bias in "n" "e" "a"; do
for dbias in 0 1 ; do
for p_drop in 0.0 0.2; do
for bias in "n" "a" ; do
for dbias in 0 ; do
for p_drop in 0.0 0.2 ; do
for deterministic in 0 ; do
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done
done
......@@ -31,4 +32,5 @@ done
done
done
done
done
set +x
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -153,8 +153,8 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
// set stochastic rounding as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 1
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
......
......@@ -65,6 +65,12 @@ inline bool is_lds_direct_load_supported()
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
}
inline bool is_bf16_atomic_supported()
{
return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942";
}
inline bool is_gfx101_supported()
{
return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
......
......@@ -14,6 +14,124 @@
namespace ck {
namespace utility {
template <typename Argument, typename DsDataType>
struct RotatingMemWrapperMultiD
{
static constexpr index_t NumDs = DsDataType::Size();
using ADataType = decltype(Argument::p_a_grid);
using BDataType = decltype(Argument::p_b_grid);
using DsGridPointer = decltype(Argument::p_ds_grid);
RotatingMemWrapperMultiD() = delete;
RotatingMemWrapperMultiD(Argument& arg_,
std::size_t rotating_count_,
std::size_t size_a_,
std::size_t size_b_,
std::array<std::size_t, NumDs> size_ds_)
: arg(arg_),
rotating_count(rotating_count_),
size_a(size_a_),
size_b(size_b_),
size_ds(size_ds_)
{
p_a_grids.push_back(arg.p_a_grid);
p_b_grids.push_back(arg.p_b_grid);
p_ds_grids.push_back(arg.p_ds_grid);
for(size_t i = 1; i < rotating_count; i++)
{
{
void* pADeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
const_cast<void*>(p_a_grids[0]),
size_a_,
hipMemcpyDeviceToDevice));
p_a_grids.push_back(pADeviceBuf);
}
{
void* pBDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
const_cast<void*>(p_b_grids[0]),
size_b_,
hipMemcpyDeviceToDevice));
p_b_grids.push_back(pBDeviceBuf);
}
{
DsGridPointer ds_buffer;
static_for<0, NumDs, 1>{}([&](auto j) {
void* pDDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
static_cast<const void*>(p_ds_grids[0][j]),
size_ds_[j],
hipMemcpyDeviceToDevice));
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
});
p_ds_grids.push_back(ds_buffer);
}
}
}
void Next()
{
if(rotating_count > 1)
{
std::size_t idx = iter++ % rotating_count;
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
arg.p_ds_grid = p_ds_grids[idx];
}
}
void Print()
{
std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
<< ", rotating_count: " << rotating_count << "}" << std::endl;
}
~RotatingMemWrapperMultiD()
{
if(rotating_count > 1)
{
// restore ptr
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
arg.p_ds_grid = p_ds_grids[0];
// free device mem
for(size_t i = 1; i < rotating_count; i++)
{
hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
static_for<0, NumDs, 1>{}([&](auto j) {
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
hip_check_error(
hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
});
}
}
}
private:
Argument& arg;
std::size_t iter = 0;
std::size_t rotating_count = 1;
std::size_t size_a = 0;
std::size_t size_b = 0;
std::array<std::size_t, NumDs> size_ds = {0};
std::vector<const void*> p_a_grids;
std::vector<const void*> p_b_grids;
std::vector<DsGridPointer> p_ds_grids;
};
template <typename Argument>
struct RotatingMemWrapper
{
......
......@@ -53,6 +53,49 @@ struct DeviceGemmMultipleD : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleDSplitK : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -126,6 +126,29 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(APointers p_a,
BPointers p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment