Unverified Commit 2cab8d39 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

CK Tile FA Training kernels (#1286)



* FA fwd dropout

* FA bwd

* epilogue reuse

* CMakeLists update

* [CK_TILE] support alibi (#1269)

* add alibi support

* fix code

* update code based on comment

* Support more hdim

* fix fp8 bias

* support seqlen_k=0 case

* remove unused printf

* fix format

---------
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>

* now fwd/bwd can build

* bwd alibi

* add bwd validation stream_config

* update generated filenames

* update bwd kernel launch

* CK_TILE_HOST_DEVICE in philox

* Transpose -> transpose

* format

* format

* format

* Generate the instance for FA required

* format

* fix error in WarpGemm

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: default avatarcarlushuang <carlus.huang@amd.com>
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>
Co-authored-by: default avatarPo Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: default avatarJing Zhang <jizhan@amd.com>
parent 76827d82
# generate a list of kernels, but not actually emit files at config stage # generate a list of kernels, but not actually emit files at config stage
execute_process( execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt --direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
) )
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
)
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
# as current cmake list, otherwise will not figure out the dependency properly # as current cmake list, otherwise will not figure out the dependency properly
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS) file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command( add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS} OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--output_dir ${CMAKE_CURRENT_BINARY_DIR} --direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
) )
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
...@@ -22,6 +34,14 @@ add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) ...@@ -22,6 +34,14 @@ add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
# NOTE: this is dangerous since will change the whole kernel to flush denormals # NOTE: this is dangerous since will change the whole kernel to flush denormals
# WIP with compiler team for an exp2 intrinsic..., then remove this # WIP with compiler team for an exp2 intrinsic..., then remove this
if(NOT DEFINED FMHA_FWD_FAST_EXP2) if(NOT DEFINED FMHA_FWD_FAST_EXP2)
...@@ -29,16 +49,27 @@ if(NOT DEFINED FMHA_FWD_FAST_EXP2) ...@@ -29,16 +49,27 @@ if(NOT DEFINED FMHA_FWD_FAST_EXP2)
endif() endif()
set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)
set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# ... because they are auto-generated # ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2) if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else() else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif() endif()
# Allow comparing floating points directly in order to check sentinel values # Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
This diff is collapsed.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "bias.hpp"
#include <type_traits>
template <typename DataType>
struct FmhaBwdTypeConfig;
template <>
struct FmhaBwdTypeConfig<ck_tile::half_t>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using GemmDataType = ck_tile::half_t;
using BiasDataType = ck_tile::half_t;
using LSEDataType = float;
using AccDataType = float; // data type for gemm accumulation
using DDataType = float;
using RandValOutputDataType = uint8_t;
using ODataType = ck_tile::half_t;
using OGradDataType = ck_tile::half_t;
using QGradDataType = ck_tile::half_t;
using KGradDataType = ck_tile::half_t;
using VGradDataType = ck_tile::half_t;
using BiasGradDataType = ck_tile::half_t;
};
template <>
struct FmhaBwdTypeConfig<ck_tile::bf16_t>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using GemmDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t;
using LSEDataType = float;
using AccDataType = float; // data type for gemm accumulation
using DDataType = float;
using RandValOutputDataType = uint8_t;
using ODataType = ck_tile::bf16_t;
using OGradDataType = ck_tile::bf16_t;
using QGradDataType = ck_tile::bf16_t;
using KGradDataType = ck_tile::bf16_t;
using VGradDataType = ck_tile::bf16_t;
using BiasGradDataType = ck_tile::bf16_t;
};
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
// runtime args, some will passed to karg, some will used to compute grids/blocks
struct fmha_bwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
const void* o_ptr;
const void* lse_ptr;
const void* do_ptr;
void* d_ptr;
void* rand_val_ptr;
void* dq_ptr;
void* dk_ptr;
void* dv_ptr;
void* dbias_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t max_seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_o;
ck_tile::index_t stride_randval;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dbias;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
float p_undrop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
template <typename FmhaBwdDQDKDVKernel>
auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dbias,
args.batch_stride_lsed,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dbias,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaBwdOGradDotOKernel>
auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
{
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_lsed);
}
else
{ // create batch mode kernel arguments
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.p_undrop,
args.seqlen_q,
args.hdim_v,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_do,
args.batch_stride_o,
args.batch_stride_lsed);
}
}();
dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kHasDropout_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
struct fmha_bwd_dq_dk_dv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dq_dk_dv_get_name_();
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dot_do_o_get_name_();
// This is the public API, will be generated by script
struct fmha_bwd_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_dbias;
bool has_dropout;
// TODO: padding check is inside this api
};
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd.hpp" #include "fmha_fwd.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
...@@ -110,6 +110,9 @@ auto create_args(int argc, char* argv[]) ...@@ -110,6 +110,9 @@ auto create_args(int argc, char* argv[])
"11939", "11939",
"random seed used for initializing input tensors. 0 for " "random seed used for initializing input tensors. 0 for "
"non-deterministic seed") "non-deterministic seed")
.insert("p_drop", "0", "0~1 probability of dropout")
.insert("drop_seed", "1", "seed for random number generator")
.insert("drop_offset", "0", "offset for random number generator")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel"); .insert("repeat", "20", "number of iterations to benchmark the kernel");
...@@ -128,26 +131,11 @@ auto get_elimit(std::string /*init_method*/) ...@@ -128,26 +131,11 @@ auto get_elimit(std::string /*init_method*/)
} }
template <> template <>
auto get_elimit<ck_tile::bf16_t>(std::string init_method) auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{ {
if(init_method == "ui" || init_method == "ni") double rtol = 1e-2;
{ double atol = 1e-2;
double rtol = 1e-2; return ck_tile::make_tuple(rtol, atol);
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
else if(init_method == "nf")
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
else
{
double rtol = 3e-3;
double atol = 3e-3;
return ck_tile::make_tuple(rtol, atol);
}
} }
template <> template <>
...@@ -250,6 +238,21 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -250,6 +238,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask_info mask = mask_info::decode( mask_info mask = mask_info::decode(
arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore
float p_drop = arg_parser.get_float("p_drop");
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
if(p_drop < 0.0f || p_drop > 1.0f)
{
std::cerr << "The value of p_drop should be 0~1" << std::endl;
return false;
}
bool s_randval = false;
if(p_drop > 0.0f && do_validation)
{
s_randval = true;
}
std::string init_method = arg_parser.get_str("init"); std::string init_method = arg_parser.get_str("init");
std::optional<uint32_t> seed = arg_parser.get_uint32("seed"); std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
if(*seed == 0) if(*seed == 0)
...@@ -274,21 +277,23 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -274,21 +277,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
using TypeConfig = FmhaFwdTypeConfig<DataType>; using TypeConfig = FmhaFwdTypeConfig<DataType>;
using QDataType = typename TypeConfig::QDataType; using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType; using KDataType = typename TypeConfig::KDataType;
using VDataType = typename TypeConfig::VDataType; using VDataType = typename TypeConfig::VDataType;
using BiasDataType = typename TypeConfig::BiasDataType; using BiasDataType = typename TypeConfig::BiasDataType;
using LSEDataType = typename TypeConfig::LSEDataType; using RandValOutputDataType = typename TypeConfig::RandValOutputDataType;
using SaccDataType = typename TypeConfig::SaccDataType; using LSEDataType = typename TypeConfig::LSEDataType;
using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; using SaccDataType = typename TypeConfig::SaccDataType;
using PDataType = typename TypeConfig::PDataType; using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType;
using OaccDataType = typename TypeConfig::OaccDataType; using PDataType = typename TypeConfig::PDataType;
using ODataType = typename TypeConfig::ODataType; using OaccDataType = typename TypeConfig::OaccDataType;
using ODataType = typename TypeConfig::ODataType;
// accumulation numbers for performance evaluation // accumulation numbers for performance evaluation
std::size_t flop = 0, num_byte = 0; std::size_t flop = 0, num_byte = 0;
auto max_seqlen_q = auto max_seqlen_q =
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
{ {
for(ck_tile::index_t wb = 0; wb < batch; ++wb) for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{ {
...@@ -300,6 +305,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -300,6 +305,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
max_seqlen_q = real_seqlen_q; max_seqlen_q = real_seqlen_q;
} }
if(max_seqlen_k < real_seqlen_k)
{
max_seqlen_k = real_seqlen_k;
}
flop += nhead * (static_cast<std::size_t>(2) * real_seqlen_q * real_seqlen_k * hdim_q + flop += nhead * (static_cast<std::size_t>(2) * real_seqlen_q * real_seqlen_k * hdim_q +
static_cast<std::size_t>(2) * real_seqlen_q * hdim_v * real_seqlen_k); static_cast<std::size_t>(2) * real_seqlen_q * hdim_v * real_seqlen_k);
...@@ -353,12 +363,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -353,12 +363,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q] // self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host( ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q} lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */); : std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<ODataType> o_host( ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<RandValOutputDataType> randval_host(
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
if(init_method == "ui" || init_method == "0") if(init_method == "ui" || init_method == "0")
{ {
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host); ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
...@@ -434,6 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -434,6 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t)); ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t));
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data()); q_buf.ToDevice(q_host.data());
...@@ -463,8 +478,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -463,8 +478,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< (seqlen_kpads[0] < 0 ? "" << (seqlen_kpads[0] < 0 ? ""
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
<< ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
<< std::flush; << ", mask:" << mask << ", v:" << vlayout << std::flush;
auto fmha_traits = fmha_fwd_traits{hdim_q, auto fmha_traits = fmha_fwd_traits{hdim_q,
hdim_v, hdim_v,
...@@ -474,6 +489,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -474,6 +489,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.type, mask.type,
bias.type, bias.type,
lse, lse,
p_drop > 0.0f,
squant}; squant};
auto p_compute_element_func = [&]() { auto p_compute_element_func = [&]() {
...@@ -505,8 +521,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -505,8 +521,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
else else
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
}(); }();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments // setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
...@@ -518,21 +535,24 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -518,21 +535,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
}(); }();
const ck_tile::index_t nhead_stride_bias = const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_lse = (shape_seqlen_q * 1); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments // setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
return fmha_fwd_args{q_buf.GetDeviceBuffer(), return fmha_fwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(),
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(), : bias_buf.GetDeviceBuffer(),
randval_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(),
...@@ -554,22 +574,28 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -554,22 +574,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_v, stride_v,
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias, : stride_bias,
stride_randval,
stride_o, stride_o,
nhead_stride_q, nhead_stride_q,
nhead_stride_k, nhead_stride_k,
nhead_stride_v, nhead_stride_v,
nhead_stride_bias, nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse, nhead_stride_lse,
nhead_stride_o, nhead_stride_o,
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
batch_stride_v, batch_stride_v,
batch_stride_bias, batch_stride_bias,
batch_stride_randval,
batch_stride_lse, batch_stride_lse,
batch_stride_o, batch_stride_o,
mask.left, mask.left,
mask.right, mask.right,
static_cast<ck_tile::index_t>(mask.type)}; static_cast<ck_tile::index_t>(mask.type),
p_drop,
s_randval,
{drop_seed, drop_offset}};
}(); }();
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
...@@ -596,6 +622,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -596,6 +622,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_buf.FromDevice(o_host.data()); o_buf.FromDevice(o_host.data());
lse_buf.FromDevice(lse_host.data()); lse_buf.FromDevice(lse_host.data());
randval_buf.FromDevice(randval_host.data());
float p_undrop = 1.0 - p_drop;
uint8_t p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
float rp_undrop = 1.0 / p_undrop;
bool pass = true; bool pass = true;
...@@ -771,6 +802,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -771,6 +802,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
s_host_ref, p_host_ref, p_compute_element_func); s_host_ref, p_host_ref, p_compute_element_func);
} }
if(p_drop > 0)
{
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
{nhead, real_seqlen_q, real_seqlen_k});
randval_host_ref.ForEach([&](auto& self, auto idx) {
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
});
ck_tile::reference_batched_dropout(
p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
}
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>( ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref, p_host_ref,
v_host_ref, v_host_ref,
...@@ -804,9 +846,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -804,9 +846,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(lse) if(lse)
{ {
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q}); ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach([&](auto& self, auto idx) { lse_host_result.ForEach(
self(idx) = lse_host(b, idx[0], idx[1] + query_offset); [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
});
bool lse_pass = ck_tile::check_err(lse_host_result, bool lse_pass = ck_tile::check_err(lse_host_result,
lse_host_ref, lse_host_ref,
......
...@@ -17,61 +17,65 @@ struct FmhaFwdTypeConfig; ...@@ -17,61 +17,65 @@ struct FmhaFwdTypeConfig;
template <> template <>
struct FmhaFwdTypeConfig<ck_tile::half_t> struct FmhaFwdTypeConfig<ck_tile::half_t>
{ {
using QDataType = ck_tile::half_t; using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t; using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t; using VDataType = ck_tile::half_t;
using BiasDataType = ck_tile::half_t; using BiasDataType = ck_tile::half_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) using RandValOutputDataType = uint8_t;
using SaccDataType = float; // data type for first gemm accumulation using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SMPLComputeDataType = float; // data type for reduction, softmax using SaccDataType = float; // data type for first gemm accumulation
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm using SMPLComputeDataType = float; // data type for reduction, softmax
using OaccDataType = float; // data type for second gemm accumulation using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
using ODataType = ck_tile::half_t; using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::half_t;
}; };
template <> template <>
struct FmhaFwdTypeConfig<ck_tile::bf16_t> struct FmhaFwdTypeConfig<ck_tile::bf16_t>
{ {
using QDataType = ck_tile::bf16_t; using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t; using VDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t; using BiasDataType = ck_tile::bf16_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) using RandValOutputDataType = uint8_t;
using SaccDataType = float; // data type for first gemm accumulation using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SMPLComputeDataType = float; // data type for reduction, softmax using SaccDataType = float; // data type for first gemm accumulation
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm using SMPLComputeDataType = float; // data type for reduction, softmax
using OaccDataType = float; // data type for second gemm accumulation using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
using ODataType = ck_tile::bf16_t; using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
}; };
template <> template <>
struct FmhaFwdTypeConfig<ck_tile::fp8_t> struct FmhaFwdTypeConfig<ck_tile::fp8_t>
{ {
using QDataType = ck_tile::fp8_t; using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t; using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t; using VDataType = ck_tile::fp8_t;
using BiasDataType = float; using BiasDataType = float;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) using RandValOutputDataType = uint8_t;
using SaccDataType = float; // data type for first gemm accumulation using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SMPLComputeDataType = float; // data type for reduction, softmax using SaccDataType = float; // data type for first gemm accumulation
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm using SMPLComputeDataType = float; // data type for reduction, softmax
using OaccDataType = float; // data type for second gemm accumulation using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using ODataType = ck_tile::fp8_t; using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::fp8_t;
}; };
template <> template <>
struct FmhaFwdTypeConfig<ck_tile::bf8_t> struct FmhaFwdTypeConfig<ck_tile::bf8_t>
{ {
using QDataType = ck_tile::bf8_t; using QDataType = ck_tile::bf8_t;
using KDataType = ck_tile::bf8_t; using KDataType = ck_tile::bf8_t;
using VDataType = ck_tile::bf8_t; using VDataType = ck_tile::bf8_t;
using BiasDataType = ck_tile::bf8_t; using BiasDataType = ck_tile::bf8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) using RandValOutputDataType = uint8_t;
using SaccDataType = float; // data type for first gemm accumulation using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SMPLComputeDataType = float; // data type for reduction, softmax using SaccDataType = float; // data type for first gemm accumulation
using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm using SMPLComputeDataType = float; // data type for reduction, softmax
using OaccDataType = float; // data type for second gemm accumulation using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm
using ODataType = ck_tile::bf8_t; using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf8_t;
}; };
struct FmhaMasks struct FmhaMasks
...@@ -88,6 +92,7 @@ struct fmha_fwd_args ...@@ -88,6 +92,7 @@ struct fmha_fwd_args
const void* k_ptr; const void* k_ptr;
const void* v_ptr; const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer const void* bias_ptr; // bias or alibi_slope pointer
void* rand_val_ptr;
void* lse_ptr; void* lse_ptr;
void* o_ptr; void* o_ptr;
const void* seqstart_q_ptr; const void* seqstart_q_ptr;
...@@ -108,22 +113,28 @@ struct fmha_fwd_args ...@@ -108,22 +113,28 @@ struct fmha_fwd_args
ck_tile::index_t stride_k; ck_tile::index_t stride_k;
ck_tile::index_t stride_v; ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval;
ck_tile::index_t stride_o; ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_o;
ck_tile::index_t window_size_left; ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right; ck_tile::index_t window_size_right;
ck_tile::index_t mask_type; ck_tile::index_t mask_type;
float p_drop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
}; };
template <typename FmhaKernel> template <typename FmhaKernel>
...@@ -138,6 +149,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -138,6 +149,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr, args.lse_ptr,
args.o_ptr, args.o_ptr,
args.seqstart_q_ptr, args.seqstart_q_ptr,
...@@ -145,6 +157,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -145,6 +157,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.seqlen_k_ptr, args.seqlen_k_ptr,
args.hdim_q, args.hdim_q,
args.hdim_v, args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k, args.nhead_q / args.nhead_k,
args.scale_s, args.scale_s,
args.scale_p, args.scale_p,
...@@ -153,16 +166,22 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -153,16 +166,22 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_randval,
args.stride_o, args.stride_o,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse, args.nhead_stride_lse,
args.nhead_stride_o, args.nhead_stride_o,
args.batch_stride_lse,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type); args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
} }
else else
{ // create batch mode kernel arguments { // create batch mode kernel arguments
...@@ -170,12 +189,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -170,12 +189,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr, args.lse_ptr,
args.o_ptr, args.o_ptr,
args.seqlen_q, args.seqlen_q,
args.seqlen_k, args.seqlen_k,
args.hdim_q, args.hdim_q,
args.hdim_v, args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k, args.nhead_q / args.nhead_k,
args.scale_s, args.scale_s,
args.scale_p, args.scale_p,
...@@ -184,22 +205,28 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -184,22 +205,28 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_randval,
args.stride_o, args.stride_o,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse, args.nhead_stride_lse,
args.nhead_stride_o, args.nhead_stride_o,
args.batch_stride_q, args.batch_stride_q,
args.batch_stride_k, args.batch_stride_k,
args.batch_stride_v, args.batch_stride_v,
args.batch_stride_bias, args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse, args.batch_stride_lse,
args.batch_stride_o, args.batch_stride_o,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type); args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
} }
}(); }();
...@@ -222,6 +249,7 @@ template <ck_tile::index_t HDim_, ...@@ -222,6 +249,7 @@ template <ck_tile::index_t HDim_,
typename FmhaMask_, typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_, ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_, bool kStoreLse_,
bool kHasDropout_,
bool kDoFp8StaticQuant_, bool kDoFp8StaticQuant_,
bool kPadS_, bool kPadS_,
bool kPadSK_, bool kPadSK_,
...@@ -243,6 +271,7 @@ struct fmha_fwd_traits_ ...@@ -243,6 +271,7 @@ struct fmha_fwd_traits_
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>; using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_; static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_; static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_; static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadSK = kPadSK_;
...@@ -264,6 +293,7 @@ struct fmha_fwd_traits ...@@ -264,6 +293,7 @@ struct fmha_fwd_traits
mask_enum mask_type; mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse; bool has_lse;
bool has_dropout;
bool do_fp8_static_quant; bool do_fp8_static_quant;
// TODO: padding check is inside this api // TODO: padding check is inside this api
}; };
......
This diff is collapsed.
#!/bin/sh
# TODO: run this script from CK root
BUILD=build
EXE=$BUILD/bin/tile_example_fmha_bwd
VALID=0
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 32 64 128 ; do
nhead=$((2048 / $hdim)) # follow fav2 setup
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
done
done
done
#!/bin/sh
# TODO: run this script from CK root
BUILD=build
EXE=$BUILD/bin/tile_example_fmha_bwd
KNAME=1
export CK_WARMUP=0
export CK_REPEAT=1
COMMON_ARGS='-v=1'
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 32 64 128 ; do
for mode in 0 1 ; do
for bias in "n" "e" "a"; do
for dbias in 0 1 ; do
for p_drop in 0.0 0.2; do
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done
done
done
done
done
done
done
...@@ -17,17 +17,18 @@ for perm in 0 1 ; do ...@@ -17,17 +17,18 @@ for perm in 0 1 ; do
for vlayout in "r" "c" ; do for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do for lse in 0 1 ; do
for bias in "n" "e" "a"; do for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done done
...@@ -37,6 +38,7 @@ done ...@@ -37,6 +38,7 @@ done
done done
done done
done done
done
for perm in 0 1 ; do for perm in 0 1 ; do
for bias in "n" "e" "a" ; do for bias in "n" "e" "a" ; do
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/arch/utility.hpp" #include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/config.hpp" #include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp" #include "ck_tile/core/container/array.hpp"
...@@ -47,10 +48,12 @@ ...@@ -47,10 +48,12 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp"
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -783,6 +783,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, ...@@ -783,6 +783,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
// buffer store ui16
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
CK_TILE_DEVICE_EXTERN void CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
...@@ -1353,7 +1375,10 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d ...@@ -1353,7 +1375,10 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), (std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, uint16_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented"); "wrong! not implemented");
if constexpr(std::is_same<T, float>::value) // fp32 if constexpr(std::is_same<T, float>::value) // fp32
...@@ -1492,6 +1517,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d ...@@ -1492,6 +1517,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
static_cast<index_t>(coherence)); static_cast<index_t>(coherence));
} }
} }
else if constexpr(std::is_same<T, uint16_t>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_ui16(bit_cast<uint16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_ui16x2(bit_cast<uint16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_ui16x4(bit_cast<uint16x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_ui16x4(
src_thread_data.template get_as<uint16x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_ui16x4(
src_thread_data.template get_as<uint16x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(uint16_t),
static_cast<index_t>(coherence));
}
}
else else
{ {
using r_t = thread_buffer<int8_t, sizeof(T) * N>; using r_t = thread_buffer<int8_t, sizeof(T) * N>;
...@@ -1609,7 +1677,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th ...@@ -1609,7 +1677,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{ {
if constexpr(N == 2) if constexpr(N == 2)
{ {
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16_t>(src_thread_data), llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16x2_t>(src_thread_data),
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
namespace ck_tile {
CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b)
{
return type_convert<bf16_t>(type_convert<float>(a) + type_convert<float>(b));
}
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{
bf16x2_t rtn;
rtn[0] = add_bf16_t(a[0], b[0]);
rtn[1] = add_bf16_t(a[1], b[1]);
return rtn;
}
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
// each datatype.
template <typename X>
CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
template <>
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
{
union U32BF162_ADDR
{
uint32_t* u32_a;
bf16x2_t* bf162_a;
};
union U32BF162
{
uint32_t u32;
bf16x2_t bf162;
};
U32BF162_ADDR dword_addr;
U32BF162 cur_v;
U32BF162 new_;
uint32_t old_v, new_v;
dword_addr.bf162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
}
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4)),
"wrong! not implemented");
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
if constexpr(std::is_same<T, float>::value)
{
if constexpr(N == 1)
{
atomicAdd(p_dst, bit_cast<float>(x));
}
else if constexpr(N == 2)
{
atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
}
}
else if constexpr(std::is_same<T, double>::value)
{
if constexpr(N == 1)
{
return atomicAdd(p_dst, bit_cast<double>(x));
}
else if constexpr(N == 2)
{
atomicAdd(c_style_pointer_cast<double*>(p_dst), x.template get_as<double>()[I0]);
atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, x.template get_as<double>()[I1]);
}
}
else if constexpr(std::is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
atomicAdd(p_dst, bit_cast<int32_t>(x));
}
}
else if constexpr(std::is_same<T, uint32_t>::value)
{
if constexpr(N == 1)
{
atomicAdd(p_dst, bit_cast<uint32_t>(x));
}
}
else if constexpr(std::is_same<T, bf16_t>::value)
{
if constexpr(N == 2)
{
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), bit_cast<bf16x2_t>(x));
}
else if constexpr(N == 4)
{
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst) + 1,
x.template get_as<bf16x2_t>()[I1]);
}
}
}
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer<T, N>& x)
{
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
(std::is_same<T, double>::value && (N == 1)),
"wrong! not implemented");
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
if constexpr(std::is_same<T, float>::value)
{
if constexpr(N == 1)
{
atomicMax(p_dst, bit_cast<float>(x));
}
else if constexpr(N == 2)
{
atomicMax(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
}
}
else if constexpr(std::is_same<T, double>::value)
{
if constexpr(N == 1)
{
atomicMax(p_dst, bit_cast<double>(x));
}
}
else if constexpr(std::is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
atomicMax(p_dst, bit_cast<int32_t>(x));
}
}
else if constexpr(std::is_same<T, uint32_t>::value)
{
if constexpr(N == 1)
{
atomicMax(p_dst, bit_cast<uint32_t>(x));
}
}
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -144,6 +144,15 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16))); ...@@ -144,6 +144,15 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16)));
using int8x32_t = int8_t __attribute((ext_vector_type(32))); using int8x32_t = int8_t __attribute((ext_vector_type(32)));
using int8x64_t = int8_t __attribute((ext_vector_type(64))); using int8x64_t = int8_t __attribute((ext_vector_type(64)));
// ui8
// using uint8_t
using uint8x2_t = uint8_t __attribute((ext_vector_type(2)));
using uint8x4_t = uint8_t __attribute((ext_vector_type(4)));
using uint8x8_t = uint8_t __attribute((ext_vector_type(8)));
using uint8x16_t = uint8_t __attribute((ext_vector_type(16)));
using uint8x32_t = uint8_t __attribute((ext_vector_type(32)));
using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
#if CK_TILE_USE_CUSTOM_DATA_TYPE #if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8 // f8
// using fp8_t // using fp8_t
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "ck_tile/core/config.hpp" #include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/container/array.hpp" #include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/integral_constant.hpp"
...@@ -507,10 +508,10 @@ struct buffer_view<address_space_enum::global, ...@@ -507,10 +508,10 @@ struct buffer_view<address_space_enum::global,
bool constexpr use_amd_buffer_addressing = false; bool constexpr use_amd_buffer_addressing = false;
#endif #endif
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(use_amd_buffer_addressing) if constexpr(use_amd_buffer_addressing)
{ {
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>( amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, buffer_size_); x, p_data_, i, is_valid_element, buffer_size_);
} }
...@@ -518,7 +519,7 @@ struct buffer_view<address_space_enum::global, ...@@ -518,7 +519,7 @@ struct buffer_view<address_space_enum::global,
{ {
if(is_valid_element) if(is_valid_element)
{ {
atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x); atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
} }
} }
} }
...@@ -547,16 +548,16 @@ struct buffer_view<address_space_enum::global, ...@@ -547,16 +548,16 @@ struct buffer_view<address_space_enum::global,
bool constexpr use_amd_buffer_addressing = false; bool constexpr use_amd_buffer_addressing = false;
#endif #endif
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(use_amd_buffer_addressing) if constexpr(use_amd_buffer_addressing)
{ {
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>( amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, buffer_size_); x, p_data_, i, is_valid_element, buffer_size_);
} }
else if(is_valid_element) else if(is_valid_element)
{ {
atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x); atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
} }
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -16,7 +16,9 @@ ...@@ -16,7 +16,9 @@
namespace ck_tile { namespace ck_tile {
template <typename BufferView_, typename TensorDesc_> template <typename BufferView_,
typename TensorDesc_,
memory_operation_enum DstInMemOp_ = memory_operation_enum::set>
struct tensor_view struct tensor_view
{ {
using buffer_view = remove_reference_t<BufferView_>; using buffer_view = remove_reference_t<BufferView_>;
...@@ -24,6 +26,7 @@ struct tensor_view ...@@ -24,6 +26,7 @@ struct tensor_view
using TensorDesc = remove_cvref_t<TensorDesc_>; using TensorDesc = remove_cvref_t<TensorDesc_>;
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>; using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{})); using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
static constexpr auto DstInMemOp = DstInMemOp_;
CK_TILE_HOST_DEVICE constexpr tensor_view() = default; CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
...@@ -140,6 +143,23 @@ struct tensor_view ...@@ -140,6 +143,23 @@ struct tensor_view
x); x);
} }
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
{
buf_.template update<DstInMemOp, X, oob_conditional_check>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
CK_TILE_HOST_DEVICE void print() const CK_TILE_HOST_DEVICE void print() const
{ {
printf("tensor_view{"); printf("tensor_view{");
...@@ -178,6 +198,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, ...@@ -178,6 +198,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
} }
template <address_space_enum BufferAddressSpace = address_space_enum::generic, template <address_space_enum BufferAddressSpace = address_space_enum::generic,
memory_operation_enum DstInMemOp = memory_operation_enum::set,
typename DataType, typename DataType,
typename... Lengths, typename... Lengths,
typename... Strides, typename... Strides,
...@@ -198,7 +219,7 @@ make_naive_tensor_view(DataType* p, ...@@ -198,7 +219,7 @@ make_naive_tensor_view(DataType* p,
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size()); auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc}; return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
} }
template <address_space_enum BufferAddressSpace = address_space_enum::generic, template <address_space_enum BufferAddressSpace = address_space_enum::generic,
...@@ -232,8 +253,9 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& ol ...@@ -232,8 +253,9 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& ol
NewLowerDimensionOldVisibleIdss{}, NewLowerDimensionOldVisibleIdss{},
NewUpperDimensionNewVisibleIdss{}); NewUpperDimensionNewVisibleIdss{});
return tensor_view<typename OldTensorView::buffer_view, remove_cvref_t<decltype(new_desc)>>{ return tensor_view<typename OldTensorView::buffer_view,
old_tensor_view.buf_, new_desc}; remove_cvref_t<decltype(new_desc)>,
remove_cvref_t<OldTensorView>::DstInMemOp>{old_tensor_view.buf_, new_desc};
} }
template <typename TensorView, template <typename TensorView,
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "ck_tile/core/container/sequence.hpp" #include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp" #include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/meta_data_buffer.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional.hpp"
......
...@@ -594,6 +594,66 @@ struct tile_window_with_static_distribution ...@@ -594,6 +594,66 @@ struct tile_window_with_static_distribution
}); });
} }
template <bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
// move thread's botom tensor coordiante // move thread's botom tensor coordiante
// [x0', x1', ... ] ==> [offset] // [x0', x1', ... ] ==> [offset]
// also move window-origin // also move window-origin
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
constexpr auto tile_dstr = TileDstr{};
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
tile_window_tmp.get_window_lengths(),
tile_window_tmp.get_window_origin(),
tile_dstr);
tile_window.update(dstr_tensor);
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
tile_window.update(dstr_tensor);
}
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment