Unverified Commit 7d50244e authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #209 from ROCm/andriy/merge_from_public

Update develop branch from public repository
parents f221c2b0 d51701d4
......@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include <string>
#include <type_traits>
......@@ -24,14 +25,19 @@ struct Layernorm2dFwdPipelineOnePass
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
......@@ -46,20 +52,30 @@ struct Layernorm2dFwdPipelineOnePass
}
template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename BetaWindow,
typename YWindow,
typename YResidualWindow,
typename MeanWindow,
typename InvStdWindow>
typename InvStdWindow,
typename XScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_,
const BetaWindow& beta_window_,
YWindow& y_window,
YWindow& y_window_,
const YResidualWindow& y_residual_window_,
MeanWindow& mean_window,
InvStdWindow& inv_std_window,
const XScaleWindow& x_scale_window_,
YScaleWindow& y_scale_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem) const
void* smem,
Epilogue) const
{
const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
......@@ -67,8 +83,17 @@ struct Layernorm2dFwdPipelineOnePass
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto beta_window = make_tile_window(
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto x_scale_window = make_tile_window(
x_scale_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
auto x_scale = load_tile(x_scale_window);
const auto x = load_tile(x_window);
int cur_count = 0;
int max_count =
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
......@@ -81,6 +106,18 @@ struct Layernorm2dFwdPipelineOnePass
const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx));
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
store_tile(y_residual_window, x);
}
// compute welford each-thread->cross-lane->cross-warp
auto [mean, var] = block_welford(x, cur_count, max_count);
block_welford_sync(mean, var, cur_count);
......@@ -90,7 +127,7 @@ struct Layernorm2dFwdPipelineOnePass
// compute inv-std
auto inv_std = tile_elementwise_in(
[&](const auto& v_) {
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_) + epsilon);
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
},
var);
......@@ -100,8 +137,8 @@ struct Layernorm2dFwdPipelineOnePass
store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
// layernorm computation
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
sweep_tile(y, [&, mean_ = mean](auto idx) {
auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(ln, [&, mean_ = mean](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
......@@ -109,11 +146,28 @@ struct Layernorm2dFwdPipelineOnePass
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
y(idx) = type_convert<YDataType>(y_);
ln(idx) = ln_;
});
store_tile(y_window, y);
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
// smooth-quant pre-scale, then run rowwise-quant
sweep_tile(ln, [&](auto idx) {
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto xs_ = type_convert<ComputeDataType>(x_scale[j_idx]);
ln(idx) = ln(idx) * xs_;
});
}
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
Epilogue{}(y_window_, y_scale_window, ln, smem);
}
else
Epilogue{}(y_window_, ln);
}
};
} // namespace ck_tile
......@@ -14,10 +14,10 @@ template <typename XDataType_,
typename YDataType_,
typename MeanDataType_,
typename InvStdDataType_,
typename XScaleDataType_,
typename YScaleDataType_,
typename BlockShape_,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
typename Traits_>
struct Layernorm2dFwdPipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
......@@ -27,14 +27,14 @@ struct Layernorm2dFwdPipelineProblem
using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
using Traits = remove_cvref_t<Traits_>;
};
} // namespace ck_tile
......@@ -24,20 +24,25 @@ struct Layernorm2dFwdPipelineTwoPass
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr"; // block per row
return "bpr_2p"; // block per row
else
return "wpr"; // warp per row
return "wpr_2p"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
......@@ -46,20 +51,30 @@ struct Layernorm2dFwdPipelineTwoPass
}
template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename BetaWindow,
typename YWindow,
typename YResidualWindow,
typename MeanWindow,
typename InvStdWindow>
typename InvStdWindow,
typename XScaleWindow,
typename YScaleWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_,
const BetaWindow& beta_window_,
YWindow& y_window,
const YResidualWindow& y_residual_window_,
MeanWindow& mean_window,
InvStdWindow& inv_std_window,
const XScaleWindow& /*x_scale_window*/,
YScaleWindow& /*y_scale_window*/,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem) const
void* smem,
Epilogue) const
{
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
......@@ -67,6 +82,10 @@ struct Layernorm2dFwdPipelineTwoPass
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto beta_window = make_tile_window(
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
// Problem::BlockShape
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
......@@ -93,9 +112,26 @@ struct Layernorm2dFwdPipelineTwoPass
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
block_welford(x, mean, var, cur_count, max_count);
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx));
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, x);
move_tile_window(y_residual_window, {0, Block_N});
}
}
block_welford(x, mean, var, cur_count, max_count);
}
block_welford_sync(mean, var, cur_count);
......@@ -105,7 +141,7 @@ struct Layernorm2dFwdPipelineTwoPass
// compute inv-std
auto inv_std = tile_elementwise_in(
[&](const auto& v_) {
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_) + epsilon);
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
},
var);
......@@ -118,9 +154,8 @@ struct Layernorm2dFwdPipelineTwoPass
ck_tile::index_t stride_to_right_most_window =
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
// x_window.foo();
// gamma_window.foo();
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(beta_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window});
......@@ -128,14 +163,24 @@ struct Layernorm2dFwdPipelineTwoPass
// layernorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx));
});
}
// load gamma/beta (TODO: support no gamma/beta?)
const auto gamma = load_tile(gamma_window);
const auto beta = load_tile(beta_window);
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(y, [&, mean_ = mean](auto idx) {
sweep_tile(ln, [&, mean_ = mean](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
......@@ -143,14 +188,16 @@ struct Layernorm2dFwdPipelineTwoPass
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
y(idx) = type_convert<YDataType>(y_);
ln(idx) = ln_;
});
store_tile(y_window, y);
static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT);
Epilogue{}(y_window, ln);
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N});
move_tile_window(beta_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
enum class Layernorm2dFusedAddEnum
{
NO_ADD = 0,
// fused add before layernorm and store result to global
PRE_ADD_STORE = 1,
// fused add before layernorm, but not store result
PRE_ADD = 2,
};
// clang-format off
template<Layernorm2dFusedAddEnum> struct Layernorm2dFusedAddEnumName;
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
// clang-format on
enum class Layernorm2dFusedQuantEnum
{
NO_SWEEP = 0,
SMOOTH_DYNAMIC_QUANT = 1, // smooth oulier + rowwise quant, need input x-scale and store y_scale
DYNAMIC_QUANT = 2, // rowwise quant, store out a y-scale
};
// clang-format off
template<Layernorm2dFusedQuantEnum> struct Layernorm2dFusedQuantEnumName;
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
// clang-format on
template <bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_,
Layernorm2dFusedAddEnum kFusedAdd_,
Layernorm2dFusedQuantEnum kFusedQuant_>
struct Layernorm2dFwdTraits
{
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
// #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
namespace ck_tile {
/* independent host side argument, no template
*/
struct GenericPermuteHostArgs
{
static constexpr index_t kMaxRanks = 8; // TODO: hardcoded
const void* p_src;
void* p_dst;
index_t rank;
index_t shape[kMaxRanks]; // input shape
index_t perm[kMaxRanks]; // permute index
};
/*
simulate torch.permute:
x_ = x_.view(x.shape[0],
x.shape[1]//16, 16,
x.shape[2]//32, 4, 8)
x_ = x_.permute(0,1,3,4,2,5)
x_ = x_.contiguous()
x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]);//
this kernel is supposed not to be performant(just OK), with functional support up to kMaxRanks
dim of permutation, with a single kernel
*/
template <typename Problem_>
struct GenericPermute
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using DataType = remove_cvref_t<typename Problem::DataType>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMaxRanks = Problem::kMaxRanks;
static constexpr bool KeepLastDim = Problem::KeepLastDim;
struct __attribute__((packed)) Kargs
{
const void* p_src;
void* p_dst;
// index_t rank;
index_t num_elements;
index_t perm_length[kMaxRanks]; // tensor length after permutation
index_t perm_stride[kMaxRanks]; // tensor stride after permutation
};
CK_TILE_HOST static constexpr index_t TotalElements(const GenericPermuteHostArgs& h)
{
index_t n = 1;
for(auto i = 0; i < h.rank; i++)
{
n *= h.shape[i];
}
return n;
}
CK_TILE_HOST static constexpr Kargs MakeKargs(const GenericPermuteHostArgs& h)
{
Kargs a;
a.p_src = h.p_src;
a.p_dst = h.p_dst;
// assert rank <= kMaxRanks
index_t i = 0;
index_t perm[kMaxRanks];
index_t x_shape[kMaxRanks];
index_t x_stride[kMaxRanks];
// index_t perm_length[kMaxRanks];
for(; i < h.rank; i++)
{
x_shape[i] = h.shape[i];
perm[i] = h.perm[i];
}
for(; i < kMaxRanks; i++)
{
x_shape[i] = 1;
perm[i] = i; // will index to len = 1
}
index_t stride = 1;
for(index_t j = kMaxRanks - 1; j >= 0; j--)
{
x_stride[j] = stride;
stride *= x_shape[j];
}
for(index_t j = 0; j < kMaxRanks; j++)
{
a.perm_length[j] = x_shape[perm[j]];
a.perm_stride[j] = x_stride[perm[j]];
}
a.num_elements = TotalElements(h);
return a;
}
CK_TILE_HOST static constexpr auto GridSize(GenericPermuteHostArgs h)
{
auto total = TotalElements(h);
auto grids = dim3((total + BlockSize() - 1) / BlockSize());
// printf("### total:%d, grids:%dx%dx%d\n", total, );
return grids;
}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
index_t id = blockIdx.x * BlockSize() + threadIdx.x;
if(id >= kargs.num_elements)
return;
const auto perm_length =
generate_tuple([&](auto I) { return kargs.perm_length[I]; }, number<kMaxRanks>{});
const auto perm_stride =
generate_tuple([&](auto I) { return kargs.perm_stride[I]; }, number<kMaxRanks>{});
const DataType* p_src = reinterpret_cast<const DataType*>(kargs.p_src);
DataType* p_dst = reinterpret_cast<DataType*>(kargs.p_dst);
const auto src_view_0 = make_naive_tensor_view<address_space_enum::global>(
p_src, perm_length, perm_stride, number<1>{}, number<1>{});
const auto src_view = transform_tensor_view(
src_view_0,
make_tuple(make_merge_transform(perm_length)),
make_tuple(typename arithmetic_sequence_gen<0, kMaxRanks, 1>::type{}),
make_tuple(sequence<0>{}));
auto dst_view_0 = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst, perm_length, number<1>{});
auto dst_view = transform_tensor_view(
dst_view_0,
make_tuple(make_merge_transform(perm_length)),
make_tuple(typename arithmetic_sequence_gen<0, kMaxRanks, 1>::type{}),
make_tuple(sequence<0>{}));
// TODO: hard code to vector 1
using vector_t = thread_buffer<DataType, 1>;
const auto src_coord =
make_tensor_coordinate(src_view.get_tensor_descriptor(), array<index_t, 1>{id});
const auto dst_coord =
make_tensor_coordinate(dst_view.get_tensor_descriptor(), array<index_t, 1>{id});
// printf("src id:%d, os:%d\n", id, src_coord.get_offset());
// printf("dst id:%d, os:%d\n", id, dst_coord.get_offset());
const vector_t x = src_view.template get_vectorized_elements<vector_t>(src_coord, 0);
dst_view.template set_vectorized_elements<vector_t>(dst_coord, 0, x);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename DataType_,
index_t kBlockSize_ = 256,
index_t kMaxRanks_ = 8,
bool KeepLastDim_ = false>
struct GenericPermuteProblem
{
using DataType = remove_cvref_t<DataType_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kMaxRanks = kMaxRanks_;
/* KeepLastDim:
* if last dim keep the same? this can help enable vector load
* permute(0, 2, 4, 1, 3, 5) -> true
* permute(0, 3, 2, 1) -> false
*/
static constexpr bool KeepLastDim = KeepLastDim_;
// TODO: not used(?)
};
} // namespace ck_tile
......@@ -4,4 +4,8 @@
#pragma once
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
......@@ -4,9 +4,15 @@
#pragma once
#include "ck_tile/core.hpp"
#include <tuple>
// This file is not support cross warp reduce
namespace ck_tile {
/*
* TODO: block_tile_reduce_sync() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
......@@ -22,7 +28,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
constexpr index_t idim_p_lane = NDimP - 1;
const auto ps_idx = make_array<index_t>(get_block_id(), get_lane_id());
const auto ps_idx = detail::get_partition_index(acc_tensor.get_tile_distribution());
const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
......@@ -104,6 +110,65 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
});
}
/*
* this version is faster, using xor to do reduce, no need broadcast anymore
* TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
*/
template <typename AccDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func)
{
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_lane = NDimP - 1;
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
// loop over thread data
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local = acc_tensor.get_thread_buffer()[i];
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
static_assert(is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(r_length);
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
// xor
index_t src_lane =
__lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
// reduce
v_local = reduce_func(v_local, v_remote);
});
}
});
acc_tensor.get_thread_buffer()(i) = v_local;
});
}
// FIXME: this is for 2D to 1D reduce only, need to support n-D
template <typename AccDistributedTensor_,
typename InDistributedTensor_,
......@@ -175,6 +240,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
#endif
}
/*
* TODO: block_tile_reduce() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
template <typename AccDataType_,
typename InDistributedTensor_,
index_t... InReduceDims,
......@@ -208,4 +277,109 @@ CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
return acc_tensor;
}
// this version only support 2D->1D reduce (reduce-dim=seq<0, 1>)
// this version only support in/acc/out datatypes are the same
// this version will call thread/warp+sync in one function call
//
template <typename InDistributedTensor_>
struct BlockReduce2D
{
using InDistributedTensor = remove_cvref_t<InDistributedTensor_>;
using InDataType = typename InDistributedTensor::DataType;
CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor& t_, const InDataType& reduce_init_)
: t(t_), reduce_init(reduce_init_)
{
}
CK_TILE_HOST_DEVICE constexpr auto MakeDstBlockTile() const
{
using ReduceDim = sequence<1>; // hard coded
constexpr auto acc_dstr =
make_static_tile_distribution(ck_tile::detail::make_reduce_tile_distribution_encoding(
InDistributedTensor::get_tile_distribution()
.get_static_tile_distribution_encoding(),
ReduceDim{}));
auto dst_ = make_static_distributed_tensor<InDataType>(acc_dstr);
// init acc_tensor
tile_elementwise_inout([&](auto& x_) { x_ = type_convert<InDataType>(reduce_init); }, dst_);
return dst_;
}
// return number of pixels each lane need to reduce
CK_TILE_HOST_DEVICE constexpr auto get_reduce_length_y() const
{
constexpr auto spans = InDistributedTensor::get_distributed_spans();
}
// Here ReducePacksPerXDim is not the same meaning as that in static_uford/sweep_tile_uspan
// this is number of packs along the X-dim. We need to compute the Unpacks along the Y dim
// internally
// For simplicity, we just support along the row dimension, ReducePacksPerXDim is always 2
// element , and the first element is always ignored For simplicity, will always try from
// right-to-left to find alone which Y dim to split
template <typename ReduceFunc,
typename ReduceSyncFunc,
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func,
const ReduceSyncFunc& reduce_sync_func,
ReducePacksPerXDim = {}) const
{
constexpr auto spans = InDistributedTensor::get_distributed_spans();
constexpr auto row_y_unpacks = [&]() {
constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
constexpr auto row_y_size =
reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
static_assert(row_y_size % row_y_packs == 0);
constexpr auto row_y_slice_size = row_y_size / row_y_packs;
constexpr auto slice_info = slice_sequence(row_y_lengths, number<row_y_slice_size>{});
constexpr auto unpacks = slice_info[number<1>{}];
return unpacks;
}();
auto acc_tensor = MakeDstBlockTile();
// in-thread reduction
// FIXME: hard coded to be 2D to 1D reduction
sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) {
constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);
auto acc = acc_tensor[acc_dstr_idx];
sweep_tile_uspan(
spans[number<1>{}],
[&](auto... dstr_idx_i1) {
acc = reduce_func(acc, t[make_tuple(dstr_idx_i0, dstr_idx_i1)]...);
},
row_y_unpacks);
acc_tensor(acc_dstr_idx) = acc;
});
// TODO: always use xor to do cross-lane reduce
block_tile_reduce_xor_sync(acc_tensor, reduce_sync_func);
return acc_tensor;
}
template <typename ReduceFunc>
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func) const
{
return operator()(reduce_func, reduce_func);
}
InDistributedTensor t;
InDataType reduce_init;
};
// deduction guide
template <typename T>
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D<T>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2d
{
// in-thread reduction
using Problem = remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
using ComputeDataType = typename Problem::ComputeDataType;
CK_TILE_DEVICE constexpr BlockReduce2d() {}
template <typename XDistributedTensor_,
typename YDistributedTensor_,
typename ReduceFunc,
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
YDistributedTensor_& y_tensor,
const ReduceFunc& reduce_func,
ReducePacksPerXDim = {})
{
sweep_tile<XDistributedTensor_>(
[&](auto... idx_) {
constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
y_tensor(idx_0) = reduce_func(y_tensor(idx_0), x_tensor[idx_]...);
},
ReducePacksPerXDim{});
#if 0
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
// FIXME: hard coded to reduce 2nd axis
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);
auto y = y_tensor[y_dstr_idx];
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
y = reduce_func(y, x);
});
y_tensor(y_dstr_idx) = y;
});
#endif
}
template <typename XDistributedTensor_>
CK_TILE_DEVICE static auto MakeYBlockTile()
{
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
// FIXME: hard coded to reduce 2nd axis
constexpr auto reduce_dims = sequence<1>{};
constexpr auto dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
XDistributedTensor_::get_tile_distribution()
.get_static_tile_distribution_encoding(),
reduce_dims));
auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
return tensor;
}
template <typename XDistributedTensor_,
typename ReduceFunc,
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor,
const ComputeDataType& reduce_init,
const ReduceFunc& reduce_func,
ReducePacksPerXDim = {})
{
auto y_tensor = MakeYBlockTile<XDistributedTensor_>();
set_tile(y_tensor, reduce_init);
(*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{});
return y_tensor;
}
};
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dSync
{
using Problem = remove_cvref_t<Problem_>;
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
{
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_lane = NDimP - 1;
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
// const auto rs_idx =
// y_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
// loop over thread data
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local = y_tensor.get_thread_buffer()[i];
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
static_assert(is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(r_length);
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
// xor
index_t src_lane =
(__lane_id()) ^
(number<lid_over_rid_derivative << istage.value>{}.value);
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
// reduce
v_local = reduce_func(v_local, v_remote);
});
}
});
// TODO - Do we need to broadcast to other lane?
y_tensor.get_thread_buffer()(i) = v_local;
});
}
};
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dCrossWarpSync
{
using Problem = remove_cvref_t<Problem_>;
using BlockShape = typename Problem::BlockShape;
template <typename YDistributedTensor_>
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
{
constexpr index_t num_reduce_warps = [&]() {
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_warp = 0;
index_t len_ = 1;
static_for<0, NDimR, 1>{}([&](auto idim_r) {
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
len_ *= r_length;
}
});
return len_;
}();
return num_reduce_warps;
}
// return in byte
template <typename YDistributedTensor_>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
using DataType = typename YDistributedTensor_::DataType;
// constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
// we need to store all data from every wave into smem
// e.g. 2x2 reduce along N
// -------------> reduce N
// | w0 | w1 | ___> | w01 |
// | w2 | w3 | | w23 |
//
// -> store data from every wave into LDS
//
//
// -------------> reduce N
// | w0 | w1 | w2 | w3 | -----> | w0123 |
//
// -> also store data from every wave into LDS
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
return num_warps * thread_buf_size * sizeof(DataType);
}
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
{
using DataType = typename YDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
const index_t smem_offset = warp_id;
// skip if nonthing to do
if constexpr(num_reduce_warps == 1)
return;
// store into smem only for lane-0 within one warp
if(lane_id == 0)
{
static_for<0, thread_buf_size, 1>{}([&](auto i) {
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
});
}
block_sync_lds();
// load from smem. here we let everythread to do compute :)
index_t local_warp_id = warp_id / num_reduce_warps;
index_t local_smem_os = local_warp_id * num_reduce_warps;
DataType all_scratch[thread_buf_size * num_reduce_warps];
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
all_scratch[i_0 * num_reduce_warps + i_1] =
smem_ptr[i_0 * num_warps + local_smem_os + i_1];
});
});
block_sync_lds(); // TODO: we don't need sync here
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
// TODO: use descriptor for this
auto v_local = all_scratch[i_0 * num_reduce_warps];
// further reduce mean/var
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
constexpr auto i_1 = number<i_1_n1 + 1>{};
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
// reduce
v_local = reduce_func(v_local, v_remote);
});
y_tensor.get_thread_buffer()(i_0) = v_local;
});
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
namespace ck_tile {
struct BlockReduce2dDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2d<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
if constexpr(Problem::kNeedCrossWarpSync)
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::XDataType>(
MakeXBlockTileDistribution<Problem>()));
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
}
else
{
return 1; // zero size arrays are an extension
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_>
struct BlockReduce2dProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
// host side args
struct Rmsnorm2dFwdHostArgs
{
const void* p_x;
const void* p_gamma;
void* p_y;
void* p_invRms;
float epsilon;
index_t m;
index_t n;
index_t stride; // row_stride
};
// TODO: Extract some type to wrapper class
template <typename Pipeline_>
struct Rmsnorm2dFwd
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
static constexpr bool kSaveInvRms = Problem::kSaveInvRms;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
struct Kargs
{
const void* p_x;
const void* p_gamma;
void* p_y;
void* p_invRms;
float epsilon;
index_t m;
index_t n;
index_t stride; // row_stride
};
using Hargs = Rmsnorm2dFwdHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{hargs.p_x,
hargs.p_gamma,
hargs.p_y,
hargs.p_invRms,
hargs.epsilon,
hargs.m,
hargs.n,
hargs.stride};
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
return (hargs.m + Block_M - 1) / Block_M;
}
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
// in byte
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
CK_TILE_HOST static std::string GetName()
{
// clang-format off
using S_ = typename Problem::BlockShape;
auto surfix = [&] () {
std::string n;
if (kPadN) n += "_pn";
if (kSaveInvRms) n += "_rms";
if (kTwoPass) n += "_2p";
return n; }();
#define _SS_ std::string
#define _TS_ std::to_string
return _SS_("rmsnorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix;
#undef _SS_
#undef _TS_
// clang-format on
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const auto iM = get_block_id() * Block_M;
const auto x_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma),
make_tuple(kargs.n),
make_tuple(1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadM>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}();
auto y_window = [&]() {
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<YDataType*>(kargs.p_y),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
auto inv_rms_window = [&]() {
if constexpr(kSaveInvRms)
{
const auto inv_rms_m = [&]() {
const auto inv_rms_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<InvRmsDataType*>(kargs.p_invRms),
make_tuple(kargs.m),
number<1>{});
return pad_tensor_view(
inv_rms_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
}();
return make_tile_window(inv_rms_m, make_tuple(number<Block_M>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<Block_M>{}));
}();
__shared__ char smem[GetSmemSize()];
Pipeline{}(x_window,
gamma_window,
y_window,
inv_rms_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n,
smem);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -42,7 +42,7 @@ template <typename BlockTile_, // block size, seq<M, N>
typename Vector_, // contiguous pixels(vector size) along seq<M, N>
index_t BlockSize_ =
warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
struct Layernorm2dShape
struct Rmsnorm2dShape
{
// block size
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
namespace ck_tile {
struct Rmsnorm2dFwdPipelineDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 1>,
sequence<0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2d<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
if constexpr(Problem::kNeedCrossWarpSync)
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::XDataType>(
MakeXBlockTileDistribution<Problem>()));
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
}
else
{
return 1; // zero size arrays are an extension
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
struct Rmsnorm2dFwdPipelineOnePass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::kSaveInvRms;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_op"; // block per row
else
return "wpr_op"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow, typename GammaWindow, typename YWindow, typename InvRmsWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const GammaWindow& gamma_window_,
YWindow& y_window,
InvRmsWindow& inv_rms_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem) const
{
const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
const auto x = load_tile(x_window);
// load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window);
// compute mean square each-thread->cross-lane->cross-warp
auto square_sum = block_reduce2d(
x, reduce_square_sum_func.GetIdentityValue<ComputeDataType>(), reduce_square_sum_func);
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
// compute inv-rms
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) {
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
},
square_sum);
if constexpr(kSaveInvRms)
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
// rmsnorm computation
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
y(idx) = type_convert<YDataType>(y_);
});
store_tile(y_window, y);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename XDataType_,
typename GammaDataType_,
typename ComputeDataType_,
typename YDataType_,
typename InvRmsDataType_,
typename BlockShape_,
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_>
struct Rmsnorm2dFwdPipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using InvRmsDataType = remove_cvref_t<InvRmsDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
struct Rmsnorm2dFwdPipelineTwoPass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::kSaveInvRms;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_tp"; // block per row
else
return "wpr_tp"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow, typename GammaWindow, typename YWindow, typename InvRmsWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const GammaWindow& gamma_window_,
YWindow& y_window,
InvRmsWindow& inv_rms_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem) const
{
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
// Problem::BlockShape
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
using XTensorType = decltype(load_tile(x_window));
auto square_sum = block_reduce2d.template MakeYBlockTile<XTensorType>();
set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
block_reduce2d(x, square_sum, reduce_square_sum_func);
move_tile_window(x_window, {0, Block_N});
}
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
// compute inv-rms
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) {
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
},
square_sum);
if constexpr(kSaveInvRms)
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
// reverse read x to reuse cache
ck_tile::index_t stride_to_right_most_window =
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
move_tile_window(x_window, {0, -Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window});
// rmsnorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
// load gamma/beta (TODO: support no gamma/beta?)
const auto gamma = load_tile(gamma_window);
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
const auto x_ = type_convert<ComputeDataType>(x[idx]);
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
y(idx) = type_convert<YDataType>(y_);
});
store_tile(y_window, y);
move_tile_window(x_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N});
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment