Commit d7110645 authored by carlushuang's avatar carlushuang
Browse files

fix 2p pass

parent d5d7de90
......@@ -35,7 +35,7 @@ float layernorm2d_fwd_(const S& s, A a)
{
using DataType = typename Traits_::DataType;
using PipelineProblem = ck_tile::Layernorm2dFwdRowwiseProblem<
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
typename LayerNormTypeConfig<DataType>::XDataType,
typename LayerNormTypeConfig<DataType>::GammaDataType,
typename LayerNormTypeConfig<DataType>::BetaDataType,
......@@ -48,8 +48,8 @@ float layernorm2d_fwd_(const S& s, A a)
Traits_::kSaveMeanInvStd,
Traits_::kTwoPass>;
using OnePassPipeline = ck_tile::Layernorm2dFwdOnePassPipeline<PipelineProblem>;
using TwoPassPipeline = ck_tile::Layernorm2dFwdTwoPassPipeline<PipelineProblem>;
using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Layernorm2dFwd<Pipeline>;
......
#!/bin/sh
# call from top of CK folder
EXE=./bin/tile_example_layernorm2d_fwd
EXE=./build/bin/tile_example_layernorm2d_fwd
for pr_i in "fp16" "bf16" ; do
$EXE -prec=$pr_i -m=99 -n=13
......@@ -27,4 +27,5 @@ $EXE -prec=$pr_i -m=1 -n=3182
$EXE -prec=$pr_i -m=9 -n=4096
$EXE -prec=$pr_i -m=3 -n=8192
$EXE -prec=$pr_i -m=1 -n=10547
$EXE -prec=$pr_i -m=3 -n=17134
done
......@@ -5,8 +5,8 @@
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_one_pass_pipeline.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_two_pass_pipeline.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_problem.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
......@@ -9,7 +9,7 @@
namespace ck_tile {
struct Layernorm2dFwdRowwiseDefaultPolicy
struct Layernorm2dFwdPipelineDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
......
......@@ -4,14 +4,14 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = Layernorm2dFwdRowwiseDefaultPolicy>
struct Layernorm2dFwdOnePassPipeline
template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
struct Layernorm2dFwdPipelineOnePass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
......
......@@ -18,7 +18,7 @@ template <typename XDataType_,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
struct Layernorm2dFwdRowwiseProblem
struct Layernorm2dFwdPipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
......
......@@ -4,14 +4,14 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = Layernorm2dFwdRowwiseDefaultPolicy>
struct Layernorm2dFwdTwoPassPipeline
template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
struct Layernorm2dFwdPipelineTwoPass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
......@@ -73,9 +73,15 @@ struct Layernorm2dFwdTwoPassPipeline
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
// total number of count assume current iter have no pad(only last iter has pad)
constexpr index_t count_per_iter =
Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N;
const index_t last_iter_n = row_size - (num_n_tile_iteration - 1) * Block_N;
int cur_count = 0;
int max_count =
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
(num_n_tile_iteration - 1) * count_per_iter +
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
auto block_welford = Policy::template GetBlockWelford<Problem>();
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
auto block_welford_cross_warp_sync =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment