Commit 30e15644 authored by AMD-dteng

temp commit

parent 677a842e
......@@ -521,7 +521,7 @@ include_directories(BEFORE
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
#add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
......
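Note: with -Werror commented out, the -Weverything diagnostics still print during the dev build but no longer turn warnings into build failures.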
......@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
#-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
......
make tile_example_layernorm2d_bwd -j 200
./bin/tile_example_layernorm2d_bwd -m=2048 -n=2048
rocprofv2 --kernel-trace -d /home/dteng/PerfProf/out -o kernel_trace
rocprofv2 -i /home/dteng/PerfProf/input.txt --plugin att auto -d /home/dteng/PerfProf/out
rocprofv2 -i /home/dteng/PerfProf/input.txt --plugin att auto --mode csv -d /home/dteng/PerfProf/out
\ No newline at end of file
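For context on the new profiling script: --kernel-trace records per-dispatch kernel timings, while the att plugin captures an Advanced Thread Trace (driven by the filters in input.txt) for instruction-level timing; --mode csv presumably just emits the decoded trace as CSV for post-processing.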
......@@ -84,7 +84,8 @@ struct layernorm2d_fwd_traits_
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
//return total_warps * (warpSize / ThreadPerBlock_N_);
return total_warps;
}
else
{
......@@ -483,7 +484,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
_cond = '((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
f_vec_n = ins.F_Vector_N, f_xbias = ins.F_kXbias, f_fused_add = ins.F_kFusedAdd,
f_vec_n = 1, f_xbias = ins.F_kXbias, f_fused_add = ins.F_kFusedAdd,
f_sweep_cond = _sweep_cond)
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
F_VEC_COND = _cond, F_instance_func=ins.call_name)
......
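Note that with f_vec_n pinned to 1, the generated (a.n % 1 == 0) clause is always true, so the vector-width check no longer restricts which n these instances dispatch for.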
......@@ -5,7 +5,29 @@
#include "layernorm2d_bwd_instance_common.hpp"
// clang-format off
// rm tm tn pd
template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 64, true>>(const S&, A);
template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 64, true>>(const S&, A);
// rm rn tm tn vn pd
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 64, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 64, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 128, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 128, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 128, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 128, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 1, true>>(const S&, A);
// large m
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 16, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 16, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 3, 8, 8, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 3, 8, 8, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 32, 8, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 32, 8, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 8, 64, 4, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 8, 64, 4, 8, true>>(const S&, A);
// large n
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 32, 4, 16, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 32, 4, 16, 8, true>>(const S&, A);
template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 128, 8, true>>(const S&, A);
template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 128, 8, true>>(const S&, A);
// clang-format on
......@@ -126,6 +126,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
dgamma_buf.GetDeviceBuffer(),
dbeta_buf.GetDeviceBuffer(),
dx_buf.GetDeviceBuffer(),
//tmp
ds_buf.GetDeviceBuffer(),
db_buf.GetDeviceBuffer(),
m,
n,
stride};
......@@ -155,12 +160,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
dgamma_buf.FromDevice(dgamma_host_dev.data());
dbeta_buf.FromDevice(dbeta_host_dev.data());
dx_buf.FromDevice(dx_host_dev.data());
//tmp
ds_buf.FromDevice(ds_host_dev.data());
db_buf.FromDevice(db_host_dev.data());
auto [rtol, atol] = get_elimit<DataType>();
pass = ck_tile::check_err(
dgamma_host_dev, dgamma_host_ref, std::string("GAMMA OUT Error: Incorrect results!"), rtol, atol);
// pass = ck_tile::check_err(
// dgamma_host_dev, dgamma_host_ref, std::string("GAMMA OUT Error: Incorrect results!"), rtol, atol);
// pass &= ck_tile::check_err(
// dbeta_host_dev, dbeta_host_ref, std::string("BETA OUT Error: Incorrect results!"), rtol, atol);
pass &= ck_tile::check_err(
dbeta_host_dev, dbeta_host_ref, std::string("BETA OUT Error: Incorrect results!"), rtol, atol);
dx_host_dev, dx_host_ref, std::string("DX OUT Error: Incorrect results!"), rtol, atol);
//tmp
// pass &= ck_tile::check_err(
// ds_host_dev, ds_host_ref, std::string("DS OUT Error: Incorrect results!"), rtol, atol);
// pass &= ck_tile::check_err(
// db_host_dev, db_host_ref, std::string("DB OUT Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
......
......@@ -43,8 +43,10 @@ struct layernorm2d_bwd_args : public ck_tile::Layernorm2dBwdGammaBetaHostArgs
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_>
struct layernorm2d_bwd_traits_
{
......@@ -60,7 +62,8 @@ struct layernorm2d_bwd_traits_
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
// return total_warps * (warpSize / ThreadPerBlock_N_);
return total_warps;
}
else
{
......@@ -84,17 +87,18 @@ struct layernorm2d_bwd_traits_
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = ThreadPerBlock_N_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, 1>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
......@@ -103,13 +107,17 @@ struct layernorm2d_bwd_traits_
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_>
using trait_ = layernorm2d_bwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_>;
template <typename Traits_>
......@@ -126,7 +134,9 @@ template <typename data_type>
struct layernorm2d_bwd_b16_
{
/* data */
using Trait = trait_<data_type, 1, 1, 64, true>;
//using Trait = trait_<data_type, 1, 1, 1, 256, 1, true>;
//using Trait = trait_<data_type, 1, 8, 64, 4, 8, true>;
using Trait = trait_<data_type, 1, 4, 1, 128, 8, true>;
float operator() (layernorm2d_bwd_traits /*t*/,
layernorm2d_bwd_args a,
const ck_tile::stream_config& s) {
......
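As a sanity check on the new tile arithmetic, here is a standalone sketch (a hypothetical file, not part of the commit) that mirrors the Block_M/Block_N formulas above for the selected trait_<data_type, 1, 4, 1, 128, 8, true>:

#include <cstdio>

// constants mirror trait_<data_type, Repeat_M=1, Repeat_N=4,
//                        ThreadPerBlock_M=1, ThreadPerBlock_N=128,
//                        Vector_N=8, kPadN=true>
int main() {
    constexpr int Repeat_M = 1, Repeat_N = 4;
    constexpr int ThreadPerBlock_M = 1, ThreadPerBlock_N = 128;
    constexpr int Vector_N = 8;

    constexpr int Block_M = Repeat_M * ThreadPerBlock_M;            // = 1
    constexpr int Block_N = Repeat_N * ThreadPerBlock_N * Vector_N; // = 4096
    static_assert(Block_M == 1 && Block_N == 4096, "block tile shape");

    // each of the 128 threads along N loads Vector_N contiguous elements,
    // repeated Repeat_N times, to cover one Block_N-wide tile
    std::printf("block tile: %d x %d\n", Block_M, Block_N);
    return 0;
}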
......@@ -48,6 +48,7 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp
const ComputeDataType dy = ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n));
gamma_acc += dy * (x - mean) * inv_std;
beta_acc += dy;
//printf("\ndteng print---dy[%d][%d]=%f\n",m_offset + inner_m,n,dy);
}
dgamma_mpart_n(m, n) = ck_tile::type_convert<GammaDataType>(gamma_acc);
......@@ -69,14 +70,18 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp
ds += dy * gamma * x;
db += dy * gamma;
}
ds_m(m_offset + inner_m) = ds;
db_m(m_offset + inner_m) = db;
ComputeDataType b = (db * mean - ds) * inv_std * inv_std * inv_std / N;
ComputeDataType c = -b * mean - db * inv_std / N;
for(int n = 0; n < N; ++n)
{
const ComputeDataType dy = ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n));
const ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m_offset + inner_m, n));
const ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
dx_m_n(m_offset + inner_m, n) = ck_tile::type_convert<XDataType>(dy * gamma * inv_std + b * x + c);
//printf("\ndteng print---dx[%d][%d]=%f\n",m_offset + inner_m,n,ck_tile::type_convert<ComputeDataType>(dx_m_n(m_offset + inner_m, n)));
}
}
};
......
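For reference, the b/c coefficients computed in the host code above are the standard layernorm-backward terms. Restated in LaTeX with the same names (\mu = mean, \sigma^{-1} = inv_std, N = row length):

ds = \sum_n dy_n \, \gamma_n \, x_n, \qquad db = \sum_n dy_n \, \gamma_n,
\qquad b = \frac{(db \, \mu - ds) \, \sigma^{-3}}{N}, \qquad
c = -b \, \mu - \frac{db \, \sigma^{-1}}{N}, \qquad
dx_n = dy_n \, \gamma_n \, \sigma^{-1} + b \, x_n + c,

which is exactly the dy * gamma * inv_std + b * x + c expression in the inner loop.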
......@@ -21,6 +21,10 @@ struct Layernorm2dBwdGammaBetaHostArgs
void* p_dBeta;
void* p_dX;
//tmp
void* p_dS;
void* p_dB;
index_t m;
index_t n;
index_t stride; // row_stride
......@@ -43,6 +47,7 @@ struct Layernorm2dBwdGammaBeta
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN;
......@@ -63,6 +68,10 @@ struct Layernorm2dBwdGammaBeta
void* p_dBeta;
void* p_dX;
//tmp
void* p_dS;
void* p_dB;
index_t m;
index_t n;
index_t stride; // row_stride
......@@ -79,6 +88,11 @@ struct Layernorm2dBwdGammaBeta
hargs.p_dGamma,
hargs.p_dBeta,
hargs.p_dX,
//tmp
hargs.p_dS,
hargs.p_dB,
hargs.m,
hargs.n,
hargs.stride};
......@@ -128,11 +142,17 @@ struct Layernorm2dBwdGammaBeta
const auto block_id = get_block_id();
const auto iM = block_id * Block_M;
// if(threadIdx.x == 0 && blockIdx.x == 0){
// printf("dteng block shape---WarpPerBlock_M=%d, WarpPerBlock_N=%d, ThreadPerWarp_M=%d, ThreadPerWarp_N=%d, Vector_N=%d\n", static_cast<int>(Problem::BlockShape::WarpPerBlock_M), static_cast<int>(Problem::BlockShape::WarpPerBlock_N), static_cast<int>(Problem::BlockShape::ThreadPerWarp_M), static_cast<int>(Problem::BlockShape::ThreadPerWarp_N), static_cast<int>(Problem::BlockShape::Vector_N));
// }
const auto x_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1));
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
// NOTE: we don't do any padding in this kernel for loading; we assume the
// kernel checks the max count dynamically
......@@ -146,7 +166,9 @@ struct Layernorm2dBwdGammaBeta
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const YDataType*>(kargs.p_dY),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1));
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
// NOTE: we don't do any padding in this kernel for loading; we assume the
// kernel checks the max count dynamically
......@@ -160,7 +182,9 @@ struct Layernorm2dBwdGammaBeta
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const MeanDataType*>(kargs.p_gamma),
make_tuple(kargs.n),
make_tuple(1));
make_tuple(1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
......@@ -175,7 +199,7 @@ struct Layernorm2dBwdGammaBeta
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}();
......@@ -187,7 +211,7 @@ struct Layernorm2dBwdGammaBeta
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}();
......@@ -196,7 +220,9 @@ struct Layernorm2dBwdGammaBeta
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<GammaDataType*>(kargs.p_dGamma),
make_tuple(gridDim.x, kargs.n),
make_tuple(kargs.n, 1));
make_tuple(kargs.n, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{});
......@@ -208,7 +234,9 @@ struct Layernorm2dBwdGammaBeta
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<BetaDataType*>(kargs.p_dBeta),
make_tuple(gridDim.x, kargs.n),
make_tuple(kargs.n, 1));
make_tuple(kargs.n, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{});
......@@ -219,14 +247,42 @@ struct Layernorm2dBwdGammaBeta
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<XDataType*>(kargs.p_dX),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1));
make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
__shared__ char smem[GetSmemSize()];
//tmp
const auto ds_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<ComputeDataType*>(kargs.p_dS),
make_tuple(kargs.m),
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}();
const auto db_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<ComputeDataType*>(kargs.p_dB),
make_tuple(kargs.m),
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}();
// __shared__ char smem[GetSmemSize()];
__shared__ char smem[0];
Pipeline{}(x_window,
dy_window,
......@@ -236,6 +292,11 @@ struct Layernorm2dBwdGammaBeta
dgamma_window,
dbeta_window,
dx_window,
//tmp
ds_window,
db_window,
kargs.n,
smem);
}
......
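The recurring number<Vector_N>{}, number<1>{} arguments added to the make_naive_tensor_view calls pass ck_tile's guaranteed last-dimension vector length (and stride) hints to each view, which is what lets the pipeline issue Vector_N-wide loads and stores against them.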
......@@ -192,7 +192,9 @@ struct Layernorm2dFwd
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const auto iM = get_block_id() * Block_M;
// if(threadIdx.x == 0 && blockIdx.x == 0){
// printf("dteng block shape---WarpPerBlock_M=%d, WarpPerBlock_N=%d, ThreadPerWarp_M=%d, ThreadPerWarp_N=%d, Vector_N=%d\n", static_cast<int>(Problem::BlockShape::WarpPerBlock_M), static_cast<int>(Problem::BlockShape::WarpPerBlock_N), static_cast<int>(Problem::BlockShape::ThreadPerWarp_M), static_cast<int>(Problem::BlockShape::ThreadPerWarp_N), static_cast<int>(Problem::BlockShape::Vector_N));
// }
const auto x_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x),
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy>
struct Layernorm2dBwdGammaBetaPipeline
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using ReducePolicy = ck_tile::remove_cvref_t<BlockReduce2dDefaultPolicy>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
static constexpr bool kPadM = false;
static constexpr bool kPadN = Problem::kPadN;
static constexpr const char* name = []() { return "bwd_gamma_beta"; }();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow,
typename GammaWindow,
typename MeanWindow,
typename InvStdWindow,
typename DGammaWindow,
typename DBetaWindow,
typename DXWindow,
// tmp
typename DSWindow,
typename DBWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XWindow& dy_window_,
const GammaWindow& gamma_window_,
const MeanWindow& mean_window_,
const InvStdWindow& inv_std_window_,
DGammaWindow& dgamma_window_,
DBetaWindow& dbeta_window_,
DXWindow& dx_window_,
// tmp
DSWindow& ds_window_,
DBWindow& db_window_,
ck_tile::index_t row_size,
void* smem) const
{
(void)smem;
auto gamma_beta_dist = Policy::template MakeGammaBetaBlockTileDistribution<Problem>();
auto dgamma_beta_dist = Policy::template MakeDGammaBetaBlockTileDistribution<Problem>();
auto mean_dist = Policy::template MakeMeanBlockTileDistribution<Problem>();
auto x_dist = Policy::template MakeXBlockTileDistribution<Problem>();
// these three windows are advanced with move_tile_window below, so they must be mutable
auto x_window = make_tile_window(x_window_, x_dist);
auto dy_window = make_tile_window(dy_window_, x_dist);
auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist); // TO CHECK
const auto mean_window = make_tile_window(mean_window_, mean_dist);
const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
auto dgamma_window = make_tile_window(dgamma_window_, dgamma_beta_dist);
auto dbeta_window = make_tile_window(dbeta_window_, dgamma_beta_dist);
auto dx_window = make_tile_window(dx_window_, x_dist);
const auto mean_tile = load_tile(mean_window);
const auto inv_std_tile = load_tile(inv_std_window);
// tmp
(void)ds_window_;
(void)db_window_;
//auto ds_window = make_tile_window(ds_window_, mean_dist);
//auto db_window = make_tile_window(db_window_, mean_dist);
auto ds_tile = make_static_distributed_tensor<ComputeDataType>(mean_dist);
auto db_tile = make_static_distributed_tensor<ComputeDataType>(mean_dist);
clear_tile(ds_tile);
clear_tile(db_tile);
auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
auto dbeta_tile = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
auto dx_tile = make_static_distributed_tensor<XDataType>(x_dist);
auto dgamma = cast_tile<ComputeDataType>(dgamma_tile);
auto dbeta = cast_tile<ComputeDataType>(dbeta_tile);
auto dx = cast_tile<ComputeDataType>(dx_tile);
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
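// pass 1: walk the row tiles left-to-right, accumulating the per-row sums
// ds = sum(dy * gamma * x) and db = sum(dy * gamma)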
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x_tile = load_tile(x_window);
const auto dy_tile = load_tile(dy_window);
const auto gamma_tile = load_tile(gamma_window);
move_tile_window(x_window, {0, Block_N});
move_tile_window(dy_window, {0, Block_N});
move_tile_window(gamma_window, {Block_N});
sweep_tile(x_tile, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto x = type_convert<ComputeDataType>(x_tile[idx]);
const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
const auto gamma = type_convert<ComputeDataType>(gamma_tile[j_idx]);
ds_tile(i_idx) += dy * gamma * x;
db_tile(i_idx) += dy * gamma;
// printf("threadidx=%d, blockidx=%d, ds_tile=%f\n",threadIdx.x, blockIdx.x, ds_tile[i_idx]);
});
}
auto block_reduce2d_sync = ReducePolicy::template GetBlockReduce2dSync<Problem>();
block_reduce2d_sync(ds_tile, ck_tile::ReduceOp::Add{});
block_reduce2d_sync(db_tile, ck_tile::ReduceOp::Add{});
// sweep_tile(x_tile, [&](auto idx) {
// constexpr auto i_idx = make_tuple(idx[number<0>{}]);
// printf("post::threadidx=%d, blockidx=%d, ds_tile=%f\n",threadIdx.x, blockIdx.x,
// ds_tile[i_idx]);
// });
//store_tile(ds_window, ds_tile);
//store_tile(db_window, db_tile);
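// rewind the consumed input windows to the right-most tile and fast-forward the
// output windows to match; pass 2 below then walks the tiles right-to-left,
// computing dx (and the dgamma/dbeta partials) tile by tile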
ck_tile::index_t stride_to_right_most_window = row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
move_tile_window(x_window, {0, -Block_N});
move_tile_window(dy_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N});
move_tile_window(dx_window, {0, stride_to_right_most_window});
move_tile_window(dbeta_window, {0, stride_to_right_most_window});
move_tile_window(dgamma_window, {0, stride_to_right_most_window});
using XDistributedTensor = decltype(load_tile(x_window));
constexpr auto spans = XDistributedTensor::get_distributed_spans();
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
// reload this tile's inputs (the first-pass locals went out of scope at the
// end of the loop above) and zero the per-tile dgamma/dbeta accumulators
const auto x_tile = load_tile(x_window);
const auto dy_tile = load_tile(dy_window);
const auto gamma_tile = load_tile(gamma_window);
clear_tile(dgamma);
clear_tile(dbeta);
sweep_tile_span(spans[number<0>{}], [&](auto i_idx) {
constexpr auto idx0 = make_tuple(i_idx);
const auto mean = type_convert<ComputeDataType>(mean_tile[idx0]);
const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[idx0]);
auto b = (db_tile[idx0] * mean - ds_tile[idx0]) * inv_std * inv_std * inv_std / row_size;
auto c = -b * mean - db_tile[idx0] * inv_std / row_size;
sweep_tile_span(spans[number<1>{}], [&](auto j_idx) {
constexpr auto idx = make_tuple(i_idx, j_idx);
constexpr auto idx1 = make_tuple(j_idx); // 1-d column index for the gamma tile
constexpr auto gb_idx = make_tuple(number<0>{}, j_idx);
const auto x = type_convert<ComputeDataType>(x_tile[idx]);
const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
const auto gamma = type_convert<ComputeDataType>(gamma_tile[idx1]);
dbeta(gb_idx) += dy;
dgamma(gb_idx) += dy * (x - mean) * inv_std;
dx(idx) = dy * gamma * inv_std + b * x + c;
});
});
store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
store_tile(dx_window, cast_tile<XDataType>(dx));
move_tile_window(x_window, {0, -Block_N});
move_tile_window(dy_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N});
move_tile_window(dx_window, {0, -Block_N});
move_tile_window(dbeta_window, {0, -Block_N});
move_tile_window(dgamma_window, {0, -Block_N});
}
}
};
} // namespace ck_tile
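To make the two-pass structure above concrete, here is a minimal scalar host model (a hypothetical standalone sketch, not CK code) of what the pipeline computes for a single row:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int N = 8;
    std::vector<float> x(N), dy(N), gamma(N), dx(N);
    for (int n = 0; n < N; ++n) { x[n] = float(n); dy[n] = 1.0f; gamma[n] = 0.5f; }

    // forward statistics that the kernel receives precomputed (mean / inv_std)
    float mean = 0.0f, var = 0.0f;
    for (int n = 0; n < N; ++n) mean += x[n] / N;
    for (int n = 0; n < N; ++n) var += (x[n] - mean) * (x[n] - mean) / N;
    const float inv_std = 1.0f / std::sqrt(var + 1e-5f);

    // pass 1: accumulate ds and db across the row
    float ds = 0.0f, db = 0.0f;
    for (int n = 0; n < N; ++n) { ds += dy[n] * gamma[n] * x[n]; db += dy[n] * gamma[n]; }

    // pass 2: apply dx = dy * gamma * inv_std + b * x + c
    const float b = (db * mean - ds) * inv_std * inv_std * inv_std / N;
    const float c = -b * mean - db * inv_std / N;
    for (int n = 0; n < N; ++n) dx[n] = dy[n] * gamma[n] * inv_std + b * x[n] + c;

    for (int n = 0; n < N; ++n) std::printf("dx[%d] = %f\n", n, dx[n]);
    return 0;
}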
......@@ -17,12 +17,12 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMeanBlockTileDistribution()
......@@ -32,11 +32,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>>,
tuple<sequence<1, 0>, sequence<1, 0>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
sequence<1>,
sequence<0>>{});
sequence<1, 1>,
sequence<0, 3>>{});
}
template <typename Problem>
......@@ -48,11 +48,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<2>,
sequence<0>>{});
sequence<2, 2>,
sequence<0, 3>>{});
}
template <typename Problem>
......
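In these encodings, each per-dimension sequence gains a fourth component (S::Vector_M / S::Vector_N), and the trailing y-dim index sequences are extended with entries pointing at it (the new 3s), so each thread now owns Vector_N contiguous elements along N instead of one.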
......@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
#include <string>
#include <type_traits>
......@@ -15,6 +16,7 @@ struct Layernorm2dBwdGammaBetaPipeline
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using ReducePolicy = ck_tile::remove_cvref_t<BlockReduce2dDefaultPolicy>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
......@@ -27,13 +29,12 @@ struct Layernorm2dBwdGammaBetaPipeline
static constexpr bool kPadM = false;
static constexpr bool kPadN = Problem::kPadN;
static constexpr const char* name = []() {
return "bwd_gamma_beta";
}();
static constexpr const char* name = []() { return "bwd_gamma_beta"; }();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
return ReducePolicy::template GetSmemSize<Problem>();
//GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
}
template <typename XWindow,
typename GammaWindow,
......@@ -41,7 +42,11 @@ struct Layernorm2dBwdGammaBetaPipeline
typename InvStdWindow,
typename DGammaWindow,
typename DBetaWindow,
typename DXWindow>
typename DXWindow,
// tmp
typename DSWindow,
typename DBWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XWindow& dy_window_,
const GammaWindow& gamma_window_,
......@@ -50,11 +55,14 @@ struct Layernorm2dBwdGammaBetaPipeline
DGammaWindow& dgamma_window_,
DBetaWindow& dbeta_window_,
DXWindow& dx_window_,
// tmp
DSWindow& ds_window_,
DBWindow& db_window_,
ck_tile::index_t row_size,
void* smem) const
{
(void)row_size;
(void)smem;
auto gamma_beta_dist = Policy::template MakeGammaBetaBlockTileDistribution<Problem>();
auto dgamma_beta_dist = Policy::template MakeDGammaBetaBlockTileDistribution<Problem>();
......@@ -63,70 +71,109 @@ struct Layernorm2dBwdGammaBetaPipeline
const auto x_window = make_tile_window(x_window_, x_dist);
const auto dy_window = make_tile_window(dy_window_, x_dist);
const auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist); //TO CHECK
const auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist); // TO CHECK
const auto mean_window = make_tile_window(mean_window_, mean_dist);
const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
auto dgamma_window = make_tile_window(dgamma_window_, dgamma_beta_dist);
auto dbeta_window = make_tile_window(dbeta_window_, dgamma_beta_dist);
auto dx_window = make_tile_window(dx_window_, x_dist);
const auto x_tile = load_tile(x_window);
const auto dy_tile = load_tile(dy_window);
const auto gamma_tile = load_tile(gamma_window);
const auto mean_tile = load_tile(mean_window);
const auto inv_std_tile = load_tile(inv_std_window);
auto dgamma_window = make_tile_window(dgamma_window_, dgamma_beta_dist);
auto dbeta_window = make_tile_window(dbeta_window_, dgamma_beta_dist);
auto dx_window = make_tile_window(dx_window_, x_dist);
auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
auto dbeta_tile = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
// tmp
auto ds_window = make_tile_window(ds_window_, mean_dist);
auto db_window = make_tile_window(db_window_, mean_dist);
auto ds_tile = make_static_distributed_tensor<ComputeDataType>(mean_dist);
auto db_tile = make_static_distributed_tensor<ComputeDataType>(mean_dist);
clear_tile(ds_tile);
clear_tile(db_tile);
// (void)ds_window;
// (void)db_window;
// auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
// auto dbeta_tile = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
auto dx_tile = make_static_distributed_tensor<XDataType>(x_dist);
auto dgamma = cast_tile<ComputeDataType>(dgamma_tile);
auto dbeta = cast_tile<ComputeDataType>(dbeta_tile);
auto dx = cast_tile<XDataType>(dx_tile);
// auto dgamma = cast_tile<ComputeDataType>(dgamma_tile);
// auto dbeta = cast_tile<ComputeDataType>(dbeta_tile);
auto dx = cast_tile<ComputeDataType>(dx_tile);
// auto gen_ones = [](ck_tile::index_t size) -> uint64_t {
// if (size <= 0) return 0;
// if (size >= 64) return 0xFFFFFFFFFFFFFFFF;
// return (1ULL << size) - 1;
// };
(void)dx_window;
(void)dx;
(void)gamma_tile;
// uint64_t lane_en = gen_ones(row_size);
// printf("lane en is %lu", lane_en);
// //uint64_t lane_en = (1ULL << row_size) - 1;
// asm volatile("s_mov_b64 exec, %[s_lane_en]"
// :
// : [s_lane_en]"s"(lane_en)
// : );
sweep_tile(x_tile, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
//constexpr auto j_idx = make_tuple(idx[number<1>{}]);
constexpr auto gb_idx = make_tuple(number<0>{}, idx[number<1>{}]);
// auto &gamma = gamma_tile(gb_idx);
// auto &beta = beta_tile(gb_idx);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto x = type_convert<ComputeDataType>(x_tile[idx]);
const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
const auto mean = type_convert<ComputeDataType>(mean_tile[i_idx]);
const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[i_idx]);
// beta += type_convert<BetaDataType>(dy);
// gamma += type_convert<GammaDataType>(dy * (x - mean) * inv_std);
dbeta(gb_idx) += dy;
dgamma(gb_idx) += dy * (x - mean) * inv_std;
// index_t tid = (threadIdx.y * blockDim.x) + threadIdx.x;
// if(blockIdx.x < 3 && blockIdx.y == 0 && tid < 3) {
// printf("bid %d tid %d count %d gb %f %f\n",blockIdx.x, tid, count, type_convert<float>(g), type_convert<float>(b));
// }
const auto gamma = type_convert<ComputeDataType>(gamma_tile[j_idx]);
ds_tile(i_idx) += dy * gamma * x;
db_tile(i_idx) += dy * gamma;
// printf("db_tile pre: threadidx=%d, blockidx=%d, db_tile=%f\n",threadIdx.x, blockIdx.x, db_tile[i_idx]);
// printf("dy_tile: threadidx=%d, blockidx=%d, dy_tile=%f\n",threadIdx.x, blockIdx.x, dy);
// printf("x: threadidx=%d, blockidx=%d, x_tile=%f\n",threadIdx.x, blockIdx.x, x);
// printf("gamma: threadidx=%d, blockidx=%d, gamma_tile=%f\n",threadIdx.x, blockIdx.x, gamma);
});
store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
// store_tile(gamma_window, gamma_tile);
// store_tile(beta_window, beta_tile);
auto block_reduce2d_sync = ReducePolicy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync = ReducePolicy::template GetBlockReduce2dCrossWarpSync<Problem>();
block_reduce2d_sync(ds_tile, ck_tile::ReduceOp::Add{});
block_reduce2d_sync(db_tile, ck_tile::ReduceOp::Add{});
// block_reduce2d_cross_warp_sync(ds_tile, smem, ck_tile::ReduceOp::Add{});
// block_reduce2d_cross_warp_sync(db_tile, smem, ck_tile::ReduceOp::Add{});
// auto ds = cast_tile<ComputeDataType>(mean_tile);
// auto db = cast_tile<ComputeDataType>(mean_tile);
// //calculate dx
// sweep_tile(x_tile, [&](auto idx)) {
// sweep_tile(x_tile, [&](auto idx) {
// constexpr auto i_idx = make_tuple(idx[number<0>{}]);
// constexpr auto j_idx = make_tuple(idx[number<1>{}]);
// const auto x = type_convert<ComputeDataType>(x_tile[idx]);
// const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
// const auto gamma = type_convert<ComputeDataType>(gamma_tile[j_idx]);
// // const auto mean = type_convert<ComputeDataType>(mean_tile[i_idx]);
// // const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[i_idx]);
// ds[i_idx] += dy * gamma * x;
// db[i_idx] += dy * gamma;
// }
// printf("db_tile post: threadidx=%d, blockidx=%d, db_tile=%f\n",threadIdx.x, blockIdx.x,
// db_tile[i_idx]);
// });
// store_tile(ds_window, ds_tile);
// store_tile(db_window, db_tile);
using XDistributedTensor = decltype(load_tile(x_window));
constexpr auto spans = XDistributedTensor::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto i_idx) {
constexpr auto idx0 = make_tuple(i_idx);
const auto mean = type_convert<ComputeDataType>(mean_tile[idx0]);
const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[idx0]);
auto b = (db_tile[idx0] * mean - ds_tile[idx0]) * inv_std * inv_std * inv_std / row_size;
auto c = -b * mean - db_tile[idx0] * inv_std / row_size;
sweep_tile_span(spans[number<1>{}], [&](auto j_idx) {
constexpr auto idx1 = make_tuple(j_idx);
constexpr auto idx = make_tuple(i_idx, j_idx);
//constexpr auto gb_idx = make_tuple(number<0>{}, j_idx);
const auto x = type_convert<ComputeDataType>(x_tile[idx]);
const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
const auto gamma = type_convert<ComputeDataType>(gamma_tile[idx1]);
// dbeta(gb_idx) += dy;
// dgamma(gb_idx) += dy * (x - mean) * inv_std;
dx(idx) = dy * gamma * inv_std + b * x + c;
//printf("dx: threadidx=%d, blockidx=%d, dx_tile=%f\n",threadIdx.x, blockIdx.x, dx(idx));
});
});
// store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
// store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
store_tile(dx_window, cast_tile<XDataType>(dx));
}
};
} // namespace ck_tile
......@@ -28,6 +28,8 @@ struct Layernorm2dBwdGammaBetaPipelineProblem
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
};
} // namespace ck_tile
......@@ -133,7 +133,10 @@ struct Layernorm2dFwdPipelineOnePass
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
//printf("x: threadidx=%d, blockidx=%d, acc=%f\n",threadIdx.x, blockIdx.x, x(idx));
// printf("acc pre: threadidx=%d, blockidx=%d, acc=%f\n",threadIdx.x, blockIdx.x, acc(idx));
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
// printf("acc post: threadidx=%d, blockidx=%d, acc=%f\n",threadIdx.x, blockIdx.x, acc(idx));
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
......@@ -184,6 +187,7 @@ struct Layernorm2dFwdPipelineOnePass
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
// printf("ln: threadidx=%d, blockidx=%d, acc=%f\n",threadIdx.x, blockIdx.x, ln_);
ln(idx) = ln_;
});
......
......@@ -17,7 +17,7 @@ fi
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm/ \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker --save-temps" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \
......
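Adding --save-temps keeps the intermediate compilation artifacts (including the generated device assembly), which is handy for matching the att instruction traces above back to ISA.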