Commit 6f210155 authored by root's avatar root
Browse files

Merge branch 'gemm_bf16_sk_muozturk' of...

Merge branch 'gemm_bf16_sk_muozturk' of https://github.com/ROCm/composable_kernel into gemm_bf16_sk_muozturk
parents 41b54611 78394194
...@@ -97,13 +97,6 @@ ...@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ #cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif #endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
// //
// CK kernels which support XDL (MI series) // CK kernels which support XDL (MI series)
// //
......
...@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error(hipEventElapsedTime(&total_time, start, stop)); hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat; return total_time / nrepeat;
} }
else else
...@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error(hipEventElapsedTime(&total_time, start, stop)); hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat; return total_time / nrepeat;
} }
else else
......
...@@ -64,7 +64,7 @@ __global__ void ...@@ -64,7 +64,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
......
...@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const index_t N = gemm_descs[i].N_; const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_; const index_t K = gemm_descs[i].K_;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
{ {
skipped_group_count_++; skipped_group_count_++;
continue; continue;
......
...@@ -109,7 +109,7 @@ __global__ void ...@@ -109,7 +109,7 @@ __global__ void
N = gemm_desc_ptr[group_id].N; N = gemm_desc_ptr[group_id].N;
K = gemm_desc_ptr[group_id].K; K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
{ {
grid_size_grp = 0; grid_size_grp = 0;
continue; continue;
......
...@@ -68,7 +68,7 @@ __global__ void ...@@ -68,7 +68,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideA = gemm_desc_ptr[group_id].StrideA; const auto StrideA = gemm_desc_ptr[group_id].StrideA;
......
...@@ -58,7 +58,7 @@ struct thread_buffer { ...@@ -58,7 +58,7 @@ struct thread_buffer {
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); } template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); } template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); } template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
template <typename X_, template <typename X_,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false> typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto _get_as() const CK_TILE_HOST_DEVICE constexpr auto _get_as() const
......
...@@ -50,12 +50,22 @@ class ArgParser ...@@ -50,12 +50,22 @@ class ArgParser
} }
return *this; return *this;
} }
void print() void print() const
{ {
// find max key length
std::string::size_type max_key_length = 11;
for(auto& key : keys)
{
if(max_key_length < key.length())
{
max_key_length = key.length();
}
}
printf("args:\n"); printf("args:\n");
for(auto& key : keys) for(auto& key : keys)
{ {
auto value = input_map[key]; auto value = input_map.at(key);
std::vector<std::string> help_text_lines; std::vector<std::string> help_text_lines;
size_t pos = 0; size_t pos = 0;
for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;) for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;)
...@@ -69,8 +79,7 @@ class ArgParser ...@@ -69,8 +79,7 @@ class ArgParser
std::string(value.help_text.begin() + pos, value.help_text.end())); std::string(value.help_text.begin() + pos, value.help_text.end()));
std::string default_value = std::string("(default:") + value.value + std::string(")"); std::string default_value = std::string("(default:") + value.value + std::string(")");
std::cout << std::setw(1 + max_key_length - value.name.length()) << "-" << key
std::cout << std::setw(2) << std::setw(12 - value.name.length()) << "-" << key
<< std::setw(4) << " " << help_text_lines[0] << " " << default_value << std::setw(4) << " " << help_text_lines[0] << " " << default_value
<< std::endl; << std::endl;
...@@ -78,7 +87,8 @@ class ArgParser ...@@ -78,7 +87,8 @@ class ArgParser
help_next_line != help_text_lines.end(); help_next_line != help_text_lines.end();
++help_next_line) ++help_next_line)
{ {
std::cout << std::setw(17) << " " << *help_next_line << std::endl; std::cout << std::setw(1 + max_key_length + 4) << " " << *help_next_line
<< std::endl;
} }
} }
} }
......
...@@ -13,7 +13,6 @@ namespace conv { ...@@ -13,7 +13,6 @@ namespace conv {
struct ConvParam struct ConvParam
{ {
ConvParam();
ConvParam(ck_tile::index_t n_dim, ConvParam(ck_tile::index_t n_dim,
ck_tile::index_t group_count, ck_tile::index_t group_count,
ck_tile::index_t n_batch, ck_tile::index_t n_batch,
...@@ -199,11 +198,6 @@ struct ConvParam ...@@ -199,11 +198,6 @@ struct ConvParam
} }
}; };
ConvParam::ConvParam()
: ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
{
}
CK_TILE_HOST std::string get_conv_param_parser_helper_msg() CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
{ {
std::string msg; std::string msg;
......
...@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k, ...@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const BElementOp& b_element_op = {}, const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {}) const ACCElementOp& acc_element_op = {})
{ {
const int N = b_n_k.mDesc.get_lengths()[0]; const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_n_k.mDesc.get_lengths()[0]
: b_n_k.mDesc.get_lengths()[1];
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>) const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[1] ? a_m_k.mDesc.get_lengths()[1]
: a_m_k.mDesc.get_lengths()[0]; : a_m_k.mDesc.get_lengths()[0];
...@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k, ...@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>) ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_element_op(a_m_k(m, k)) ? a_element_op(a_m_k(m, k))
: a_element_op(a_m_k(k, m)); : a_element_op(a_m_k(k, m));
BDataType v_b = b_element_op(b_n_k(n, k)); BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_element_op(b_n_k(n, k))
: b_element_op(b_n_k(k, n));
v_acc += ck_tile::type_convert<AccDataType>(v_a) * v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b); ck_tile::type_convert<AccDataType>(v_b);
} }
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc)); CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? c_m_n(m, n)
: c_m_n(n, m);
c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
} }
}; };
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
} }
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType> template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A, __global__ void naive_gemm_kernel(ADataType* A,
BDataType* B, BDataType* B,
CDataType* C, CDataType* C,
...@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A, ...@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
if(row < M && col < N) if(row < M && col < N)
{ {
AccDataType acc = 0.0; AccDataType acc = 0.0;
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
acc += static_cast<AccDataType>(A[row * strideA + k]) * // Adjust indexing based on matrix layout
static_cast<AccDataType>(B[col * strideB + k]); int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
: k * strideA + row;
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
} }
C[row * strideC + col] = acc; // Store as AccDataType int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = acc;
} }
} }
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType> template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_gemm_gpu(DeviceMem& a_device, void reference_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device, DeviceMem& b_device,
DeviceMem& c_device, DeviceMem& c_device,
...@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device, ...@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
int numThreadsPerBlock = 256; // Common choice for threads per block int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType> naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c); <<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
errC = hipMemcpy( errC = hipMemcpy(
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
......
...@@ -3,5 +3,6 @@ ...@@ -3,5 +3,6 @@
#pragma once #pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define CK_TILE_MAX_RANK 5
namespace ck_tile {
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool kTilePermute_,
index_t kRank_,
index_t kPerm0,
index_t kPerm1,
index_t TileSize0,
index_t TileSize1,
index_t kPerm2 = 0,
index_t kPerm3 = 0,
index_t kPerm4 = 0,
index_t TileSize2 = 0,
index_t TileSize3 = 0,
index_t TileSize4 = 0>
struct CShuffleEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kTilePermute = kTilePermute_;
static constexpr index_t kRank = kRank_;
static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4};
static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = {
TileSize0, TileSize1, TileSize2, TileSize3, TileSize4};
};
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
const index_t* kPerm = Problem::kPerm;
static constexpr bool kTilePermute = Problem::kTilePermute;
static constexpr index_t kRank = Problem::kRank;
const index_t* tile_sizes = Problem::tile_sizes;
// No additional shared memory needed
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
template <typename OAccTile>
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
{
using DataType = typename OAccTile::DataType;
// Get thread buffer
auto& thread_buf = o_acc_tile.get_thread_buffer();
// Create a temporary buffer to hold the permuted data
thread_buffer<DataType, OAccTile::kThreadElementSpaceSize> permuted_thread_buf;
// Get the lengths of each dimension
auto thread_tensor_lengths = o_acc_tile.get_lengths();
// Total number of elements
index_t total_elements = OAccTile::kThreadElementSpaceSize;
// Iterate over all elements
for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx)
{
// Convert linear index to multi-dimensional indices
array<index_t, kRank> indices;
index_t remaining = linear_idx;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
indices(rev_i) = remaining % thread_tensor_lengths.get(number<rev_i>{});
remaining /= thread_tensor_lengths.get(number<rev_i>{});
});
// Apply the permutation
array<index_t, kRank> permuted_indices;
static_for<0, kRank, 1>{}(
[&](auto i) { permuted_indices(i) = indices.get(number<Problem::kPerm[i]>{}); });
// Compute offsets
index_t dst_offset = 0;
index_t stride = 1;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
dst_offset += permuted_indices[rev_i] * stride;
stride *= thread_tensor_lengths.get(number<rev_i>{});
});
// Move the data
permuted_thread_buf(dst_offset) = thread_buf[linear_idx];
}
// Copy the permuted data back to the original thread buffer
for(index_t i = 0; i < total_elements; ++i)
{
thread_buf.set_as(i, permuted_thread_buf.get(i));
}
}
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
{
const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t tile_coords[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
tile_coords[i] = current_window_origin[i] / tile_sizes[i];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t permuted_tile_coords[CK_TILE_MAX_RANK];
for(index_t i = 0; i < kRank; ++i)
{
permuted_tile_coords[i] = tile_coords[kPerm[i]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename ODramWindowTmp::BottomTensorIndex step = {};
for(index_t i = 0; i < kRank; ++i)
{
step[i] = permuted_window_origin[i] - current_window_origin[i];
}
// Move the window
move_tile_window(o_dram_window_tmp, step);
// Permute the data within the tile if necessary
if constexpr(kTilePermute)
{
permute_tile_data(o_acc_tile);
}
// Store the tile data to the permuted location
if constexpr(kPadM || kPadN)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
}
};
} // namespace ck_tile
...@@ -6,8 +6,11 @@ ...@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel ...@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel
ck_tile::GenericAttentionMaskEnum mask_type; ck_tile::GenericAttentionMaskEnum mask_type;
}; };
struct FmhaBwdCommonDropoutKargs struct FmhaBwdDropoutSeedOffset
{ {
void init_dropout(const float p_drop, template <typename T>
const std::tuple<uint64_t, uint64_t>& drop_seed_offset, union ValueOrPointer
const float raw_scale) {
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale;
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop,
const uint64_t* seed_ptr,
const uint64_t* offset_ptr,
float raw_scale)
{ {
float p_undrop = 1.0 - p_drop; float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t = p_undrop_in_uint8_t =
...@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel ...@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel
rp_undrop = 1.0 / p_undrop; rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale; scale_rp_undrop = rp_undrop * raw_scale;
drop_seed = std::get<0>(drop_seed_offset); this->drop_seed.ptr = seed_ptr;
drop_offset = std::get<1>(drop_seed_offset); this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
} }
float rp_undrop = 1; float rp_undrop = 1;
float scale_rp_undrop = 1; float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr; void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0; ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0; ck_tile::index_t nhead_stride_randval = 0;
}; };
struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
{ {
ck_tile::index_t batch_stride_randval = 0; ck_tile::index_t batch_stride_randval = 0;
}; };
struct FmhaBwdDeterministicKargs struct FmhaBwdDeterministicKargs
{ {
ck_tile::index_t split_stride_dq_acc = 0; ck_tile::index_t split_stride_dq_acc = 0;
...@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel ...@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel ...@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset, scale); if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval) if constexpr(kIsStoreRandval)
{ {
kargs.rand_val_ptr = rand_val_ptr; kargs.rand_val_ptr = rand_val_ptr;
...@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel ...@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel ...@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
} }
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset, scale); if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval) if constexpr(kIsStoreRandval)
{ {
kargs.rand_val_ptr = rand_val_ptr; kargs.rand_val_ptr = rand_val_ptr;
...@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel ...@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
return FmhaDropout{i_batch_, return FmhaDropout{i_batch_,
i_nhead_, i_nhead_,
kargs.num_head_q, kargs.num_head_q,
kargs.drop_seed, kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
kargs.drop_offset, : *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop, kargs.rp_undrop,
kargs.p_undrop_in_uint8_t}; kargs.p_undrop_in_uint8_t};
} }
......
...@@ -6,8 +6,11 @@ ...@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...@@ -170,29 +173,55 @@ struct FmhaFwdKernel ...@@ -170,29 +173,55 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_lse = 0; ck_tile::index_t batch_stride_lse = 0;
}; };
struct FmhaFwdCommonDropoutKargs struct FmhaFwdDropoutSeedOffset
{ {
void init_dropout(const float p_drop, template <typename T>
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) union ValueOrPointer
{
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
{ {
float p_undrop = 1.0 - p_drop; float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t = p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max())); uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop; rp_undrop = 1.0 / p_undrop;
drop_seed = std::get<0>(drop_seed_offset); this->drop_seed.ptr = seed_ptr;
drop_offset = std::get<1>(drop_seed_offset); this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
} }
float rp_undrop = 1; float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max(); uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
bool is_store_randval = false; bool is_store_randval = false;
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr; void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0; ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0; ck_tile::index_t nhead_stride_randval = 0;
}; };
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
{ {
ck_tile::index_t batch_stride_randval = 0; ck_tile::index_t batch_stride_randval = 0;
...@@ -278,7 +307,8 @@ struct FmhaFwdKernel ...@@ -278,7 +307,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval, bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -344,7 +374,19 @@ struct FmhaFwdKernel ...@@ -344,7 +374,19 @@ struct FmhaFwdKernel
} }
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset); if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr; kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval; kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval; kargs.nhead_stride_randval = nhead_stride_randval;
...@@ -392,7 +434,8 @@ struct FmhaFwdKernel ...@@ -392,7 +434,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval, bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -455,7 +498,19 @@ struct FmhaFwdKernel ...@@ -455,7 +498,19 @@ struct FmhaFwdKernel
} }
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
kargs.init_dropout(p_drop, drop_seed_offset); if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr; kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval; kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval; kargs.nhead_stride_randval = nhead_stride_randval;
...@@ -748,8 +803,10 @@ struct FmhaFwdKernel ...@@ -748,8 +803,10 @@ struct FmhaFwdKernel
return BlockDropout{i_batch_, return BlockDropout{i_batch_,
i_nhead_, i_nhead_,
kargs.num_head_q, kargs.num_head_q,
kargs.drop_seed, kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
kargs.drop_offset, : *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop, kargs.rp_undrop,
kargs.p_undrop_in_uint8_t, kargs.p_undrop_in_uint8_t,
kargs.is_store_randval}; kargs.is_store_randval};
......
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
...@@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::QDataType, GemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>; typename Problem::BlockFmhaShape::Gemm0WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = WarpGemmMfmaDispatcher< using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType, typename Problem::QDataType,
...@@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::GemmDataType, GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType, typename Problem::OGradDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0, TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim, Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>, Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps, typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>; typename Problem::BlockFmhaShape::Gemm1WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadHeadDimV,
Problem::kPadHeadDimV,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -84,21 +97,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -84,21 +97,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm1BlockWarps, typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::OGradDataType, GemmPipelineProblem<typename Problem::OGradDataType,
typename Problem::VDataType, typename Problem::VDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>, Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2BlockWarps, typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>; typename Problem::BlockFmhaShape::Gemm2WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = WarpGemmMfmaDispatcher< using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType, typename Problem::OGradDataType,
...@@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm2BlockWarps, typename Problem::BlockFmhaShape::Gemm2BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::GemmDataType, GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::QDataType, typename Problem::QDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0, TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim, Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>, Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps, typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>>; typename Problem::BlockFmhaShape::Gemm3WarpTile>,
TileGemmTraits<Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -149,21 +174,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -149,21 +174,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm3BlockWarps, typename Problem::BlockFmhaShape::Gemm3BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::GemmDataType, GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim, Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>, Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps, typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>; typename Problem::BlockFmhaShape::Gemm4WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -181,7 +212,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -181,7 +212,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm4BlockWarps, typename Problem::BlockFmhaShape::Gemm4BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
// these are for global load // these are for global load
......
...@@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity()); lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{}); block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
static const auto get_validated_m = [](LSEDataType raw_m) {
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
: raw_m;
};
decltype(lse_accum) lse_exp; decltype(lse_accum) lse_exp;
{ {
constexpr auto spans = decltype(lse_exp)::get_distributed_spans(); constexpr auto spans = decltype(lse_exp)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(spans[number<1>{}], [&](auto idx1) { if(lse_max[i_idx] == -numeric<LSEDataType>::infinity())
constexpr auto i_j_idx = make_tuple(idx0, idx1); {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
lse_exp(i_j_idx) = lse_exp(i_j_idx) = ck_tile::type_convert<LSEDataType>(0.0f);
ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx))); });
}); }
else
{
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
lse_exp(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - lse_max(i_idx));
});
}
}); });
} }
...@@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx)) if(lse_sum[i_idx] == ck_tile::type_convert<LSEDataType>(0.0f))
{ lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
lse_logsum(i_idx) = numeric<LSEDataType>::infinity();
}
else else
{ lse_logsum(i_idx) = ck_tile::log(lse_sum(i_idx)) + lse_max(i_idx);
lse_logsum(i_idx) =
ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx));
}
}); });
} }
...@@ -218,37 +218,47 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -218,37 +218,47 @@ struct BlockFmhaFwdSplitKVCombinePipeline
constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(spans[number<1>{}], [&](auto idx1) { if(lse_logsum(i_idx) == -numeric<LSEDataType>::infinity())
constexpr auto i_j_idx = make_tuple(idx0, idx1); {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices( const auto x_indices = get_x_indices_from_distributed_indices(
lse_accum.get_tile_distribution(), i_j_idx); lse_accum.get_tile_distribution(), i_j_idx);
const auto col = x_indices.at(number<1>{}); const auto col = x_indices.at(number<1>{});
if(col < num_splits) if(col < num_splits)
{ {
const auto row = x_indices.at(number<0>{}); const auto row = x_indices.at(number<0>{});
lse_acc_lds(row, col) = lse_acc_lds(row, col) = ck_tile::type_convert<LSEDataType>(0.0f);
ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx)); }
} });
}); }
else
{
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_accum.get_tile_distribution(), i_j_idx);
const auto col = x_indices.at(number<1>{});
if(col < num_splits)
{
const auto row = x_indices.at(number<0>{});
lse_acc_lds(row, col) =
ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
}
});
}
}); });
} }
block_sync_lds(); block_sync_lds();
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
constexpr auto spans = decltype(lse_logsum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_logsum(i_idx) == numeric<LSEDataType>::infinity())
{
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
}
});
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
} }
......
...@@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{ {
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
return 16 / sizeof(OaccDataType);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = kNPerBlock / N0;
return min(N1, static_cast<index_t>(16 / sizeof(OaccDataType)));
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{ {
using ODataType = remove_cvref_t<typename Problem::ODataType>; return GetAlignmentOacc<Problem>();
return 16 / sizeof(ODataType);
} }
template <typename Problem> template <typename Problem>
...@@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
{ {
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1; constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t N1 = 16 / sizeof(OaccDataType);
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t M2 = get_warp_size() / N0;
constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = kNPerBlock / N0;
constexpr index_t M0 = kMPerBlock / (M2 * M1); constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution( return make_static_tile_distribution(
......
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...@@ -75,15 +76,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -75,15 +76,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::QDataType, GemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::SaccDataType, typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>; typename Problem::BlockFmhaShape::Gemm0WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -116,7 +123,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -116,7 +123,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>; decltype(warp_gemm)>;
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
} }
}; };
...@@ -199,15 +206,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -199,15 +206,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::QDataType, GemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::SaccDataType, typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>; typename Problem::BlockFmhaShape::Gemm0WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -240,7 +253,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -240,7 +253,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>; decltype(warp_gemm)>;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmASmemBSmemCRegV1<GemmProblem, BlockGemmPolicy>{};
} }
}; };
...@@ -954,15 +967,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -954,15 +967,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem< using GemmProblem =
typename Problem::PDataType, GemmPipelineProblem<typename Problem::PDataType,
typename Problem::VDataType, typename Problem::VDataType,
typename Problem::OaccDataType, typename Problem::OaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1, Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>, Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps, typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>; typename Problem::BlockFmhaShape::Gemm1WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
auto warp_gemm = [&]() { auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> && if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
...@@ -996,7 +1015,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -996,7 +1015,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
typename Problem::OaccDataType, typename Problem::OaccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps, typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
} }
}; };
......
...@@ -23,12 +23,13 @@ ...@@ -23,12 +23,13 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
......
...@@ -11,20 +11,12 @@ ...@@ -11,20 +11,12 @@
namespace ck_tile { namespace ck_tile {
template <typename TilePartitioner_, template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
typename GemmPipeline_,
typename EpiloguePipeline_,
typename LayoutA_,
typename LayoutB_,
typename LayoutC_>
struct GemmKernel struct GemmKernel
{ {
using TilePartitioner = remove_cvref_t<TilePartitioner_>; using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>; using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>; using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using LayoutA = remove_cvref_t<LayoutA_>;
using LayoutB = remove_cvref_t<LayoutB_>;
using LayoutC = remove_cvref_t<LayoutC_>;
static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize; static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>; using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
...@@ -32,6 +24,10 @@ struct GemmKernel ...@@ -32,6 +24,10 @@ struct GemmKernel
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>; using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>; using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using LayoutA = remove_cvref_t<typename GemmPipeline::LayoutA>;
using LayoutB = remove_cvref_t<typename GemmPipeline::LayoutB>;
using LayoutC = remove_cvref_t<typename GemmPipeline::LayoutC>;
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size) __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
{ {
return TilePartitioner::GridSize(M_size, N_size, Batch_size); return TilePartitioner::GridSize(M_size, N_size, Batch_size);
...@@ -184,6 +180,7 @@ struct GemmKernel ...@@ -184,6 +180,7 @@ struct GemmKernel
c_pad_view, c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n}); {i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, acc); EpiloguePipeline{}(CBlockWindow_pad, acc);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment