Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
df45a6b5
Commit
df45a6b5
authored
Jan 24, 2025
by
Jiming Ruan
Browse files
50ms -> 28ms
parent
3c7fef7f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
16 deletions
+30
-16
example/ck_tile/10_rmsnorm2d/generate.py
example/ck_tile/10_rmsnorm2d/generate.py
+5
-4
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
+25
-12
No files found.
example/ck_tile/10_rmsnorm2d/generate.py
View file @
df45a6b5
...
@@ -535,10 +535,11 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
...
@@ -535,10 +535,11 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
512
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
512
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
'big'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
8
,
True
,
False
,
True
,
0
,
0
),
'big'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
1
,
1024
,
8
,
True
,
False
,
True
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
4
,
True
,
False
,
True
,
0
,
0
),
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
1024
,
2
,
True
,
False
,
True
,
0
,
0
),
# h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
1
,
True
,
False
,
True
,
0
,
0
)]}
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)
]}
total_blob
=
list
()
total_blob
=
list
()
for
hs_key
in
h_trait_dict
:
for
hs_key
in
h_trait_dict
:
hs
=
h_trait_dict
[
hs_key
]
hs
=
h_trait_dict
[
hs_key
]
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
View file @
df45a6b5
...
@@ -125,7 +125,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
...
@@ -125,7 +125,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
// compute inv-rms
// compute inv-rms
auto
inv_rms
=
tile_elementwise_in
(
auto
inv_rms
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
/
row_size
+
epsilon
)
)
;
return
r
sqrt
f
(
v_
/
row_size
+
epsilon
);
},
},
square_sum
);
square_sum
);
...
@@ -136,32 +136,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
...
@@ -136,32 +136,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
ck_tile
::
index_t
stride_to_right_most_window
=
ck_tile
::
index_t
stride_to_right_most_window
=
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
move_tile_window
(
y_residual_window
,
{
0
,
-
Block_N
});
}
else
{
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
}
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
// rmsnorm computation
// rmsnorm computation
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
auto
acc
=
make_static_distributed_tensor
<
ComputeDataType
>
(
decltype
(
load_tile
(
x_window
))
::
get_tile_distribution
());
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
)
{
{
auto
x
=
load_tile
(
x_window
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
});
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
}
else
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
acc
=
cast_tile
<
ComputeDataType
>
(
load_tile
(
y_residual_window
));
move_tile_window
(
y_residual_window
,
{
0
,
-
Block_N
});
}
}
// load gamma (TODO: support no gamma?)
// load gamma (TODO: support no gamma?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
gamma
=
load_tile
(
gamma_window
);
// rmsnorm computation
// rmsnorm computation
auto
rmsn
=
make_static_distributed_tensor
<
ComputeDataType
>
(
x
.
get_tile_distribution
());
auto
rmsn
=
make_static_distributed_tensor
<
ComputeDataType
>
(
decltype
(
load_tile
(
x_window
))
::
get_tile_distribution
());
sweep_tile
(
rmsn
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
sweep_tile
(
rmsn
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
...
@@ -176,8 +191,6 @@ struct Rmsnorm2dFwdPipelineTwoPass
...
@@ -176,8 +191,6 @@ struct Rmsnorm2dFwdPipelineTwoPass
static_assert
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
NO_SWEEP
);
static_assert
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
NO_SWEEP
);
Epilogue
{}(
y_window
,
rmsn
);
Epilogue
{}(
y_window
,
rmsn
);
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment