gaoqiong / composable_kernel_ROCM · Commit 54617a85

Authored Jan 22, 2025 by Jiming Ruan
Parent: c2ea75ed

Adds support for the Welford algorithm and fast div for rmsnorm.

Showing 5 changed files with 164 additions and 84 deletions:
- include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp (+55 −30)
- include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp (+38 −24)
- include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp (+32 −16)
- include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp (+35 −14)
- include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp (+4 −0)
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
```diff
@@ -36,7 +36,7 @@ struct BlockNormReduce
         constexpr auto I1    = number<1>{};
         constexpr auto spans = XDistributedTensor_::get_distributed_spans();
-        constexpr bool computeVariance =
+        constexpr bool comp_var =
             !std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
         sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
@@ -50,7 +50,7 @@ struct BlockNormReduce
                 auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
                 if(kWelford)
                 {
-                    if constexpr(computeVariance)
+                    if constexpr(comp_var)
                     {
                         welford_update(mean_tensor(out_dstr_idx),
                                        var_tensor(out_dstr_idx),
```
```diff
@@ -67,7 +67,7 @@ struct BlockNormReduce
                 else
                 {
                     mean_tensor(out_dstr_idx) += x;
-                    if constexpr(computeVariance)
+                    if constexpr(comp_var)
                     {
                         var_tensor(out_dstr_idx) += x * x;
                     }
@@ -98,7 +98,8 @@ struct BlockNormReduce
                            int& cur_count_,
                            const int& max_count_)
     {
-        Impl(x_tensor, mean_tensor, null_tensor{}, cur_count_, max_count_);
+        auto nt = null_tensor{};
+        Impl(x_tensor, mean_tensor, nt, cur_count_, max_count_);
     }

     template <typename XDistributedTensor_>
```
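The hunks above show the two accumulation paths the kernel now selects between: a Welford running update (`welford_update`) and the naive `sum(x)` / `sum(x*x)` accumulation. Below is a minimal scalar sketch of the Welford recurrence, assuming the usual mean/M2 formulation; it is illustrative only and does not match ck_tile's distributed-tensor `welford_update` signature.

```cpp
#include <cstdio>
#include <initializer_list>

// Scalar Welford running update (illustrative; ck_tile's welford_update works
// on distributed-tensor elements and also tracks cur_count/max_count for padding).
void welford_update_scalar(float& mean, float& m2, int& count, float x)
{
    count += 1;
    const float delta = x - mean;
    mean += delta / count;      // running mean, no large intermediate sums
    m2   += delta * (x - mean); // running sum of squared deviations
}

int main()
{
    float mean = 0.f, m2 = 0.f;
    int count = 0;
    for(float x : {1.f, 2.f, 3.f, 4.f})
        welford_update_scalar(mean, m2, count, x);
    // mean = 2.5, population variance = m2 / count = 1.25
    std::printf("mean=%g var=%g\n", mean, m2 / count);
    return 0;
}
```

The payoff over summing `x` and `x*x` directly is numerical: the naive path loses precision badly when the squared mean is large relative to the variance.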
```diff
@@ -152,31 +153,39 @@ struct BlockNormReduceSync
         using DstrEncode       = typename Dstr::DstrEncode;
         using DstrEncodeDetail = typename DstrEncode::detail;

-        static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
-                      "wrong!");
-
         constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
         constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

         constexpr index_t idim_p_lane = NDimP - 1;

-        constexpr bool computeVariance =
+        constexpr bool comp_var =
             !std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;

+        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
+        if constexpr(comp_var)
+        {
+            static_assert(
+                std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
+                "wrong!");
+            static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
+        }
+
         // const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
         // const auto rs_idx =
         //     mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);

-        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
-        static_assert((computeVariance == false) ||
-                      (thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()));
-
         const int original_count = count;

         // loop over thread data
         static_for<0, thread_buf_size, 1>{}([&](auto i) {
-            auto v_local_mean = mean_tensor.get_thread_buffer()[i];
-            auto v_local_var  = computeVariance ? var_tensor.get_thread_buffer()[i] : 0;
+            auto v_local_mean = mean_tensor.get_thread_buffer()[i];
+            auto v_local_var  = [&]() {
+                if constexpr(comp_var)
+                    return var_tensor.get_thread_buffer()[i];
+                else
+                    return 0;
+            }();
             auto v_local_count = original_count;

             // cross-lane reduce for replication
```
```diff
@@ -206,13 +215,13 @@ struct BlockNormReduceSync
                 // pull data from remote lane
                 const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
                 const auto v_remote_var =
-                    computeVariance ? warp_shuffle(v_local_var, src_lane) : 0;
+                    comp_var ? warp_shuffle(v_local_var, src_lane) : 0;

                 if(kWelford)
                 {
                     const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
                     // norm_reduce merge
-                    if constexpr(computeVariance)
+                    if constexpr(comp_var)
                     {
                         welford_merge(v_local_mean,
                                       v_local_var,
```
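`welford_merge` combines the local partial with a (mean, var, count) triple fetched from a remote lane via `warp_shuffle`. A scalar sketch of the standard pairwise combination (Chan et al.), assuming the same mean/M2 formulation as above; the real ck_tile routine may perform its divisions differently, e.g. via the fast-div path this commit wires in.

```cpp
// Pairwise combination of two Welford partials (illustrative only).
void welford_merge_scalar(float& mean_a, float& m2_a, int& count_a,
                          float mean_b, float m2_b, int count_b)
{
    const int count = count_a + count_b;
    if(count == 0)
        return; // both partials empty: nothing to merge
    const float delta = mean_b - mean_a;
    mean_a += delta * count_b / count;
    m2_a   += m2_b + delta * delta *
              (static_cast<float>(count_a) * count_b / count);
    count_a = count;
}
```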
```diff
@@ -234,7 +243,7 @@ struct BlockNormReduceSync
                 else
                 {
                     v_local_mean += v_remote_mean;
-                    if constexpr(computeVariance)
+                    if constexpr(comp_var)
                     {
                         v_local_var += v_remote_var;
                     }
```
```diff
@@ -244,7 +253,11 @@ struct BlockNormReduceSync
             });

             mean_tensor.get_thread_buffer()(i) = v_local_mean;
-            var_tensor.get_thread_buffer()(i)  = v_local_var;
+            if constexpr(comp_var)
+            {
+                var_tensor.get_thread_buffer()(i) = v_local_var;
+            }
             if(kWelford)
             {
                 count = v_local_count;
```
```diff
@@ -263,7 +276,8 @@ struct BlockNormReduceSync
     template <typename MeanDistributedTensor_>
     CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, int& count)
     {
-        Impl(mean_tensor, null_tensor{}, count);
+        auto nt = null_tensor{};
+        Impl(mean_tensor, nt, count);
     }
 };
```
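A pattern repeated across this commit (hunks @@ -98, @@ -263, and @@ -479 below) is replacing a `null_tensor{}` temporary argument with a named `auto nt = null_tensor{};`. The likely reason, sketched below under the assumption that `Impl` takes the variance tensor by non-const reference, is that a prvalue cannot bind to such a parameter:

```cpp
// Minimal reproduction of the temporary-vs-named-variable issue (hypothetical
// signature; the real Impl takes several more parameters).
struct null_tensor {};

template <typename VarTensor>
void Impl(VarTensor& var_tensor) { (void)var_tensor; }

void caller()
{
    // Impl(null_tensor{}); // ill-formed: an rvalue cannot bind to VarTensor&
    auto nt = null_tensor{}; // a named lvalue binds fine
    Impl(nt);
}
```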
```diff
@@ -348,17 +362,18 @@ struct BlockNormReduceCrossWarpSync
         // using DstrEncode       = typename Dstr::DstrEncode;
         // using DstrEncodeDetail = typename DstrEncode::detail;

-        constexpr bool computeVariance =
+        constexpr bool comp_var =
             !std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;

-        static_assert((computeVariance == false) ||
-                          std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
-                      "wrong!");
         constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
-        static_assert((computeVariance == false) ||
-                      (thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()));
+        if constexpr(comp_var)
+        {
+            static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
+            static_assert(
+                std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
+                "wrong!");
+        }

         // Note: we always pack everything into fp32x4
         smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);
```
```diff
@@ -413,7 +428,12 @@ struct BlockNormReduceCrossWarpSync
             // TODO: use descriptor for this
             auto v_local      = all_scratch[i_0 * num_reduce_warps];
             auto v_local_mean = bit_cast<DataType>(v_local[0]);
-            auto v_local_var  = kComputeVariance ? bit_cast<DataType>(v_local[1]) : 0;
+            auto v_local_var  = [&]() {
+                if constexpr(comp_var)
+                    return bit_cast<DataType>(v_local[1]);
+                else
+                    return 0;
+            }();
             int v_local_count =
                 kWelford ? (kComputeVariance ? bit_cast<int>(v_local[2]) : bit_cast<int>(v_local[1]))
                          : 0;
```
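The unpacking above implies a small layout contract for the per-warp scratch vector noted earlier ("we always pack everything into fp32x4"): slot 0 holds the mean, slot 1 holds the variance when it is computed (otherwise the bit-cast Welford count), and slot 2 holds the count when both are present. A stand-alone sketch of that decoding rule, with `float[4]` standing in for the real `smem_dtype` and `memcpy` standing in for `bit_cast`:

```cpp
#include <cstring>

// Illustrative decoding of the fp32x4 scratch vector described above.
template <bool kComputeVariance, bool kWelford>
int unpack_welford_count(const float (&v)[4])
{
    int count = 0;
    if constexpr(kWelford)
        std::memcpy(&count, &v[kComputeVariance ? 2 : 1], sizeof(int));
    return count; // naive (non-Welford) reductions carry no element count
}
```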
```diff
@@ -458,7 +478,11 @@ struct BlockNormReduceCrossWarpSync
             });

             mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
-            var_tensor.get_thread_buffer()(i_0)  = v_local_var;
+            if constexpr(comp_var)
+            {
+                var_tensor.get_thread_buffer()(i_0) = v_local_var;
+            }
             if constexpr(kWelford)
             {
                 count = v_local_count;
```
```diff
@@ -479,7 +503,8 @@ struct BlockNormReduceCrossWarpSync
     template <typename MeanDistributedTensor_>
     CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, int& count, void* smem)
     {
-        Impl(mean_tensor, null_tensor{}, count, smem);
+        auto nt = null_tensor{};
+        Impl(mean_tensor, nt, count, smem);
     }
 };
```
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
-#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"

 namespace ck_tile {
```
```diff
@@ -43,30 +43,39 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduce()
     {
-        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
-                                        typename Problem::ComputeDataType,
-                                        typename Problem::BlockShape>;
-
-        return BlockReduce2d<P_>{};
+        using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                          typename Problem::ComputeDataType,
+                                          typename Problem::BlockShape,
+                                          false,
+                                          Problem::Traits::kFastFDiv,
+                                          Problem::Traits::kWelford>;
+
+        return BlockNormReduce<P_>{};
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceSync()
     {
-        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
-                                        typename Problem::ComputeDataType,
-                                        typename Problem::BlockShape>;
-
-        return BlockReduce2dSync<P_>{};
+        using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                          typename Problem::ComputeDataType,
+                                          typename Problem::BlockShape,
+                                          false,
+                                          Problem::Traits::kFastFDiv,
+                                          Problem::Traits::kWelford>;
+
+        return BlockNormReduceSync<P_>{};
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceCrossWarpSync()
     {
-        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
-                                        typename Problem::ComputeDataType,
-                                        typename Problem::BlockShape>;
-
-        return BlockReduce2dCrossWarpSync<P_>{};
+        using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                          typename Problem::ComputeDataType,
+                                          typename Problem::BlockShape,
+                                          false,
+                                          Problem::Traits::kFastFDiv,
+                                          Problem::Traits::kWelford>;
+
+        return BlockNormReduceCrossWarpSync<P_>{};
     }

     template <typename Problem>
```
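All three getters now build a `BlockNormReduceProblem` carrying three extra compile-time flags taken from the pipeline traits. Judging from the diff, the hard-coded `false` is the compute-variance flag (rmsnorm only needs the mean of squares), followed by `kFastFDiv` and `kWelford`. A self-contained toy of the wiring pattern, with simplified stand-in types in place of the real ck_tile ones:

```cpp
#include <cstdio>

// Stand-ins for the real traits/problem types (names simplified).
struct TraitsToy
{
    static constexpr bool kFastFDiv = true;
    static constexpr bool kWelford  = true;
};

template <bool kComputeVariance, bool kFastFDiv, bool kWelford>
struct BlockNormReduceProblemToy
{
    static constexpr bool compute_variance = kComputeVariance;
    static constexpr bool fast_fdiv        = kFastFDiv;
    static constexpr bool welford          = kWelford;
};

int main()
{
    // rmsnorm: variance off, fast-div and Welford taken from the traits.
    using P_ = BlockNormReduceProblemToy<false, TraitsToy::kFastFDiv, TraitsToy::kWelford>;
    std::printf("variance=%d fast_fdiv=%d welford=%d\n",
                P_::compute_variance, P_::fast_fdiv, P_::welford);
    return 0;
}
```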
```diff
@@ -74,17 +83,22 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
     {
         if constexpr(Problem::kNeedCrossWarpSync)
         {
-            using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
-                                            typename Problem::ComputeDataType,
-                                            typename Problem::BlockShape>;
+            using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                              typename Problem::ComputeDataType,
+                                              typename Problem::BlockShape,
+                                              false,
+                                              Problem::Traits::kFastFDiv,
+                                              Problem::Traits::kWelford>;

-            using block_reduce2d = BlockReduce2d<P_>;
+            using block_reduce = BlockNormReduce<P_>;
             using x_block_tile =
                 decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
                     MakeXBlockTileDistribution<Problem>()));
-            using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
+            using mean_var_block_tile =
+                decltype(block_reduce::template MakeMeanVarBlockTile<x_block_tile>());

-            return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
+            return GetBlockNormReduceCrossWarpSync<Problem>()
+                .template GetSmemSize<mean_var_block_tile>();
         }
         else
         {
```
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
```diff
@@ -31,6 +31,8 @@ struct Rmsnorm2dFwdPipelineOnePass
     static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
     static constexpr bool kPadM              = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
     static constexpr bool kPadN              = Problem::Traits::kPadN;
+    static constexpr bool kFastFDiv          = Problem::Traits::kFastFDiv;
+    static constexpr bool kWelford           = Problem::Traits::kWelford;
     static constexpr auto kFusedAdd          = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant        = Problem::Traits::kFusedQuant;
```
```diff
@@ -62,7 +64,7 @@ struct Rmsnorm2dFwdPipelineOnePass
                const YResidualWindow& y_residual_window_,
                InvRmsWindow& inv_rms_window,
                const SmoothScaleWindow& sm_scale_window_,
-               YScaleWindow& y_scale_window_,
+               YScaleWindow& y_scale_window,
                ComputeDataType epsilon,
                ck_tile::index_t row_size,
                void* smem,
```
```diff
@@ -77,12 +79,13 @@ struct Rmsnorm2dFwdPipelineOnePass
         auto y_residual_window = make_tile_window(
             y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());

-        auto reduce_square_sum_func = ReduceOp::SquareAdd{};
-        auto reduce_sum_func        = ReduceOp::Add{};
-        auto block_reduce2d         = Policy::template GetBlockReduce2d<Problem>();
-        auto block_reduce2d_sync    = Policy::template GetBlockReduce2dSync<Problem>();
-        auto block_reduce2d_cross_warp_sync =
-            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
+        int cur_count = 0;
+        int max_count =
+            block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
+        auto block_norm_reduce      = Policy::template GetBlockNormReduce<Problem>();
+        auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
+        auto block_norm_reduce_cross_warp_sync =
+            Policy::template GetBlockNormReduceCrossWarpSync<Problem>();

         auto x      = load_tile(x_window);
         auto x_resi = load_tile(x_residual_window);
```
```diff
@@ -105,19 +108,32 @@ struct Rmsnorm2dFwdPipelineOnePass
             }
         }

+        // Calculate square here because block norm reduce only supports naive mean.
+        auto square =
+            make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
+        sweep_tile(acc, [&](auto idx) { square(idx) = acc(idx) * acc(idx); });
+
         // compute mean square each-thread->cross-lane->cross-warp
-        auto square_sum = block_reduce2d(
-            acc, reduce_square_sum_func.GetIdentityValue<ComputeDataType>(), reduce_square_sum_func);
-        block_reduce2d_sync(square_sum, reduce_sum_func);
-        block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
+        using XTensorType = decltype(cast_tile<ComputeDataType>(x));
+        auto square_mean  = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
+        clear_tile(square_mean);
+        block_norm_reduce(square, square_mean, cur_count, max_count);
+        block_norm_reduce_sync(square_mean, cur_count);
+        block_norm_reduce_cross_warp_sync(square_mean, cur_count, smem);

         // compute inv-rms
         auto inv_rms = tile_elementwise_in(
             [&](const auto& v_) {
-                return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
+                if constexpr(kFastFDiv && std::is_same_v<ComputeDataType, float>)
+                {
+                    return type_convert<ComputeDataType>(1.0f) *
+                           __builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
+                }
+                else
+                {
+                    return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
+                }
             },
-            square_sum);
+            square_mean);

         if constexpr(kSaveInvRms)
             store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
```
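Two things change in the inv-rms lambda. First, the `v_ / row_size` divisor disappears: `square_mean` already holds a count-normalized mean of squares, whereas the old `square_sum` held a raw sum. Second, when `kFastFDiv` is set and the compute type is `float`, the IEEE division is replaced by `__builtin_amdgcn_rcpf`, AMD's hardware reciprocal approximation, trading a little precision for latency. A host-side sketch of the two paths, with a stand-in for the device-only builtin:

```cpp
#include <cmath>
#include <cstdio>

// Stand-in for __builtin_amdgcn_rcpf, which exists only in AMD GPU device code.
static float rcp_approx(float x) { return 1.0f / x; }

// Exact path: IEEE division.
float inv_rms_exact(float mean_sq, float eps) { return 1.0f / std::sqrt(mean_sq + eps); }

// kFastFDiv path: multiply by a reciprocal approximation instead of dividing.
float inv_rms_fast(float mean_sq, float eps) { return 1.0f * rcp_approx(std::sqrt(mean_sq + eps)); }

int main()
{
    const float mean_sq = 4.0f, eps = 1e-6f;
    // Both print ~0.5 for this input; on the GPU the fast path is one v_rcp_f32.
    std::printf("exact=%g fast=%g\n", inv_rms_exact(mean_sq, eps), inv_rms_fast(mean_sq, eps));
    return 0;
}
```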
```diff
@@ -137,11 +153,11 @@ struct Rmsnorm2dFwdPipelineOnePass
         if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
         {
-            Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
+            Epilogue{}(y_window_, sm_scale_window_, y_scale_window, rmsn, smem);
         }
         else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
         {
-            Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
+            Epilogue{}(y_window_, y_scale_window, rmsn, smem);
         }
         else
         {
```
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
```diff
@@ -31,6 +31,8 @@ struct Rmsnorm2dFwdPipelineTwoPass
     static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
     static constexpr bool kPadM              = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
     static constexpr bool kPadN              = Problem::Traits::kPadN;
+    static constexpr bool kFastFDiv          = Problem::Traits::kFastFDiv;
+    static constexpr bool kWelford           = Problem::Traits::kWelford;
     static constexpr auto kFusedAdd          = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant        = Problem::Traits::kFusedQuant;
```
```diff
@@ -82,16 +84,23 @@ struct Rmsnorm2dFwdPipelineTwoPass
         index_t num_n_tile_iteration =
             __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));

-        auto reduce_square_sum_func = ReduceOp::SquareAdd{};
-        auto reduce_sum_func        = ReduceOp::Add{};
-        auto block_reduce2d         = Policy::template GetBlockReduce2d<Problem>();
-        auto block_reduce2d_sync    = Policy::template GetBlockReduce2dSync<Problem>();
-        auto block_reduce2d_cross_warp_sync =
-            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
-
-        using ComputeTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
-        auto square_sum = block_reduce2d.template MakeYBlockTile<ComputeTensorType>();
-        set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
+        // total number of count assume current iter have no pad(only last iter has pad)
+        constexpr index_t count_per_iter =
+            Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N;
+        const index_t last_iter_n = row_size - (num_n_tile_iteration - 1) * Block_N;
+
+        int cur_count = 0;
+        int max_count =
+            (num_n_tile_iteration - 1) * count_per_iter +
+            block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
+        auto block_norm_reduce      = Policy::template GetBlockNormReduce<Problem>();
+        auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
+        auto block_norm_reduce_cross_warp_sync =
+            Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
+
+        using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
+        auto square_mean  = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
+        clear_tile(square_mean);

         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
```
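The two-pass `max_count` arithmetic deserves a worked example: every full N-tile contributes `count_per_iter = Repeat_N * Vector_N` elements per thread, and only the last, possibly padded, tile is handed to `block_tile_welford_calculate_max_count`. With made-up shape numbers (the ceiling-divide stand-in below only approximates what the real helper returns per thread):

```cpp
#include <cstdio>

int main()
{
    const int row_size = 1000, Block_N = 256;
    const int Repeat_N = 2, Vector_N = 4;

    const int num_iters      = (row_size + Block_N - 1) / Block_N;   // 4 tiles
    const int count_per_iter = Repeat_N * Vector_N;                  // 8 per thread per full tile
    const int last_iter_n    = row_size - (num_iters - 1) * Block_N; // 232 valid columns remain

    // Stand-in for block_tile_welford_calculate_max_count: max per-thread count
    // over the valid (unpadded) columns of the final tile (illustrative only).
    const int last_count = (last_iter_n * count_per_iter + Block_N - 1) / Block_N; // 8

    const int max_count = (num_iters - 1) * count_per_iter + last_count;           // 32
    std::printf("num_iters=%d last_iter_n=%d max_count=%d\n", num_iters, last_iter_n, max_count);
    return 0;
}
```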
```diff
@@ -102,6 +111,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
             move_tile_window(x_residual_window, {0, Block_N});

             auto acc = cast_tile<ComputeDataType>(x);
+
             if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
                          kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
             {
```
```diff
@@ -116,18 +126,29 @@ struct Rmsnorm2dFwdPipelineTwoPass
                 }
             }

-            block_reduce2d(acc, square_sum, reduce_square_sum_func);
+            // Calculate square here because block norm reduce only supports naive mean.
+            sweep_tile(acc, [&](auto idx) { acc(idx) *= acc(idx); });
+            block_norm_reduce(acc, square_mean, cur_count, max_count);
         }

-        block_reduce2d_sync(square_sum, reduce_sum_func);
-        block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
+        block_norm_reduce_sync(square_mean, cur_count);
+        block_norm_reduce_cross_warp_sync(square_mean, cur_count, smem);

         // compute inv-rms
         auto inv_rms = tile_elementwise_in(
             [&](const auto& v_) {
-                return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
+                if constexpr(kFastFDiv && std::is_same_v<ComputeDataType, float>)
+                {
+                    return type_convert<ComputeDataType>(1.0f) *
+                           __builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
+                }
+                else
+                {
+                    return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
+                }
             },
-            square_sum);
+            square_mean);

         if constexpr(kSaveInvRms)
             store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
```
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp
```diff
@@ -39,6 +39,8 @@ template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DY
 template <bool kPadN_,
           bool kSaveInvRms_,
+          bool kFastFDiv_,
+          bool kWelford_,
           bool kTwoPass_,
           Rmsnorm2dFusedAddEnum kFusedAdd_,
           Rmsnorm2dFusedQuantEnum kFusedQuant_>
```
```diff
@@ -46,6 +48,8 @@ struct Rmsnorm2dFwdTraits
 {
     static constexpr bool kPadN       = kPadN_;
     static constexpr bool kSaveInvRms = kSaveInvRms_;
+    static constexpr bool kFastFDiv   = kFastFDiv_;
+    static constexpr bool kWelford    = kWelford_;
     static constexpr bool kTwoPass    = kTwoPass_;
     static constexpr Rmsnorm2dFusedAddEnum kFusedAdd     = kFusedAdd_;
     static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
```
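Since the two new template parameters land between `kSaveInvRms_` and `kTwoPass_`, any existing `Rmsnorm2dFwdTraits` instantiation needs updating. A hypothetical instantiation, assuming the usual `ck_tile` namespace (the flag values and enum choices here are made up; only the parameter order comes from the diff):

```cpp
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"

using MyRmsnormTraits =
    ck_tile::Rmsnorm2dFwdTraits<true,   // kPadN_
                                false,  // kSaveInvRms_
                                true,   // kFastFDiv_  (new in this commit)
                                true,   // kWelford_   (new in this commit)
                                false,  // kTwoPass_
                                ck_tile::Rmsnorm2dFusedAddEnum::PRE_ADD,
                                ck_tile::Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT>;
```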