Commit 1cb3e443 authored by carlushuang's avatar carlushuang
Browse files

rename more

parent cc1898fc
...@@ -106,7 +106,7 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -106,7 +106,7 @@ float layernorm2d_fwd_(const S& s, A a)
{ {
using DataType = typename Traits_::DataType; using DataType = typename Traits_::DataType;
using PipelineProblem = ck_tile::Layernorm2dFwdWarpPerRowProblem< using PipelineProblem = ck_tile::Layernorm2dFwdRowwiseProblem<
typename LayerNormTypeConfig<DataType>::XDataType, typename LayerNormTypeConfig<DataType>::XDataType,
typename LayerNormTypeConfig<DataType>::GammaDataType, typename LayerNormTypeConfig<DataType>::GammaDataType,
typename LayerNormTypeConfig<DataType>::BetaDataType, typename LayerNormTypeConfig<DataType>::BetaDataType,
...@@ -118,7 +118,7 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -118,7 +118,7 @@ float layernorm2d_fwd_(const S& s, A a)
Traits_::kPadN, Traits_::kPadN,
Traits_::kSaveMeanInvStd, Traits_::kSaveMeanInvStd,
Traits_::kTwoPass>; Traits_::kTwoPass>;
using Pipeline = ck_tile::Layernorm2dFwdWarpPerRowPipeline<PipelineProblem>; using Pipeline = ck_tile::Layernorm2dFwdRowwisePipeline<PipelineProblem>;
using Kernel = ck_tile::Layernorm2dFwd<Pipeline>; using Kernel = ck_tile::Layernorm2dFwd<Pipeline>;
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp" #include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp" #include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_default_policy.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_pipeline.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_pipeline.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_problem.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
...@@ -21,7 +21,7 @@ namespace ck_tile { ...@@ -21,7 +21,7 @@ namespace ck_tile {
+--------------+--------------+ | <WarpPerBlock_M(2)> | +--------------+--------------+ | <WarpPerBlock_M(2)> |
| wrap_2 | wrap_3 | | v | wrap_2 | wrap_3 | | v
+--------------+--------------+--------------+--------------+----+ Block_M +--------------+--------------+--------------+--------------+----+ Block_M
| | | (Warp_M * WarpPerBlock_M * Repeat_M ) | | |
+ + | + + |
| | | v | | | v
+--------------+--------------+--------------+--------------+ + +--------------+--------------+--------------+--------------+ +
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace ck_tile { namespace ck_tile {
struct Layernorm2dFwdWarpPerRowDefaultPolicy struct Layernorm2dFwdRowwiseDefaultPolicy
{ {
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
......
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_default_policy.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_default_policy.hpp"
#include <string> #include <string>
#include <type_traits> #include <type_traits>
namespace ck_tile { namespace ck_tile {
template <typename Problem_, typename Policy_ = Layernorm2dFwdWarpPerRowDefaultPolicy> template <typename Problem_, typename Policy_ = Layernorm2dFwdRowwiseDefaultPolicy>
struct Layernorm2dFwdWarpPerRowPipeline struct Layernorm2dFwdRowwisePipeline
{ {
using Problem = ck_tile::remove_cvref_t<Problem_>; using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
......
...@@ -18,7 +18,7 @@ template <typename XDataType_, ...@@ -18,7 +18,7 @@ template <typename XDataType_,
bool kPadN_, bool kPadN_,
bool kSaveMeanInvStd_, bool kSaveMeanInvStd_,
bool kTwoPass_> bool kTwoPass_>
struct Layernorm2dFwdWarpPerRowProblem struct Layernorm2dFwdRowwiseProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>; using GammaDataType = remove_cvref_t<GammaDataType_>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment