Commit 677a842e authored by AMD-dteng

local base version

parent 5d671a5f
......@@ -42,3 +42,25 @@ target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_
# however, this property is global and may affect other targets
# TODO: consider generating a makefile ourselves
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
set(EXAMPLE_LAYERNORM2D_BWD "tile_example_layernorm2d_bwd")
# not using add_example_executable() to add this target, since we don't want it
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_LAYERNORM2D_BWD}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${EXAMPLE_LAYERNORM2D_BWD} EXCLUDE_FROM_ALL layernorm2d_bwd.cpp)
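# NOTE: because of EXCLUDE_FROM_ALL, this target must be built explicitly, e.g.:
#   cmake --build <build-dir> --target tile_example_layernorm2d_bwd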
target_include_directories(${EXAMPLE_LAYERNORM2D_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_LAYERNORM2D_BWD} PRIVATE ${INSTANCE_SRCS})
set(EXAMPLE_layernorm2d_bwd_COMPILE_OPTIONS)
# NOTE: we turn off -Wundefined-func-template so the sources compile without explicitly declared function specializations
list(APPEND EXAMPLE_layernorm2d_bwd_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_LAYERNORM2D_BWD} PRIVATE ${EXAMPLE_layernorm2d_bwd_COMPILE_OPTIONS})
# TODO: we have to turn off this global property, otherwise the progress bar generated
# by cmake prints too many files (execvp: /bin/sh: Argument list too long)
# however, this property is global and may affect other targets
# TODO: consider generating a makefile ourselves
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include <stdexcept>

#include "layernorm2d_bwd.hpp"
float layernorm2d_bwd(layernorm2d_bwd_traits t,
layernorm2d_bwd_args a,
const ck_tile::stream_config& s)
{
    if(t.data_type.compare("fp16") == 0)
    {
        return layernorm2d_bwd_b16_<ck_tile::fp16_t>{}(t, a, s);
    }
    else if(t.data_type.compare("bf16") == 0)
    {
        return layernorm2d_bwd_b16_<ck_tile::bf16_t>{}(t, a, s);
    }
    // no instance matched the requested data type
    throw std::runtime_error("no supported instance for this data type!");
}
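// A minimal host-side usage sketch (illustrative only; assumes the device
// buffers referenced by the args are already allocated and filled):
//
//   layernorm2d_bwd_traits t{"fp16"};
//   layernorm2d_bwd_args   a{/* device pointers, m, n, stride */};
//   float avg_ms = layernorm2d_bwd(t, a, ck_tile::stream_config{nullptr, true});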
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_bwd_instance_common.hpp"
// clang-format off
//                                       rm tm tn pd = Repeat_M, ThreadPerBlock_M, ThreadPerBlock_N, kPadN
template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 64, true>>(const S&, A);
template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 64, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include <ck_tile/core.hpp>
#include <iostream>

#include "layernorm2d_bwd.hpp"
using S = ck_tile::stream_config;
using A = layernorm2d_bwd_args;
template <typename Traits_>
float layernorm2d_bwd_(const S& s, A a)
{
using DataType = typename Traits_::DataType;
using PipelineProblem = ck_tile::Layernorm2dBwdGammaBetaPipelineProblem<
typename LayerNormTypeConfig<DataType>::XDataType,
typename LayerNormTypeConfig<DataType>::GammaDataType,
typename LayerNormTypeConfig<DataType>::BetaDataType,
typename LayerNormTypeConfig<DataType>::ComputeDataType,
typename LayerNormTypeConfig<DataType>::YDataType,
typename LayerNormTypeConfig<DataType>::MeanDataType,
typename LayerNormTypeConfig<DataType>::InvStdDataType,
typename Traits_::Shape,
Traits_::kPadN>;
using Pipeline = ck_tile::Layernorm2dBwdGammaBetaPipeline<PipelineProblem>;
using Kernel = ck_tile::Layernorm2dBwdGammaBeta<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
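    // occupancy hint: this kernel is bandwidth-bound, so one block per CU is
    // assumed to be sufficient here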
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
#include "ck_tile/host.hpp"
#include "layernorm2d_bwd.hpp"
#include <cstring>
// error tolerance per data type; bf16 keeps its own specialization so it can be tuned independently
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
std::string data_type = arg_parser.get_str("prec");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(stride >= n);
using TypeConfig = LayerNormTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType;
using GammaDataType = typename TypeConfig::GammaDataType;
using BetaDataType = typename TypeConfig::BetaDataType;
using MeanDataType = typename TypeConfig::MeanDataType;
using InvStdDataType = typename TypeConfig::InvStdDataType;
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> dy_host({m, n}, {stride, 1});
ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<MeanDataType> mean_host({m});
ck_tile::HostTensor<InvStdDataType> invStd_host({m});
ck_tile::index_t blockM = layernorm2d_bwd_block_m<XDataType>();
ck_tile::index_t reduce_m = (m + blockM - 1) / blockM;
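    // the kernel emits one partial dgamma/dbeta row per block of blockM input rows,
    // so the outputs below have reduce_m rows; the final reduction over those rows
    // is left to a later step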
ck_tile::HostTensor<GammaDataType> dgamma_host_dev({reduce_m, n});
ck_tile::HostTensor<BetaDataType> dbeta_host_dev({reduce_m, n});
ck_tile::HostTensor<XDataType> dx_host_dev({m, n});
ck_tile::HostTensor<GammaDataType> dgamma_host_ref({reduce_m, n});
ck_tile::HostTensor<BetaDataType> dbeta_host_ref({reduce_m, n});
ck_tile::HostTensor<XDataType> dx_host_ref({m, n});
    // temporary buffers for the per-row ds/db sums (debug only)
ck_tile::HostTensor<ComputeDataType> ds_host_dev({m});
ck_tile::HostTensor<ComputeDataType> db_host_dev({m});
ck_tile::HostTensor<ComputeDataType> ds_host_ref({m});
ck_tile::HostTensor<ComputeDataType> db_host_ref({m});
// ck_tile::FillMonotonicSeq<YDataType>{}(dy_host);
ck_tile::FillUniformDistribution<YDataType>{-.5f, .5f}(dy_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<MeanDataType>{-.5f, .5f}(mean_host);
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
// ck_tile::FillMonotonicSeq<MeanDataType>{}(mean_host);
ck_tile::FillUniformDistribution<InvStdDataType>{-.5f, .5f}(invStd_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dy_buf(dy_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem mean_buf(mean_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem invStd_buf(invStd_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dgamma_buf(dgamma_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem dbeta_buf(dbeta_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem dx_buf(dx_host_dev.get_element_space_size_in_bytes());
    // temporary device buffers for the per-row ds/db sums (debug only)
ck_tile::DeviceMem ds_buf(ds_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem db_buf(db_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
dy_buf.ToDevice(dy_host.data());
gamma_buf.ToDevice(gamma_host.data());
mean_buf.ToDevice(mean_host.data());
invStd_buf.ToDevice(invStd_host.data());
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
layernorm2d_bwd_traits traits{data_type};
layernorm2d_bwd_args args{x_buf.GetDeviceBuffer(),
dy_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(),
mean_buf.GetDeviceBuffer(),
invStd_buf.GetDeviceBuffer(),
dgamma_buf.GetDeviceBuffer(),
dbeta_buf.GetDeviceBuffer(),
dx_buf.GetDeviceBuffer(),
m,
n,
stride};
float ave_time = layernorm2d_bwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
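    // rough traffic estimate: read x and dy (m*n each) plus one n-wide row for each
    // of gamma/beta; the reduce_m partial output rows and mean/invstd reads are ignored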
float gb_per_sec = num_byte / 1.E6 / ave_time;
    std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
bool pass = true;
if(do_validation)
{
// reference
ck_tile::reference_layernorm2d_bwd_gamma_part<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType>(
x_host, dy_host, gamma_host, mean_host, invStd_host, dgamma_host_ref, dbeta_host_ref, dx_host_ref, ds_host_ref, db_host_ref);
dgamma_buf.FromDevice(dgamma_host_dev.data());
dbeta_buf.FromDevice(dbeta_host_dev.data());
auto [rtol, atol] = get_elimit<DataType>();
pass = ck_tile::check_err(
dgamma_host_dev, dgamma_host_ref, std::string("GAMMA OUT Error: Incorrect results!"), rtol, atol);
pass &= ck_tile::check_err(
dbeta_host_dev, dbeta_host_ref, std::string("BETA OUT Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16")
{
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
}
return -3;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>
template <typename DataType>
struct LayerNormTypeConfig;
template <>
struct LayerNormTypeConfig<ck_tile::half_t>
{
using XDataType = ck_tile::half_t;
using YDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t;
using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t;
using ComputeDataType = float;
};
template <>
struct LayerNormTypeConfig<ck_tile::bf16_t>
{
using XDataType = ck_tile::bf16_t;
using YDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t;
using BetaDataType = ck_tile::bf16_t;
using MeanDataType = ck_tile::bf16_t;
using InvStdDataType = ck_tile::bf16_t;
using ComputeDataType = float;
};
// runtime args
struct layernorm2d_bwd_args : public ck_tile::Layernorm2dBwdGammaBetaHostArgs
{
};
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
bool kPadN_>
struct layernorm2d_bwd_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = ThreadPerBlock_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
};
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
bool kPadN_>
using trait_ = layernorm2d_bwd_traits_<DataType_,
Repeat_M_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
kPadN_>;
template <typename Traits_>
float layernorm2d_bwd_(const ck_tile::stream_config& s, layernorm2d_bwd_args a);
// This is the public API; it will be generated by a script
struct layernorm2d_bwd_traits
{
std::string data_type;
};
template <typename data_type>
struct layernorm2d_bwd_b16_
{
using Trait = trait_<data_type, 1, 1, 64, true>;
float operator() (layernorm2d_bwd_traits /*t*/,
layernorm2d_bwd_args a,
const ck_tile::stream_config& s) {
return layernorm2d_bwd_<Trait>(s, a);
}
};
template <typename data_type>
ck_tile::index_t layernorm2d_bwd_block_m()
{
    return layernorm2d_bwd_b16_<data_type>::Trait::Block_M;
}
float layernorm2d_bwd(layernorm2d_bwd_traits, layernorm2d_bwd_args, const ck_tile::stream_config&);
......@@ -25,6 +25,7 @@
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_bwd.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
......
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
typename YDataType,
typename MeanDataType,
typename InvStdDataType>
CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataType>& x_m_n,
const HostTensor<YDataType>& dy_m_n,
const HostTensor<GammaDataType>& gamma_n,
const HostTensor<MeanDataType>& mean_m,
const HostTensor<InvStdDataType>& inv_std_m,
HostTensor<GammaDataType>& dgamma_mpart_n,
HostTensor<BetaDataType>& dbeta_mpart_n,
HostTensor<XDataType>& dx_m_n,
                                               // temporary per-row ds/db sums (debug only)
HostTensor<ComputeDataType>& ds_m,
HostTensor<ComputeDataType>& db_m)
{
const auto MN = x_m_n.mDesc.get_lengths();
const int M = MN[0];
const int N = MN[1];
const int PartM = dgamma_mpart_n.mDesc.get_lengths()[0];
const int MLoop = (M + PartM - 1) / PartM;
auto f = [&](auto m) {
const int m_offset = m * MLoop;
//calculate dgamma, dbeta
for(int n = 0; n < N; ++n)
{
ComputeDataType gamma_acc = 0;
ComputeDataType beta_acc = 0;
for(int inner_m = 0; inner_m < MLoop && m_offset + inner_m < M; inner_m++)
{
const ComputeDataType mean = ck_tile::type_convert<ComputeDataType>(mean_m(m_offset + inner_m));
const ComputeDataType inv_std = ck_tile::type_convert<ComputeDataType>(inv_std_m(m_offset + inner_m));
const ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m_offset + inner_m, n));
const ComputeDataType dy = ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n));
gamma_acc += dy * (x - mean) * inv_std;
beta_acc += dy;
}
dgamma_mpart_n(m, n) = ck_tile::type_convert<GammaDataType>(gamma_acc);
dbeta_mpart_n(m, n) = ck_tile::type_convert<BetaDataType>(beta_acc);
}
//calculate dx
for(int inner_m = 0; inner_m < MLoop && m_offset + inner_m < M; inner_m++)
{
ComputeDataType ds = 0;
ComputeDataType db = 0;
const ComputeDataType mean = ck_tile::type_convert<ComputeDataType>(mean_m(m_offset + inner_m));
const ComputeDataType inv_std = ck_tile::type_convert<ComputeDataType>(inv_std_m(m_offset + inner_m));
for(int n = 0; n < N; ++n)
{
const ComputeDataType dy = ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n));
const ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m_offset + inner_m, n));
const ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
ds += dy * gamma * x;
db += dy * gamma;
}
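            // with ds = sum_n(dy * gamma * x) and db = sum_n(dy * gamma), the input
            // gradient of y = (x - mean) * inv_std * gamma + beta can be written as
            //   dx = dy * gamma * inv_std + b * x + c
            // where b scales x and c is the per-row constant computed below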
ComputeDataType b = (db * mean - ds) * inv_std * inv_std * inv_std / N;
ComputeDataType c = -b * mean - db * inv_std / N;
for(int n = 0; n < N; ++n)
{
const ComputeDataType dy = ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n));
const ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m_offset + inner_m, n));
const ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
dx_m_n(m_offset + inner_m, n) = ck_tile::type_convert<XDataType>(dy * gamma * inv_std + b * x + c);
}
}
};
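    // evaluate f for each of the PartM partial output rows, spread across host threads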
make_ParallelTensorFunctor(f, PartM)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
......@@ -10,4 +10,10 @@
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
// host side args
struct Layernorm2dBwdGammaBetaHostArgs
{
const void* p_x;
const void* p_dY;
const void* p_gamma;
const void* p_mean;
const void* p_invStd;
void* p_dGamma;
void* p_dBeta;
void* p_dX;
index_t m;
index_t n;
index_t stride; // row_stride
};
// TODO: Extract some type to wrapper class
template <typename Pipeline_>
struct Layernorm2dBwdGammaBeta
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
    static constexpr bool kPadM = false; // padding along M is never needed
static constexpr bool kPadN = Problem::kPadN;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
struct Kargs
{
const void* p_x;
const void* p_dY;
const void* p_gamma;
const void* p_mean;
const void* p_invStd;
void* p_dGamma;
void* p_dBeta;
void* p_dX;
index_t m;
index_t n;
index_t stride; // row_stride
};
using Hargs = Layernorm2dBwdGammaBetaHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{hargs.p_x,
hargs.p_dY,
hargs.p_gamma,
hargs.p_mean,
hargs.p_invStd,
hargs.p_dGamma,
hargs.p_dBeta,
hargs.p_dX,
hargs.m,
hargs.n,
hargs.stride};
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
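        // 1-D grid: one workgroup per Block_M rows of the input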
return (hargs.m + Block_M - 1) / Block_M;
}
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
// in byte
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
CK_TILE_HOST static std::string GetName()
{
// clang-format off
using S_ = typename Problem::BlockShape;
        auto suffix = [&] () {
            std::string n;
            if (kPadN) n += "_pn";
            return n; }();
        #define _SS_ std::string
        #define _TS_ std::to_string
        return _SS_("layernorm2d_bwd_") + _SS_(t2s<XDataType>::name) + "_" +
            _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
            _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(1) + "_" +
            _SS_(Pipeline::name) + suffix;
#undef _SS_
#undef _TS_
// clang-format on
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const auto block_id = get_block_id();
const auto iM = block_id * Block_M;
const auto x_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1));
            // NOTE: we don't pad for loading in this kernel; we assume the kernel
            // checks the max count dynamically
const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto dy_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const YDataType*>(kargs.p_dY),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1));
            // NOTE: we don't pad for loading in this kernel; we assume the kernel
            // checks the max count dynamically
const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const GammaDataType*>(kargs.p_gamma),
make_tuple(kargs.n),
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}();
const auto mean_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const MeanDataType*>(kargs.p_mean),
make_tuple(kargs.m),
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}();
const auto invstd_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const InvStdDataType*>(kargs.p_invStd),
make_tuple(kargs.m),
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}();
auto dgamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<GammaDataType*>(kargs.p_dGamma),
make_tuple(gridDim.x, kargs.n),
make_tuple(kargs.n, 1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{});
return make_tile_window(tmp2_, make_tuple(number<1>{}, number<Block_N>{}), {block_id, 0});
}();
auto dbeta_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<BetaDataType*>(kargs.p_dBeta),
make_tuple(gridDim.x, kargs.n),
make_tuple(kargs.n, 1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{});
return make_tile_window(tmp2_, make_tuple(number<1>{}, number<Block_N>{}), {block_id, 0});
}();
auto dx_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<XDataType*>(kargs.p_dX),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
__shared__ char smem[GetSmemSize()];
Pipeline{}(x_window,
dy_window,
gamma_window,
mean_window,
invstd_window,
dgamma_window,
dbeta_window,
dx_window,
kargs.n,
smem);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMeanBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>>,
tuple<sequence<1, 0>, sequence<1, 0>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
sequence<1>,
sequence<0>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeDGammaBetaBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<2>,
sequence<0>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 1>,
sequence<0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
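        // this pipeline keeps all accumulation in registers; 1 byte (rather than 0)
        // is returned so the kernel's __shared__ buffer never has zero length (assumption)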
return 1;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy>
struct Layernorm2dBwdGammaBetaPipeline
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
static constexpr bool kPadM = false;
static constexpr bool kPadN = Problem::kPadN;
    static constexpr const char* name = "bwd_gamma_beta";
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow,
typename GammaWindow,
typename MeanWindow,
typename InvStdWindow,
typename DGammaWindow,
typename DBetaWindow,
typename DXWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XWindow& dy_window_,
const GammaWindow& gamma_window_,
const MeanWindow& mean_window_,
const InvStdWindow& inv_std_window_,
DGammaWindow& dgamma_window_,
DBetaWindow& dbeta_window_,
DXWindow& dx_window_,
ck_tile::index_t row_size,
void* smem) const
{
(void)row_size;
(void)smem;
auto gamma_beta_dist = Policy::template MakeGammaBetaBlockTileDistribution<Problem>();
auto dgamma_beta_dist = Policy::template MakeDGammaBetaBlockTileDistribution<Problem>();
auto mean_dist = Policy::template MakeMeanBlockTileDistribution<Problem>();
auto x_dist = Policy::template MakeXBlockTileDistribution<Problem>();
const auto x_window = make_tile_window(x_window_, x_dist);
const auto dy_window = make_tile_window(dy_window_, x_dist);
        const auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist); // TODO: check this distribution
const auto mean_window = make_tile_window(mean_window_, mean_dist);
const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
const auto x_tile = load_tile(x_window);
const auto dy_tile = load_tile(dy_window);
const auto gamma_tile = load_tile(gamma_window);
const auto mean_tile = load_tile(mean_window);
const auto inv_std_tile = load_tile(inv_std_window);
auto dgamma_window = make_tile_window(dgamma_window_, dgamma_beta_dist);
auto dbeta_window = make_tile_window(dbeta_window_, dgamma_beta_dist);
auto dx_window = make_tile_window(dx_window_, x_dist);
auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
auto dbeta_tile = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
auto dx_tile = make_static_distributed_tensor<XDataType>(x_dist);
auto dgamma = cast_tile<ComputeDataType>(dgamma_tile);
auto dbeta = cast_tile<ComputeDataType>(dbeta_tile);
auto dx = cast_tile<XDataType>(dx_tile);
(void)dx_window;
(void)dx;
(void)gamma_tile;
sweep_tile(x_tile, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
//constexpr auto j_idx = make_tuple(idx[number<1>{}]);
constexpr auto gb_idx = make_tuple(number<0>{}, idx[number<1>{}]);
// auto &gamma = gamma_tile(gb_idx);
// auto &beta = beta_tile(gb_idx);
const auto x = type_convert<ComputeDataType>(x_tile[idx]);
const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
const auto mean = type_convert<ComputeDataType>(mean_tile[i_idx]);
const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[i_idx]);
// beta += type_convert<BetaDataType>(dy);
// gamma += type_convert<GammaDataType>(dy * (x - mean) * inv_std);
dbeta(gb_idx) += dy;
dgamma(gb_idx) += dy * (x - mean) * inv_std;
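            // each thread folds its entire M-slice of the tile into row 0 (gb_idx);
            // the per-block partial rows are reduced across blocks in a later step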
// index_t tid = (threadIdx.y * blockDim.x) + threadIdx.x;
// if(blockIdx.x < 3 && blockIdx.y == 0 && tid < 3) {
// printf("bid %d tid %d count %d gb %f %f\n",blockIdx.x, tid, count, type_convert<float>(g), type_convert<float>(b));
// }
});
store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
// store_tile(gamma_window, gamma_tile);
// store_tile(beta_window, beta_tile);
// auto ds = cast_tile<ComputeDataType>(mean_tile);
// auto db = cast_tile<ComputeDataType>(mean_tile);
// //calculate dx
// sweep_tile(x_tile, [&](auto idx)) {
// constexpr auto i_idx = make_tuple(idx[number<0>{}]);
// constexpr auto j_idx = make_tuple(idx[number<1>{}]);
// const auto x = type_convert<ComputeDataType>(x_tile[idx]);
// const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
// const auto gamma = type_convert<ComputeDataType>(gamma_tile[j_idx]);
// // const auto mean = type_convert<ComputeDataType>(mean_tile[i_idx]);
// // const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[i_idx]);
// ds[i_idx] += dy * gamma * x;
// db[i_idx] += dy * gamma;
// }
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename XDataType_,
typename GammaDataType_,
typename BetaDataType_,
typename ComputeDataType_,
typename YDataType_,
typename MeanDataType_,
typename InvStdDataType_,
typename BlockShape_,
bool kPadN_>
struct Layernorm2dBwdGammaBetaPipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kPadN = kPadN_;
};
} // namespace ck_tile