Unverified commit dda18da0, authored by Illia Silin, committed by GitHub

Merge branch 'develop' into ck_migraphx_integration

parents 3b2a7aee 4cf70b36
@@ -3,5 +3,6 @@
 #pragma once

+#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
 #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define CK_TILE_MAX_RANK 5
namespace ck_tile {
// This epilogue stores a tile from shared memory to global memory with a
// different (permuted) layout.
template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool kTilePermute_,
index_t kRank_,
index_t kPerm0,
index_t kPerm1,
index_t TileSize0,
index_t TileSize1,
index_t kPerm2 = 0,
index_t kPerm3 = 0,
index_t kPerm4 = 0,
index_t TileSize2 = 0,
index_t TileSize3 = 0,
index_t TileSize4 = 0>
struct CShuffleEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kTilePermute = kTilePermute_;
static constexpr index_t kRank = kRank_;
static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4};
static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = {
TileSize0, TileSize1, TileSize2, TileSize3, TileSize4};
};
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
const index_t* kPerm = Problem::kPerm;
static constexpr bool kTilePermute = Problem::kTilePermute;
static constexpr index_t kRank = Problem::kRank;
const index_t* tile_sizes = Problem::tile_sizes;
// No additional shared memory needed
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
template <typename OAccTile>
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
{
using DataType = typename OAccTile::DataType;
// Get thread buffer
auto& thread_buf = o_acc_tile.get_thread_buffer();
// Create a temporary buffer to hold the permuted data
thread_buffer<DataType, OAccTile::kThreadElementSpaceSize> permuted_thread_buf;
// Get the lengths of each dimension
auto thread_tensor_lengths = o_acc_tile.get_lengths();
// Total number of elements
index_t total_elements = OAccTile::kThreadElementSpaceSize;
// Iterate over all elements
for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx)
{
// Convert linear index to multi-dimensional indices
array<index_t, kRank> indices;
index_t remaining = linear_idx;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
indices(rev_i) = remaining % thread_tensor_lengths.get(number<rev_i>{});
remaining /= thread_tensor_lengths.get(number<rev_i>{});
});
// Apply the permutation
array<index_t, kRank> permuted_indices;
static_for<0, kRank, 1>{}(
[&](auto i) { permuted_indices(i) = indices.get(number<Problem::kPerm[i]>{}); });
// Compute offsets
index_t dst_offset = 0;
index_t stride = 1;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
dst_offset += permuted_indices[rev_i] * stride;
stride *= thread_tensor_lengths.get(number<rev_i>{});
});
// Move the data
permuted_thread_buf(dst_offset) = thread_buf[linear_idx];
}
// Copy the permuted data back to the original thread buffer
for(index_t i = 0; i < total_elements; ++i)
{
thread_buf.set_as(i, permuted_thread_buf.get(i));
}
}
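// A concrete trace of the mapping above (illustrative values, not from this
// commit): with kRank = 2, kPerm = {1, 0} and thread tensor lengths {4, 8},
// the element at indices (1, 3) has source offset 1 * 8 + 3 = 11; its
// permuted indices are (3, 1), so it is written to offset 3 * 8 + 1 = 25 of
// permuted_thread_buf (the strides are built from the unpermuted lengths).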
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
{
const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t tile_coords[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
tile_coords[i] = current_window_origin[i] / tile_sizes[i];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t permuted_tile_coords[CK_TILE_MAX_RANK];
for(index_t i = 0; i < kRank; ++i)
{
permuted_tile_coords[i] = tile_coords[kPerm[i]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename ODramWindowTmp::BottomTensorIndex step = {};
for(index_t i = 0; i < kRank; ++i)
{
step[i] = permuted_window_origin[i] - current_window_origin[i];
}
// Move the window
move_tile_window(o_dram_window_tmp, step);
// Permute the data within the tile if necessary
if constexpr(kTilePermute)
{
permute_tile_data(o_acc_tile);
}
// Store the tile data to the permuted location
if constexpr(kPadM || kPadN)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
}
};
} // namespace ck_tile
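A minimal sketch of how the new epilogue might be instantiated; the data types, rank and tile sizes below are illustrative assumptions, not values taken from this commit:

using Problem = ck_tile::CShuffleEpilogueProblem<float,           // AccDataType_
                                                 ck_tile::half_t, // ODataType_
                                                 false,           // kPadM_
                                                 false,           // kPadN_
                                                 true,            // kTilePermute_
                                                 2,               // kRank_
                                                 1,               // kPerm0: swap the
                                                 0,               // kPerm1: two dims
                                                 64,              // TileSize0
                                                 64>;             // TileSize1
using Epilogue = ck_tile::CShuffleEpilogue<Problem>;
// inside a kernel: Epilogue{}(o_dram_window, o_acc_tile);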
@@ -6,8 +6,11 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"

 #include <string>
 #include <type_traits>
+#include <utility>
+#include <variant>

 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel
        ck_tile::GenericAttentionMaskEnum mask_type;
    };

+   struct FmhaBwdDropoutSeedOffset
+   {
+       template <typename T>
+       union ValueOrPointer
+       {
+           T val;
+           const T* ptr;
+       };
+
+       ValueOrPointer<uint64_t> drop_seed;
+       ValueOrPointer<uint64_t> drop_offset;
+       bool is_drop_seed_offset_from_host;
+   };
+
-   struct FmhaBwdCommonDropoutKargs
+   struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
    {
-       void init_dropout(const float p_drop,
-                         const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
-                         const float raw_scale)
+       void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
+       {
+           float p_undrop = 1.0 - p_drop;
+           p_undrop_in_uint8_t =
+               uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
+           rp_undrop       = 1.0 / p_undrop;
+           scale_rp_undrop = rp_undrop * raw_scale;
+
+           this->drop_seed.val                 = seed;
+           this->drop_offset.val               = offset;
+           this->is_drop_seed_offset_from_host = true;
+       }
+
+       void init_dropout(float p_drop,
+                         const uint64_t* seed_ptr,
+                         const uint64_t* offset_ptr,
+                         float raw_scale)
        {
            float p_undrop = 1.0 - p_drop;
            p_undrop_in_uint8_t =
@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel
            rp_undrop       = 1.0 / p_undrop;
            scale_rp_undrop = rp_undrop * raw_scale;

-           drop_seed   = std::get<0>(drop_seed_offset);
-           drop_offset = std::get<1>(drop_seed_offset);
+           this->drop_seed.ptr                 = seed_ptr;
+           this->drop_offset.ptr               = offset_ptr;
+           this->is_drop_seed_offset_from_host = false;
        }
        float rp_undrop             = 1;
        float scale_rp_undrop       = 1;
        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-       uint64_t drop_seed          = 1;
-       uint64_t drop_offset        = 0;
        void* rand_val_ptr          = nullptr;

        ck_tile::index_t stride_randval       = 0;
        ck_tile::index_t nhead_stride_randval = 0;
    };
    struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
    {
        ck_tile::index_t batch_stride_randval = 0;
    };

    struct FmhaBwdDeterministicKargs
    {
        ck_tile::index_t split_stride_dq_acc = 0;
@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
                          ck_tile::index_t window_size_right,
                          ck_tile::index_t mask_type,
                          float p_drop,
-                         const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+                         std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
+                             drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
        if constexpr(kHasDropout)
        {
-           kargs.init_dropout(p_drop, drop_seed_offset, scale);
+           if(drop_seed_offset.index() == 0) // seed & offset come from host
+           {
+               const auto& [seed, offset] = std::get<0>(drop_seed_offset);
+               kargs.init_dropout(p_drop, seed, offset, scale);
+           }
+           else // seed & offset come from device
+           {
+               const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
+               kargs.init_dropout(p_drop,
+                                  reinterpret_cast<const uint64_t*>(seed_ptr),
+                                  reinterpret_cast<const uint64_t*>(offset_ptr),
+                                  scale);
+           }
            if constexpr(kIsStoreRandval)
            {
                kargs.rand_val_ptr = rand_val_ptr;
@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
                          ck_tile::index_t window_size_right,
                          ck_tile::index_t mask_type,
                          float p_drop,
-                         const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+                         std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
+                             drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
        }
        if constexpr(kHasDropout)
        {
-           kargs.init_dropout(p_drop, drop_seed_offset, scale);
+           if(drop_seed_offset.index() == 0) // seed & offset come from host
+           {
+               const auto& [seed, offset] = std::get<0>(drop_seed_offset);
+               kargs.init_dropout(p_drop, seed, offset, scale);
+           }
+           else // seed & offset come from device
+           {
+               const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
+               kargs.init_dropout(p_drop,
+                                  reinterpret_cast<const uint64_t*>(seed_ptr),
+                                  reinterpret_cast<const uint64_t*>(offset_ptr),
+                                  scale);
+           }
            if constexpr(kIsStoreRandval)
            {
                kargs.rand_val_ptr = rand_val_ptr;
@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
        return FmhaDropout{i_batch_,
                           i_nhead_,
                           kargs.num_head_q,
-                          kargs.drop_seed,
-                          kargs.drop_offset,
+                          kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
+                                                              : *kargs.drop_seed.ptr,
+                          kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
+                                                              : *kargs.drop_offset.ptr,
                           kargs.rp_undrop,
                           kargs.p_undrop_in_uint8_t};
    }
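The seed/offset pair is now a std::variant, so the host-vs-device decision belongs to the caller. A self-contained sketch of that dispatch, mirroring the MakeKargs logic above (the consume() wrapper and the literal values are illustrative, not part of this commit):

#include <cstdint>
#include <utility>
#include <variant>

using SeedOffset = std::variant<std::pair<uint64_t, uint64_t>,        // host values
                                std::pair<const void*, const void*>>; // device pointers

void consume(const SeedOffset& drop_seed_offset)
{
    if(drop_seed_offset.index() == 0) // seed & offset come from host
    {
        const auto& [seed, offset] = std::get<0>(drop_seed_offset);
        // kargs.init_dropout(p_drop, seed, offset, scale);
        (void)seed;
        (void)offset;
    }
    else // seed & offset live in device memory
    {
        const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
        // kargs.init_dropout(p_drop,
        //                    reinterpret_cast<const uint64_t*>(seed_ptr),
        //                    reinterpret_cast<const uint64_t*>(offset_ptr),
        //                    scale);
        (void)seed_ptr;
        (void)offset_ptr;
    }
}

int main()
{
    consume(SeedOffset{std::make_pair(uint64_t{1234}, uint64_t{0})}); // host path
}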
@@ -6,8 +6,11 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"

 #include <string>
 #include <type_traits>
+#include <utility>
+#include <variant>

 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
@@ -170,29 +173,55 @@ struct FmhaFwdKernel
        ck_tile::index_t batch_stride_lse = 0;
    };

-   struct FmhaFwdCommonDropoutKargs
+   struct FmhaFwdDropoutSeedOffset
    {
-       void init_dropout(const float p_drop,
-                         const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+       template <typename T>
+       union ValueOrPointer
+       {
+           T val;
+           const T* ptr;
+       };
+
+       ValueOrPointer<uint64_t> drop_seed;
+       ValueOrPointer<uint64_t> drop_offset;
+       bool is_drop_seed_offset_from_host;
+   };
+
+   struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
+   {
+       void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
+       {
+           float p_undrop = 1.0 - p_drop;
+           p_undrop_in_uint8_t =
+               uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
+           rp_undrop = 1.0 / p_undrop;
+
+           this->drop_seed.val                 = seed;
+           this->drop_offset.val               = offset;
+           this->is_drop_seed_offset_from_host = true;
+       }
+
+       void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
        {
            float p_undrop = 1.0 - p_drop;
            p_undrop_in_uint8_t =
                uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
            rp_undrop = 1.0 / p_undrop;

-           drop_seed   = std::get<0>(drop_seed_offset);
-           drop_offset = std::get<1>(drop_seed_offset);
+           this->drop_seed.ptr                 = seed_ptr;
+           this->drop_offset.ptr               = offset_ptr;
+           this->is_drop_seed_offset_from_host = false;
        }
        float rp_undrop             = 1;
        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
        bool is_store_randval       = false;
-       uint64_t drop_seed          = 1;
-       uint64_t drop_offset        = 0;
        void* rand_val_ptr          = nullptr;

        ck_tile::index_t stride_randval       = 0;
        ck_tile::index_t nhead_stride_randval = 0;
    };
    struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
    {
        ck_tile::index_t batch_stride_randval = 0;
@@ -278,7 +307,8 @@ struct FmhaFwdKernel
                          ck_tile::index_t mask_type,
                          float p_drop,
                          bool s_randval,
-                         const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+                         std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
+                             drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
@@ -344,7 +374,19 @@ struct FmhaFwdKernel
        }
        if constexpr(kHasDropout)
        {
-           kargs.init_dropout(p_drop, drop_seed_offset);
+           if(drop_seed_offset.index() == 0) // seed & offset come from host
+           {
+               const auto& [seed, offset] = std::get<0>(drop_seed_offset);
+               kargs.init_dropout(p_drop, seed, offset);
+           }
+           else // seed & offset come from device
+           {
+               const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
+               kargs.init_dropout(p_drop,
+                                  reinterpret_cast<const uint64_t*>(seed_ptr),
+                                  reinterpret_cast<const uint64_t*>(offset_ptr));
+           }
            kargs.rand_val_ptr         = rand_val_ptr;
            kargs.stride_randval       = stride_randval;
            kargs.nhead_stride_randval = nhead_stride_randval;
@@ -392,7 +434,8 @@ struct FmhaFwdKernel
                          ck_tile::index_t mask_type,
                          float p_drop,
                          bool s_randval,
-                         const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+                         std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
+                             drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
@@ -455,7 +498,19 @@ struct FmhaFwdKernel
        }
        if constexpr(kHasDropout)
        {
-           kargs.init_dropout(p_drop, drop_seed_offset);
+           if(drop_seed_offset.index() == 0) // seed & offset come from host
+           {
+               const auto& [seed, offset] = std::get<0>(drop_seed_offset);
+               kargs.init_dropout(p_drop, seed, offset);
+           }
+           else // seed & offset come from device
+           {
+               const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
+               kargs.init_dropout(p_drop,
+                                  reinterpret_cast<const uint64_t*>(seed_ptr),
+                                  reinterpret_cast<const uint64_t*>(offset_ptr));
+           }
            kargs.rand_val_ptr         = rand_val_ptr;
            kargs.stride_randval       = stride_randval;
            kargs.nhead_stride_randval = nhead_stride_randval;
@@ -748,8 +803,10 @@ struct FmhaFwdKernel
        return BlockDropout{i_batch_,
                            i_nhead_,
                            kargs.num_head_q,
-                           kargs.drop_seed,
-                           kargs.drop_offset,
+                           kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
+                                                               : *kargs.drop_seed.ptr,
+                           kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
+                                                               : *kargs.drop_offset.ptr,
                            kargs.rp_undrop,
                            kargs.p_undrop_in_uint8_t,
                            kargs.is_store_randval};
@@ -5,7 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
 #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
@@ -25,15 +25,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
    {
-       using BlockGemmProblem = BlockGemmPipelineProblem<
-           typename Problem::QDataType,
-           typename Problem::KDataType,
-           typename Problem::AccDataType,
-           TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
-                                  Problem::BlockFmhaShape::kN0,
-                                  Problem::BlockFmhaShape::kK0>,
-                         typename Problem::BlockFmhaShape::Gemm0BlockWarps,
-                         typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
+       using GemmProblem =
+           BlockGemmProblem<typename Problem::QDataType,
+                            typename Problem::KDataType,
+                            typename Problem::AccDataType,
+                            Problem::kBlockSize,
+                            TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                   Problem::BlockFmhaShape::kN0,
+                                                   Problem::BlockFmhaShape::kK0>,
+                                          typename Problem::BlockFmhaShape::Gemm0BlockWarps,
+                                          typename Problem::BlockFmhaShape::Gemm0WarpTile>>;

        using WarpGemm = WarpGemmMfmaDispatcher<
            typename Problem::QDataType,
@@ -52,21 +53,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
            typename Problem::BlockFmhaShape::Gemm0BlockWarps,
            WarpGemm>;
-       return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
+       return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
    {
-       using BlockGemmProblem = BlockGemmPipelineProblem<
-           typename Problem::GemmDataType,
-           typename Problem::OGradDataType,
-           typename Problem::AccDataType,
-           TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
-                                  Problem::BlockFmhaShape::kVHeaddim,
-                                  Problem::BlockFmhaShape::kK1>,
-                         typename Problem::BlockFmhaShape::Gemm1BlockWarps,
-                         typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
+       using GemmProblem =
+           BlockGemmProblem<typename Problem::GemmDataType,
+                            typename Problem::OGradDataType,
+                            typename Problem::AccDataType,
+                            Problem::kBlockSize,
+                            TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
+                                                   Problem::BlockFmhaShape::kVHeaddim,
+                                                   Problem::BlockFmhaShape::kK1>,
+                                          typename Problem::BlockFmhaShape::Gemm1BlockWarps,
+                                          typename Problem::BlockFmhaShape::Gemm1WarpTile>>;

        using WarpGemm =
            WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
@@ -84,21 +86,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
            typename Problem::BlockFmhaShape::Gemm1BlockWarps,
            WarpGemm>;
-       return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
+       return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
    {
-       using BlockGemmProblem = BlockGemmPipelineProblem<
-           typename Problem::OGradDataType,
-           typename Problem::VDataType,
-           typename Problem::AccDataType,
-           TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
-                                  Problem::BlockFmhaShape::kN0,
-                                  Problem::BlockFmhaShape::kK2>,
-                         typename Problem::BlockFmhaShape::Gemm2BlockWarps,
-                         typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
+       using GemmProblem =
+           BlockGemmProblem<typename Problem::OGradDataType,
+                            typename Problem::VDataType,
+                            typename Problem::AccDataType,
+                            Problem::kBlockSize,
+                            TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                   Problem::BlockFmhaShape::kN0,
+                                                   Problem::BlockFmhaShape::kK2>,
+                                          typename Problem::BlockFmhaShape::Gemm2BlockWarps,
+                                          typename Problem::BlockFmhaShape::Gemm2WarpTile>>;

        using WarpGemm = WarpGemmMfmaDispatcher<
            typename Problem::OGradDataType,
@@ -117,21 +120,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
            typename Problem::BlockFmhaShape::Gemm2BlockWarps,
            WarpGemm>;
-       return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
+       return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
    {
-       using BlockGemmProblem = BlockGemmPipelineProblem<
-           typename Problem::GemmDataType,
-           typename Problem::QDataType,
-           typename Problem::AccDataType,
-           TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
-                                  Problem::BlockFmhaShape::kQKHeaddim,
-                                  Problem::BlockFmhaShape::kK3>,
-                         typename Problem::BlockFmhaShape::Gemm3BlockWarps,
-                         typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
+       using GemmProblem =
+           BlockGemmProblem<typename Problem::GemmDataType,
+                            typename Problem::QDataType,
+                            typename Problem::AccDataType,
+                            Problem::kBlockSize,
+                            TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
+                                                   Problem::BlockFmhaShape::kQKHeaddim,
+                                                   Problem::BlockFmhaShape::kK3>,
+                                          typename Problem::BlockFmhaShape::Gemm3BlockWarps,
+                                          typename Problem::BlockFmhaShape::Gemm3WarpTile>>;

        using WarpGemm =
            WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
@@ -149,21 +153,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
            typename Problem::BlockFmhaShape::Gemm3BlockWarps,
            WarpGemm>;
-       return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
+       return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
    {
-       using BlockGemmProblem = BlockGemmPipelineProblem<
-           typename Problem::GemmDataType,
-           typename Problem::KDataType,
-           typename Problem::AccDataType,
-           TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
-                                  Problem::BlockFmhaShape::kQKHeaddim,
-                                  Problem::BlockFmhaShape::kK4>,
-                         typename Problem::BlockFmhaShape::Gemm4BlockWarps,
-                         typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
+       using GemmProblem =
+           BlockGemmProblem<typename Problem::GemmDataType,
+                            typename Problem::KDataType,
+                            typename Problem::AccDataType,
+                            Problem::kBlockSize,
+                            TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                   Problem::BlockFmhaShape::kQKHeaddim,
+                                                   Problem::BlockFmhaShape::kK4>,
+                                          typename Problem::BlockFmhaShape::Gemm4BlockWarps,
+                                          typename Problem::BlockFmhaShape::Gemm4WarpTile>>;

        using WarpGemm =
            WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
@@ -181,7 +186,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
            typename Problem::BlockFmhaShape::Gemm4BlockWarps,
            WarpGemm>;
-       return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
+       return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
    }

    // these are for global load
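All five getters above receive the same mechanical rewrite; schematically (type names abbreviated here, not the actual ck_tile declarations):

// before: the problem type carried no explicit block size
//   using P = BlockGemmPipelineProblem<A, B, Acc, TileGemmShape<...>>;
// after:  the block size is an explicit problem parameter
//   using P = BlockGemmProblem<A, B, Acc, Problem::kBlockSize, TileGemmShape<...>>;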
@@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline
            lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
        block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});

-       static const auto get_validated_m = [](LSEDataType raw_m) {
-           return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
-                                                             : raw_m;
-       };

        decltype(lse_accum) lse_exp;
        {
            constexpr auto spans = decltype(lse_exp)::get_distributed_spans();
            sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
-               sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
-                   constexpr auto i_j_idx = make_tuple(idx0, idx1);
-                   lse_exp(i_j_idx) =
-                       ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx)));
-               });
+               if(lse_max[i_idx] == -numeric<LSEDataType>::infinity())
+               {
+                   sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
+                       constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                       lse_exp(i_j_idx)       = ck_tile::type_convert<LSEDataType>(0.0f);
+                   });
+               }
+               else
+               {
+                   sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
+                       constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                       lse_exp(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - lse_max(i_idx));
+                   });
+               }
            });
        }
@@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
        sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
-           if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx))
-           {
-               lse_logsum(i_idx) = numeric<LSEDataType>::infinity();
-           }
+           if(lse_sum[i_idx] == ck_tile::type_convert<LSEDataType>(0.0f))
+               lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
            else
-           {
-               lse_logsum(i_idx) =
-                   ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx));
-           }
+               lse_logsum(i_idx) = ck_tile::log(lse_sum(i_idx)) + lse_max(i_idx);
        });
    }
@@ -218,37 +218,47 @@ struct BlockFmhaFwdSplitKVCombinePipeline
        constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
        sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
-           sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
-               constexpr auto i_j_idx = make_tuple(idx0, idx1);
-
-               const auto x_indices = get_x_indices_from_distributed_indices(
-                   lse_accum.get_tile_distribution(), i_j_idx);
-               const auto col = x_indices.at(number<1>{});
-               if(col < num_splits)
-               {
-                   const auto row = x_indices.at(number<0>{});
-                   lse_acc_lds(row, col) =
-                       ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
-               }
-           });
+           if(lse_logsum(i_idx) == -numeric<LSEDataType>::infinity())
+           {
+               sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
+                   constexpr auto i_j_idx = make_tuple(idx0, idx1);
+
+                   const auto x_indices = get_x_indices_from_distributed_indices(
+                       lse_accum.get_tile_distribution(), i_j_idx);
+                   const auto col = x_indices.at(number<1>{});
+                   if(col < num_splits)
+                   {
+                       const auto row        = x_indices.at(number<0>{});
+                       lse_acc_lds(row, col) = ck_tile::type_convert<LSEDataType>(0.0f);
+                   }
+               });
+           }
+           else
+           {
+               sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
+                   constexpr auto i_j_idx = make_tuple(idx0, idx1);
+
+                   const auto x_indices = get_x_indices_from_distributed_indices(
+                       lse_accum.get_tile_distribution(), i_j_idx);
+                   const auto col = x_indices.at(number<1>{});
+                   if(col < num_splits)
+                   {
+                       const auto row = x_indices.at(number<0>{});
+                       lse_acc_lds(row, col) =
+                           ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
+                   }
+               });
+           }
        });
    }

    block_sync_lds();

    if constexpr(kStoreLSE)
    {
-       constexpr auto spans = decltype(lse_logsum)::get_distributed_spans();
-       sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
-           constexpr auto i_idx = make_tuple(idx0);
-           if(lse_logsum(i_idx) == numeric<LSEDataType>::infinity())
-           {
-               lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
-           }
-       });
        store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
    }
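For orientation, the quantity being merged: each split i contributes a partial log-sum-exp lse_i, and the combined value and per-split weights follow the standard identity (a reference restatement, not code from this commit):

    m   = max_i lse_i
    lse = log( sum_i exp(lse_i - m) ) + m
    w_i = exp(lse_i - lse)

The new branches special-case m = -inf (a fully masked or empty row) so the kernel never evaluates exp(-inf - (-inf)); such rows get zero weights in LDS and store -inf as their logsum directly, which also removes the old +inf placeholder that the kStoreLSE path previously had to patch before storing.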
@@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
    {
        using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
-       return 16 / sizeof(OaccDataType);
+
+       constexpr index_t kBlockSize = Problem::kBlockSize;
+       constexpr index_t kMPerBlock = Problem::kM0;
+       constexpr index_t kNPerBlock = Problem::kN1;
+
+       constexpr index_t M1 = kBlockSize / get_warp_size();
+       constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
+       constexpr index_t N0 = get_warp_size() / M2;
+       constexpr index_t N1 = kNPerBlock / N0;
+
+       return min(N1, static_cast<index_t>(16 / sizeof(OaccDataType)));
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
    {
-       using ODataType = remove_cvref_t<typename Problem::ODataType>;
-       return 16 / sizeof(ODataType);
+       return GetAlignmentOacc<Problem>();
    }

    template <typename Problem>
@@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
    {
-       using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
-
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::kM0;
        constexpr index_t kNPerBlock = Problem::kN1;

-       constexpr index_t N1 = 16 / sizeof(OaccDataType);
-       constexpr index_t N0 = kNPerBlock / N1;
-       constexpr index_t M2 = get_warp_size() / N0;
        constexpr index_t M1 = kBlockSize / get_warp_size();
+       constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
+       constexpr index_t N0 = get_warp_size() / M2;
+       constexpr index_t N1 = kNPerBlock / N0;
        constexpr index_t M0 = kMPerBlock / (M2 * M1);

        return make_static_tile_distribution(
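Plugging illustrative numbers into the shared decomposition (assumed values, not from this commit): kBlockSize = 256, kMPerBlock = 64, kNPerBlock = 32, 64-lane wavefronts, 4-byte OaccDataType:

// M1 = 256 / 64        = 4   warps per block along M
// M2 = min(64 / 4, 64) = 16  rows handled per warp
// N0 = 64 / 16         = 4   lanes along N
// N1 = 32 / 4          = 8   elements per lane along N
// M0 = 64 / (16 * 4)   = 1
// GetAlignmentOacc() = min(N1, 16 / sizeof(OaccDataType)) = min(8, 4) = 4

Deriving the alignment from the same M1/M2/N0/N1 split keeps the vector width from exceeding what each lane actually owns contiguously.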
@@ -5,7 +5,8 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
 #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
@@ -75,15 +76,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
    {
-       using BlockGemmProblem = BlockGemmPipelineProblem<
-           typename Problem::QDataType,
-           typename Problem::KDataType,
-           typename Problem::SaccDataType,
-           TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
-                                  Problem::BlockFmhaShape::kN0,
-                                  Problem::BlockFmhaShape::kK0>,
-                         typename Problem::BlockFmhaShape::Gemm0BlockWarps,
-                         typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
+       using GemmProblem =
+           BlockGemmProblem<typename Problem::QDataType,
+                            typename Problem::KDataType,
+                            typename Problem::SaccDataType,
+                            Problem::kBlockSize,
+                            TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                   Problem::BlockFmhaShape::kN0,
+                                                   Problem::BlockFmhaShape::kK0>,
+                                          typename Problem::BlockFmhaShape::Gemm0BlockWarps,
+                                          typename Problem::BlockFmhaShape::Gemm0WarpTile>>;

        constexpr auto warp_gemm = []() {
            if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
@@ -116,7 +118,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
            typename Problem::BlockFmhaShape::Gemm0BlockWarps,
            decltype(warp_gemm)>;

-       return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{};
+       return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
    }
};
@@ -199,15 +201,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
    {
-       using BlockGemmProblem = BlockGemmPipelineProblem<
-           typename Problem::QDataType,
-           typename Problem::KDataType,
-           typename Problem::SaccDataType,
-           TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
-                                  Problem::BlockFmhaShape::kN0,
-                                  Problem::BlockFmhaShape::kK0>,
-                         typename Problem::BlockFmhaShape::Gemm0BlockWarps,
-                         typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
+       using GemmProblem =
+           BlockGemmProblem<typename Problem::QDataType,
+                            typename Problem::KDataType,
+                            typename Problem::SaccDataType,
+                            Problem::kBlockSize,
+                            TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                   Problem::BlockFmhaShape::kN0,
+                                                   Problem::BlockFmhaShape::kK0>,
+                                          typename Problem::BlockFmhaShape::Gemm0BlockWarps,
+                                          typename Problem::BlockFmhaShape::Gemm0WarpTile>>;

        constexpr auto warp_gemm = []() {
            if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
@@ -240,7 +243,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
            typename Problem::BlockFmhaShape::Gemm0BlockWarps,
            decltype(warp_gemm)>;

-       return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
+       return BlockGemmASmemBSmemCRegV1<GemmProblem, BlockGemmPolicy>{};
    }
};
@@ -954,15 +957,16 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
    {
-       using BlockGemmProblem = BlockGemmPipelineProblem<
-           typename Problem::PDataType,
-           typename Problem::VDataType,
-           typename Problem::OaccDataType,
-           TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
-                                  Problem::BlockFmhaShape::kN1,
-                                  Problem::BlockFmhaShape::kK1>,
-                         typename Problem::BlockFmhaShape::Gemm1BlockWarps,
-                         typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
+       using GemmProblem =
+           BlockGemmProblem<typename Problem::PDataType,
+                            typename Problem::VDataType,
+                            typename Problem::OaccDataType,
+                            Problem::kBlockSize,
+                            TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                   Problem::BlockFmhaShape::kN1,
+                                                   Problem::BlockFmhaShape::kK1>,
+                                          typename Problem::BlockFmhaShape::Gemm1BlockWarps,
+                                          typename Problem::BlockFmhaShape::Gemm1WarpTile>>;

        auto warp_gemm = [&]() {
            if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
@@ -996,7 +1000,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
            typename Problem::OaccDataType,
            typename Problem::BlockFmhaShape::Gemm1BlockWarps,
            WarpGemm>;
-       return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{};
+       return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
    }
};
@@ -23,12 +23,13 @@
 #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
 #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
+#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
@@ -11,20 +11,12 @@
 namespace ck_tile {

-template <typename TilePartitioner_,
-          typename GemmPipeline_,
-          typename EpiloguePipeline_,
-          typename LayoutA_,
-          typename LayoutB_,
-          typename LayoutC_>
+template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
 struct GemmKernel
 {
     using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
     using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
     using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
-    using LayoutA          = remove_cvref_t<LayoutA_>;
-    using LayoutB          = remove_cvref_t<LayoutB_>;
-    using LayoutC          = remove_cvref_t<LayoutC_>;

     static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;
     using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
@@ -32,6 +24,10 @@ struct GemmKernel
     using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
     using CODataType   = remove_cvref_t<typename EpiloguePipeline::ODataType>;

+    using LayoutA = remove_cvref_t<typename GemmPipeline::LayoutA>;
+    using LayoutB = remove_cvref_t<typename GemmPipeline::LayoutB>;
+    using LayoutC = remove_cvref_t<typename GemmPipeline::LayoutC>;
+
     __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
     {
         return TilePartitioner::GridSize(M_size, N_size, Batch_size);
@@ -184,6 +180,7 @@ struct GemmKernel
             c_pad_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
             {i_m, i_n});
+
         EpiloguePipeline{}(CBlockWindow_pad, acc);
     }
};
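Call sites shrink accordingly; schematically (layout tag names illustrative):

// before: GemmKernel<TilePartitioner, Pipeline, Epilogue, LayoutA, LayoutB, LayoutC>
// after:  GemmKernel<TilePartitioner, Pipeline, Epilogue>
//         (LayoutA/B/C are now taken from Pipeline::LayoutA/B/C)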
@@ -4,15 +4,15 @@
 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"

 namespace ck_tile {

 // A Tile Window: global memory
 // B Tile Window: global memory
 // C Distributed tensor: register
-template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy>
-struct BlockGemmPipelineAGmemBGmemCRegV1
+template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
+struct GemmPipelineAGmemBGmemCRegV1
 {
     using ADataType = remove_cvref_t<typename Problem::ADataType>;
     using BDataType = remove_cvref_t<typename Problem::BDataType>;
@@ -33,6 +33,10 @@ struct GemmPipelineAGmemBGmemCRegV1
     static constexpr bool kPadB = Problem::kPadB;
     static constexpr bool kPadC = Problem::kPadC;

+    using LayoutA = remove_cvref_t<typename Problem::LayoutA>;
+    using LayoutB = remove_cvref_t<typename Problem::LayoutB>;
+    using LayoutC = remove_cvref_t<typename Problem::LayoutC>;
+
     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
     {
         return ck_tile::integer_divide_ceil(
@@ -7,9 +7,9 @@
 namespace ck_tile {

-// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
+// Default policy for GemmPipelineAGmemBGmemCRegV1
 // Default policy class should not be templated, put template on member functions instead
-struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
+struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
 {
 #if 0
     // 2d
@@ -4,15 +4,15 @@
 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"

 namespace ck_tile {

 // A Tile Window: global memory
 // B Tile Window: global memory
 // C Distributed tensor: register
-template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>
-struct BlockGemmPipelineAGmemBGmemCRegV2
+template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
+struct GemmPipelineAGmemBGmemCRegV2
 {
     using ADataType = remove_cvref_t<typename Problem::ADataType>;
     using BDataType = remove_cvref_t<typename Problem::BDataType>;
@@ -7,12 +7,11 @@
 namespace ck_tile {

-// Default policy for BlockGemmPipelineAGmemBGmemCRegV2
+// Default policy for GemmPipelineAGmemBGmemCRegV2
 // Default policy class should not be templated, put template on member functions instead
 // NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
-// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
-// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
-using BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy =
-    BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy;
+// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
+// GemmPipelineAGmemBGmemCRegV1DefaultPolicy
+using GemmPipelineAGmemBGmemCRegV2DefaultPolicy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy;

 } // namespace ck_tile
@@ -13,20 +13,23 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
-          bool kPadA_ = false,
-          bool kPadB_ = false,
-          bool kPadC_ = false>
-struct BlockGemmPipelineProblem
+          typename TileGemmTraits_>
+struct GemmPipelineProblem
 {
     using ADataType      = remove_cvref_t<ADataType_>;
     using BDataType      = remove_cvref_t<BDataType_>;
     using CDataType      = remove_cvref_t<CDataType_>;
     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
+    using GemmTraits     = remove_cvref_t<TileGemmTraits_>;

     static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();

-    static constexpr bool kPadA = kPadA_;
-    static constexpr bool kPadB = kPadB_;
-    static constexpr bool kPadC = kPadC_;
+    static constexpr bool kPadA = GemmTraits::kPadA;
+    static constexpr bool kPadB = GemmTraits::kPadB;
+    static constexpr bool kPadC = GemmTraits::kPadC;
+
+    using LayoutA = remove_cvref_t<typename GemmTraits::LayoutA>;
+    using LayoutB = remove_cvref_t<typename GemmTraits::LayoutB>;
+    using LayoutC = remove_cvref_t<typename GemmTraits::LayoutC>;

     static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
     static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <bool kPadA_,
bool kPadB_,
bool kPadC_,
typename LayoutA_,
typename LayoutB_,
typename LayoutC_>
struct TileGemmTraits
{
static constexpr bool kPadA = kPadA_;
static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_;
using LayoutA = LayoutA_;
using LayoutB = LayoutB_;
using LayoutC = LayoutC_;
};
} // namespace ck_tile
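A sketch of the new traits type feeding the reworked GemmPipelineProblem shown earlier; the layout tags are assumed to be the ones declared in ck_tile/ops/common/tensor_layout.hpp, and the data/shape types are placeholders:

using Traits = ck_tile::TileGemmTraits<false, false, false, // kPadA, kPadB, kPadC
                                       ck_tile::tensor_layout::gemm::RowMajor,    // LayoutA
                                       ck_tile::tensor_layout::gemm::ColumnMajor, // LayoutB
                                       ck_tile::tensor_layout::gemm::RowMajor>;   // LayoutC
// using Problem = ck_tile::GemmPipelineProblem<ADataType, BDataType, CDataType,
//                                              BlockGemmShape, Traits>;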
@@ -31,8 +31,14 @@ struct Layernorm2dFwd
    static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
    static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
+   static constexpr bool kPadM = Problem::kPadM;
+   static constexpr bool kPadN = Problem::kPadN;

    static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
+   static constexpr ck_tile::index_t kNPerThread     = Problem::BlockShape::kNPerThread;
+
+   static constexpr auto I0 = number<0>{};
+   static constexpr auto I1 = number<1>{};

    struct Kargs
    {
@@ -96,19 +102,25 @@ struct Layernorm2dFwd
            sequence<2>>{});
    }

-   template <typename Dstr>
-   CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
+   CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
    {
-       constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
-       using Lengths             = decltype(nDstrSpan.impl_);
-       ck_tile::index_t ret      = 1;
+       constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;

-       ck_tile::static_for<0, Lengths::size(), 1>{}(
-           [&](auto idx) { ret *= Lengths::template at(idx); });
+       int thread_id_n = get_thread_id() % kNThreadPerBlock;
+       int max_count =
+           __builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
+       int n_per_block_tail_loop =
+           __builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
+
+       if(n_per_block_tail_loop > 0)
+       {
+           int thread_max_n = (thread_id_n + 1) * kNPerThread;
+           int delta        = thread_max_n - n_per_block_tail_loop;
+           delta            = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
+           max_count += kNPerThread - delta;
+       }

-       return ret;
+       return max_count;
    }
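A worked check of the tail-loop accounting (assumed tile sizes, not from this commit): with kNPerBlock = 128 and kNPerThread = 4, kNThreadPerBlock = 32; for N = 200:

// max_count             = 4 * (200 / 128) = 4    one full N-tile per thread
// n_per_block_tail_loop = 200 - 4 * 32    = 72   valid columns in the tail tile
// thread 10: thread_max_n = 44, delta = clamp(44 - 72, 0, 4) = 0 -> max_count = 8
// thread 20: thread_max_n = 84, delta = clamp(84 - 72, 0, 4) = 4 -> max_count = 4
// Summed over all 32 threads: 32 * 4 + 72 = 200 counted elements per row, as expected.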
    template <typename DistributedTensor>
@@ -129,42 +141,29 @@ struct Layernorm2dFwd
        return out_dstr_tensor;
    }

-   template <bool Cond = (kHasGamma && kHasBeta)>
-   CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(const XDataType* p_x,
-                                                               const GammaDataType* p_gamma,
-                                                               const BetaDataType* p_beta,
-                                                               YDataType* p_y,
-                                                               MeanDataType* p_mean,
-                                                               InvStdDataType* p_invStd,
-                                                               const ComputeDataType epsilon,
-                                                               ck_tile::index_t M,
-                                                               ck_tile::index_t N) const
+   template <typename XBlockWindow,
+             typename GammaBlockWindow,
+             typename BetaBlockWindow,
+             typename YBlockWindow,
+             typename MeanBlockWindow,
+             typename InvStdBlockWindow,
+             bool Cond = (kHasGamma && kHasBeta)>
+   CK_TILE_DEVICE std::enable_if_t<Cond>
+   TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
+                         GammaBlockWindow& gamma_block_window,
+                         BetaBlockWindow& beta_block_window,
+                         YBlockWindow& y_block_window,
+                         MeanBlockWindow& mean_block_window,
+                         InvStdBlockWindow& inv_std_block_window,
+                         ComputeDataType epsilon,
+                         ck_tile::index_t N) const
    {
-       constexpr auto I0 = number<0>{};
-       constexpr auto I1 = number<1>{};
-
-       const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
-           p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
-
-       const auto gamma_n = make_naive_tensor_view<address_space_enum::global>(
-           p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
-
-       const auto beta_n = make_naive_tensor_view<address_space_enum::global>(
-           p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
-
-       const auto iM = get_block_id() * kMPerBlock;
-
-       constexpr auto xDstr = MakeXBlockTileDistribution();
-
-       auto x_block_window = make_tile_window(
-           x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
-
-       index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock);
-
-       // TODO: padding - handle max_count if N % kNPerBlock != 0
-       constexpr auto NPerThread = GetNPerThread(xDstr);
-       ThreadWelford<ComputeDataType, XDataType> thread_welford{
-           type_convert<int>(NPerThread * N / kNPerBlock)};
+       // TODO - Optimize tail loop to reduce move_tile_window()
+       index_t num_n_tile_iteration =
+           __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
+
+       int welford_max_count = GetWelfordMaxCount(N);
+       ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};

        using XTensorType = decltype(load_tile(x_block_window));
        auto mean_compute_block_tensor =
@@ -190,44 +189,14 @@ struct Layernorm2dFwd
        auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);

        if constexpr(kSaveMean)
-       {
-           const auto mean_m = make_naive_tensor_view_packed<address_space_enum::global>(
-               p_mean, make_tuple(M), number<32>{});
-           auto mean_block_window =
-               make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
            store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
-       }

        if constexpr(kSaveInvStd)
-       {
-           const auto inv_std_m = make_naive_tensor_view_packed<address_space_enum::global>(
-               p_invStd, make_tuple(M), number<32>{});
-           auto inv_std_block_window =
-               make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
-           store_tile(inv_std_block_window, cast_tile<MeanDataType>(inv_std_compute_block_tensor));
-       }
-
-       // TODO: Extract normalize pipeline
-       const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
-           p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
-       auto y_block_window = make_tile_window(
-           y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
-
-       constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
-       constexpr auto betaDstr  = gammaDstr;
-
-       auto gamma_block_window =
-           make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
-       auto beta_block_window = make_tile_window(
-           beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
+           store_tile(inv_std_block_window,
+                      cast_tile<InvStdDataType>(inv_std_compute_block_tensor));

        // reverse read x to reuse cache
-       ck_tile::index_t stride_to_right_most_window = N - kNPerBlock;
+       ck_tile::index_t stride_to_right_most_window =
+           N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;

        move_tile_window(x_block_window, {0, -kNPerBlock});
        move_tile_window(gamma_block_window, {stride_to_right_most_window});
@@ -274,17 +243,209 @@ struct Layernorm2dFwd
        }
    }
template <typename XBlockWindow,
typename GammaBlockWindow,
typename BetaBlockWindow,
typename YBlockWindow,
typename MeanBlockWindow,
typename InvStdBlockWindow,
bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond>
OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
{
int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
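        // For reference, ThreadWelford applies the standard running (Welford) update
        // per element; a scalar sketch of one step (illustrative only):
        //   count += 1;
        //   delta  = x - mean;
        //   mean  += delta / count;
        //   var   += delta * (x - mean); // accumulates M2; normalized on merge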
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
auto var_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor);
const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
// TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}(
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
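        // Conceptually, the warp merge combines partial results with the standard
        // pairwise rule (Chan et al.); a hedged sketch for two partials a and b:
        //   n     = n_a + n_b;
        //   delta = mean_b - mean_a;
        //   mean  = mean_a + delta * n_b / n;
        //   M2    = M2_a + M2_b + delta * delta * n_a * n_b / n;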
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean)
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
if constexpr(kSaveInvStd)
store_tile(inv_std_block_window,
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
// normalize
const auto gamma_block_tensor = load_tile(gamma_block_window);
const auto beta_block_tensor = load_tile(beta_block_window);
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
auto y_block_tensor =
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
sweep_tile_span(x_spans[I1], [&](auto idx1) {
constexpr auto j_idx = make_tuple(idx1);
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
sweep_tile_span(x_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto mean = mean_compute_block_tensor[i_idx];
const auto inv_std = inv_std_compute_block_tensor[i_idx];
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
auto y = (x - mean) * inv_std * gamma + beta;
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
});
});
store_tile(y_block_window, y_block_tensor);
}
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
-        TwoPassLayernorm2dFwd(static_cast<const XDataType*>(kargs.p_x),
-                              static_cast<const GammaDataType*>(kargs.p_gamma),
-                              static_cast<const BetaDataType*>(kargs.p_beta),
-                              static_cast<YDataType*>(kargs.p_y),
-                              static_cast<MeanDataType*>(kargs.p_mean),
-                              static_cast<InvStdDataType*>(kargs.p_invStd),
-                              static_cast<const ComputeDataType>(kargs.epsilon),
-                              kargs.M,
-                              kargs.N);
+        const auto x_m_n = [&]() {
+            const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+                static_cast<const XDataType*>(kargs.p_x),
+                make_tuple(kargs.M, kargs.N),
+                make_tuple(kargs.N, 1),
+                number<kNPerThread>{},
+                number<1>{});
+            return pad_tensor_view(x_dram_naive,
+                                   make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
+                                   sequence<kPadM, kPadN>{});
+        }();
const auto gamma_n = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma),
make_tuple(kargs.N),
make_tuple(1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
}();
const auto beta_n = [&]() {
            const auto beta_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const BetaDataType*>(kargs.p_beta),
                make_tuple(kargs.N),
                make_tuple(1),
                number<kNPerThread>{},
                number<1>{});
            return pad_tensor_view(
                beta_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
}();
const auto iM = get_block_id() * kMPerBlock;
constexpr auto xDstr = MakeXBlockTileDistribution();
auto x_block_window = make_tile_window(
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
const auto y_m_n = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<YDataType*>(kargs.p_y),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(y_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{});
}();
auto y_block_window = make_tile_window(
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
constexpr auto betaDstr = gammaDstr;
auto gamma_block_window =
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
        auto beta_block_window =
            make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);
auto mean_block_window = [&]() {
if constexpr(kSaveMean)
{
const auto mean_m = [&]() {
const auto mean_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<MeanDataType*>(kargs.p_mean),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view(
mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
}();
return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
}();
auto inv_std_block_window = [&]() {
if constexpr(kSaveInvStd)
{
const auto inv_std_m = [&]() {
const auto inv_std_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<InvStdDataType*>(kargs.p_invStd),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view(
inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
}();
return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
}();
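        // Dispatch: when the whole row fits in one block tile (N <= kNPerBlock), a
        // single load yields the exact mean/variance, so the one-pass path suffices;
        // e.g. with a hypothetical kNPerBlock of 1024, N = 768 takes the one-pass
        // path while N = 4096 falls back to the iterative two-pass Welford path.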
if(kargs.N <= kNPerBlock)
OnePassLayernorm2dFwd(x_block_window,
gamma_block_window,
beta_block_window,
y_block_window,
mean_block_window,
inv_std_block_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.N);
else
TwoPassLayernorm2dFwd(x_block_window,
gamma_block_window,
beta_block_window,
y_block_window,
mean_block_window,
inv_std_block_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.N);
    }
};
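For reference, both paths compute ordinary row-wise layer normalization. A minimal host-side sketch, assuming plain float buffers (illustrative only, not the kernel's implementation):

#include <cmath>

// y[m][n] = (x[m][n] - mean_m) * inv_std_m * gamma[n] + beta[n]
void layernorm2d_ref(const float* x, const float* gamma, const float* beta,
                     float* y, int M, int N, float epsilon)
{
    for(int m = 0; m < M; ++m)
    {
        float mean = 0.f;
        for(int n = 0; n < N; ++n)
            mean += x[m * N + n];
        mean /= N;
        float var = 0.f;
        for(int n = 0; n < N; ++n)
        {
            const float d = x[m * N + n] - mean;
            var += d * d;
        }
        var /= N;
        const float inv_std = 1.f / std::sqrt(var + epsilon);
        for(int n = 0; n < N; ++n)
            y[m * N + n] = (x[m * N + n] - mean) * inv_std * gamma[n] + beta[n];
    }
}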
...
@@ -14,17 +14,21 @@ template <typename XDataType_,
          typename YDataType_,
          typename MeanDataType_,
          typename InvStdDataType_,
-          typename BlockShape_>
+          typename BlockShape_,
+          bool kPadM_,
+          bool kPadN_>
struct BlockLayernorm2dFwdProblem
{
    using XDataType       = remove_cvref_t<XDataType_>;
    using GammaDataType   = remove_cvref_t<GammaDataType_>;
    using BetaDataType    = remove_cvref_t<BetaDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using YDataType       = remove_cvref_t<YDataType_>;
    using MeanDataType    = remove_cvref_t<MeanDataType_>;
    using InvStdDataType  = remove_cvref_t<InvStdDataType_>;
    using BlockShape      = remove_cvref_t<BlockShape_>;
+    static constexpr bool kPadM = kPadM_;
+    static constexpr bool kPadN = kPadN_;
};
} // namespace ck_tile
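The new pad flags are supplied at instantiation time; a hedged sketch (the data types and tile-shape type below are placeholders, not taken from this commit):

using Problem = ck_tile::BlockLayernorm2dFwdProblem<float,        // XDataType
                                                    float,        // GammaDataType
                                                    float,        // BetaDataType
                                                    float,        // ComputeDataType
                                                    float,        // YDataType
                                                    float,        // MeanDataType
                                                    float,        // InvStdDataType
                                                    MyBlockShape, // hypothetical shape type
                                                    true,         // kPadM
                                                    true>;        // kPadN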
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename ComputeTypeA,
typename ComputeTypeB>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
naive_gemm_kernel(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
CDataType* __restrict__ p_c_grid,
index_t m,
index_t n,
index_t k,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op)
{
using RowMajor = ck::tensor_layout::gemm::RowMajor;
const int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int col_idx = blockIdx.y * blockDim.y + threadIdx.y;
if(row_idx < m && col_idx < n)
{
AccDataType v_acc = static_cast<AccDataType>(0.0);
ComputeTypeA v_a = static_cast<ComputeTypeA>(0.0);
ComputeTypeB v_b = static_cast<ComputeTypeB>(0.0);
CDataType v_c = static_cast<CDataType>(0.0);
for(int k_idx = 0; k_idx < k; ++k_idx)
{
// check input matrices layout
int element_idx_a = 0;
int element_idx_b = 0;
if constexpr(std::is_same_v<ALayout, RowMajor>)
{
element_idx_a = row_idx * k + k_idx;
}
else
{
element_idx_a = row_idx + m * k_idx;
}
if constexpr(std::is_same_v<BLayout, RowMajor>)
{
element_idx_b = k_idx * n + col_idx;
}
else
{
element_idx_b = k_idx + k * col_idx;
}
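            // Worked example (illustrative): for a 4x3 matrix A, element (row 1, k 2)
            // sits at 1 * 3 + 2 = 5 in row-major storage and at 1 + 4 * 2 = 9 in
            // column-major storage, matching the formulas above.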
// apply a_element_op
a_element_op(v_a, p_a_grid[element_idx_a]);
// apply b_element_op
b_element_op(v_b, p_b_grid[element_idx_b]);
// multiply and accumulate
v_acc += static_cast<AccDataType>(v_a) * static_cast<AccDataType>(v_b);
}
// apply c_element_op
c_element_op(v_c, v_acc);
// check output matrix layout
int element_idx_c = 0;
if constexpr(std::is_same_v<CLayout, RowMajor>)
{
element_idx_c = row_idx * n + col_idx;
}
else
{
element_idx_c = row_idx + m * col_idx;
}
// prepare output
p_c_grid[element_idx_c] = v_c;
}
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceGemm : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
void* p_c_grid,
index_t m,
index_t n,
index_t k,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_c_grid_{static_cast<CDataType*>(p_c_grid)},
m_{m},
n_{n},
k_{k},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
index_t m_;
index_t n_;
index_t k_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceGemm::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
int block_size = 16;
dim3 block_dim(block_size, block_size, 1);
dim3 grid_dim(
(arg.m_ + block_size - 1) / block_size, (arg.n_ + block_size - 1) / block_size, 1);
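            // (x + block_size - 1) / block_size is integer ceil-division: e.g.
            // m = 100 with block_size = 16 yields 7 blocks (112 threads in x);
            // the kernel's row/col bounds check discards the surplus threads.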
auto launch_kernel = [&]() {
const auto kernel = naive_gemm_kernel<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputeTypeA,
ComputeTypeB>;
return launch_and_time_kernel(stream_config,
kernel,
grid_dim,
block_dim,
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.m_,
arg.n_,
arg.k_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
};
return launch_kernel();
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const void* p_a_grid,
const void* p_b_grid,
void* p_c_grid,
index_t m,
index_t n,
index_t k,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{
p_a_grid, p_b_grid, p_c_grid, m, n, k, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Device Reference Gemm"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
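Typical usage of this reference operator might look as follows; a sketch only: row-major layouts, PassThrough element-ops, and pre-allocated device buffers p_a/p_b/p_c are assumptions, not part of this file.

using Row         = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using RefGemm     = ck::tensor_operation::device::ReferenceGemm<Row, Row, Row,
                                                                float, float, float, float,
                                                                PassThrough, PassThrough, PassThrough>;

auto argument = RefGemm::MakeArgument(p_a, p_b, p_c, M, N, K,
                                      PassThrough{}, PassThrough{}, PassThrough{});
auto invoker  = RefGemm::MakeInvoker();
float avg_ms  = invoker.Run(argument, StreamConfig{}); // launches naive_gemm_kernel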
@@ -37,11 +37,7 @@ function(add_instance_library INSTANCE_NAME)
    endforeach()
  endif()
-  if(INSTANCES_ONLY)
-    set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
-  else()
-    set(INST_TARGETS ${GPU_TARGETS})
-  endif()
+  set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
  # Do not build DL instances if DL_KERNELS macro is not set
  foreach(source IN LISTS ARGN)
@@ -64,9 +60,9 @@ function(add_instance_library INSTANCE_NAME)
      list(REMOVE_ITEM ARGN "${source}")
    endif()
  endforeach()
-  # Do not build mha instances if gfx94 targets are not on the target list
+  # Do not build mha instances if gfx94 or gfx90a targets are not on the target list
  foreach(source IN LISTS ARGN)
-    if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "mha")
+    if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
      message("removing mha instance ${source} ")
      list(REMOVE_ITEM ARGN "${source}")
    endif()
@@ -75,17 +71,13 @@ function(add_instance_library INSTANCE_NAME)
  if(ARGN)
    set(INST_OBJ)
    foreach(source IN LISTS ARGN)
-      if(INSTANCES_ONLY)
-        set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
-      else()
-        set(INST_TARGETS ${GPU_TARGETS})
-      endif()
+      set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
      if(source MATCHES "_xdl")
        list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
      elseif(ARGN MATCHES "_wmma")
        list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
      elseif(ARGN MATCHES "mha")
-        list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
+        list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
      endif()
      set(offload_targets)
      foreach(target IN LISTS INST_TARGETS)
@@ -191,12 +183,7 @@ FOREACH(subdir_path ${dir_list})
      set(add_inst 1)
    endif()
-    if(INSTANCES_ONLY)
-      set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
-    else()
-      set(INST_TARGETS ${GPU_TARGETS})
-    endif()
+    set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
    if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
      message("quantization instances will not be built!")
@@ -320,8 +307,7 @@ if(CK_DEVICE_CONV_INSTANCES)
endif()
if(CK_DEVICE_MHA_INSTANCES)
  set(gpu_list ${INST_TARGETS})
-  list(FILTER gpu_list INCLUDE REGEX "^gfx94")
-  if(gpu_list)
+  if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
    add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
    add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
    target_compile_features(device_mha_operations PUBLIC)
...