"tests/vscode:/vscode.git/clone" did not exist on "72182747a17231075b26768d37694781bd992daf"
Commit 09d4c3a4 authored by illsilin

merge from public repo

parents 171ed358 8e4c3fb1
@@ -176,7 +176,20 @@ struct HostTensorDescriptor
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
{
os << "dim " << desc.get_num_of_dimension() << ", ";
os << "lengths {";
LogRange(os, desc.get_lengths(), ", ");
os << "}, ";
os << "strides {";
LogRange(os, desc.get_strides(), ", ");
os << "}";
return os;
}
private:
std::vector<std::size_t> mLens;
...
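As a quick illustration (a hypothetical descriptor, not part of the diff; assumes the lengths/strides constructor), the new inline operator<< prints the descriptor like this:
HostTensorDescriptor desc({2, 3}, {3, 1}); // lengths {2, 3}, strides {3, 1}
std::cout << desc << std::endl;            // prints: dim 2, lengths {2, 3}, strides {3, 1}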
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include <thread>
namespace ck_tile {
@@ -13,6 +14,9 @@ template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
@@ -24,7 +28,12 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const ACCElementOp& acc_element_op = {})
{
const int N = b_n_k.mDesc.get_lengths()[0];
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[1]
: a_m_k.mDesc.get_lengths()[0];
const int M = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[0]
: a_m_k.mDesc.get_lengths()[1];
auto f = [&](auto m) {
for(int n = 0; n < N; ++n)
@@ -33,7 +42,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
for(int k = 0; k < K; ++k)
{
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_element_op(a_m_k(m, k))
: a_element_op(a_m_k(k, m));
BDataType v_b = b_element_op(b_n_k(n, k));
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
@@ -44,7 +55,123 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
}
};
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
}
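For orientation, a minimal host-side sketch of calling the layout-aware reference (hypothetical sizes, not part of the diff; assumes fp16 inputs with fp32 accumulation, HostTensor constructed directly from its lengths, and B stored as N x K):
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
const int M = 256, N = 128, K = 64;
ck_tile::HostTensor<ck_tile::half_t> a({M, K}); // A is M x K (row-major)
ck_tile::HostTensor<ck_tile::half_t> b({N, K}); // B is stored N x K
ck_tile::HostTensor<float> c({M, N});
// ... fill a and b ...
ck_tile::reference_gemm<ck_tile::half_t, ck_tile::half_t, float, float, Row, Col, Row>(a, b, c);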
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
__global__ void naive_gemm_kernel(ADataType* A,
BDataType* B,
CDataType* C,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t strideA,
ck_tile::index_t strideB,
ck_tile::index_t strideC)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N; // Compute row index
int col = idx % N; // Compute column index
if(row < M && col < N)
{
AccDataType acc = 0.0;
for(int k = 0; k < K; ++k)
{
acc += static_cast<AccDataType>(A[row * strideA + k]) *
static_cast<AccDataType>(B[col * strideB + k]);
}
C[row * strideC + col] = acc; // implicit conversion from AccDataType to CDataType on store
}
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device,
DeviceMem& c_device,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c)
{
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
if(errA != hipSuccess)
{
std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
<< std::endl;
return; // Early exit on error
}
if(errB != hipSuccess)
{
std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
<< std::endl;
return; // Early exit on error
}
if(errC != hipSuccess)
{
std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
<< std::endl;
return; // Early exit on error
}
errA = hipMemcpy(
d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice);
if(errA != hipSuccess)
{
std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
}
errB = hipMemcpy(
d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice);
if(errB != hipSuccess)
{
std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
}
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
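// Example (hypothetical sizes): M = N = 1024 gives totalElements = 1,048,576, so
// numBlocks = (1,048,576 + 255) / 256 = 4096 blocks of 256 threads, one thread per
// element of C; threads with row >= M or col >= N simply do nothing.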
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType>
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
errC = hipMemcpy(
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
if(errC != hipSuccess)
{
std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
}
errA = hipFree(d_A);
if(errA != hipSuccess)
{
std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
}
errB = hipFree(d_B);
if(errB != hipSuccess)
{
std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
}
errC = hipFree(d_C);
if(errC != hipSuccess)
{
std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
}
return;
}
} // namespace ck_tile
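A minimal sketch of exercising the naive GPU reference (hypothetical sizes, not part of the diff; assumes fp16 inputs with fp32 accumulation/output and DeviceMem buffers that already hold A and B):
const ck_tile::index_t M = 256, N = 128, K = 64;
ck_tile::DeviceMem a_buf(M * K * sizeof(ck_tile::half_t));
ck_tile::DeviceMem b_buf(N * K * sizeof(ck_tile::half_t));
ck_tile::DeviceMem c_buf(M * N * sizeof(float));
// ... upload A (M x K, row-major) and B (N x K) into a_buf / b_buf ...
ck_tile::reference_gemm_gpu<ck_tile::half_t, ck_tile::half_t, float, float>(
    a_buf, b_buf, c_buf, M, N, K, /*stride_a=*/K, /*stride_b=*/K, /*stride_c=*/N);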
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -9,53 +9,125 @@
namespace ck_tile {
template <typename InDataType, typename OutDataType, index_t NDimSpatial>
CK_TILE_HOST void reference_im2col(const HostTensor<InDataType>& in_host,
HostTensor<OutDataType>& out_host,
const ck_tile::conv::ConvParam& conv_params)
{
const long_index_t G = in_host.get_lengths()[0];
const long_index_t N = in_host.get_lengths()[1];
const long_index_t C = in_host.get_lengths()[2];
if constexpr(NDimSpatial == 1)
{
const long_index_t Wo = conv_params.output_spatial_lengths_[0];
auto func = [&](auto g, auto n, auto wo) {
long_index_t row = n * Wo + wo;
long_index_t column = 0;
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[0]; ++x)
{
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t c = 0; c < C; ++c)
{
if(wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[3])
{
InDataType v_in = in_host(g, n, c, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
};
make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 2)
{
const long_index_t Ho = conv_params.output_spatial_lengths_[0];
const long_index_t Wo = conv_params.output_spatial_lengths_[1];
auto func = [&](auto g, auto n, auto ho, auto wo) {
long_index_t row = n * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[0]; ++y)
{
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[1]; ++x)
{
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[1]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[1]) -
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
for(long_index_t c = 0; c < C; ++c)
{
if(hi >= 0 && type_convert<std::size_t>(hi) < in_host.get_lengths()[3] &&
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[4])
{
InDataType v_in = in_host(g, n, c, hi, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
}
};
make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 3)
{
const long_index_t Do = conv_params.output_spatial_lengths_[0];
const long_index_t Ho = conv_params.output_spatial_lengths_[1];
const long_index_t Wo = conv_params.output_spatial_lengths_[2];
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(long_index_t z = 0; z < conv_params.filter_spatial_lengths_[0]; ++z)
{
auto di = static_cast<long_index_t>(d_o * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(z * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[1]; ++y)
{
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[1]) +
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[1]) -
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[2]; ++x)
{
auto wi =
static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[2]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[2]) -
static_cast<long_index_t>(conv_params.input_left_pads_[2]);
for(long_index_t c = 0; c < C; ++c)
{
if(di >= 0 &&
type_convert<std::size_t>(di) < in_host.get_lengths()[3] &&
hi >= 0 &&
type_convert<std::size_t>(hi) < in_host.get_lengths()[4] &&
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[5])
{
InDataType v_in = in_host(g, n, c, di, hi, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
}
}
};
make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(std::thread::hardware_concurrency());
}
}
} // namespace ck_tile
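To make the row/column mapping above concrete, a small worked example (hypothetical 2D shapes):
// With Ho = Wo = 3, filter Y = X = 3 and C = 3 channels, output pixel (ho = 1, wo = 2)
// of image n = 0 maps to row = n*Ho*Wo + ho*Wo + wo = 5, and filter tap (y = 2, x = 1)
// at channel c = 0 maps to column = (y*X + x)*C + c = 21, matching the loop order
// (y outermost, then x, then c) used by the 2D branch.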
@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
{
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split;
const index_t split_end = split_start + x_per_split;
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
ck_tile::min(origin_end, split_end));
...
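A small worked example of the new split computation (hypothetical values, not part of the diff):
// x_total = 10 tiles split 4 ways: x_per_split = integer_divide_ceil(10, 4) = 3, so the
// splits cover [0,3), [3,6), [6,9), [9,12); each range is then clamped against
// [origin_start, origin_end), so the last split simply gets fewer tiles instead of the
// old special case for i_split == num_splits - 1.
const ck_tile::index_t x_per_split = ck_tile::max(1, ck_tile::integer_divide_ceil(10, 4)); // == 3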
@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
void* o_ptr;
ck_tile::index_t batch;
ck_tile::index_t seqlen_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_splits;
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
};
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o;
};
struct GroupModeKargs
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
void* lse_ptr,
void* o_ptr,
ck_tile::index_t batch,
ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr,
o_ptr,
batch,
seqlen_q,
hdim_v,
num_splits,
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
batch_stride_lse_acc,
batch_stride_o_acc,
batch_stride_o};
if constexpr(kStoreLSE)
{
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
void* lse_ptr,
void* o_ptr,
ck_tile::index_t batch,
const void* seqstart_q_ptr,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc)
{
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr,
o_ptr,
batch,
-1, // seqlen will be updated by another pointer
hdim_v,
num_splits,
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
return kargs;
}
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v)
{
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.row_stride_o_acc;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
}
batch_offset_o = query_start * kargs.row_stride_o;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
}
// for simplicity, batch stride we just modify the pointer
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
auto o_acc_dram = [&]() {
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
const index_t padded_seqlen_q =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
const index_t padded_hdim_v =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
return transform_tensor_view(
o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)),
make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
kargs.seqlen_q,
smem_ptr);
}
else
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window,
lse_dram_window,
kargs.num_splits,
kargs.seqlen_q,
smem_ptr);
}
}();
...
@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN1 = kN1_;
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead,
batch_size);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.x;
...
@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
};
@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
};
struct GroupModeKargs
@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_k; // only used for paged-kvcache
ck_tile::index_t batch_stride_v; // only used for paged-kvcache
};
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for bias
@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_lse_acc,
batch_stride_o_acc};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_k, // only used for paged-kvcache
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for bias
@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits)
{
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
if constexpr(kIsGroupMode)
{
@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
batch_offset_bias = query_start * kargs.stride_bias + key_start;
}
batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.stride_o_acc;
// get real # queries & # keys under group mode
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o_acc, 1),
number<1>{},
number<1>{});
return pad_tensor_view(
...
@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead * num_splits,
batch_size);
...
@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
},
s_acc,
bias_s_tile);
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 4, OGrad@V Gemm2
auto dp_acc = SPGradBlockTileType{};
@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 5, P^T(PGrad^T - D)
auto ds = SPGradBlockTileType{};
@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbias_tile, shuffled_dbias_tile);
store_tile(dbias_dram_window, dbias_tile);
__builtin_amdgcn_sched_barrier(0);
}
// STAGE 6, SGrad^T@Q^T Gemm3
@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
move_tile_window(ds_lds_read_window, {0, kK4});
HotLoopScheduler::template GemmStagedScheduler<3>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 7, SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc);
@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
});
HotLoopScheduler::template GemmStagedScheduler<4>();
__builtin_amdgcn_sched_barrier(0);
// Results Scale
if constexpr(FmhaDropout::IsDropout)
...
@@ -25,14 +25,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType,
@@ -57,14 +58,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
@@ -88,14 +90,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType,
@@ -120,14 +123,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
@@ -151,14 +155,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
@@ -1722,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE constexpr void GemmStagedScheduler<0>()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
@@ -1754,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE constexpr void GemmStagedScheduler<1>()
{
// Mem: Q^T LDS load
// Comp: OGrad x V
@@ -1772,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE constexpr void GemmStagedScheduler<2>()
{
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad
@@ -1791,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE constexpr void GemmStagedScheduler<3>()
{
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
@@ -1825,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE constexpr void GemmStagedScheduler<4>()
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
...
@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func,
index_t num_splits,
index_t seqlen_q,
void* smem_ptr) const
{
// lse_acc tile in LDS
@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;
for(index_t i_split = 0; i_split < num_splits; ++i_split)
{
@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
});
}
move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
}
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const OaccDramBlockWindow& o_acc_dram_block_window,
LSEDramBlockWindow& lse_dram_block_window,
index_t num_splits,
index_t seqlen_q,
void* smem_ptr) const
{
return operator()(lse_acc_dram_block_window,
@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity{},
identity{},
num_splits,
seqlen_q,
smem_ptr);
}
};
...
@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
const index_t original_num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
...
@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
...
@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
...
@@ -75,14 +75,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
@@ -198,14 +199,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
@@ -952,14 +954,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
...
@@ -21,6 +21,8 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
...
@@ -4,7 +4,8 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
namespace ck_tile {
@@ -27,9 +28,9 @@ struct BlockGemmARegBGmemCRegV1
static constexpr index_t kBlockSize = Problem::kBlockSize;
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
using BlockGemmARegBGmemCRegImpl = BlockGemmARegBGmemCRegV1<
BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
BlockGemmARegBGmemCRegV1DefaultPolicy>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{
@@ -82,7 +83,7 @@ struct BlockGemmARegBGmemCRegV1
block_sync_lds();
// block GEMM
BlockGemmARegBGmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window);
}
// C = A * B
@@ -128,7 +129,7 @@ struct BlockGemmARegBGmemCRegV1
block_sync_lds();
// block GEMM
return BlockGemmARegBGmemCRegImpl{}(a_block_tensor, b_block_smem_window);
}
};
...
@@ -49,6 +49,10 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
else
{
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
}
}
};
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <iostream>
#include <string>
namespace ck_tile {
template <typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_,
typename LayoutA_,
typename LayoutB_,
typename LayoutC_>
struct GemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using LayoutA = remove_cvref_t<LayoutA_>;
using LayoutB = remove_cvref_t<LayoutB_>;
using LayoutC = remove_cvref_t<LayoutC_>;
static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
{
return TilePartitioner::GridSize(M_size, N_size, Batch_size);
}
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
struct GemmCommonKargs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
float epsilon;
ck_tile::index_t M;
ck_tile::index_t N;
ck_tile::index_t K;
ck_tile::index_t stride_A;
ck_tile::index_t stride_B;
ck_tile::index_t stride_C;
};
CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
const void* b_ptr,
void* c_ptr,
float epsilon,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C)
{
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, M, N, K, stride_A, stride_B, stride_C};
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
// options
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<LayoutA, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<GemmPipeline::AlignmentA>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::AlignmentA>{},
number<1>{});
}
}();
auto b_tensor_view = [&]() {
if constexpr(std::is_same_v<LayoutB, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B),
number<GemmPipeline::AlignmentB>{},
number<1>{});
}
else
{ // Default NK layout
return make_naive_tensor_view<address_space_enum::global>(
b_start,
make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::AlignmentB>{},
number<1>{});
}
}();
auto a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<0, GemmPipeline::kPadA ? 1 : 0>{});
auto ABlockWindow = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
auto b_pad_view = pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<0, GemmPipeline::kPadB ? 1 : 0>{});
auto BBlockWindow = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK;
auto acc = GemmPipeline{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr);
CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<GemmPipeline::AlignmentC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::AlignmentC>{},
number<1>{});
}
}();
auto c_pad_view = pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<0, GemmPipeline::kPadC ? 1 : 0>{});
auto CBlockWindow_pad = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, acc);
}
};
} // namespace ck_tile
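For orientation, a minimal host-side sketch of how such a kernel might be configured (hypothetical; MyPartitioner, MyPipeline, MyEpilogue and the device pointers are placeholders, and the actual launch goes through the host-side kernel-launch helpers rather than a direct call):
using Row    = ck_tile::tensor_layout::gemm::RowMajor;
using Col    = ck_tile::tensor_layout::gemm::ColumnMajor;
using Kernel = ck_tile::GemmKernel<MyPartitioner, MyPipeline, MyEpilogue, Row, Col, Row>;
auto kargs = Kernel::MakeKargs(a_dev, b_dev, c_dev, /*epsilon=*/0.f, M, N, K,
                               /*stride_A=*/K, /*stride_B=*/K, /*stride_C=*/N);
dim3 grid  = Kernel::GridSize(M, N, /*Batch_size=*/1);
dim3 block = Kernel::BlockSize(); // dim3(GemmPipeline::kBlockSize)
// a device-side wrapper then invokes Kernel{}(kargs) with this grid/block configuration.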
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockGemmShape_>
struct GemmTilePartitioner
{
using BlockGemmShape = ck_tile::remove_cvref_t<BlockGemmShape_>;
static constexpr ck_tile::index_t kM = BlockGemmShape::kM;
static constexpr ck_tile::index_t kN = BlockGemmShape::kN;
static constexpr ck_tile::index_t kK = BlockGemmShape::kK;
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t batch_size)
{
ck_tile::index_t GridDimX = (M + kM - 1) / kM;
ck_tile::index_t GridDimY = (N + kN - 1) / kN;
ck_tile::index_t GridDimZ = batch_size;
return dim3(GridDimX, GridDimY, GridDimZ);
}
CK_TILE_DEVICE auto operator()()
{
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM);
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN);
return ck_tile::make_tuple(iM, iN);
}
};
} // namespace ck_tile
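A small worked example of the tile mapping (hypothetical shape):
// With kM = kN = 128, GridSize(1024, 1024, 4) returns dim3(8, 8, 4); the block at
// blockIdx = (2, 5, z) computes its tile origin as (iM, iN) = (2*128, 5*128) = (256, 640),
// independent of z (the batch index is handled by the kernel, not the partitioner).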