Unverified Commit 140d2fa6 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #197 from ROCm/merge_from_public

Merge from public
parents 87ea11d0 d4d83037
......@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
//
// CK kernels which support XDL (MI series)
//
......
......@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat;
}
else
......@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat;
}
else
......
......@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
......@@ -390,7 +390,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
......
......@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
......@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
......@@ -518,7 +518,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
......@@ -575,7 +576,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
......
......@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
......@@ -504,7 +504,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
......
......@@ -64,7 +64,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0)
if(M == 0 || N == 0 || K == 0)
return;
const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
......
......@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_;
if(M * N * K == 0)
if(M == 0 || N == 0 || K == 0)
{
skipped_group_count_++;
continue;
......
......@@ -109,7 +109,7 @@ __global__ void
N = gemm_desc_ptr[group_id].N;
K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0)
if(M == 0 || N == 0 || K == 0)
{
grid_size_grp = 0;
continue;
......
......@@ -68,7 +68,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0)
if(M == 0 || N == 0 || K == 0)
return;
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
......
......@@ -50,12 +50,22 @@ class ArgParser
}
return *this;
}
void print()
void print() const
{
// find max key length
std::string::size_type max_key_length = 11;
for(auto& key : keys)
{
if(max_key_length < key.length())
{
max_key_length = key.length();
}
}
printf("args:\n");
for(auto& key : keys)
{
auto value = input_map[key];
auto value = input_map.at(key);
std::vector<std::string> help_text_lines;
size_t pos = 0;
for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;)
......@@ -69,8 +79,7 @@ class ArgParser
std::string(value.help_text.begin() + pos, value.help_text.end()));
std::string default_value = std::string("(default:") + value.value + std::string(")");
std::cout << std::setw(2) << std::setw(12 - value.name.length()) << "-" << key
std::cout << std::setw(1 + max_key_length - value.name.length()) << "-" << key
<< std::setw(4) << " " << help_text_lines[0] << " " << default_value
<< std::endl;
......@@ -78,7 +87,8 @@ class ArgParser
help_next_line != help_text_lines.end();
++help_next_line)
{
std::cout << std::setw(17) << " " << *help_next_line << std::endl;
std::cout << std::setw(1 + max_key_length + 4) << " " << *help_next_line
<< std::endl;
}
}
}
......
......@@ -13,7 +13,6 @@ namespace conv {
struct ConvParam
{
ConvParam();
ConvParam(ck_tile::index_t n_dim,
ck_tile::index_t group_count,
ck_tile::index_t n_batch,
......@@ -199,11 +198,6 @@ struct ConvParam
}
};
ConvParam::ConvParam()
: ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
{
}
CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
{
std::string msg;
......
......@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = b_n_k.mDesc.get_lengths()[0];
const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_n_k.mDesc.get_lengths()[0]
: b_n_k.mDesc.get_lengths()[1];
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[1]
: a_m_k.mDesc.get_lengths()[0];
......@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_element_op(a_m_k(m, k))
: a_element_op(a_m_k(k, m));
BDataType v_b = b_element_op(b_n_k(n, k));
BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_element_op(b_n_k(n, k))
: b_element_op(b_n_k(k, n));
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b);
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? c_m_n(m, n)
: c_m_n(n, m);
c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
}
};
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A,
BDataType* B,
CDataType* C,
......@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
if(row < M && col < N)
{
AccDataType acc = 0.0;
for(int k = 0; k < K; ++k)
{
acc += static_cast<AccDataType>(A[row * strideA + k]) *
static_cast<AccDataType>(B[col * strideB + k]);
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
: k * strideA + row;
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
}
C[row * strideC + col] = acc; // Store as AccDataType
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = acc;
}
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device,
DeviceMem& c_device,
......@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType>
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
errC = hipMemcpy(
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
......
......@@ -3,5 +3,6 @@
#pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define CK_TILE_MAX_RANK 5
namespace ck_tile {
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool kTilePermute_,
index_t kRank_,
index_t kPerm0,
index_t kPerm1,
index_t TileSize0,
index_t TileSize1,
index_t kPerm2 = 0,
index_t kPerm3 = 0,
index_t kPerm4 = 0,
index_t TileSize2 = 0,
index_t TileSize3 = 0,
index_t TileSize4 = 0>
struct CShuffleEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kTilePermute = kTilePermute_;
static constexpr index_t kRank = kRank_;
static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4};
static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = {
TileSize0, TileSize1, TileSize2, TileSize3, TileSize4};
};
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
const index_t* kPerm = Problem::kPerm;
static constexpr bool kTilePermute = Problem::kTilePermute;
static constexpr index_t kRank = Problem::kRank;
const index_t* tile_sizes = Problem::tile_sizes;
// No additional shared memory needed
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
template <typename OAccTile>
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
{
using DataType = typename OAccTile::DataType;
// Get thread buffer
auto& thread_buf = o_acc_tile.get_thread_buffer();
// Create a temporary buffer to hold the permuted data
thread_buffer<DataType, OAccTile::kThreadElementSpaceSize> permuted_thread_buf;
// Get the lengths of each dimension
auto thread_tensor_lengths = o_acc_tile.get_lengths();
// Total number of elements
index_t total_elements = OAccTile::kThreadElementSpaceSize;
// Iterate over all elements
for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx)
{
// Convert linear index to multi-dimensional indices
array<index_t, kRank> indices;
index_t remaining = linear_idx;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
indices(rev_i) = remaining % thread_tensor_lengths.get(number<rev_i>{});
remaining /= thread_tensor_lengths.get(number<rev_i>{});
});
// Apply the permutation
array<index_t, kRank> permuted_indices;
static_for<0, kRank, 1>{}(
[&](auto i) { permuted_indices(i) = indices.get(number<Problem::kPerm[i]>{}); });
// Compute offsets
index_t dst_offset = 0;
index_t stride = 1;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
dst_offset += permuted_indices[rev_i] * stride;
stride *= thread_tensor_lengths.get(number<rev_i>{});
});
// Move the data
permuted_thread_buf(dst_offset) = thread_buf[linear_idx];
}
// Copy the permuted data back to the original thread buffer
for(index_t i = 0; i < total_elements; ++i)
{
thread_buf.set_as(i, permuted_thread_buf.get(i));
}
}
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
{
const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t tile_coords[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
tile_coords[i] = current_window_origin[i] / tile_sizes[i];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t permuted_tile_coords[CK_TILE_MAX_RANK];
for(index_t i = 0; i < kRank; ++i)
{
permuted_tile_coords[i] = tile_coords[kPerm[i]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename ODramWindowTmp::BottomTensorIndex step = {};
for(index_t i = 0; i < kRank; ++i)
{
step[i] = permuted_window_origin[i] - current_window_origin[i];
}
// Move the window
move_tile_window(o_dram_window_tmp, step);
// Permute the data within the tile if necessary
if constexpr(kTilePermute)
{
permute_tile_data(o_acc_tile);
}
// Store the tile data to the permuted location
if constexpr(kPadM || kPadN)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
}
};
} // namespace ck_tile
......@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
......@@ -194,11 +197,23 @@ struct FmhaBwdDQDKDVKernel
ck_tile::GenericAttentionMaskEnum mask_type;
};
struct FmhaBwdCommonDropoutKargs
struct FmhaBwdDropoutSeedOffset
{
template <typename T>
union ValueOrPointer
{
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
{
void init_dropout(const float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
const float raw_scale)
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
......@@ -206,23 +221,41 @@ struct FmhaBwdDQDKDVKernel
rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale;
drop_seed = std::get<0>(drop_seed_offset);
drop_offset = std::get<1>(drop_seed_offset);
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop,
const uint64_t* seed_ptr,
const uint64_t* offset_ptr,
float raw_scale)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale;
this->drop_seed.ptr = seed_ptr;
this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
}
float rp_undrop = 1;
float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0;
};
struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
{
ck_tile::index_t batch_stride_randval = 0;
};
struct FmhaBwdDeterministicKargs
{
ck_tile::index_t split_stride_dq_acc = 0;
......@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset, scale);
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval)
{
kargs.rand_val_ptr = rand_val_ptr;
......@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
}
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset, scale);
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval)
{
kargs.rand_val_ptr = rand_val_ptr;
......@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
return FmhaDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
kargs.drop_seed,
kargs.drop_offset,
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
: *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t};
}
......
......@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
......@@ -170,29 +173,55 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_lse = 0;
};
struct FmhaFwdCommonDropoutKargs
struct FmhaFwdDropoutSeedOffset
{
template <typename T>
union ValueOrPointer
{
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
{
void init_dropout(const float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
drop_seed = std::get<0>(drop_seed_offset);
drop_offset = std::get<1>(drop_seed_offset);
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
this->drop_seed.ptr = seed_ptr;
this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
}
float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
bool is_store_randval = false;
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0;
};
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
{
ck_tile::index_t batch_stride_randval = 0;
......@@ -278,7 +307,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -344,7 +374,19 @@ struct FmhaFwdKernel
}
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset);
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
......@@ -392,7 +434,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -455,7 +498,19 @@ struct FmhaFwdKernel
}
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset);
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
......@@ -748,8 +803,10 @@ struct FmhaFwdKernel
return BlockDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
kargs.drop_seed,
kargs.drop_offset,
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
: *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t,
kargs.is_store_randval};
......
......@@ -5,8 +5,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
......@@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::QDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType,
......@@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::GemmDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadHeadDimV,
Problem::kPadHeadDimV,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -84,21 +97,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::OGradDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
typename Problem::BlockFmhaShape::Gemm2WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType,
......@@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::GemmDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
typename Problem::BlockFmhaShape::Gemm3WarpTile>,
TileGemmTraits<Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -149,21 +174,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem<
typename Problem::GemmDataType,
using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
typename Problem::BlockFmhaShape::Gemm4WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -181,7 +212,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
// these are for global load
......
......@@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
static const auto get_validated_m = [](LSEDataType raw_m) {
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
: raw_m;
};
decltype(lse_accum) lse_exp;
{
constexpr auto spans = decltype(lse_exp)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_max[i_idx] == -numeric<LSEDataType>::infinity())
{
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
lse_exp(i_j_idx) =
ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx)));
lse_exp(i_j_idx) = ck_tile::type_convert<LSEDataType>(0.0f);
});
}
else
{
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
lse_exp(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - lse_max(i_idx));
});
}
});
}
......@@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx))
{
lse_logsum(i_idx) = numeric<LSEDataType>::infinity();
}
if(lse_sum[i_idx] == ck_tile::type_convert<LSEDataType>(0.0f))
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
else
{
lse_logsum(i_idx) =
ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx));
}
lse_logsum(i_idx) = ck_tile::log(lse_sum(i_idx)) + lse_max(i_idx);
});
}
......@@ -218,6 +218,25 @@ struct BlockFmhaFwdSplitKVCombinePipeline
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_logsum(i_idx) == -numeric<LSEDataType>::infinity())
{
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_accum.get_tile_distribution(), i_j_idx);
const auto col = x_indices.at(number<1>{});
if(col < num_splits)
{
const auto row = x_indices.at(number<0>{});
lse_acc_lds(row, col) = ck_tile::type_convert<LSEDataType>(0.0f);
}
});
}
else
{
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
......@@ -233,22 +252,13 @@ struct BlockFmhaFwdSplitKVCombinePipeline
ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
}
});
}
});
}
block_sync_lds();
if constexpr(kStoreLSE)
{
constexpr auto spans = decltype(lse_logsum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_logsum(i_idx) == numeric<LSEDataType>::infinity())
{
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
}
});
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
}
......
......@@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
return 16 / sizeof(OaccDataType);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = kNPerBlock / N0;
return min(N1, static_cast<index_t>(16 / sizeof(OaccDataType)));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{
using ODataType = remove_cvref_t<typename Problem::ODataType>;
return 16 / sizeof(ODataType);
return GetAlignmentOacc<Problem>();
}
template <typename Problem>
......@@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t N1 = 16 / sizeof(OaccDataType);
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t M2 = get_warp_size() / N0;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = kNPerBlock / N0;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment