Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0ef27d53
Commit
0ef27d53
authored
Jan 21, 2025
by
Andriy Roshchenko
Browse files
Merge remote-tracking branch 'origin/gfx950' into andriy/lwpck-2682
parents
6778c318
74a743e2
Changes
348
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
651 additions
and
220 deletions
+651
-220
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
...rm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
+30
-27
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+49
-19
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
...layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
+14
-12
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
+48
-17
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
..._tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
+17
-0
include/ck_tile/ops/norm_reduce.hpp
include/ck_tile/ops/norm_reduce.hpp
+10
-0
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
+78
-48
include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp
..._tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp
+7
-2
include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp
include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp
+0
-0
include/ck_tile/ops/permute.hpp
include/ck_tile/ops/permute.hpp
+1
-1
include/ck_tile/ops/reduce.hpp
include/ck_tile/ops/reduce.hpp
+1
-1
include/ck_tile/ops/rmsnorm2d.hpp
include/ck_tile/ops/rmsnorm2d.hpp
+2
-1
include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp
...ude/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp
+169
-28
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
...norm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
+5
-5
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
+66
-15
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp
...ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp
+13
-13
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
+73
-18
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp
...e/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp
+54
-0
include/ck_tile/ops/smoothquant.hpp
include/ck_tile/ops/smoothquant.hpp
+1
-1
include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp
...ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp
+13
-12
No files found.
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
View file @
0ef27d53
...
...
@@ -4,8 +4,8 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/
welford
/block/block_
welford
_problem.hpp"
#include "ck_tile/ops/
welford
/block/block_
welford
.hpp"
#include "ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
_problem.hpp"
#include "ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
.hpp"
namespace
ck_tile
{
...
...
@@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
Welford
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
NormReduce
()
{
using
P_
=
Block
Welford
Problem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
return
Block
Welford
<
P_
>
{};
using
P_
=
Block
NormReduce
Problem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
return
Block
NormReduce
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
Welford
Sync
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
NormReduce
Sync
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
return
Block
Welford
Sync
<
P_
>
{};
return
Block
NormReduce
Sync
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
Welford
CrossWarpSync
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
NormReduce
CrossWarpSync
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
return
Block
Welford
CrossWarpSync
<
P_
>
{};
return
Block
NormReduce
CrossWarpSync
<
P_
>
{};
}
template
<
typename
Problem
>
...
...
@@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
using
block_welford
=
Block
Welford
<
P_
>
;
using
block_welford
=
Block
NormReduce
<
P_
>
;
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
ComputeDataType
>
(
MakeXBlockTileDistribution
<
Problem
>
()));
using
mean_var_block_tile
=
decltype
(
block_welford
::
template
MakeMeanVarBlockTile
<
x_block_tile
>());
return
GetBlock
Welford
CrossWarpSync
<
Problem
>
()
return
GetBlock
NormReduce
CrossWarpSync
<
Problem
>
()
.
template
GetSmemSize
<
mean_var_block_tile
>();
}
else
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -18,6 +18,7 @@ struct Layernorm2dFwdPipelineOnePass
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XBiasDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XBiasDataType
>
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
BetaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
BetaDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
...
...
@@ -37,6 +38,8 @@ struct Layernorm2dFwdPipelineOnePass
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kFastFDiv
=
Problem
::
Traits
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
Traits
::
kWelford
;
static
constexpr
auto
kXbias
=
Problem
::
Traits
::
kXbias
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
...
...
@@ -54,24 +57,26 @@ struct Layernorm2dFwdPipelineOnePass
template
<
typename
XWindow
,
typename
XResidualWindow
,
typename
XBiasWindow
,
typename
GammaWindow
,
typename
BetaWindow
,
typename
YWindow
,
typename
YResidualWindow
,
typename
MeanWindow
,
typename
InvStdWindow
,
typename
X
ScaleWindow
,
typename
Smooth
ScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XResidualWindow
&
x_residual_window_
,
const
XBiasWindow
&
x_bias_window_
,
const
GammaWindow
&
gamma_window_
,
const
BetaWindow
&
beta_window_
,
YWindow
&
y_window_
,
const
YResidualWindow
&
y_residual_window_
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
const
X
ScaleWindow
&
x
_scale_window_
,
const
Smooth
ScaleWindow
&
sm
_scale_window_
,
YScaleWindow
&
y_scale_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
...
...
@@ -80,6 +85,8 @@ struct Layernorm2dFwdPipelineOnePass
{
const
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
const
auto
x_bias_window
=
make_tile_window
(
x_bias_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
const
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
const
auto
beta_window
=
make_tile_window
(
...
...
@@ -89,23 +96,38 @@ struct Layernorm2dFwdPipelineOnePass
auto
y_residual_window
=
make_tile_window
(
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
const
auto
x_bias
=
load_tile
(
x_bias_window
);
int
cur_count
=
0
;
int
max_count
=
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
row_size
);
auto
block_welford
=
Policy
::
template
GetBlockWelford
<
Problem
>();
auto
block_welford_sync
=
Policy
::
template
GetBlockWelfordSync
<
Problem
>();
auto
block_welford_cross_warp_sync
=
Policy
::
template
GetBlockWelfordCrossWarpSync
<
Problem
>();
auto
block_norm_reduce
=
Policy
::
template
GetBlockNormReduce
<
Problem
>();
auto
block_norm_reduce_sync
=
Policy
::
template
GetBlockNormReduceSync
<
Problem
>();
auto
block_norm_reduce_cross_warp_sync
=
Policy
::
template
GetBlockNormReduceCrossWarpSync
<
Problem
>();
using
XTensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
x
));
auto
mean
=
block_norm_reduce
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_norm_reduce
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
clear_tile
(
mean
);
clear_tile
(
var
);
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kXbias
==
Layernorm2dXBiasEnum
::
ADD_BIAS
)
{
sweep_tile
(
x
,
[
&
](
auto
idx
)
{
// compute x = bias + x
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_bias
[
j_idx
])
+
acc
(
idx
);
});
}
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
{
...
...
@@ -117,12 +139,21 @@ struct Layernorm2dFwdPipelineOnePass
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
}
// compute welford each-thread->cross-lane->cross-warp
auto
[
mean
,
var
]
=
block_welford
(
acc
,
cur_count
,
max_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{});
// compute reduce each-thread->cross-lane->cross-warp
block_norm_reduce
(
acc
,
mean
,
var
,
cur_count
,
max_count
);
block_norm_reduce_sync
(
mean
,
var
,
cur_count
);
block_norm_reduce_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
if
(
kWelford
)
{
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{});
}
else
{
sweep_tile
(
mean
,
[
&
](
auto
idx
)
{
mean
(
idx
)
=
mean
(
idx
)
/
type_convert
<
MeanDataType
>
(
row_size
);
var
(
idx
)
=
var
(
idx
)
/
type_convert
<
MeanDataType
>
(
row_size
)
-
mean
(
idx
)
*
mean
(
idx
);
});
}
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
...
...
@@ -153,14 +184,13 @@ struct Layernorm2dFwdPipelineOnePass
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
auto
ln_
=
(
acc
[
idx
]
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
ln
(
idx
)
=
ln_
;
ln
(
idx
)
=
ln_
;
});
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
||
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
Epilogue
{}(
y_window_
,
x
_scale_window_
,
y_scale_window
,
ln
,
smem
);
Epilogue
{}(
y_window_
,
sm
_scale_window_
,
y_scale_window
,
ln
,
smem
);
}
else
Epilogue
{}(
y_window_
,
ln
);
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -8,28 +8,30 @@
namespace
ck_tile
{
template
<
typename
XDataType_
,
typename
XBiasDataType_
,
typename
GammaDataType_
,
typename
BetaDataType_
,
typename
ComputeDataType_
,
typename
YDataType_
,
typename
MeanDataType_
,
typename
InvStdDataType_
,
typename
X
ScaleDataType_
,
typename
Smooth
ScaleDataType_
,
typename
YScaleDataType_
,
typename
BlockShape_
,
typename
Traits_
>
struct
Layernorm2dFwdPipelineProblem
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
XScaleDataType
=
remove_cvref_t
<
XScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XBiasDataType
=
remove_cvref_t
<
XBiasDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
SmoothScaleDataType
=
remove_cvref_t
<
SmoothScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XBiasDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XBiasDataType
>
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
BetaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
BetaDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
...
...
@@ -36,6 +37,8 @@ struct Layernorm2dFwdPipelineTwoPass
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kFastFDiv
=
Problem
::
Traits
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
Traits
::
kWelford
;
static
constexpr
auto
kXbias
=
Problem
::
Traits
::
kXbias
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
...
...
@@ -53,32 +56,37 @@ struct Layernorm2dFwdPipelineTwoPass
template
<
typename
XWindow
,
typename
XResidualWindow
,
typename
XBiasWindow
,
typename
GammaWindow
,
typename
BetaWindow
,
typename
YWindow
,
typename
YResidualWindow
,
typename
MeanWindow
,
typename
InvStdWindow
,
typename
X
ScaleWindow
,
typename
Smooth
ScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XResidualWindow
&
x_residual_window_
,
const
XBiasWindow
&
x_bias_window_
,
const
GammaWindow
&
gamma_window_
,
const
BetaWindow
&
beta_window_
,
YWindow
&
y_window
,
const
YResidualWindow
&
y_residual_window_
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
const
X
ScaleWindow
&
/*
x
_scale_window*/
,
const
Smooth
ScaleWindow
&
/*
sm
_scale_window*/
,
YScaleWindow
&
/*y_scale_window*/
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
,
Epilogue
)
const
{
static_assert
(
kWelford
==
true
,
"2 pass only supports welford merge"
);
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
x_bias_window
=
make_tile_window
(
x_bias_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
auto
beta_window
=
make_tile_window
(
...
...
@@ -102,24 +110,35 @@ struct Layernorm2dFwdPipelineTwoPass
int
max_count
=
(
num_n_tile_iteration
-
1
)
*
count_per_iter
+
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
last_iter_n
);
auto
block_
welford
=
Policy
::
template
GetBlock
Welford
<
Problem
>();
auto
block_
welford
_sync
=
Policy
::
template
GetBlock
Welford
Sync
<
Problem
>();
auto
block_
welford
_cross_warp_sync
=
Policy
::
template
GetBlock
Welford
CrossWarpSync
<
Problem
>();
auto
block_
norm_reduce
=
Policy
::
template
GetBlock
NormReduce
<
Problem
>();
auto
block_
norm_reduce
_sync
=
Policy
::
template
GetBlock
NormReduce
Sync
<
Problem
>();
auto
block_
norm_reduce
_cross_warp_sync
=
Policy
::
template
GetBlock
NormReduce
CrossWarpSync
<
Problem
>();
using
XTensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
load_tile
(
x_window
)));
auto
mean
=
block_
welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_
welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
mean
=
block_
norm_reduce
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_
norm_reduce
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
const
auto
x_bias
=
load_tile
(
x_bias_window
);
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
Block_N
});
move_tile_window
(
x_bias_window
,
{
Block_N
});
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kXbias
==
Layernorm2dXBiasEnum
::
ADD_BIAS
)
{
sweep_tile
(
x
,
[
&
](
auto
idx
)
{
// compute x = bias + x
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_bias
[
j_idx
])
+
acc
(
idx
);
});
}
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
{
...
...
@@ -133,11 +152,11 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window
(
y_residual_window
,
{
0
,
Block_N
});
}
}
block_
welford
(
acc
,
mean
,
var
,
cur_count
,
max_count
);
block_
norm_reduce
(
acc
,
mean
,
var
,
cur_count
,
max_count
);
}
block_
welford
_sync
(
mean
,
var
,
cur_count
);
block_
welford
_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_
norm_reduce
_sync
(
mean
,
var
,
cur_count
);
block_
norm_reduce
_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{});
// compute inv-std
...
...
@@ -165,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_bias_window
,
{
-
Block_N
});
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
beta_window
,
{
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
...
...
@@ -172,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass
// layernorm computation
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
const
auto
x_bias
=
load_tile
(
x_bias_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kXbias
==
Layernorm2dXBiasEnum
::
ADD_BIAS
)
{
sweep_tile
(
x
,
[
&
](
auto
idx
)
{
// compute x = bias + x
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_bias
[
j_idx
])
+
acc
(
idx
);
});
}
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
...
...
@@ -207,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_bias_window
,
{
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
beta_window
,
{
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
View file @
0ef27d53
...
...
@@ -7,6 +7,19 @@
namespace
ck_tile
{
enum
class
Layernorm2dXBiasEnum
{
NO_BIAS
=
0
,
// add bias before fused add
ADD_BIAS
=
1
,
};
// clang-format off
template
<
Layernorm2dXBiasEnum
>
struct
Layernorm2dXBiasEnumName
;
template
<
>
struct
Layernorm2dXBiasEnumName
<
Layernorm2dXBiasEnum
::
NO_BIAS
>
{
static
constexpr
const
char
*
name
=
"no"
;
};
template
<
>
struct
Layernorm2dXBiasEnumName
<
Layernorm2dXBiasEnum
::
ADD_BIAS
>
{
static
constexpr
const
char
*
name
=
"xbias"
;
};
// clang-format on
enum
class
Layernorm2dFusedAddEnum
{
NO_ADD
=
0
,
...
...
@@ -40,7 +53,9 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
template
<
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
bool
kFastFDiv_
,
bool
kWelford_
,
bool
kTwoPass_
,
Layernorm2dXBiasEnum
kXbias_
,
Layernorm2dFusedAddEnum
kFusedAdd_
,
Layernorm2dFusedQuantEnum
kFusedQuant_
>
struct
Layernorm2dFwdTraits
...
...
@@ -48,7 +63,9 @@ struct Layernorm2dFwdTraits
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kFastFDiv
=
kFastFDiv_
;
static
constexpr
bool
kWelford
=
kWelford_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
Layernorm2dXBiasEnum
kXbias
=
kXbias_
;
static
constexpr
Layernorm2dFusedAddEnum
kFusedAdd
=
kFusedAdd_
;
static
constexpr
Layernorm2dFusedQuantEnum
kFusedQuant
=
kFusedQuant_
;
};
...
...
include/ck_tile/ops/
welford
.hpp
→
include/ck_tile/ops/
norm_reduce
.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/
welford
/block/block_
welford
.hpp"
#include "ck_tile/ops/
welford
/block/block_
welford
_problem.hpp"
#include "ck_tile/ops/
welford
/thread/thread_welford.hpp"
#include "ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
.hpp"
#include "ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
_problem.hpp"
#include "ck_tile/ops/
norm_reduce
/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/
welford
/block/block_
welford
.hpp
→
include/ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
.hpp
View file @
0ef27d53
...
...
@@ -4,22 +4,23 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/
welford
/thread/thread_welford.hpp"
#include "ck_tile/ops/
norm_reduce
/thread/thread_welford.hpp"
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
Block
Welford
struct
Block
NormReduce
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
XDataType
=
typename
Problem
::
XDataType
;
using
ComputeDataType
=
typename
Problem
::
ComputeDataType
;
static
constexpr
bool
kFastFDiv
=
Problem
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
kWelford
;
CK_TILE_DEVICE
constexpr
Block
Welford
()
{}
CK_TILE_DEVICE
constexpr
Block
NormReduce
()
{}
// [CAUTION] - max_count_ is to deal with the padding problem
// max_count_ is depend on caller, eg: naive and splitN
welford
will have different
// max_count_ depends on the caller, e.g. naive and splitN
norm_reduce
will have different
// calculation of max_count_
// -> use block_tile_welford_calculate_max_count to compute
template
<
typename
XDistributedTensor_
,
...
...
@@ -40,18 +41,24 @@ struct BlockWelford
if
(
cur_count_
<
max_count_
)
{
++
cur_count_
;
sweep_tile_span
(
spans
[
I0
],
[
&
](
auto
dstr_idx_i0
)
{
constexpr
auto
in_dstr_idx
=
make_tuple
(
dstr_idx_i0
,
dstr_idx_i1
);
constexpr
auto
out_dstr_idx
=
make_tuple
(
dstr_idx_i0
);
auto
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_tensor
[
in_dstr_idx
]);
welford_update
(
mean_tensor
(
out_dstr_idx
),
var_tensor
(
out_dstr_idx
),
x
,
cur_count_
,
constant
<
kFastFDiv
>
{});
if
(
kWelford
)
{
welford_update
(
mean_tensor
(
out_dstr_idx
),
var_tensor
(
out_dstr_idx
),
x
,
cur_count_
,
constant
<
kFastFDiv
>
{});
}
else
{
mean_tensor
(
out_dstr_idx
)
+=
x
;
var_tensor
(
out_dstr_idx
)
+=
x
*
x
;
}
});
}
});
...
...
@@ -91,10 +98,11 @@ struct BlockWelford
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
Block
Welford
Sync
struct
Block
NormReduce
Sync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
static
constexpr
bool
kFastFDiv
=
Problem
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
kWelford
;
template
<
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
...
...
@@ -152,36 +160,48 @@ struct BlockWelfordSync
(
number
<
lid_over_rid_derivative
<<
istage
.
value
>
{}.
value
);
// pull data from remote lane
const
auto
v_remote_mean
=
warp_shuffle
(
v_local_mean
,
src_lane
);
const
auto
v_remote_var
=
warp_shuffle
(
v_local_var
,
src_lane
);
const
auto
v_remote_count
=
warp_shuffle
(
v_local_count
,
src_lane
);
// welford merge
welford_merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
,
constant
<
kFastFDiv
>
{});
const
auto
v_remote_mean
=
warp_shuffle
(
v_local_mean
,
src_lane
);
const
auto
v_remote_var
=
warp_shuffle
(
v_local_var
,
src_lane
);
if
(
kWelford
)
{
const
auto
v_remote_count
=
warp_shuffle
(
v_local_count
,
src_lane
);
// norm_reduce merge
welford_merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
,
constant
<
kFastFDiv
>
{});
}
else
{
v_local_mean
+=
v_remote_mean
;
v_local_var
+=
v_remote_var
;
}
});
}
});
mean_tensor
.
get_thread_buffer
()(
i
)
=
v_local_mean
;
var_tensor
.
get_thread_buffer
()(
i
)
=
v_local_var
;
count
=
v_local_count
;
if
(
kWelford
)
{
count
=
v_local_count
;
}
});
}
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
Block
Welford
CrossWarpSync
struct
Block
NormReduce
CrossWarpSync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
BlockShape
=
typename
Problem
::
BlockShape
;
static
constexpr
bool
kFastFDiv
=
Problem
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
kWelford
;
using
smem_dtype
=
std
::
conditional_t
<
kWelford
,
fp32x4_t
,
fp32x2_t
>
;
template
<
typename
MeanDistributedTensor_
>
CK_TILE_DEVICE
static
constexpr
index_t
GetReduceWarps
()
...
...
@@ -252,7 +272,7 @@ struct BlockWelfordCrossWarpSync
static_assert
(
thread_buf_size
==
VarDistributedTensor_
::
get_thread_buffer_size
());
// Note: we pack everything into smem_dtype (fp32x4 when kWelford, fp32x2 otherwise)
fp32x4_t
*
smem_ptr
=
reinterpret_cast
<
fp32x4_t
*>
(
smem
);
smem_dtype
*
smem_ptr
=
reinterpret_cast
<
smem_dtype
*>
(
smem
);
const
index_t
lane_id
=
get_lane_id
();
const
index_t
warp_id
=
get_warp_id
();
constexpr
auto
num_reduce_warps
=
GetReduceWarps
<
MeanDistributedTensor_
>
();
...
...
@@ -267,11 +287,13 @@ struct BlockWelfordCrossWarpSync
if
(
lane_id
==
0
)
{
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
fp32x4_t
local_scratch_
;
smem_dtype
local_scratch_
;
local_scratch_
[
0
]
=
bit_cast
<
float
>
(
mean_tensor
.
get_thread_buffer
()[
i
]);
local_scratch_
[
1
]
=
bit_cast
<
float
>
(
var_tensor
.
get_thread_buffer
()[
i
]);
local_scratch_
[
2
]
=
bit_cast
<
float
>
(
count
);
if
(
kWelford
)
{
local_scratch_
[
2
]
=
bit_cast
<
float
>
(
count
);
}
smem_ptr
[
smem_offset
+
i
*
num_warps
]
=
local_scratch_
;
});
}
...
...
@@ -280,7 +302,7 @@ struct BlockWelfordCrossWarpSync
// load from smem. here we let every thread do the compute :)
index_t
local_warp_id
=
warp_id
/
num_reduce_warps
;
index_t
local_smem_os
=
local_warp_id
*
num_reduce_warps
;
fp32x4_t
all_scratch
[
thread_buf_size
*
num_reduce_warps
];
smem_dtype
all_scratch
[
thread_buf_size
*
num_reduce_warps
];
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
static_for
<
0
,
num_reduce_warps
,
1
>
{}([
&
](
auto
i_1
)
{
all_scratch
[
i_0
*
num_reduce_warps
+
i_1
]
=
...
...
@@ -293,32 +315,40 @@ struct BlockWelfordCrossWarpSync
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
// TODO: use descriptor for this
auto
v_local
=
all_scratch
[
i_0
*
num_reduce_warps
];
auto
v_local_mean
=
bit_cast
<
DataType
>
(
v_local
[
0
]);
auto
v_local_var
=
bit_cast
<
DataType
>
(
v_local
[
1
]);
auto
v_local_count
=
bit_cast
<
int
>
(
v_local
[
2
]);
auto
v_local
=
all_scratch
[
i_0
*
num_reduce_warps
];
auto
v_local_mean
=
bit_cast
<
DataType
>
(
v_local
[
0
]);
auto
v_local_var
=
bit_cast
<
DataType
>
(
v_local
[
1
]);
int
v_local_count
=
kWelford
?
bit_cast
<
int
>
(
v_local
[
2
])
:
0
;
// further reduce mean/var
static_for
<
0
,
num_reduce_warps
-
1
,
1
>
{}([
&
](
auto
i_1_n1
)
{
constexpr
auto
i_1
=
number
<
i_1_n1
+
1
>
{};
const
fp32x4_t
v_remote
=
all_scratch
[
i_0
*
num_reduce_warps
+
i_1
];
const
smem_dtype
v_remote
=
all_scratch
[
i_0
*
num_reduce_warps
+
i_1
];
const
auto
v_remote_mean
=
bit_cast
<
DataType
>
(
v_remote
[
0
]);
const
auto
v_remote_var
=
bit_cast
<
DataType
>
(
v_remote
[
1
]);
const
auto
v_remote_count
=
bit_cast
<
int
>
(
v_remote
[
2
]);
welford_merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
,
constant
<
kFastFDiv
>
{});
if
(
kWelford
)
{
const
auto
v_remote_count
=
bit_cast
<
int
>
(
v_remote
[
2
]);
welford_merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
,
constant
<
kFastFDiv
>
{});
}
else
{
v_local_mean
+=
v_remote_mean
;
v_local_var
+=
v_remote_var
;
}
});
mean_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_mean
;
var_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_var
;
count
=
v_local_count
;
if
(
kWelford
)
count
=
v_local_count
;
});
}
};
...
...
include/ck_tile/ops/
welford
/block/block_
welford
_problem.hpp
→
include/ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
_problem.hpp
View file @
0ef27d53
...
...
@@ -7,13 +7,18 @@
namespace
ck_tile
{
template
<
typename
XDataType_
,
typename
ComputeDataType_
,
typename
BlockShape_
,
bool
kFastFDiv_
>
struct
BlockWelfordProblem
// Compile-time problem description for the block-level norm-reduce building blocks.
// It only normalizes the incoming types (cv/ref qualifiers stripped) and republishes
// the two boolean switches as static members.
//
// XDataType_       : element type of the input tile
// ComputeDataType_ : element type used for the reduction
// BlockShape_      : block/warp/thread tiling description
// kFastFDiv_       : republished as kFastFDiv
// kWelford_        : republished as kWelford
template <typename XDataType_,
          typename ComputeDataType_,
          typename BlockShape_,
          bool kFastFDiv_,
          bool kWelford_>
struct BlockNormReduceProblem
{
    // boolean switches
    static constexpr bool kFastFDiv = kFastFDiv_;
    static constexpr bool kWelford  = kWelford_;

    // normalized type parameters
    using XDataType       = remove_cvref_t<XDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using BlockShape      = remove_cvref_t<BlockShape_>;
};
}
// namespace ck_tile
include/ck_tile/ops/
welford
/thread/thread_welford.hpp
→
include/ck_tile/ops/
norm_reduce
/thread/thread_welford.hpp
View file @
0ef27d53
File moved
include/ck_tile/ops/permute.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/reduce.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/rmsnorm2d.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -8,5 +8,6 @@
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
namespace
ck_tile
{
// Host-side argument bundle for the RMSNorm2d forward kernel: untyped tensor pointers
// followed by epsilon, the problem sizes (m rows, n columns) and the row strides.
// Optional tensors are passed as nullptr when the corresponding feature is not used.
// NOTE(review): p_x, p_gamma, p_y and p_invRms are each declared twice below, and the
// single `stride` sits next to the per-tensor strides. This looks like removed and added
// diff lines captured together -- confirm against the upstream header before relying on it.
struct
Rmsnorm2dFwdHostArgs
{
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_x_residual
;
// [m ,n], shortcut input, prec same as input, nullptr if not used
const
void
*
p_sm_scale
;
// [1 ,n], smooth scale input, fp32, nullptr if not used
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
void
*
p_y
;
// [m, n], output, fp16/bf16
void
*
p_invRms
;
// [m, 1], output inv-rms, prec same as input, nullptr if not used
void
*
p_y
;
// [m, n], output, fp16/bf16
void
*
p_y_residual
;
// [m, n], shortcut output, prec same as input, nullptr if not used
void
*
p_y_scale
;
// [m, 1], output a dynamic quant per row, nullptr if not used
void
*
p_invRms
;
// [m, 1], output inv-rms, prec same as input, nullptr if not used
float
epsilon
;
index_t
m
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// x row stride
index_t
xr_stride
;
// x residual row stride
index_t
y_stride
;
// y row stride
index_t
yr_stride
;
// y residual row stride
};
// TODO: Extract some type to wrapper class
template
<
typename
Pipeline_
>
template
<
typename
Pipeline_
,
typename
Epilogue_
>
struct
Rmsnorm2dFwd
{
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Epilogue
=
remove_cvref_t
<
Epilogue_
>
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
SmoothScaleDataType
=
remove_cvref_t
<
typename
Problem
::
SmoothScaleDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
// for simplicity, shortcut input/output type is same as X
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
null_type
>
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
kSaveInvRms
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
Traits
::
kSaveInvRms
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
Traits
::
kTwoPass
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
index_t
ThreadPerWarp_N
=
Problem
::
BlockShape
::
ThreadPerWarp_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
...
...
@@ -56,29 +73,43 @@ struct Rmsnorm2dFwd
// Kernel-side argument struct; MakeKargs fills it positionally from the host args,
// so the member order here must match the initializer order used there.
// NOTE(review): `stride` appears alongside x_stride/xr_stride/y_stride/yr_stride, which
// looks like a removed diff line captured together with the added ones -- confirm upstream.
struct
Kargs
{
const
void
*
p_x
;
const
void
*
p_x_residual
;
const
void
*
p_sm_scale
;
const
void
*
p_gamma
;
void
*
p_y
;
void
*
p_y_residual
;
void
*
p_y_scale
;
void
*
p_invRms
;
float
epsilon
;
index_t
m
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// x row stride
index_t
xr_stride
;
// x residual row stride
index_t
y_stride
;
// y row stride
index_t
yr_stride
;
// y residual row stride
};
using
Hargs
=
Rmsnorm2dFwdHostArgs
;
// Builds the kernel argument struct from the host arguments by brace-initializing
// Kargs with the host fields, one by one, in positional order.
// NOTE(review): the initializer list is closed twice below (`hargs.stride };` and then
// `... hargs.yr_stride };`), which looks like the removed and the added diff line captured
// together -- confirm against the upstream header.
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_x_residual
,
hargs
.
p_sm_scale
,
hargs
.
p_gamma
,
hargs
.
p_y
,
hargs
.
p_y_residual
,
hargs
.
p_y_scale
,
hargs
.
p_invRms
,
hargs
.
epsilon
,
hargs
.
m
,
hargs
.
n
,
hargs
.
stride
};
hargs
.
x_stride
,
hargs
.
xr_stride
,
hargs
.
y_stride
,
hargs
.
yr_stride
};
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
...
...
@@ -95,6 +126,7 @@ struct Rmsnorm2dFwd
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
int8_t
>
{
static
constexpr
const
char
*
name
=
"int8"
;
};
// clang-format on
// in byte
...
...
@@ -102,24 +134,41 @@ struct Rmsnorm2dFwd
CK_TILE_HOST
static
std
::
string
GetName
()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using
S_
=
typename
Problem
::
BlockShape
;
auto
surfix
=
[
&
]
()
{
std
::
string
n
;
if
(
kFusedAdd
!=
Rmsnorm2dFusedAddEnum
::
NO_ADD
)
n
+=
_SS_
(
"_"
)
+
Rmsnorm2dFusedAddEnumName
<
kFusedAdd
>::
name
;
if
(
kFusedQuant
!=
Rmsnorm2dFusedQuantEnum
::
NO_SWEEP
)
n
+=
_SS_
(
"_"
)
+
Rmsnorm2dFusedQuantEnumName
<
kFusedQuant
>::
name
;
if
(
kPadN
)
n
+=
"_pn"
;
if
(
kSaveInvRms
)
n
+=
"_rms"
;
if
(
kTwoPass
)
n
+=
"_2p"
;
return
n
;
}();
#define _SS_ std::string
#define _TS_ std::to_string
return
_SS_
(
"rmsnorm2d_fwd_"
)
+
_SS_
(
t2s
<
XDataType
>::
name
)
+
"_"
+
auto
prec_str
=
[
&
]
()
{
std
::
string
base_str
=
_SS_
(
t2s
<
XDataType
>::
name
);
if
(
!
std
::
is_same_v
<
XDataType
,
YDataType
>
)
{
base_str
+=
_SS_
(
"_"
)
+
_SS_
(
t2s
<
YDataType
>::
name
);
}
if
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
base_str
+=
_SS_
(
"_sx"
)
+
_SS_
(
t2s
<
SmoothScaleDataType
>::
name
);
base_str
+=
_SS_
(
"_sy"
)
+
_SS_
(
t2s
<
YScaleDataType
>::
name
);
}
if
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
base_str
+=
_SS_
(
"_sy"
)
+
_SS_
(
t2s
<
YScaleDataType
>::
name
);
}
return
base_str
;
}();
return
_SS_
(
"rmsnorm2d_fwd_"
)
+
_SS_
(
prec_str
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_SS_
(
Pipeline
::
name
)
+
surfix
;
#undef _SS_
#undef _TS_
// clang-format on
#undef _SS_
#undef _TS_
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
...
...
@@ -130,7 +179,7 @@ struct Rmsnorm2dFwd
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
x_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
...
...
@@ -140,6 +189,29 @@ struct Rmsnorm2dFwd
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
const
auto
x_residual_window
=
[
&
]()
{
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XResidualDataType
*>
(
kargs
.
p_x_residual
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
xr_stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}));
}
}();
const
auto
gamma_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_gamma
),
...
...
@@ -158,7 +230,7 @@ struct Rmsnorm2dFwd
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
y_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
...
...
@@ -168,6 +240,28 @@ struct Rmsnorm2dFwd
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
auto
y_residual_window
=
[
&
]()
{
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YResidualDataType
*>
(
kargs
.
p_y_residual
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
yr_stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}));
}
}();
auto
inv_rms_window
=
[
&
]()
{
if
constexpr
(
kSaveInvRms
)
{
...
...
@@ -187,15 +281,62 @@ struct Rmsnorm2dFwd
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
}();
auto
sm_scale_window
=
[
&
]()
{
if
constexpr
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
const
auto
win_
=
[
&
]()
{
const
auto
tmp_0_
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
const
SmoothScaleDataType
*>
(
kargs
.
p_sm_scale
),
make_tuple
(
kargs
.
n
),
number
<
Vector_N
>
{});
return
pad_tensor_view
(
tmp_0_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
false
>
{});
// sm_scale no need pad
}();
return
make_tile_window
(
win_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_N
>
{}));
}
}();
auto
y_scale_window
=
[
&
]()
{
if
constexpr
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
||
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
const
auto
win_
=
[
&
]()
{
const
auto
tmp_0_
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
YScaleDataType
*>
(
kargs
.
p_y_scale
),
make_tuple
(
kargs
.
m
),
number
<
1
>
{});
return
pad_tensor_view
(
tmp_0_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
kPadM
>
{});
}();
return
make_tile_window
(
win_
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
}
}();
__shared__
char
smem
[
GetSmemSize
()];
Pipeline
{}(
x_window
,
x_residual_window
,
gamma_window
,
y_window
,
y_residual_window
,
inv_rms_window
,
sm_scale_window
,
y_scale_window
,
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
kargs
.
n
,
smem
);
smem
,
Epilogue
{});
}
};
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
View file @
0ef27d53
...
...
@@ -45,7 +45,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2d
()
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockReduce2d
<
P_
>
{};
...
...
@@ -54,7 +54,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dSync
()
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockReduce2dSync
<
P_
>
{};
...
...
@@ -63,7 +63,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dCrossWarpSync
()
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockReduce2dCrossWarpSync
<
P_
>
{};
...
...
@@ -74,13 +74,13 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
using
block_reduce2d
=
BlockReduce2d
<
P_
>
;
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
X
DataType
>
(
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
Compute
DataType
>
(
MakeXBlockTileDistribution
<
Problem
>
()));
using
y_block_tile
=
decltype
(
block_reduce2d
::
template
MakeYBlockTile
<
x_block_tile
>());
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineOnePass
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
kSaveInvRms
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
Traits
::
kSaveInvRms
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockRmsnorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
...
...
@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineOnePass
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
InvRmsWindow
>
template
<
typename
XWindow
,
typename
XResidualWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
YResidualWindow
,
typename
InvRmsWindow
,
typename
SmoothScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XResidualWindow
&
x_residual_window_
,
const
GammaWindow
&
gamma_window_
,
YWindow
&
y_window
,
YWindow
&
y_window_
,
const
YResidualWindow
&
y_residual_window_
,
InvRmsWindow
&
inv_rms_window
,
const
SmoothScaleWindow
&
sm_scale_window_
,
YScaleWindow
&
y_scale_window_
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
void
*
smem
,
Epilogue
)
const
{
const
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
const
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
const
auto
x_residual_window
=
make_tile_window
(
x_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
y_residual_window
=
make_tile_window
(
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
reduce_square_sum_func
=
ReduceOp
::
SquareAdd
{};
auto
reduce_sum_func
=
ReduceOp
::
Add
{};
...
...
@@ -62,13 +84,31 @@ struct Rmsnorm2dFwdPipelineOnePass
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
const
auto
x
=
load_tile
(
x_window
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
// load gamma (TODO: support no gamma?)
const
auto
gamma
=
load_tile
(
gamma_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
}
}
// compute mean square each-thread->cross-lane->cross-warp
auto
square_sum
=
block_reduce2d
(
x
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
(),
reduce_square_sum_func
);
auto
square_sum
=
block_reduce2d
(
acc
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
(),
reduce_square_sum_func
);
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
block_reduce2d_cross_warp_sync
(
square_sum
,
smem
,
reduce_sum_func
);
...
...
@@ -83,19 +123,30 @@ struct Rmsnorm2dFwdPipelineOnePass
store_tile
(
inv_rms_window
,
cast_tile
<
InvRmsDataType
>
(
inv_rms
));
// rmsnorm computation
auto
y
=
make_static_distributed_tensor
<
Y
DataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
auto
rmsn
=
make_static_distributed_tensor
<
Compute
DataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
rmsn
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
x_
*
inv_rms_
[
i_idx
]
*
gamma_
;
auto
rmsn_
=
acc
[
idx
]
*
inv_rms_
[
i_idx
]
*
gamma_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
)
;
rmsn
(
idx
)
=
rmsn_
;
});
store_tile
(
y_window
,
y
);
if
constexpr
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
Epilogue
{}(
y_window_
,
sm_scale_window_
,
y_scale_window_
,
rmsn
,
smem
);
}
else
if
constexpr
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
Epilogue
{}(
y_window_
,
y_scale_window_
,
rmsn
,
smem
);
}
else
{
Epilogue
{}(
y_window_
,
rmsn
);
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -12,25 +12,25 @@ template <typename XDataType_,
typename
ComputeDataType_
,
typename
YDataType_
,
typename
InvRmsDataType_
,
typename
SmoothScaleDataType_
,
typename
YScaleDataType_
,
typename
BlockShape_
,
bool
kPadN_
,
bool
kSaveInvRms_
,
bool
kTwoPass_
>
typename
Traits_
>
struct
Rmsnorm2dFwdPipelineProblem
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
InvRmsDataType
=
remove_cvref_t
<
InvRmsDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
InvRmsDataType
=
remove_cvref_t
<
InvRmsDataType_
>
;
using
SmoothScaleDataType
=
remove_cvref_t
<
SmoothScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveInvRms
=
kSaveInvRms_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
};
}
// namespace ck_tile
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineTwoPass
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
kSaveInvRms
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
Traits
::
kSaveInvRms
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockRmsnorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
...
...
@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineTwoPass
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
InvRmsWindow
>
template
<
typename
XWindow
,
typename
XResidualWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
YResidualWindow
,
typename
InvRmsWindow
,
typename
SmoothScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XResidualWindow
&
x_residual_window_
,
const
GammaWindow
&
gamma_window_
,
YWindow
&
y_window
,
const
YResidualWindow
&
y_residual_window_
,
InvRmsWindow
&
inv_rms_window
,
const
SmoothScaleWindow
&
/*sm_scale_window_*/
,
YScaleWindow
&
/*y_scale_window*/
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
void
*
smem
,
Epilogue
)
const
{
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
auto
x_residual_window
=
make_tile_window
(
x_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
y_residual_window
=
make_tile_window
(
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
// Problem::BlockShape
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
...
...
@@ -67,15 +89,34 @@ struct Rmsnorm2dFwdPipelineTwoPass
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
using
X
TensorType
=
decltype
(
load_tile
(
x_window
));
auto
square_sum
=
block_reduce2d
.
template
MakeYBlockTile
<
X
TensorType
>();
using
Compute
TensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
load_tile
(
x_window
))
)
;
auto
square_sum
=
block_reduce2d
.
template
MakeYBlockTile
<
Compute
TensorType
>();
set_tile
(
square_sum
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
());
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
block_reduce2d
(
x
,
square_sum
,
reduce_square_sum_func
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
Block_N
});
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
move_tile_window
(
y_residual_window
,
{
0
,
Block_N
});
}
}
block_reduce2d
(
acc
,
square_sum
,
reduce_square_sum_func
);
}
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
...
...
@@ -96,33 +137,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
// rmsnorm computation
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
// load gamma/beta (TODO: support no gamma/beta?)
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
}
// load gamma (TODO: support no gamma?)
const
auto
gamma
=
load_tile
(
gamma_window
);
auto
y
=
make_static_distributed_tensor
<
YDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
// rmsnorm computation
auto
rmsn
=
make_static_distributed_tensor
<
ComputeDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
rmsn
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
x_
*
inv_rms_
[
i_idx
]
*
gamma_
;
auto
rmsn_
=
acc
(
idx
)
*
inv_rms_
[
i_idx
]
*
gamma_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
)
;
rmsn
(
idx
)
=
rmsn_
;
});
store_tile
(
y_window
,
y
);
static_assert
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
NO_SWEEP
);
Epilogue
{}(
y_window
,
rmsn
);
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
}
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp
0 → 100644
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {

// Selects whether (and how) an extra add is fused in front of RMSNorm.
enum class Rmsnorm2dFusedAddEnum
{
    NO_ADD = 0,
    // fused add before RMSNorm, and store the result to global
    PRE_ADD_STORE = 1,
    // fused add before RMSNorm, but do not store the result
    PRE_ADD = 2,
};

// Compile-time map from a Rmsnorm2dFusedAddEnum value to a short string tag (`name`).
// Only the explicit specializations below are defined; the primary template is left
// incomplete, so asking for any other value is a compile error.
template <Rmsnorm2dFusedAddEnum>
struct Rmsnorm2dFusedAddEnumName;

template <>
struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::NO_ADD>
{
    static constexpr const char* name = "no";
};

template <>
struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD_STORE>
{
    static constexpr const char* name = "pras";
};

template <>
struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD>
{
    static constexpr const char* name = "pra";
};

// Selects the quantization variant fused behind RMSNorm.
enum class Rmsnorm2dFusedQuantEnum
{
    NO_SWEEP = 0,
    // smooth outlier + rowwise quant, needs an input x-scale and stores a y-scale
    SMOOTH_DYNAMIC_QUANT = 1,
    // rowwise quant, stores out a y-scale
    DYNAMIC_QUANT = 2,
};

// Compile-time map from a Rmsnorm2dFusedQuantEnum value to a short string tag (`name`).
// As above, the primary template is intentionally left incomplete.
template <Rmsnorm2dFusedQuantEnum>
struct Rmsnorm2dFusedQuantEnumName;

template <>
struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::NO_SWEEP>
{
    static constexpr const char* name = "no";
};

template <>
struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT>
{
    static constexpr const char* name = "smdqt";
};

template <>
struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT>
{
    static constexpr const char* name = "dqt";
};

// Bundle of compile-time switches for the RMSNorm2d forward op. Every template argument
// is republished unchanged as a static member of the same name without the trailing
// underscore, so users can read e.g. Traits::kPadN or Traits::kFusedAdd.
template <bool kPadN_,
          bool kSaveInvRms_,
          bool kTwoPass_,
          Rmsnorm2dFusedAddEnum kFusedAdd_,
          Rmsnorm2dFusedQuantEnum kFusedQuant_>
struct Rmsnorm2dFwdTraits
{
    // fusion selectors
    static constexpr Rmsnorm2dFusedAddEnum kFusedAdd     = kFusedAdd_;
    static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;

    // boolean switches
    static constexpr bool kPadN       = kPadN_;
    static constexpr bool kSaveInvRms = kSaveInvRms_;
    static constexpr bool kTwoPass    = kTwoPass_;
};

} // namespace ck_tile
include/ck_tile/ops/smoothquant.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -12,7 +12,7 @@ namespace ck_tile {
struct
MoeSmoothquantHostArgs
{
const
void
*
p_x
;
// [tokens ,hidden_size], input, fp16/bf16
const
void
*
p_
x
scale
;
// [experts, hidden_size], input, columnwise scale, fp32
const
void
*
p_
sm
scale
;
// [experts, hidden_size], input, columnwise scale, fp32
const
void
*
p_topk_ids
;
// [tokens, topk]
void
*
p_yscale
;
// [topk * tokens, 1], output, rowwise quant scale
...
...
@@ -33,11 +33,11 @@ struct MoeSmoothquant
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
X
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
X
ScaleDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
QYDataType
=
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
Smooth
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
Smooth
ScaleDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
QYDataType
=
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
...
...
@@ -57,7 +57,7 @@ struct MoeSmoothquant
struct
Kargs
{
const
void
*
p_x
;
// [tokens ,hidden_size], input, fp16/bf16
const
void
*
p_
x
scale
;
// [experts, hidden_size], input, columnwise scale, fp32
const
void
*
p_
sm
scale
;
// [experts, hidden_size], input, columnwise scale, fp32
const
void
*
p_topk_ids
;
// [tokens, topk]
void
*
p_yscale
;
// [topk, tokens, 1], output, rowwise quant scale
...
...
@@ -75,7 +75,7 @@ struct MoeSmoothquant
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_
x
scale
,
hargs
.
p_
sm
scale
,
hargs
.
p_topk_ids
,
hargs
.
p_yscale
,
hargs
.
p_qy
,
...
...
@@ -153,9 +153,10 @@ struct MoeSmoothquant
}();
// [experts, hidden_size],
const
auto
x
scale_window
=
[
&
]()
{
const
auto
sm
scale_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XScaleDataType
*>
(
kargs
.
p_xscale
)
+
i_expert
*
kargs
.
hidden_size
,
static_cast
<
const
SmoothScaleDataType
*>
(
kargs
.
p_smscale
)
+
i_expert
*
kargs
.
hidden_size
,
make_tuple
(
kargs
.
hidden_size
),
make_tuple
(
1
),
number
<
Vector_N
>
{},
...
...
@@ -198,7 +199,7 @@ struct MoeSmoothquant
__shared__
char
smem
[
GetSmemSize
()];
Pipeline
{}(
x_window
,
x
scale_window
,
yscale_window
,
qy_window
,
kargs
.
hidden_size
,
smem
);
Pipeline
{}(
x_window
,
sm
scale_window
,
yscale_window
,
qy_window
,
kargs
.
hidden_size
,
smem
);
}
};
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment