Commit 7ffb0921 authored by Adam Osewski

Merge branch 'develop' into aosewski/ck_tile_universal_gemm_p1

parents 4cf45f1b 0023f01a
......@@ -569,7 +569,9 @@ if(NOT DEFINED INSTANCES_ONLY)
PACKAGE_NAME examples
)
add_subdirectory(example)
add_subdirectory(test)
if(BUILD_TESTING)
add_subdirectory(test)
endif()
rocm_package_setup_component(profiler
LIBRARY_NAME composablekernel
......
......@@ -320,7 +320,7 @@ def cmake_build(Map conf=[:]){
if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) {
archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true
}
if (params.RUN_CK_TILE_TESTS){
if (params.RUN_CK_TILE_FMHA_TESTS){
try{
archiveArtifacts "perf_fmha_fwd_*.log"
archiveArtifacts "perf_fmha_bwd_*.log"
......@@ -371,7 +371,7 @@ def buildHipClangJob(Map conf=[:]){
def retimage
(retimage, image) = getDockerImage(conf)
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 48, unit: 'HOURS')
{
......@@ -426,7 +426,7 @@ def runCKProfiler(Map conf=[:]){
def variant = env.STAGE_NAME
def retimage
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try {
(retimage, image) = getDockerImage(conf)
withDockerContainer(image: image, args: dockerOpts) {
......@@ -563,7 +563,7 @@ def Build_CK(Map conf=[:]){
def variant = env.STAGE_NAME
def retimage
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try {
(retimage, image) = getDockerImage(conf)
withDockerContainer(image: image, args: dockerOpts) {
......@@ -668,7 +668,7 @@ def process_results(Map conf=[:]){
def variant = env.STAGE_NAME
def retimage
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try {
(retimage, image) = getDockerImage(conf)
}
......@@ -682,7 +682,7 @@ def process_results(Map conf=[:]){
timeout(time: 1, unit: 'HOURS'){
try{
dir("script"){
if (params.RUN_CK_TILE_TESTS){
if (params.RUN_CK_TILE_FMHA_TESTS){
try{
unstash "perf_fmha_fwd_gfx942.log"
unstash "perf_fmha_bwd_gfx942.log"
......@@ -838,7 +838,7 @@ pipeline {
dbsshport = "${dbsshport}"
dbsshuser = "${dbsshuser}"
dbsshpassword = "${dbsshpassword}"
status_wrapper_creds = "${status_wrapper_creds}"
ck_git_creds = "${ck_git_creds}"
gerrit_cred="${gerrit_cred}"
DOCKER_BUILDKIT = "1"
}
......
......@@ -233,6 +233,8 @@ function(add_embed_library EMBED_NAME)
else()
target_sources(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}>)
endif()
target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include")
target_include_directories(${EMBED_NAME} INTERFACE
$<BUILD_INTERFACE:${EMBED_DIR}/include>
$<INSTALL_INTERFACE:include/ck>)
endfunction()
......@@ -39,6 +39,7 @@ set_target_properties(ck_host PROPERTIES
target_include_directories(ck_host PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
)
add_executable(ck-template-driver driver/main.cpp)
......@@ -48,6 +49,12 @@ rocm_install(
TARGETS ck_host ck_headers
EXPORT ck_hostTargets
)
rocm_install(EXPORT ck_hostTargets
FILE composable_kernelck_hostTargets.cmake
NAMESPACE composable_kernel::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
add_subdirectory(test)
if(BUILD_TESTING)
add_subdirectory(test)
endif()
rocm-docs-core==1.8.1
rocm-docs-core==1.8.2
sphinxcontrib-bibtex==2.6.3
......@@ -103,7 +103,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core==1.8.1
rocm-docs-core==1.8.2
# via -r requirements.in
six==1.16.0
# via pybtex
......
add_example_executable(example_complex_contraction_bilinear_xdl_fp32 complex_contraction_bilinear_xdl_fp32.cpp)
add_example_executable(example_complex_contraction_bilinear_xdl_fp64 complex_contraction_bilinear_xdl_fp64.cpp)
# Instructions for `example_complex_contraction_bilinear_xdl_fp32`
## Run
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
./bin/example_complex_contraction_bilinear_xdl_fp32 1 1 1
```
This diff is collapsed.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "common_instances.hpp"
using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F32;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F32;
using ComputeDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
#include "run_complex_contraction_bilinear_example.inc"
int main(int argc, char* argv[]) { return run_complex_contraction_bilinear_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "common_instances.hpp"
using ADataType = F64;
using BDataType = F64;
using AccDataType = F64;
using CShuffleDataType = F64;
using DDataType = F64;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F64;
using ComputeDataType = F64;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
#include "run_complex_contraction_bilinear_example.inc"
int main(int argc, char* argv[]) { return run_complex_contraction_bilinear_example(argc, argv); }
......@@ -70,8 +70,13 @@ args:
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
-warmup number of iterations before benchmarking the kernel (default:5)
-repeat number of iterations to benchmark the kernel (default:20)
-drop_seed seed for the dropout layer's random number generator (default:1)
-drop_offset offset used by the dropout layer during random number generation (default:0)
-drop_prefs whether the `drop_seed` and `drop_offset` values reside on the GPU; 0 - host, 1 - GPU (default:0)
```
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run an fp16 fmha case with batch=1, nhead=16, sequence length=16384, hdim=128.
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run an fp16 fmha case with batch=1, nhead=16, sequence length=16384, hdim=128.
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run an fp16 fmha case with
batch=1, nhead=8, sequence length=16384, hdim=64, with drop_seed=10 and drop_offset=1234 read from GPU memory.
## supported features
We are still in a rapid development stage, so more features and optimizations are coming soon.
......
......@@ -85,6 +85,9 @@ auto create_args(int argc, char* argv[])
.insert("p_drop", "0", "0~1 probability of dropout")
.insert("drop_seed", "1", "seed for random number generator")
.insert("drop_offset", "0", "offset for random number generator")
.insert("drop_prefs",
"0",
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel")
......@@ -99,10 +102,23 @@ auto create_args(int argc, char* argv[])
// different threshold for different dtype
template <typename DataType>
auto get_elimit(int /*init_method*/)
auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{
double rtol = 1e-2;
double atol = 1e-2;
if(hdim_q > 128 && hdim_v > 128) // 3.2 for RTZ/1.5 for RTN
{
rtol = 3.2e-2;
atol = 3.2e-2;
}
return ck_tile::make_tuple(rtol, atol);
}
......@@ -145,6 +161,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
float p_drop = arg_parser.get_float("p_drop");
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
bool drop_prefs = arg_parser.get_bool("drop_prefs");
if(use_dbias && bias.type != bias_enum::elementwise_bias)
{
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
......@@ -368,6 +386,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
......@@ -378,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
do_buf.ToDevice(do_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
alibi_slope_buf.ToDevice(alibi_slope_host.data());
// clang-format off
......@@ -459,6 +481,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
if(drop_prefs)
{
return std::make_pair(drop_seed_buf.GetDeviceBuffer(),
drop_offset_buf.GetDeviceBuffer());
}
else
{
return std::make_pair(drop_seed, drop_offset);
}
}();
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
......@@ -532,7 +566,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
static_cast<ck_tile::index_t>(mask.type),
p_drop,
p_undrop,
{drop_seed, drop_offset}};
drop_seed_offset};
}();
float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
......@@ -899,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method);
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
dq_host_ref,
std::string("Error: QGrad Incorrect results!"),
......
......@@ -9,7 +9,10 @@
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "bias.hpp"
#include <type_traits>
#include <utility>
#include <variant>
template <typename DataType>
struct FmhaBwdTypeConfig;
......@@ -135,7 +138,8 @@ struct fmha_bwd_args
ck_tile::index_t mask_type;
float p_drop;
float p_undrop;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
};
template <typename FmhaBwdDQDKDVKernel>
......
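The switch from `std::tuple` to `std::variant` above lets callers pass the dropout seed and offset either as immediate host values or as pointers to GPU memory (filled from `drop_seed_buf`/`drop_offset_buf` in the example). Below is a minimal sketch of how a launcher might dispatch on the variant; `DropoutKargs` and `make_dropout_kargs` are hypothetical names for illustration, not part of the CK API:

```cpp
#include <cstdint>
#include <type_traits>
#include <utility>
#include <variant>

// Hypothetical kernel-argument fields a dropout implementation might need.
struct DropoutKargs
{
    uint64_t seed           = 0;
    uint64_t offset         = 0;
    const void* seed_ptr    = nullptr; // dereferenced on the device when non-null
    const void* offset_ptr  = nullptr;
    bool is_device_resident = false;
};

using DropSeedOffset =
    std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>;

// Translate the user-facing variant into kernel arguments.
DropoutKargs make_dropout_kargs(const DropSeedOffset& drop_seed_offset)
{
    DropoutKargs kargs;
    std::visit(
        [&](const auto& p) {
            using T = std::decay_t<decltype(p)>;
            if constexpr(std::is_same_v<T, std::pair<uint64_t, uint64_t>>)
            {
                // Host immediates: embed the values directly.
                kargs.seed   = p.first;
                kargs.offset = p.second;
            }
            else
            {
                // Device pointers: the kernel reads seed/offset at launch time.
                kargs.seed_ptr           = p.first;
                kargs.offset_ptr         = p.second;
                kargs.is_device_resident = true;
            }
        },
        drop_seed_offset);
    return kargs;
}

int main()
{
    uint64_t seed = 10, offset = 1234;
    auto host_case   = make_dropout_kargs(std::make_pair(seed, offset));
    auto device_case = make_dropout_kargs(std::make_pair( // stand-ins for device buffers
        static_cast<const void*>(&seed), static_cast<const void*>(&offset)));
    return host_case.is_device_resident || !device_case.is_device_resident ? 1 : 0;
}
```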
......@@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[])
.insert("p_drop", "0", "0~1 probability of dropout")
.insert("drop_seed", "1", "seed for random number generator")
.insert("drop_offset", "0", "offset for random number generator")
.insert("drop_prefs",
"0",
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert(
"rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
......@@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
float p_drop = arg_parser.get_float("p_drop");
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
bool drop_prefs = arg_parser.get_bool("drop_prefs");
if(p_drop < 0.0f || p_drop > 1.0f)
{
std::cerr << "The value of p_drop should be 0~1" << std::endl;
......@@ -552,16 +557,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
#endif
auto get_lengths = [&](bool permute,
ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/,
ck_tile::index_t s /*seqlen*/,
ck_tile::index_t d /*hdim*/) {
if(permute)
return std::array<ck_tile::index_t, 4>{b, h, s, d};
else
return std::array<ck_tile::index_t, 4>{b, s, h, d};
};
struct
{
auto operator()(bool permute,
ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/,
ck_tile::index_t s /*seqlen*/,
ck_tile::index_t d /*hdim*/)
{
if(permute)
return std::array<ck_tile::index_t, 4>{b, h, s, d};
else
return std::array<ck_tile::index_t, 4>{b, s, h, d};
}
auto operator()(bool permute,
ck_tile::index_t ns /*num_splits*/,
ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/,
ck_tile::index_t s /*seqlen*/,
ck_tile::index_t d /*hdim*/)
{
if(permute)
return std::array<ck_tile::index_t, 5>{ns, b, h, s, d};
else
return std::array<ck_tile::index_t, 5>{ns, b, s, h, d};
}
} get_lengths;
bool is_v_rowmajor = vlayout == std::string("r");
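`get_lengths` changes from a lambda to a local struct because a single lambda cannot provide two `operator()` overloads: the four-argument form still builds 4-D q/k/v/o shapes, while the new six-argument form prepends `num_splits` for the 5-D split-kv accumulator. A self-contained sketch of the same idea, with illustrative values:

```cpp
#include <array>
#include <cstdint>
#include <iostream>

using index_t = int32_t; // stand-in for ck_tile::index_t

// Standalone copy of the overloaded helper above (illustrative).
struct GetLengths
{
    auto operator()(bool permute, index_t b, index_t h, index_t s, index_t d)
    {
        return permute ? std::array<index_t, 4>{b, h, s, d}
                       : std::array<index_t, 4>{b, s, h, d};
    }
    auto operator()(bool permute, index_t ns, index_t b, index_t h, index_t s, index_t d)
    {
        return permute ? std::array<index_t, 5>{ns, b, h, s, d}
                       : std::array<index_t, 5>{ns, b, s, h, d};
    }
};

int main()
{
    GetLengths get_lengths;
    auto q_lens     = get_lengths(true, 1, 16, 16384, 128);    // 4-D: {b, h, s, d}
    auto o_acc_lens = get_lengths(true, 4, 1, 16, 16384, 128); // 5-D: {ns, b, h, s, d}
    std::cout << q_lens.size() << " " << o_acc_lens.size() << "\n"; // prints: 4 5
}
```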
......@@ -617,7 +639,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits || use_kvcache
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
? get_lengths(o_perm, num_splits, shape_batch, nhead, shape_seqlen_q, hdim_v)
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// batch mode of lse data layout is [batch, nhead, seqlen_q]
......@@ -739,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
......@@ -757,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
rotary_cos_buf.ToDevice(rotary_cos_host.data());
rotary_sin_buf.ToDevice(rotary_sin_host.data());
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
alibi_slope_buf.ToDevice(alibi_slope_host.data());
block_table_buf.ToDevice(block_table_host.data());
cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
......@@ -854,7 +880,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o_acc = hdim_v;
const ck_tile::index_t stride_o_acc = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
......@@ -881,7 +907,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o_acc = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
......@@ -897,12 +923,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t split_stride_o_acc = (shape_batch * nhead * shape_seqlen_q * hdim_v);
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
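With this change `o_acc` follows the same `o_perm`-dependent layout as `o`, and its strides use `shape_seqlen_q`/`shape_batch` instead of `max_seqlen_q`/`batch`. A hedged sketch of how these strides compose into an element offset for the 5-D accumulator, assuming dense packing (function and parameter names are illustrative, not from the example source):

```cpp
#include <cstdint>
#include <iostream>

// Element offset into the o_acc accumulator laid out as
// [num_splits, batch, nhead, seqlen_q, hdim_v] when o_perm == true, or
// [num_splits, batch, seqlen_q, nhead, hdim_v] when o_perm == false.
int64_t o_acc_offset(int64_t i_split, int64_t i_batch, int64_t i_head,
                     int64_t i_seq, int64_t i_dim,
                     int64_t batch, int64_t nhead, int64_t seqlen_q,
                     int64_t hdim_v, bool o_perm)
{
    // Mirrors the stride expressions computed above.
    const int64_t stride_o_acc       = o_perm ? hdim_v : nhead * hdim_v;
    const int64_t nhead_stride_o_acc = o_perm ? seqlen_q * hdim_v : hdim_v;
    const int64_t batch_stride_o_acc = nhead * seqlen_q * hdim_v;
    const int64_t split_stride_o_acc = batch * nhead * seqlen_q * hdim_v;

    return i_split * split_stride_o_acc + i_batch * batch_stride_o_acc +
           i_head * nhead_stride_o_acc + i_seq * stride_o_acc + i_dim;
}

int main()
{
    // batch=1, nhead=16, seqlen_q=16384, hdim_v=128, permuted layout
    std::cout << o_acc_offset(1, 0, 2, 3, 4, 1, 16, 16384, 128, true) << "\n";
}
```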
......@@ -996,9 +1022,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.nhead_stride_randval = nhead_stride_randval;
args.batch_stride_randval = batch_stride_randval;
args.p_drop = p_drop;
args.s_randval = s_randval;
args.drop_seed_offset = std::tie(drop_seed, drop_offset);
args.p_drop = p_drop;
args.s_randval = s_randval;
if(drop_prefs)
{
args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(),
drop_offset_buf.GetDeviceBuffer());
}
else
{
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
}
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
{
......
......@@ -13,6 +13,8 @@
#include "rotary.hpp"
#include <type_traits>
#include <utility>
#include <variant>
template <typename DataType>
struct FmhaFwdTypeConfig;
......@@ -144,7 +146,9 @@ struct fmha_fwd_args
float p_drop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
};
struct fmha_fwd_splitkv_args
......@@ -398,10 +402,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.nhead_stride_bias,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_k, // only used for paged-kvcache
args.batch_stride_v, // only used for paged-kvcache
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
......@@ -475,7 +477,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.lse_ptr,
args.o_ptr,
args.batch,
args.max_seqlen_q,
args.seqstart_q_ptr,
args.hdim_v,
args.num_splits,
......@@ -486,7 +487,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
......@@ -497,7 +497,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.lse_ptr,
args.o_ptr,
args.batch,
args.max_seqlen_q,
args.seqlen_q,
args.hdim_v,
args.num_splits,
......
......@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
YDataType,
MeanDataType,
InvStdDataType,
Shape>;
Shape,
true,
true>;
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
......
# Not using add_example_executable() to add this target, since we don't want it
# to be included in "make all/install/check".
add_executable(tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp)
# Image to Column
This folder contains an Image to Column example using the ck_tile tile-programming implementation.
## build
```bash
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with gfx90a, gfx942, ...
make tile_example_img2col -j
```
This will result in an executable `build/bin/tile_example_img2col`.