Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
018e939f
"git@developer.sourcefind.cn:chenzk/alphafold2_jax.git" did not exist on "189f7d81f31abf20bbd56e098bf8dacd8a933d05"
Commit
018e939f
authored
Jan 22, 2025
by
Jiming Ruan
Browse files
Modify tests and bug fix
parent
54617a85
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
118 additions
and
73 deletions
+118
-73
example/ck_tile/10_rmsnorm2d/generate.py
example/ck_tile/10_rmsnorm2d/generate.py
+103
-73
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
+6
-0
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
+7
-0
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
+2
-0
No files found.
example/ck_tile/10_rmsnorm2d/generate.py
View file @
018e939f
This diff is collapsed.
Click to expand it.
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
View file @
018e939f
...
...
@@ -200,6 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
float
ave_time
=
rmsnorm2d_fwd
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
if
(
ave_time
<
0
)
{
std
::
cout
<<
" not supported!"
<<
std
::
endl
<<
std
::
flush
;
return
false
;
}
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
m
*
n
+
sizeof
(
GammaDataType
)
*
n
+
sizeof
(
YDataType
)
*
m
*
n
;
num_byte
+=
SaveRms
?
sizeof
(
InvRmsDataType
)
*
m
*
n
:
0
;
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
View file @
018e939f
...
...
@@ -120,6 +120,13 @@ struct Rmsnorm2dFwdPipelineOnePass
block_norm_reduce_sync
(
square_mean
,
cur_count
);
block_norm_reduce_cross_warp_sync
(
square_mean
,
cur_count
,
smem
);
if
constexpr
(
!
kWelford
)
{
sweep_tile
(
square_mean
,
[
&
](
auto
idx
)
{
square_mean
(
idx
)
=
square_mean
(
idx
)
/
type_convert
<
ComputeDataType
>
(
row_size
);
});
}
// compute inv-rms
auto
inv_rms
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
View file @
018e939f
...
...
@@ -70,6 +70,8 @@ struct Rmsnorm2dFwdPipelineTwoPass
void
*
smem
,
Epilogue
)
const
{
static_assert
(
kWelford
==
true
,
"2 pass only supports welford merge"
);
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
gamma_window
=
make_tile_window
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment