Commit 6a25d081 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into ck_tile/fav2_fwd_sept

parents 02f8c487 ceaed8e0
...@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
YDataType, YDataType,
MeanDataType, MeanDataType,
InvStdDataType, InvStdDataType,
Shape>; Shape,
true,
true>;
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>; using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
......
...@@ -6,7 +6,8 @@ This folder contains example for GEMM using ck_tile tile-programming implementat ...@@ -6,7 +6,8 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
``` ```
# in the root of ck_tile # in the root of ck_tile
mkdir build && cd build mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942... # you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_gemm_basic -j make tile_example_gemm_basic -j
``` ```
This will result in an executable `build/bin/tile_example_gemm_basic` This will result in an executable `build/bin/tile_example_gemm_basic`
...@@ -14,10 +15,17 @@ This will result in an executable `build/bin/tile_example_gemm_basic` ...@@ -14,10 +15,17 @@ This will result in an executable `build/bin/tile_example_gemm_basic`
## example ## example
``` ```
args: args:
-m m dimension (default:3328) -b batch size (default:1)
-n m dimension (default:4096) -m m dimension (default:1024)
-n n dimension (default:2048)
-k k dimension (default:64) -k k dimension (default:64)
-e epsilon (default:1e-5) -stride_a Tensor A stride (default:0)
-v cpu validation or not (default:1) -stride_b Tensor B stride (default:0)
-prec precision (default:fp16) -stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
``` ```
...@@ -6,7 +6,8 @@ This folder contains example for Image to Column using ck_tile tile-programming ...@@ -6,7 +6,8 @@ This folder contains example for Image to Column using ck_tile tile-programming
``` ```
# in the root of ck_tile # in the root of ck_tile
mkdir build && cd build mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942... # you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_img2col -j make tile_example_img2col -j
``` ```
This will result in an executable `build/bin/tile_example_img2col` This will result in an executable `build/bin/tile_example_img2col`
...@@ -97,13 +97,6 @@ ...@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ #cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif #endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
// //
// CK kernels which support XDL (MI series) // CK kernels which support XDL (MI series)
// //
......
...@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error(hipEventElapsedTime(&total_time, start, stop)); hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat; return total_time / nrepeat;
} }
else else
...@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error(hipEventElapsedTime(&total_time, start, stop)); hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat; return total_time / nrepeat;
} }
else else
......
...@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
......
...@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
...@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
......
...@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
......
...@@ -64,7 +64,7 @@ __global__ void ...@@ -64,7 +64,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
......
...@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const index_t N = gemm_descs[i].N_; const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_; const index_t K = gemm_descs[i].K_;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
{ {
skipped_group_count_++; skipped_group_count_++;
continue; continue;
......
...@@ -109,7 +109,7 @@ __global__ void ...@@ -109,7 +109,7 @@ __global__ void
N = gemm_desc_ptr[group_id].N; N = gemm_desc_ptr[group_id].N;
K = gemm_desc_ptr[group_id].K; K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
{ {
grid_size_grp = 0; grid_size_grp = 0;
continue; continue;
......
...@@ -68,7 +68,7 @@ __global__ void ...@@ -68,7 +68,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideA = gemm_desc_ptr[group_id].StrideA; const auto StrideA = gemm_desc_ptr[group_id].StrideA;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include <initializer_list> #include <initializer_list>
#include <vector>
#include "ck_tile/core/config.hpp" #include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integer.hpp"
......
...@@ -50,12 +50,22 @@ class ArgParser ...@@ -50,12 +50,22 @@ class ArgParser
} }
return *this; return *this;
} }
void print() void print() const
{ {
// find max key length
std::string::size_type max_key_length = 11;
for(auto& key : keys)
{
if(max_key_length < key.length())
{
max_key_length = key.length();
}
}
printf("args:\n"); printf("args:\n");
for(auto& key : keys) for(auto& key : keys)
{ {
auto value = input_map[key]; auto value = input_map.at(key);
std::vector<std::string> help_text_lines; std::vector<std::string> help_text_lines;
size_t pos = 0; size_t pos = 0;
for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;) for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;)
...@@ -69,8 +79,7 @@ class ArgParser ...@@ -69,8 +79,7 @@ class ArgParser
std::string(value.help_text.begin() + pos, value.help_text.end())); std::string(value.help_text.begin() + pos, value.help_text.end()));
std::string default_value = std::string("(default:") + value.value + std::string(")"); std::string default_value = std::string("(default:") + value.value + std::string(")");
std::cout << std::setw(1 + max_key_length - value.name.length()) << "-" << key
std::cout << std::setw(2) << std::setw(12 - value.name.length()) << "-" << key
<< std::setw(4) << " " << help_text_lines[0] << " " << default_value << std::setw(4) << " " << help_text_lines[0] << " " << default_value
<< std::endl; << std::endl;
...@@ -78,7 +87,8 @@ class ArgParser ...@@ -78,7 +87,8 @@ class ArgParser
help_next_line != help_text_lines.end(); help_next_line != help_text_lines.end();
++help_next_line) ++help_next_line)
{ {
std::cout << std::setw(17) << " " << *help_next_line << std::endl; std::cout << std::setw(1 + max_key_length + 4) << " " << *help_next_line
<< std::endl;
} }
} }
} }
......
...@@ -13,7 +13,6 @@ namespace conv { ...@@ -13,7 +13,6 @@ namespace conv {
struct ConvParam struct ConvParam
{ {
ConvParam();
ConvParam(ck_tile::index_t n_dim, ConvParam(ck_tile::index_t n_dim,
ck_tile::index_t group_count, ck_tile::index_t group_count,
ck_tile::index_t n_batch, ck_tile::index_t n_batch,
...@@ -199,11 +198,6 @@ struct ConvParam ...@@ -199,11 +198,6 @@ struct ConvParam
} }
}; };
ConvParam::ConvParam()
: ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
{
}
CK_TILE_HOST std::string get_conv_param_parser_helper_msg() CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
{ {
std::string msg; std::string msg;
......
...@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask ...@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
{ {
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, x_total / num_splits); const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split; const index_t split_start = x_per_split * i_split;
const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split); const index_t split_end = split_start + x_per_split;
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
ck_tile::min(origin_end, split_end)); ck_tile::min(origin_end, split_end));
......
...@@ -6,8 +6,11 @@ ...@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel ...@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel
ck_tile::GenericAttentionMaskEnum mask_type; ck_tile::GenericAttentionMaskEnum mask_type;
}; };
struct FmhaBwdCommonDropoutKargs struct FmhaBwdDropoutSeedOffset
{ {
void init_dropout(const float p_drop, template <typename T>
const std::tuple<uint64_t, uint64_t>& drop_seed_offset, union ValueOrPointer
const float raw_scale) {
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale;
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop,
const uint64_t* seed_ptr,
const uint64_t* offset_ptr,
float raw_scale)
{ {
float p_undrop = 1.0 - p_drop; float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t = p_undrop_in_uint8_t =
...@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel ...@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel
rp_undrop = 1.0 / p_undrop; rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale; scale_rp_undrop = rp_undrop * raw_scale;
drop_seed = std::get<0>(drop_seed_offset); this->drop_seed.ptr = seed_ptr;
drop_offset = std::get<1>(drop_seed_offset); this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
} }
float rp_undrop = 1; float rp_undrop = 1;
float scale_rp_undrop = 1; float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr; void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0; ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0; ck_tile::index_t nhead_stride_randval = 0;
}; };
struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
{ {
ck_tile::index_t batch_stride_randval = 0; ck_tile::index_t batch_stride_randval = 0;
}; };
struct FmhaBwdDeterministicKargs struct FmhaBwdDeterministicKargs
{ {
ck_tile::index_t split_stride_dq_acc = 0; ck_tile::index_t split_stride_dq_acc = 0;
...@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel ...@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel ...@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset, scale); if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval) if constexpr(kIsStoreRandval)
{ {
kargs.rand_val_ptr = rand_val_ptr; kargs.rand_val_ptr = rand_val_ptr;
...@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel ...@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel ...@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
} }
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset, scale); if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval) if constexpr(kIsStoreRandval)
{ {
kargs.rand_val_ptr = rand_val_ptr; kargs.rand_val_ptr = rand_val_ptr;
...@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel ...@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
return FmhaDropout{i_batch_, return FmhaDropout{i_batch_,
i_nhead_, i_nhead_,
kargs.num_head_q, kargs.num_head_q,
kargs.drop_seed, kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
kargs.drop_offset, : *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop, kargs.rp_undrop,
kargs.p_undrop_in_uint8_t}; kargs.p_undrop_in_uint8_t};
} }
......
...@@ -6,8 +6,11 @@ ...@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...@@ -170,29 +173,55 @@ struct FmhaFwdKernel ...@@ -170,29 +173,55 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_lse = 0; ck_tile::index_t batch_stride_lse = 0;
}; };
struct FmhaFwdCommonDropoutKargs struct FmhaFwdDropoutSeedOffset
{ {
void init_dropout(const float p_drop, template <typename T>
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) union ValueOrPointer
{
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
{ {
float p_undrop = 1.0 - p_drop; float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t = p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max())); uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop; rp_undrop = 1.0 / p_undrop;
drop_seed = std::get<0>(drop_seed_offset); this->drop_seed.ptr = seed_ptr;
drop_offset = std::get<1>(drop_seed_offset); this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
} }
float rp_undrop = 1; float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
bool is_store_randval = false; bool is_store_randval = false;
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr; void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0; ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0; ck_tile::index_t nhead_stride_randval = 0;
}; };
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
{ {
ck_tile::index_t batch_stride_randval = 0; ck_tile::index_t batch_stride_randval = 0;
...@@ -278,7 +307,8 @@ struct FmhaFwdKernel ...@@ -278,7 +307,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval, bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -344,7 +374,19 @@ struct FmhaFwdKernel ...@@ -344,7 +374,19 @@ struct FmhaFwdKernel
} }
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset); if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr; kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval; kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval; kargs.nhead_stride_randval = nhead_stride_randval;
...@@ -392,7 +434,8 @@ struct FmhaFwdKernel ...@@ -392,7 +434,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval, bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -455,7 +498,19 @@ struct FmhaFwdKernel ...@@ -455,7 +498,19 @@ struct FmhaFwdKernel
} }
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset); if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr; kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval; kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval; kargs.nhead_stride_randval = nhead_stride_randval;
...@@ -748,8 +803,10 @@ struct FmhaFwdKernel ...@@ -748,8 +803,10 @@ struct FmhaFwdKernel
return BlockDropout{i_batch_, return BlockDropout{i_batch_,
i_nhead_, i_nhead_,
kargs.num_head_q, kargs.num_head_q,
kargs.drop_seed, kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
kargs.drop_offset, : *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop, kargs.rp_undrop,
kargs.p_undrop_in_uint8_t, kargs.p_undrop_in_uint8_t,
kargs.is_store_randval}; kargs.is_store_randval};
......
...@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
void* o_ptr; void* o_ptr;
ck_tile::index_t batch; ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_q;
ck_tile::index_t hdim_v; ck_tile::index_t hdim_v;
ck_tile::index_t num_splits; ck_tile::index_t num_splits;
...@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc; ck_tile::index_t split_stride_o_acc;
}; };
...@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>, std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>> std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{ {
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_lse_acc; ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o;
}; };
struct GroupModeKargs struct GroupModeKargs
...@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
void* lse_ptr, void* lse_ptr,
void* o_ptr, void* o_ptr,
ck_tile::index_t batch, ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_splits, ck_tile::index_t num_splits,
...@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr, o_acc_ptr,
o_ptr, o_ptr,
batch, batch,
max_seqlen_q,
seqlen_q, seqlen_q,
hdim_v, hdim_v,
num_splits, num_splits,
...@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for lse {}, // placeholder for lse
{}, // placeholder for fp8_static_quant args {}, // placeholder for fp8_static_quant args
batch_stride_o, batch_stride_lse_acc,
batch_stride_lse_acc}; batch_stride_o_acc,
batch_stride_o};
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
...@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
void* lse_ptr, void* lse_ptr,
void* o_ptr, void* o_ptr,
ck_tile::index_t batch, ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
const void* seqstart_q_ptr, const void* seqstart_q_ptr,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_splits, ck_tile::index_t num_splits,
...@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc) ck_tile::index_t split_stride_o_acc)
{ {
...@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr, o_acc_ptr,
o_ptr, o_ptr,
batch, batch,
max_seqlen_q,
-1, // seqlen will be updated by another pointer -1, // seqlen will be updated by another pointer
hdim_v, hdim_v,
num_splits, num_splits,
...@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for lse {}, // placeholder for lse
...@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
return kargs; return kargs;
} }
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead_, ck_tile::index_t nhead,
ck_tile::index_t seqlen_q_, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v_) ck_tile::index_t hdim_v)
{ {
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
} }
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
...@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse_acc = 0; long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
long_index_t batch_offset_lse = 0; long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0; long_index_t batch_offset_o = 0;
...@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch // get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_lse_acc = query_start; batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.row_stride_o_acc;
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
batch_offset_lse = query_start; batch_offset_lse = query_start;
} }
batch_offset_o = query_start * kargs.row_stride_o;
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
...@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
} }
else else
{ {
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc; batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse; batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
} }
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
} }
// for simplicity, batch stride we just modify the pointer // for simplicity, batch stride we just modify the pointer
...@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
auto o_acc_dram = [&]() { auto o_acc_dram = [&]() {
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr, o_acc_ptr,
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v), make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1), make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{}, number<FmhaPipeline::kAlignmentOacc>{},
number<1>{}); number<1>{});
...@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}), make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{}); sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
const index_t padded_max_seqlen_q = const index_t padded_seqlen_q =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}]; o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
const index_t padded_hdim_v = const index_t padded_hdim_v =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}]; o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
return transform_tensor_view( return transform_tensor_view(
o_acc_dram_view, o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)), make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)),
make_pass_through_transform(padded_hdim_v)), make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
...@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
identity{}, // lse_element_func identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits, kargs.num_splits,
kargs.max_seqlen_q, kargs.seqlen_q,
smem_ptr); smem_ptr);
} }
else else
...@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window, o_acc_dram_window,
lse_dram_window, lse_dram_window,
kargs.num_splits, kargs.num_splits,
kargs.max_seqlen_q, kargs.seqlen_q,
smem_ptr); smem_ptr);
} }
}(); }();
......
...@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner ...@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
static constexpr ck_tile::index_t kM0 = kM0_; static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN1 = kN1_; static constexpr ck_tile::index_t kN1 = kN1_;
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead_, ck_tile::index_t nhead,
ck_tile::index_t seqlen_q_, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v_) ck_tile::index_t hdim_v)
{ {
// TODO: this may need tuning // TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1), ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead_, nhead,
batch_size_); batch_size);
} }
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{ {
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.x; const index_t i_block = blockIdx.x;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment