Commit d7110645 authored by carlushuang's avatar carlushuang
Browse files

fix 2p pass

parent d5d7de90
...@@ -35,7 +35,7 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -35,7 +35,7 @@ float layernorm2d_fwd_(const S& s, A a)
{ {
using DataType = typename Traits_::DataType; using DataType = typename Traits_::DataType;
using PipelineProblem = ck_tile::Layernorm2dFwdRowwiseProblem< using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
typename LayerNormTypeConfig<DataType>::XDataType, typename LayerNormTypeConfig<DataType>::XDataType,
typename LayerNormTypeConfig<DataType>::GammaDataType, typename LayerNormTypeConfig<DataType>::GammaDataType,
typename LayerNormTypeConfig<DataType>::BetaDataType, typename LayerNormTypeConfig<DataType>::BetaDataType,
...@@ -48,8 +48,8 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -48,8 +48,8 @@ float layernorm2d_fwd_(const S& s, A a)
Traits_::kSaveMeanInvStd, Traits_::kSaveMeanInvStd,
Traits_::kTwoPass>; Traits_::kTwoPass>;
using OnePassPipeline = ck_tile::Layernorm2dFwdOnePassPipeline<PipelineProblem>; using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Layernorm2dFwdTwoPassPipeline<PipelineProblem>; using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>; using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Layernorm2dFwd<Pipeline>; using Kernel = ck_tile::Layernorm2dFwd<Pipeline>;
......
#!/bin/sh #!/bin/sh
# call from top of CK folder # call from top of CK folder
EXE=./bin/tile_example_layernorm2d_fwd EXE=./build/bin/tile_example_layernorm2d_fwd
for pr_i in "fp16" "bf16" ; do for pr_i in "fp16" "bf16" ; do
$EXE -prec=$pr_i -m=99 -n=13 $EXE -prec=$pr_i -m=99 -n=13
...@@ -27,4 +27,5 @@ $EXE -prec=$pr_i -m=1 -n=3182 ...@@ -27,4 +27,5 @@ $EXE -prec=$pr_i -m=1 -n=3182
$EXE -prec=$pr_i -m=9 -n=4096 $EXE -prec=$pr_i -m=9 -n=4096
$EXE -prec=$pr_i -m=3 -n=8192 $EXE -prec=$pr_i -m=3 -n=8192
$EXE -prec=$pr_i -m=1 -n=10547 $EXE -prec=$pr_i -m=1 -n=10547
$EXE -prec=$pr_i -m=3 -n=17134
done done
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp" #include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp" #include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_default_policy.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_one_pass_pipeline.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_two_pass_pipeline.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_problem.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace ck_tile { namespace ck_tile {
struct Layernorm2dFwdRowwiseDefaultPolicy struct Layernorm2dFwdPipelineDefaultPolicy
{ {
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
......
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_default_policy.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include <string> #include <string>
#include <type_traits> #include <type_traits>
namespace ck_tile { namespace ck_tile {
template <typename Problem_, typename Policy_ = Layernorm2dFwdRowwiseDefaultPolicy> template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
struct Layernorm2dFwdOnePassPipeline struct Layernorm2dFwdPipelineOnePass
{ {
using Problem = ck_tile::remove_cvref_t<Problem_>; using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
......
...@@ -18,7 +18,7 @@ template <typename XDataType_, ...@@ -18,7 +18,7 @@ template <typename XDataType_,
bool kPadN_, bool kPadN_,
bool kSaveMeanInvStd_, bool kSaveMeanInvStd_,
bool kTwoPass_> bool kTwoPass_>
struct Layernorm2dFwdRowwiseProblem struct Layernorm2dFwdPipelineProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>; using GammaDataType = remove_cvref_t<GammaDataType_>;
......
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_default_policy.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include <string> #include <string>
#include <type_traits> #include <type_traits>
namespace ck_tile { namespace ck_tile {
template <typename Problem_, typename Policy_ = Layernorm2dFwdRowwiseDefaultPolicy> template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
struct Layernorm2dFwdTwoPassPipeline struct Layernorm2dFwdPipelineTwoPass
{ {
using Problem = ck_tile::remove_cvref_t<Problem_>; using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
...@@ -73,9 +73,15 @@ struct Layernorm2dFwdTwoPassPipeline ...@@ -73,9 +73,15 @@ struct Layernorm2dFwdTwoPassPipeline
index_t num_n_tile_iteration = index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
// total number of count assume current iter have no pad(only last iter has pad)
constexpr index_t count_per_iter =
Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N;
const index_t last_iter_n = row_size - (num_n_tile_iteration - 1) * Block_N;
int cur_count = 0; int cur_count = 0;
int max_count = int max_count =
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size); (num_n_tile_iteration - 1) * count_per_iter +
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
auto block_welford = Policy::template GetBlockWelford<Problem>(); auto block_welford = Policy::template GetBlockWelford<Problem>();
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>(); auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
auto block_welford_cross_warp_sync = auto block_welford_cross_warp_sync =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment