Commit f6ceef78 authored by ThomasNing's avatar ThomasNing
Browse files

merge with the develop branch

parents 536c5458 25935b57
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using ConvOutDataType = float; // data type of convolution result
using OutDataType = ck::f8_t; // data type of final result
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScaleRelu;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
ConvOutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8,
AComputeDataType,
BComputeDataType>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_convnd_fwd_example(int argc, char* argv[])
{
print_helper_msg();
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::utils::conv::ConvParam conv_param{
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
if(argc == 1)
{
// use default
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
}
// instantiate in and wei element ops, will
// instantiate out_element_op below for every iteration
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto run = [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto out_layout) {
constexpr ck::index_t ndim_spatial_value = ndim_spatial.value;
using InLayout = decltype(in_layout);
using WeiLayout = decltype(wei_layout);
using OutLayout = decltype(out_layout);
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_grouped_conv_fwd<
ndim_spatial_value,
InDataType,
WeiDataType,
ConvOutDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>>(
do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op);
};
namespace ctc = ck::tensor_layout::convolution;
if(conv_param.num_dim_spatial_ == 1)
{
return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GNWK{});
}
else if(conv_param.num_dim_spatial_ == 2)
{
return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GNHWK{});
}
else if(conv_param.num_dim_spatial_ == 3)
{
return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GNDHWK{});
}
return true;
}
...@@ -208,6 +208,7 @@ int main(int argc, char* argv[]) ...@@ -208,6 +208,7 @@ int main(int argc, char* argv[])
StrideB, StrideB,
std::array<ck::index_t, NumDTensor>{StrideD, StrideD}, std::array<ck::index_t, NumDTensor>{StrideD, StrideD},
StrideE, StrideE,
1,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op); cde_element_op);
......
...@@ -69,7 +69,7 @@ using AElementOp = PassThrough; ...@@ -69,7 +69,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply; using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off // clang-format off
...@@ -99,6 +99,8 @@ int main(int argc, char* argv[]) ...@@ -99,6 +99,8 @@ int main(int argc, char* argv[])
ck::index_t StrideD = 0; ck::index_t StrideD = 0;
ck::index_t StrideE = N; ck::index_t StrideE = N;
ck::index_t KBatch = 1;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
...@@ -109,7 +111,7 @@ int main(int argc, char* argv[]) ...@@ -109,7 +111,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 11) else if(argc == 12)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -123,13 +125,16 @@ int main(int argc, char* argv[]) ...@@ -123,13 +125,16 @@ int main(int argc, char* argv[])
StrideB = std::stoi(argv[8]); StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]); StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]); StrideE = std::stoi(argv[10]);
KBatch = std::stoi(argv[11]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); printf(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n");
exit(0); exit(0);
} }
...@@ -212,6 +217,7 @@ int main(int argc, char* argv[]) ...@@ -212,6 +217,7 @@ int main(int argc, char* argv[])
StrideB, StrideB,
std::array<ck::index_t, NumDTensor>{I0, I0}, std::array<ck::index_t, NumDTensor>{I0, I0},
StrideE, StrideE,
KBatch,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op); cde_element_op);
...@@ -236,10 +242,12 @@ int main(int argc, char* argv[]) ...@@ -236,10 +242,12 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl; << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
invoker.Run(argument, StreamConfig{nullptr, false});
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
Tensor<CShuffleDataType> c_m_n({M, N}); Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
......
...@@ -72,10 +72,24 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) ...@@ -72,10 +72,24 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
list(REMOVE_ITEM FILE_NAME "${source}") list(REMOVE_ITEM FILE_NAME "${source}")
endif() endif()
endforeach() endforeach()
#Do not build any FP8 examples if CK_ENABLE_FP8 not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8")
message("removing fp8 example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any BF8 examples if CK_ENABLE_BF8 not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8")
message("removing bf8 example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(FILE_NAME) if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl") if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(FILE_NAME MATCHES "_wmma") elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif() endif()
...@@ -162,7 +176,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) ...@@ -162,7 +176,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(FILE_NAME) if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl") if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(FILE_NAME MATCHES "_wmma") elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif() endif()
......
...@@ -6,7 +6,7 @@ execute_process( ...@@ -6,7 +6,7 @@ execute_process(
execute_process( execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3
) )
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
...@@ -23,7 +23,7 @@ add_custom_command( ...@@ -23,7 +23,7 @@ add_custom_command(
add_custom_command( add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS} OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3
) )
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
...@@ -55,11 +55,10 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) ...@@ -55,11 +55,10 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
# ... because they are auto-generated # ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2) if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else() else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif() endif()
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
# Allow comparing floating points directly in order to check sentinel values # Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
......
...@@ -66,6 +66,22 @@ BIAS_CHECK_MAP = { ...@@ -66,6 +66,22 @@ BIAS_CHECK_MAP = {
"alibi" : "bias_enum::alibi" "alibi" : "bias_enum::alibi"
} }
DROPOUT_MAP = {
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
}
DROPOUT_CHECK_MAP = {
"no" : "t.has_dropout == false",
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
}
MODE_MAP = { MODE_MAP = {
"batch" : "false", "batch" : "false",
"group" : "true" "group" : "true"
......
...@@ -427,12 +427,19 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm ...@@ -427,12 +427,19 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
else: else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
if receipt == 1: if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']: elif dtype in ['fp8', 'bf8']:
......
...@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[]) ...@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[])
.insert("drop_offset", "0", "offset for random number generator") .insert("drop_offset", "0", "offset for random number generator")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel"); .insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("deterministic",
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
"will not be used");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
...@@ -128,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -128,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0) if(hdim_v < 0)
hdim_v = hdim_q; hdim_v = hdim_q;
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
{
std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
return false;
}
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
...@@ -180,6 +179,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -180,6 +179,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int stream_warmup = arg_parser.get_int("warmup"); int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat"); int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname"); bool kname = arg_parser.get_bool("kname");
bool deterministic = arg_parser.get_bool("deterministic");
ck_tile::stream_config stream_config{nullptr, ck_tile::stream_config stream_config{nullptr,
true, true,
...@@ -265,6 +265,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -265,6 +265,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
const ck_tile::index_t shape_seqlen_k = const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
const ck_tile::index_t nsplits =
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
ck_tile::HostTensor<QDataType> q_host( ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
...@@ -284,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -284,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host( ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<LSEDataType> lse_host( ck_tile::HostTensor<LSEDataType> lse_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}); std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<DDataType> d_host( ck_tile::HostTensor<DDataType> d_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}); std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<RandValOutputDataType> randval_host( ck_tile::HostTensor<RandValOutputDataType> randval_host(
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1}); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
...@@ -302,6 +305,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -302,6 +305,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
use_dbias use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
if(init_method == 0) if(init_method == 0)
{ {
...@@ -362,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -362,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data()); q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data()); k_buf.ToDevice(k_host.data());
...@@ -387,8 +395,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -387,8 +395,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
<< std::flush; << ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
std::size_t workspace_size =
dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024);
if(deterministic == 1)
{
std::cout << "\nDeterministic mode ON: " << workspace_size
<< " MByte memory workspace allocated" << std::endl;
}
auto fmha_traits = fmha_bwd_traits{hdim_q, auto fmha_traits = fmha_bwd_traits{hdim_q,
hdim_v, hdim_v,
...@@ -397,7 +414,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -397,7 +414,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.type, mask.type,
bias.type, bias.type,
use_dbias, use_dbias,
p_drop > 0.0f}; p_drop > 0.0f,
s_randval,
deterministic};
auto fmha_args = [&]() { auto fmha_args = [&]() {
assert(nhead % nhead_k == 0); assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
...@@ -422,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -422,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_lsed = max_seqlen_q; const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias = const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k); (i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
// setup batch_stride_* arguments // setup batch_stride_* arguments
...@@ -433,10 +452,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -433,10 +452,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q); const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
return fmha_bwd_args{q_buf.GetDeviceBuffer(), return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(),
...@@ -452,6 +473,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -452,6 +473,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
dk_buf.GetDeviceBuffer(), dk_buf.GetDeviceBuffer(),
dv_buf.GetDeviceBuffer(), dv_buf.GetDeviceBuffer(),
dbias_buf.GetDeviceBuffer(), dbias_buf.GetDeviceBuffer(),
dq_acc_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(),
nullptr, nullptr,
...@@ -473,6 +495,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -473,6 +495,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_o, stride_o,
stride_randval, stride_randval,
stride_do, stride_do,
stride_q, // stride_dq_acc
stride_q, // stride_dq
stride_dk, stride_dk,
stride_dv, stride_dv,
stride_dbias, stride_dbias,
...@@ -484,6 +508,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -484,6 +508,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_randval, nhead_stride_randval,
nhead_stride_do, nhead_stride_do,
nhead_stride_lsed, nhead_stride_lsed,
nhead_stride_q, // nhead_stride_dq_acc
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
nhead_stride_dbias, nhead_stride_dbias,
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
...@@ -493,15 +521,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -493,15 +521,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_randval, batch_stride_randval,
batch_stride_do, batch_stride_do,
batch_stride_lsed, batch_stride_lsed,
batch_stride_q, // batch_stride_dq_acc
batch_stride_q, // batch_stride_dq
batch_stride_dk, batch_stride_dk,
batch_stride_dv, batch_stride_dv,
batch_stride_dbias, batch_stride_dbias,
split_stride_dq_acc,
mask.left, mask.left,
mask.right, mask.right,
static_cast<ck_tile::index_t>(mask.type), static_cast<ck_tile::index_t>(mask.type),
p_drop, p_drop,
p_undrop, p_undrop,
s_randval,
{drop_seed, drop_offset}}; {drop_seed, drop_offset}};
}(); }();
...@@ -719,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -719,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); }); if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); }); else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); }); lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
// clang-format on // clang-format on
q_host_refs.push_back(q_host_ref); q_host_refs.push_back(q_host_ref);
...@@ -738,6 +768,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -738,6 +768,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
lse_buf.ToDevice(lse_host.data()); lse_buf.ToDevice(lse_host.data());
dq_buf.SetZero(); dq_buf.SetZero();
dbias_buf.SetZero(); dbias_buf.SetZero();
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{ ck_tile::stream_config stream_config_v{
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
......
...@@ -77,6 +77,7 @@ struct fmha_bwd_args ...@@ -77,6 +77,7 @@ struct fmha_bwd_args
void* dk_ptr; void* dk_ptr;
void* dv_ptr; void* dv_ptr;
void* dbias_ptr; void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr; const void* seqstart_q_ptr;
const void* seqstart_k_ptr; const void* seqstart_k_ptr;
const void* seqlen_k_ptr; const void* seqlen_k_ptr;
...@@ -97,6 +98,8 @@ struct fmha_bwd_args ...@@ -97,6 +98,8 @@ struct fmha_bwd_args
ck_tile::index_t stride_o; ck_tile::index_t stride_o;
ck_tile::index_t stride_randval; ck_tile::index_t stride_randval;
ck_tile::index_t stride_do; ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dk; ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv; ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias; ck_tile::index_t stride_dbias;
...@@ -108,6 +111,10 @@ struct fmha_bwd_args ...@@ -108,6 +111,10 @@ struct fmha_bwd_args
ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed; ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
ck_tile::index_t nhead_stride_dbias; ck_tile::index_t nhead_stride_dbias;
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
...@@ -117,15 +124,17 @@ struct fmha_bwd_args ...@@ -117,15 +124,17 @@ struct fmha_bwd_args
ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed; ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv; ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias; ck_tile::index_t batch_stride_dbias;
ck_tile::index_t split_stride_dq_acc;
ck_tile::index_t window_size_left; ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right; ck_tile::index_t window_size_right;
ck_tile::index_t mask_type; ck_tile::index_t mask_type;
float p_drop; float p_drop;
float p_undrop; float p_undrop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset; std::tuple<uint64_t, uint64_t> drop_seed_offset;
}; };
...@@ -145,10 +154,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -145,10 +154,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr, args.do_ptr,
args.d_ptr, args.d_ptr,
args.rand_val_ptr, args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr, args.dk_ptr,
args.dv_ptr, args.dv_ptr,
args.dbias_ptr, args.dbias_ptr,
args.dq_acc_ptr,
args.seqstart_q_ptr, args.seqstart_q_ptr,
args.seqstart_k_ptr, args.seqstart_k_ptr,
args.seqlen_k_ptr, args.seqlen_k_ptr,
...@@ -163,6 +172,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -163,6 +172,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias, args.stride_bias,
args.stride_randval, args.stride_randval,
args.stride_do, args.stride_do,
args.stride_dq_acc,
args.stride_dk, args.stride_dk,
args.stride_dv, args.stride_dv,
args.stride_dbias, args.stride_dbias,
...@@ -173,13 +183,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -173,13 +183,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_lsed, args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias, args.nhead_stride_dbias,
args.batch_stride_lsed, args.split_stride_dq_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.s_randval,
args.drop_seed_offset); args.drop_seed_offset);
} }
else else
...@@ -192,10 +204,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -192,10 +204,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr, args.do_ptr,
args.d_ptr, args.d_ptr,
args.rand_val_ptr, args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr, args.dk_ptr,
args.dv_ptr, args.dv_ptr,
args.dbias_ptr, args.dbias_ptr,
args.dq_acc_ptr,
args.seqlen_q, args.seqlen_q,
args.seqlen_k, args.seqlen_k,
args.hdim_q, args.hdim_q,
...@@ -209,6 +221,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -209,6 +221,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias, args.stride_bias,
args.stride_randval, args.stride_randval,
args.stride_do, args.stride_do,
args.stride_dq_acc,
args.stride_dk, args.stride_dk,
args.stride_dv, args.stride_dv,
args.stride_dbias, args.stride_dbias,
...@@ -219,6 +232,9 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -219,6 +232,9 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_lsed, args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias, args.nhead_stride_dbias,
args.batch_stride_q, args.batch_stride_q,
args.batch_stride_k, args.batch_stride_k,
...@@ -227,14 +243,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -227,14 +243,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.batch_stride_randval, args.batch_stride_randval,
args.batch_stride_do, args.batch_stride_do,
args.batch_stride_lsed, args.batch_stride_lsed,
args.batch_stride_dq_acc,
args.batch_stride_dk, args.batch_stride_dk,
args.batch_stride_dv, args.batch_stride_dv,
args.batch_stride_dbias, args.batch_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.s_randval,
args.drop_seed_offset); args.drop_seed_offset);
} }
}(); }();
...@@ -260,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) ...@@ -260,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.stride_o, args.stride_o,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_o, args.nhead_stride_o,
args.nhead_stride_lsed, args.nhead_stride_lsed);
args.batch_stride_lsed);
} }
else else
{ // create batch mode kernel arguments { // create batch mode kernel arguments
...@@ -286,19 +302,59 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) ...@@ -286,19 +302,59 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
return ck_tile::make_tuple(kargs, grids); return ck_tile::make_tuple(kargs, grids);
} }
template <typename FmhaBwdConvertQGradKernel>
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
{
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.split_stride_dq_acc);
}
else
{ // create batch mode kernel arguments
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.batch_stride_dq,
args.batch_stride_dq_acc,
args.split_stride_dq_acc);
}
}();
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel // this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_, template <ck_tile::index_t HDim_,
typename DataType_, typename DataType_,
bool kIsGroupMode_, bool kIsGroupMode_,
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_, ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
typename FmhaMask_, typename FmhaMask_,
typename FmhaDropout_,
ck_tile::BlockAttentionBiasEnum BiasEnum_, ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_, bool kHasBiasGrad_,
bool kHasDropout_,
bool kPadS_, bool kPadS_,
bool kPadSK_, bool kPadSK_,
bool kPadD_, bool kPadD_,
bool kPadDv_> bool kPadDv_,
bool kIsDeterministic_>
struct fmha_bwd_dq_dk_dv_traits_ struct fmha_bwd_dq_dk_dv_traits_
{ {
static constexpr ck_tile::index_t HDim = HDim_; static constexpr ck_tile::index_t HDim = HDim_;
...@@ -306,13 +362,14 @@ struct fmha_bwd_dq_dk_dv_traits_ ...@@ -306,13 +362,14 @@ struct fmha_bwd_dq_dk_dv_traits_
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>; using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
static constexpr auto BiasEnum = BiasEnum_; static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kPadS = kPadS_; static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_; static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_; static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
}; };
template <typename Traits_> template <typename Traits_>
...@@ -343,6 +400,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); ...@@ -343,6 +400,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_> template <typename Traits_>
std::string fmha_bwd_dot_do_o_get_name_(); std::string fmha_bwd_dot_do_o_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_>
struct fmha_bwd_convert_dq_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_convert_dq_get_name_();
// This is the public API, will be generated by script // This is the public API, will be generated by script
struct fmha_bwd_traits struct fmha_bwd_traits
{ {
...@@ -354,6 +436,8 @@ struct fmha_bwd_traits ...@@ -354,6 +436,8 @@ struct fmha_bwd_traits
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_dbias; bool has_dbias;
bool has_dropout; bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// TODO: padding check is inside this api // TODO: padding check is inside this api
}; };
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
...@@ -479,16 +479,18 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -479,16 +479,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 2>{1, 1}); : std::array<ck_tile::index_t, 2>{1, 1});
ck_tile::HostTensor<LSEDataType> lse_acc_host( ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q} 1 < num_splits
? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1}); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host( ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits 1 < num_splits
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v} ? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1}); : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// self define lse data layout as [batch, nhead, max_seqlen_q] // batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host( ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q} lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */); : std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<ODataType> o_host( ck_tile::HostTensor<ODataType> o_host(
...@@ -669,8 +671,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -669,8 +671,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_bias = const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = max_seqlen_q; const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q; const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments // setup batch_stride_* arguments
...@@ -679,12 +681,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -679,12 +681,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
// setup split_stride_* arguments (only used in split-kv kernel) // setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q); const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
return fmha_fwd_args{q_buf.GetDeviceBuffer(), return fmha_fwd_args{q_buf.GetDeviceBuffer(),
...@@ -996,8 +998,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -996,8 +998,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(lse) if(lse)
{ {
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q}); ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach( lse_host_result.ForEach([&](auto& self, auto idx) {
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); }); self(idx) = lse_host(b, idx[0], idx[1] + query_offset);
});
cur_pass = ck_tile::check_err(lse_host_result, cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref, lse_host_ref,
......
...@@ -185,7 +185,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -185,7 +185,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_lse, args.nhead_stride_lse,
args.nhead_stride_o, args.nhead_stride_o,
args.batch_stride_lse,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
...@@ -284,7 +283,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) ...@@ -284,7 +283,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_lse_acc, args.nhead_stride_lse_acc,
args.nhead_stride_o_acc, args.nhead_stride_o_acc,
args.batch_stride_lse_acc,
args.batch_stride_o_acc, args.batch_stride_o_acc,
args.split_stride_lse_acc, args.split_stride_lse_acc,
args.split_stride_o_acc, args.split_stride_o_acc,
...@@ -376,9 +374,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) ...@@ -376,9 +374,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_o_acc, args.nhead_stride_o_acc,
args.nhead_stride_lse, args.nhead_stride_lse,
args.nhead_stride_o, args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc, args.batch_stride_o_acc,
args.batch_stride_lse,
args.split_stride_lse_acc, args.split_stride_lse_acc,
args.split_stride_o_acc); args.split_stride_o_acc);
} }
......
#!/bin/bash
#
# in order to run this script you'd first need to build the tile_example_fmha_fwd and tile_eaxmple_fmha_bwd executables in ../build/bin/
#
# run the script as "./run_full_test.sh <tag for your test environment> <branch name> <host name> <gpu_arch>
# input arguments:
# environment tag : a string describing the specifics of your test environment
# branch name : name of the branch in git repo (git status | grep -e 'On branch')
# host name : $hostname
# gpu architecture: e.g., gfx90a, or gfx942, etc.
#get the command line arguments:
export env_type=$1
echo 'Environment type: ' $env_type
export branch=$2
echo 'Branch name: ' $branch
export host_name=$3
echo 'Host name: ' $host_name
export GPU_arch=$4
echo 'GPU_arch: ' $GPU_arch
function print_log_header(){
rm -f $1;
echo 'On branch ' $3 &> $1;
echo 'Node name: ' $4 >> $1;
#get GPU_arch and number of compute units from rocminfo
echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
rocminfo | grep "Compute Unit:" >> $1;
hipcc --version | grep -e 'HIP version' >> $1;
echo 'Environment type: ' $2 >> $1;
/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
}
#run verification tests
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
#run performance benchmarks
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
print_log_header $fmha_fwd_log $env_type $branch $host_name
example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log
export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log"
print_log_header $fmha_bwd_log $env_type $branch $host_name
example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log
...@@ -11,18 +11,19 @@ COMMON_ARGS='-v=1' ...@@ -11,18 +11,19 @@ COMMON_ARGS='-v=1'
set -x set -x
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do for perm in 0 1 ; do
for hdim in 32 64 128 ; do for hdim in 32 64 128 256 ; do
for mode in 0 1 ; do for mode in 0 1 ; do
for bias in "n" "e" "a"; do for bias in "n" "a" ; do
for dbias in 0 1 ; do for dbias in 0 ; do
for p_drop in 0.0 0.2; do for p_drop in 0.0 0.2 ; do
for deterministic in 0 ; do
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done done
done done
...@@ -31,4 +32,5 @@ done ...@@ -31,4 +32,5 @@ done
done done
done done
done done
done
set +x set +x
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -153,8 +153,8 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -153,8 +153,8 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly // LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0 #define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
// set stochastic rounding as default for f8 conversions // set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 1 #define CK_USE_SR_F8_CONVERSION 0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
......
...@@ -65,6 +65,12 @@ inline bool is_lds_direct_load_supported() ...@@ -65,6 +65,12 @@ inline bool is_lds_direct_load_supported()
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
} }
inline bool is_bf16_atomic_supported()
{
return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942";
}
inline bool is_gfx101_supported() inline bool is_gfx101_supported()
{ {
return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" || return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
......
...@@ -14,6 +14,124 @@ ...@@ -14,6 +14,124 @@
namespace ck { namespace ck {
namespace utility { namespace utility {
template <typename Argument, typename DsDataType>
struct RotatingMemWrapperMultiD
{
static constexpr index_t NumDs = DsDataType::Size();
using ADataType = decltype(Argument::p_a_grid);
using BDataType = decltype(Argument::p_b_grid);
using DsGridPointer = decltype(Argument::p_ds_grid);
RotatingMemWrapperMultiD() = delete;
RotatingMemWrapperMultiD(Argument& arg_,
std::size_t rotating_count_,
std::size_t size_a_,
std::size_t size_b_,
std::array<std::size_t, NumDs> size_ds_)
: arg(arg_),
rotating_count(rotating_count_),
size_a(size_a_),
size_b(size_b_),
size_ds(size_ds_)
{
p_a_grids.push_back(arg.p_a_grid);
p_b_grids.push_back(arg.p_b_grid);
p_ds_grids.push_back(arg.p_ds_grid);
for(size_t i = 1; i < rotating_count; i++)
{
{
void* pADeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
const_cast<void*>(p_a_grids[0]),
size_a_,
hipMemcpyDeviceToDevice));
p_a_grids.push_back(pADeviceBuf);
}
{
void* pBDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
const_cast<void*>(p_b_grids[0]),
size_b_,
hipMemcpyDeviceToDevice));
p_b_grids.push_back(pBDeviceBuf);
}
{
DsGridPointer ds_buffer;
static_for<0, NumDs, 1>{}([&](auto j) {
void* pDDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
static_cast<const void*>(p_ds_grids[0][j]),
size_ds_[j],
hipMemcpyDeviceToDevice));
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
});
p_ds_grids.push_back(ds_buffer);
}
}
}
void Next()
{
if(rotating_count > 1)
{
std::size_t idx = iter++ % rotating_count;
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
arg.p_ds_grid = p_ds_grids[idx];
}
}
void Print()
{
std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
<< ", rotating_count: " << rotating_count << "}" << std::endl;
}
~RotatingMemWrapperMultiD()
{
if(rotating_count > 1)
{
// restore ptr
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
arg.p_ds_grid = p_ds_grids[0];
// free device mem
for(size_t i = 1; i < rotating_count; i++)
{
hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
static_for<0, NumDs, 1>{}([&](auto j) {
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
hip_check_error(
hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
});
}
}
}
private:
Argument& arg;
std::size_t iter = 0;
std::size_t rotating_count = 1;
std::size_t size_a = 0;
std::size_t size_b = 0;
std::array<std::size_t, NumDs> size_ds = {0};
std::vector<const void*> p_a_grids;
std::vector<const void*> p_b_grids;
std::vector<DsGridPointer> p_ds_grids;
};
template <typename Argument> template <typename Argument>
struct RotatingMemWrapper struct RotatingMemWrapper
{ {
......
...@@ -53,6 +53,49 @@ struct DeviceGemmMultipleD : public BaseOperator ...@@ -53,6 +53,49 @@ struct DeviceGemmMultipleD : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleDSplitK : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -126,6 +126,29 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator ...@@ -126,6 +126,29 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0; const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(APointers p_a,
BPointers p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment