Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
27ff3dec
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "2a5f2a952c4cd0fa4fd14beacefd5c5142b678c1"
Commit
27ff3dec
authored
Nov 04, 2024
by
dummycoderfe
Browse files
optimize small N case using vectorized I/O and reciprocal (rcp) multiplication instead of division
parent
cb6c5d39
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
12 deletions
+18
-12
example/ck_tile/02_layernorm2d/generate.py
example/ck_tile/02_layernorm2d/generate.py
+6
-3
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+2
-1
include/ck_tile/ops/welford/block/block_welford.hpp
include/ck_tile/ops/welford/block/block_welford.hpp
+3
-2
include/ck_tile/ops/welford/thread/thread_welford.hpp
include/ck_tile/ops/welford/thread/thread_welford.hpp
+7
-6
No files found.
example/ck_tile/02_layernorm2d/generate.py
View file @
27ff3dec
...
@@ -114,7 +114,7 @@ struct layernorm2d_fwd_traits_
...
@@ -114,7 +114,7 @@ struct layernorm2d_fwd_traits_
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector
, ThreadPerBlock_M_ * ThreadPerBlock_N_
>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
...
@@ -484,8 +484,11 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
...
@@ -484,8 +484,11 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
fused_sweep_list
=
[
0
,
1
]
# NOTE: only single pass can use fused dynamic quant
fused_sweep_list
=
[
0
,
1
]
# NOTE: only single pass can use fused dynamic quant
# rm rn tm tn vn pd mv 2p add sweep
# rm rn tm tn vn pd mv 2p add sweep
h_trait_dict
=
{
'64'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_trait_dict
=
{
'64'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
8
,
8
,
8
,
True
,
False
,
False
,
0
,
0
),
'128'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
16
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
'128'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
16
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
'256'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
4
,
True
,
False
,
False
,
0
,
0
),
'256'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @
27ff3dec
...
@@ -125,7 +125,8 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -125,7 +125,8 @@ struct Layernorm2dFwdPipelineOnePass
// compute inv-std
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
+
epsilon
));
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
*
__builtin_amdgcn_rcpf
(
sqrt
(
v_
+
epsilon
));
},
},
var
);
var
);
...
...
include/ck_tile/ops/welford/block/block_welford.hpp
View file @
27ff3dec
...
@@ -356,7 +356,8 @@ CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTe
...
@@ -356,7 +356,8 @@ CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTe
int
count
)
int
count
)
{
{
using
DataType
=
typename
VarDistributedTensor_
::
DataType
;
using
DataType
=
typename
VarDistributedTensor_
::
DataType
;
tile_elementwise_inout
([
&
count
](
auto
&
x
)
{
x
=
x
/
type_convert
<
DataType
>
(
count
);
},
tile_elementwise_inout
(
var_tensor
);
[
&
count
](
auto
&
x
)
{
x
=
x
*
__builtin_amdgcn_rcpf
(
type_convert
<
DataType
>
(
count
));
},
var_tensor
);
}
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/welford/thread/thread_welford.hpp
View file @
27ff3dec
...
@@ -12,7 +12,7 @@ CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count)
...
@@ -12,7 +12,7 @@ CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count)
{
{
// TODO: check nan? maybe no
// TODO: check nan? maybe no
T
delta
=
x
-
mean
;
T
delta
=
x
-
mean
;
mean
+=
delta
/
count
;
mean
+=
delta
*
__builtin_amdgcn_rcpf
(
count
)
;
T
delta2
=
x
-
mean
;
T
delta2
=
x
-
mean
;
var
+=
delta
*
delta2
;
var
+=
delta
*
delta2
;
}
}
...
@@ -21,11 +21,12 @@ template <typename T>
...
@@ -21,11 +21,12 @@ template <typename T>
CK_TILE_DEVICE
static
void
CK_TILE_DEVICE
static
void
welford_merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
welford_merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
{
{
int
count
=
count_a
+
count_b
;
int
count
=
count_a
+
count_b
;
T
count_
=
type_convert
<
T
>
(
count
);
T
count_
=
type_convert
<
T
>
(
count
);
T
count_a_
=
type_convert
<
T
>
(
count_a
);
T
count_a_
=
type_convert
<
T
>
(
count_a
);
T
count_b_
=
type_convert
<
T
>
(
count_b
);
T
count_b_
=
type_convert
<
T
>
(
count_b
);
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
count_b_
/
count_
;
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
count_b_
*
__builtin_amdgcn_rcpf
(
count_
);
T
delta
=
mean_b
-
mean_a
;
T
delta
=
mean_b
-
mean_a
;
mean_a
+=
delta
*
count_b_over_count
;
mean_a
+=
delta
*
count_b_over_count
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment