Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
93ec1681
Commit
93ec1681
authored
Oct 21, 2024
by
rocking
Browse files
Add two pass pipeline
parent
4cef2fc5
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
203 additions
and
6 deletions
+203
-6
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
.../ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
+10
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp
...rm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp
+14
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp
...rm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp
+14
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp
...layernorm2d/instances/layernorm2d_fwd_instance_common.hpp
+4
-1
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
+2
-2
example/ck_tile/02_layernorm2d/script/smoke_test.sh
example/ck_tile/02_layernorm2d/script/smoke_test.sh
+2
-0
include/ck_tile/ops/layernorm2d.hpp
include/ck_tile/ops/layernorm2d.hpp
+2
-1
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_one_pass_pipeline.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_one_pass_pipeline.hpp
+1
-2
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_two_pass_pipeline.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_two_pass_pipeline.hpp
+154
-0
No files found.
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
View file @
93ec1681
...
@@ -98,6 +98,16 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
...
@@ -98,6 +98,16 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
else
else
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
n
>
4096
)
{
if
(
a
.
n
%
8
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
true
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
true
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
true
>>
(
s
,
a
);
else
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
true
>>
(
s
,
a
);
}
return
r
;
return
r
;
#else
#else
return
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
1
,
256
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
1
,
256
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
...
...
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp
0 → 100644
View file @
93ec1681
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp
0 → 100644
View file @
93ec1681
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp
View file @
93ec1681
...
@@ -118,7 +118,10 @@ float layernorm2d_fwd_(const S& s, A a)
...
@@ -118,7 +118,10 @@ float layernorm2d_fwd_(const S& s, A a)
Traits_
::
kPadN
,
Traits_
::
kPadN
,
Traits_
::
kSaveMeanInvStd
,
Traits_
::
kSaveMeanInvStd
,
Traits_
::
kTwoPass
>
;
Traits_
::
kTwoPass
>
;
using
Pipeline
=
ck_tile
::
Layernorm2dFwdRowwisePipeline
<
PipelineProblem
>
;
using
OnePassPipeline
=
ck_tile
::
Layernorm2dFwdOnePassPipeline
<
PipelineProblem
>
;
using
TwoPassPipeline
=
ck_tile
::
Layernorm2dFwdTwoPassPipeline
<
PipelineProblem
>
;
using
Pipeline
=
std
::
conditional_t
<
Traits_
::
kTwoPass
,
TwoPassPipeline
,
OnePassPipeline
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
Pipeline
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
Pipeline
>
;
...
...
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
View file @
93ec1681
...
@@ -6,8 +6,8 @@
...
@@ -6,8 +6,8 @@
template
<
typename
DataType
>
template
<
typename
DataType
>
auto
get_elimit
()
auto
get_elimit
()
{
{
double
rtol
=
1e-
3
;
double
rtol
=
1e-
2
;
double
atol
=
1e-
3
;
double
atol
=
1e-
2
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
}
...
...
example/ck_tile/02_layernorm2d/script/smoke_test.sh
100644 → 100755
View file @
93ec1681
...
@@ -25,4 +25,6 @@ $EXE -prec=$pr_i -m=5 -n=2040
...
@@ -25,4 +25,6 @@ $EXE -prec=$pr_i -m=5 -n=2040
$EXE
-prec
=
$pr_i
-m
=
7
-n
=
2734
$EXE
-prec
=
$pr_i
-m
=
7
-n
=
2734
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
3182
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
3182
$EXE
-prec
=
$pr_i
-m
=
9
-n
=
4096
$EXE
-prec
=
$pr_i
-m
=
9
-n
=
4096
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
8192
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
23547
done
done
include/ck_tile/ops/layernorm2d.hpp
View file @
93ec1681
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_pipeline.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_one_pass_pipeline.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_two_pass_pipeline.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_problem.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_
rowwise
_pipeline.hpp
→
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_
one_pass
_pipeline.hpp
View file @
93ec1681
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
Layernorm2dFwdRowwiseDefaultPolicy
>
template
<
typename
Problem_
,
typename
Policy_
=
Layernorm2dFwdRowwiseDefaultPolicy
>
struct
Layernorm2dFwd
Rowwise
Pipeline
struct
Layernorm2dFwd
OnePass
Pipeline
{
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
...
@@ -32,7 +32,6 @@ struct Layernorm2dFwdRowwisePipeline
...
@@ -32,7 +32,6 @@ struct Layernorm2dFwdRowwisePipeline
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
const
char
*
name
=
[]()
{
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
if
constexpr
(
kNeedCrossWarpSync
)
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_two_pass_pipeline.hpp
0 → 100644
View file @
93ec1681
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_default_policy.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
Layernorm2dFwdRowwiseDefaultPolicy
>
struct
Layernorm2dFwdTwoPassPipeline
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
BetaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
BetaDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveMean
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
return
"bpr"
;
// block per row
else
return
"wpr"
;
// warp per row
}();
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
BetaWindow
,
typename
YWindow
,
typename
MeanWindow
,
typename
InvStdWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
GammaWindow
&
gamma_window_
,
const
BetaWindow
&
beta_window_
,
YWindow
&
y_window
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
{
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
auto
beta_window
=
make_tile_window
(
beta_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
// Problem::BlockShape
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
index_t
num_n_tile_iteration
=
__builtin_amdgcn_readfirstlane
(
integer_divide_ceil
(
row_size
,
Block_N
));
int
cur_count
=
0
;
int
max_count
=
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
row_size
);
auto
block_welford
=
Policy
::
template
GetBlockWelford
<
Problem
>();
auto
block_welford_sync
=
Policy
::
template
GetBlockWelfordSync
<
Problem
>();
auto
block_welford_cross_warp_sync
=
Policy
::
template
GetBlockWelfordCrossWarpSync
<
Problem
>();
using
XTensorType
=
decltype
(
load_tile
(
x_window
));
auto
mean
=
block_welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
block_welford
(
x
,
mean
,
var
,
cur_count
,
max_count
);
move_tile_window
(
x_window
,
{
0
,
Block_N
});
}
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
);
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
)
+
epsilon
);
},
var
);
if
constexpr
(
kSaveMean
)
store_tile
(
mean_window
,
cast_tile
<
MeanDataType
>
(
mean
));
if
constexpr
(
kSaveInvStd
)
store_tile
(
inv_std_window
,
cast_tile
<
InvStdDataType
>
(
inv_std
));
// reverse read x to reuse cache
ck_tile
::
index_t
stride_to_right_most_window
=
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
// x_window.foo();
// gamma_window.foo();
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
beta_window
,
{
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
// layernorm computation
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
auto
y
=
make_static_distributed_tensor
<
YDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
(
x_
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
);
});
store_tile
(
y_window
,
y
);
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
beta_window
,
{
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
}
}
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment