Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c1f8d53c
Unverified
Commit
c1f8d53c
authored
Nov 14, 2024
by
feli
Committed by
GitHub
Nov 14, 2024
Browse files
[Ck_tile] hot fix, fix rpcf param setting err (#1657)
Co-authored-by:
dummycoderfe
<
noplydummmycoder@163.com
>
parent
efd92615
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
8 deletions
+21
-8
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+1
-1
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
+11
-3
include/ck_tile/ops/welford/block/block_welford.hpp
include/ck_tile/ops/welford/block/block_welford.hpp
+9
-4
No files found.
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @
c1f8d53c
...
...
@@ -121,7 +121,7 @@ struct Layernorm2dFwdPipelineOnePass
auto
[
mean
,
var
]
=
block_welford
(
acc
,
cur_count
,
max_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
);
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{}
);
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
View file @
c1f8d53c
...
...
@@ -35,6 +35,7 @@ struct Layernorm2dFwdPipelineTwoPass
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kFastFDiv
=
Problem
::
Traits
::
kFastFDiv
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
...
...
@@ -137,15 +138,22 @@ struct Layernorm2dFwdPipelineTwoPass
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
);
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{}
);
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
+
epsilon
));
if
(
kFastFDiv
&&
std
::
is_same_v
<
ComputeDataType
,
float
>
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
*
__builtin_amdgcn_rcpf
(
sqrt
(
v_
+
epsilon
));
}
else
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
sqrt
(
v_
+
epsilon
);
}
},
var
);
if
constexpr
(
kSaveMean
)
store_tile
(
mean_window
,
cast_tile
<
MeanDataType
>
(
mean
));
if
constexpr
(
kSaveInvStd
)
...
...
include/ck_tile/ops/welford/block/block_welford.hpp
View file @
c1f8d53c
...
...
@@ -47,8 +47,11 @@ struct BlockWelford
auto
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_tensor
[
in_dstr_idx
]);
welford_update
(
mean_tensor
(
out_dstr_idx
),
var_tensor
(
out_dstr_idx
),
x
,
cur_count_
);
welford_update
(
mean_tensor
(
out_dstr_idx
),
var_tensor
(
out_dstr_idx
),
x
,
cur_count_
,
constant
<
kFastFDiv
>
{});
});
}
});
...
...
@@ -159,7 +162,8 @@ struct BlockWelfordSync
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
);
v_remote_count
,
constant
<
kFastFDiv
>
{});
});
}
});
...
...
@@ -307,7 +311,8 @@ struct BlockWelfordCrossWarpSync
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
);
v_remote_count
,
constant
<
kFastFDiv
>
{});
});
mean_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_mean
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment