gaoqiong / composable_kernel_ROCM / Commits / 4b59b5c9

Commit 4b59b5c9 authored Oct 27, 2024 by carlushuang

add prenorm/postnorm support, refactor using generate.py
parent 4d5248e2
Changes: 35

Showing 15 changed files with 411 additions and 309 deletions
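In this commit, "prenorm" (fadd=1, PRE_ADD_STORE) adds a shortcut tensor to the input before normalizing and also writes the sum back to global memory, while "postnorm" (fadd=2, PRE_ADD) performs the same add but keeps the sum only as an intermediate. A minimal scalar sketch of these semantics for orientation (plain C++; the function and enum names are illustrative, not the kernel's API — the real implementation is the tiled device code in the diffs below):

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Scalar sketch of the fused-add modes this commit adds (names illustrative).
// PreAddStore ("prenorm") also persists the pre-normalization sum; PreAdd
// ("postnorm") only consumes it.
enum class FusedAdd { None = 0, PreAddStore = 1, PreAdd = 2 };

void layernorm_row(std::vector<float> x, const std::vector<float>& sx,
                   FusedAdd mode, std::vector<float>* sy, std::vector<float>& y,
                   float eps = 1e-12f)
{
    if(mode != FusedAdd::None)
    {
        for(std::size_t i = 0; i < x.size(); ++i)
            x[i] += sx[i]; // x = x + shortcut
        if(mode == FusedAdd::PreAddStore && sy != nullptr)
            *sy = x; // prenorm: store the sum for the next layer
    }
    float mean = 0.f, var = 0.f;
    for(float v : x)
        mean += v;
    mean /= x.size();
    for(float v : x)
        var += (v - mean) * (v - mean);
    var /= x.size();
    const float inv_std = 1.f / std::sqrt(var + eps);
    y.resize(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
        y[i] = (x[i] - mean) * inv_std; // gamma/beta omitted for brevity
}

int main()
{
    std::vector<float> x = {1, 2, 3, 4}, sx = {0.5f, 0.5f, 0.5f, 0.5f}, sy, y;
    layernorm_row(x, sx, FusedAdd::PreAddStore, &sy, y);
    std::printf("sy[0]=%.2f y[0]=%.3f\n", sy[0], y[0]); // sy[0]=1.50
}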
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp      +0  -13
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp  +0  -12
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp      +0  -12
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp         +0  -67
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp                                   +93 -22
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp                                   +11 -76
example/ck_tile/02_layernorm2d/script/perf_test.sh                                   +33 -33
example/ck_tile/02_layernorm2d/script/smoke_test.sh                                  +27 -25
include/ck_tile/core/tensor/null_tile_window.hpp                                     +7  -0
include/ck_tile/ops/layernorm2d.hpp                                                  +1  -0
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp                    +84 -18
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp       +42 -12
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp        +2  -6
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp       +55 -13
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp                  +56 -0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp  (deleted, 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "layernorm2d_fwd_instance_common.hpp"

// clang-format off
//                                                      rm rn tm  tn vn  pd    mv     2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true, false, false>>(const S&, A);
// clang-format on
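For orientation: the "n512" in the file name follows from the trait parameters. Per the layernorm2d_fwd_traits_ definition (removed from layernorm2d_fwd.hpp later in this diff), the row tile covered by one block is Block_N = Repeat_N * ThreadPerBlock_N * Vector_N, and each instantiation above yields 512 (1·64·8, 2·64·4, 4·64·2, 8·64·1). A self-contained compile-time check of that arithmetic:

// Sketch: the four instances above all tile N = 512.
constexpr int block_n(int repeat_n, int thread_per_block_n, int vector_n)
{
    return repeat_n * thread_per_block_n * vector_n;
}
static_assert(block_n(1, 64, 8) == 512, "rn=1, vn=8");
static_assert(block_n(2, 64, 4) == 512, "rn=2, vn=4");
static_assert(block_n(4, 64, 2) == 512, "rn=4, vn=2");
static_assert(block_n(8, 64, 1) == 512, "rn=8, vn=1");

int main() {}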
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp  (deleted, 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "layernorm2d_fwd_instance_common.hpp"

// clang-format off
//                                                      rm rn tm  tn vn  pd    mv     2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp  (deleted, 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "layernorm2d_fwd_instance_common.hpp"

// clang-format off
//                                                      rm  rn tm  tn vn  pd    mv     2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1,  3, 4, 64, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1,  6, 4, 64, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp  (deleted, 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
#include <iostream>

#pragma once

using S = ck_tile::stream_config;
using A = layernorm2d_fwd_args;

template <typename DataType_,
          ck_tile::index_t Repeat_M_,         // each thread repeat along M
          ck_tile::index_t Repeat_N_,         // each thread repeat along N
          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
          ck_tile::index_t Vector_N_,         // vector size along N
          bool kPadN_,
          bool kSaveMeanInvStd_,
          bool kTwoPass_>
using trait_ = layernorm2d_fwd_traits_<DataType_,
                                       Repeat_M_,
                                       Repeat_N_,
                                       ThreadPerBlock_M_,
                                       ThreadPerBlock_N_,
                                       Vector_N_,
                                       kPadN_,
                                       kSaveMeanInvStd_,
                                       kTwoPass_>;

template <typename Traits_>
float layernorm2d_fwd_(const S& s, A a)
{
    using DataType = typename Traits_::DataType;

    using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
        typename LayerNormTypeConfig<DataType>::XDataType,
        typename LayerNormTypeConfig<DataType>::GammaDataType,
        typename LayerNormTypeConfig<DataType>::BetaDataType,
        typename LayerNormTypeConfig<DataType>::ComputeDataType,
        typename LayerNormTypeConfig<DataType>::YDataType,
        typename LayerNormTypeConfig<DataType>::MeanDataType,
        typename LayerNormTypeConfig<DataType>::InvStdDataType,
        typename Traits_::Shape,
        Traits_::kPadN,
        Traits_::kSaveMeanInvStd,
        Traits_::kTwoPass>;

    using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
    using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
    using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;

    using Kernel = ck_tile::Layernorm2dFwd<Pipeline>;

    const dim3 grids                       = Kernel::GridSize(a);
    constexpr dim3 blocks                  = Kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 1;

    auto kargs = Kernel::MakeKargs(a);

    if(s.log_level_ > 0)
        std::cout << ", " << Kernel::GetName() << std::flush;

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
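The one-pass/two-pass choice above is purely static: kTwoPass in the traits selects the pipeline type at compile time via std::conditional_t. A minimal self-contained sketch of that selection idiom (type names here are placeholders, not ck_tile's):

#include <cstdio>
#include <type_traits>

// Sketch of the pipeline selection in layernorm2d_fwd_ above: a boolean trait
// statically picks between two pipeline types.
struct OnePass { static constexpr const char* name = "one-pass"; };
struct TwoPass { static constexpr const char* name = "two-pass"; };

template <bool kTwoPass>
using Pipeline = std::conditional_t<kTwoPass, TwoPass, OnePass>;

int main()
{
    std::printf("%s\n", Pipeline<false>::name); // rows fitting one block tile
    std::printf("%s\n", Pipeline<true>::name);  // long rows iterate N tiles twice
}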
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp

 #include "ck_tile/host.hpp"
 #include "layernorm2d_fwd.hpp"
 #include <cstring>
+#include <algorithm>

 // different threshold for different dtype
 template <typename DataType>
@@ -29,7 +30,13 @@ auto create_args(int argc, char* argv[])
         .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
         .insert("v", "1", "cpu validation or not")
         .insert("kname", "1", "print kernel name or not")
-        .insert("prec", "fp16", "precision")
+        .insert("prec_i", "fp16", "input precision")
+        .insert("prec_o", "auto", "output precision, set auto will be the same as input")
+        .insert("fadd", "0", "fused-add, 0:no fused add, 1:fused-prenorm(preadd+store), 2:fused-postnorm(preadd)")
+        .insert("fsweep", "0", "fused-sweep")
         .insert("warmup", "5", "cold iter")
         .insert("repeat", "20", "hot iter");
@@ -37,7 +44,7 @@ auto create_args(int argc, char* argv[])
     return std::make_tuple(result, arg_parser);
 }

-template <typename DataType, bool SaveMeanVar>
+template <typename InDataType, typename OutDataType, bool SaveMeanVar>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
     ck_tile::index_t m = arg_parser.get_int("m");
@@ -46,20 +53,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(stride < 0)
         stride = n;
     float epsilon = arg_parser.get_float("e");
-    std::string data_type = arg_parser.get_str("prec");
+    std::string prec_i = arg_parser.get_str("prec_i");
+    std::string prec_o = arg_parser.get_str("prec_o");
+    if(prec_o == "auto")
+    {
+        prec_o = prec_i;
+    }
     int kname         = arg_parser.get_int("kname");
     int do_validation = arg_parser.get_int("v");
     int warmup        = arg_parser.get_int("warmup");
     int repeat        = arg_parser.get_int("repeat");
+    int fused_add     = arg_parser.get_int("fadd");
+    int fused_sweep   = arg_parser.get_int("fsweep");
     assert(stride >= n);

-    using TypeConfig = LayerNormTypeConfig<DataType>;
+    using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType>;

     using XDataType     = typename TypeConfig::XDataType;
     using YDataType     = typename TypeConfig::YDataType;
     using GammaDataType = typename TypeConfig::GammaDataType;
     using BetaDataType  = typename TypeConfig::BetaDataType;
+    using SXDataType    = XDataType;
+    using SYDataType    = YDataType;

     using MeanDataType =
         std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>;
@@ -73,6 +89,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<GammaDataType> gamma_host({n});
     ck_tile::HostTensor<BetaDataType> beta_host({n});
+    ck_tile::HostTensor<SXDataType> sx_host({m, n}, {stride, 1});
+    ck_tile::HostTensor<SYDataType> sy_host({m, n}, {stride, 1});

     ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
     ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
@@ -88,19 +107,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem sx_buf(sx_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem sy_buf(sy_host.get_element_space_size_in_bytes());

     x_buf.ToDevice(x_host.data());
     gamma_buf.ToDevice(gamma_host.data());
     beta_buf.ToDevice(beta_host.data());
+    sx_buf.ToDevice(sx_host.data());

-    std::cout << "[" << data_type << "]"
+    std::cout << "[" << prec_i << "]"
               << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;

-    layernorm2d_fwd_traits traits{data_type, SaveMeanVar};
+    layernorm2d_fwd_traits traits{prec_i, prec_o, SaveMeanVar, fused_add, fused_sweep};

     layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
+                              fused_add != 0 ? sx_buf.GetDeviceBuffer() : nullptr,
                               gamma_buf.GetDeviceBuffer(),
                               beta_buf.GetDeviceBuffer(),
                               y_buf.GetDeviceBuffer(),
+                              fused_add == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
                               nullptr,
                               nullptr,
                               epsilon,
@@ -111,6 +136,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
     float ave_time = layernorm2d_fwd(
         traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});

+    if(ave_time < 0)
+    {
+        std::cout << " not supported!" << std::endl << std::flush;
+        return false;
+    }

     std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
                            sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
@@ -122,6 +153,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(do_validation)
     {
         // reference
+        if(fused_add != 0)
+        {
+            // fused pre_add/pre_add_store
+            // TODO we accumulate directly to x_host for simplicity here...
+            std::transform(x_host.mData.cbegin(),
+                           x_host.mData.cend(),
+                           sx_host.mData.cbegin(),
+                           x_host.mData.begin(),
+                           std::plus<XDataType>{});
+        }
         ck_tile::reference_layernorm2d_fwd<XDataType,
                                            GammaDataType,
                                            BetaDataType,
@@ -133,11 +175,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
         y_buf.FromDevice(y_host_dev.data());

-        auto [rtol, atol] = get_elimit<DataType>();
+        ck_tile::HostTensor<SYDataType> sy_host_dev({m, n}, {stride, 1});
+        if(fused_add == 1)
+        {
+            sy_buf.FromDevice(sy_host_dev.data());
+        }
+        auto [rtol, atol] = get_elimit<InDataType>();

         if(stride == n)
         {
             pass = ck_tile::check_err(
                 y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
+            if(fused_add == 1)
+            {
+                pass &= ck_tile::check_err(
+                    sy_host_dev, x_host, std::string("ADD Error: Incorrect results!"), rtol, atol);
+            }
         }
         else
         {
@@ -153,6 +206,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
                 std::string("] Error: Incorrect results!"),
                 rtol,
                 atol);
+            if(fused_add == 1)
+            {
+                std::vector<SYDataType> sy_host_dev_row(sy_host_dev.begin() + i_r * stride,
+                                                        sy_host_dev.begin() + i_r * stride + n);
+                std::vector<SYDataType> sy_host_ref_row(x_host.begin() + i_r * stride,
+                                                        x_host.begin() + i_r * stride + n);
+                pass &= ck_tile::check_err(sy_host_dev_row,
+                                           sy_host_ref_row,
+                                           std::string("ADD[") + std::to_string(i_r) +
+                                               std::string("] Error: Incorrect results!"),
+                                           rtol,
+                                           atol);
+            }
         }
     }
@@ -168,23 +234,28 @@ int main(int argc, char* argv[])
     if(!result)
         return -1;

-    const std::string data_type = arg_parser.get_str("prec");
+    std::string prec_i = arg_parser.get_str("prec_i");
+    std::string prec_o = arg_parser.get_str("prec_o");
+    if(prec_o == "auto")
+    {
+        prec_o = prec_i;
+    }
     int save_mv = arg_parser.get_int("save_mv");

-    if(data_type == "fp16" && save_mv)
+    if(prec_i == "fp16" && prec_o == "fp16" && save_mv)
     {
-        return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2;
+        return run<ck_tile::half_t, ck_tile::half_t, true>(arg_parser) ? 0 : -2;
     }
-    else if(data_type == "fp16" && !save_mv)
+    else if(prec_i == "fp16" && prec_o == "fp16" && !save_mv)
     {
-        return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2;
+        return run<ck_tile::half_t, ck_tile::half_t, false>(arg_parser) ? 0 : -2;
     }
-    else if(data_type == "bf16" && save_mv)
+    else if(prec_i == "bf16" && prec_o == "bf16" && save_mv)
     {
-        return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
+        return run<ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
     }
-    else if(data_type == "bf16" && !save_mv)
+    else if(prec_i == "bf16" && prec_o == "bf16" && !save_mv)
     {
-        return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
+        return run<ck_tile::bf16_t, ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
     }

     return -3;
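The host-side validation of the fused add above is a single std::transform that accumulates the shortcut into x_host before the plain layernorm reference runs. A self-contained sketch of exactly that step:

#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Sketch of the reference-side fused-add handling above: add the shortcut
// into x first, then validate the normal layernorm of the sum.
int main()
{
    std::vector<float> x  = {1.f, 2.f, 3.f};
    std::vector<float> sx = {0.5f, 0.5f, 0.5f};
    std::transform(x.cbegin(), x.cend(), sx.cbegin(), x.begin(), std::plus<float>{});
    assert(x[0] == 1.5f && x[1] == 2.5f && x[2] == 3.5f);
}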
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp

@@ -8,14 +8,14 @@
 #include "ck_tile/ops/layernorm2d.hpp"
 #include <string>

-template <typename DataType>
+template <typename InType, typename OutType>
 struct LayerNormTypeConfig;

-template <>
-struct LayerNormTypeConfig<ck_tile::half_t>
+template <typename OutType>
+struct LayerNormTypeConfig<ck_tile::half_t, OutType>
 {
     using XDataType = ck_tile::half_t;
-    using YDataType = ck_tile::half_t;
+    using YDataType = OutType;
     using GammaDataType = ck_tile::half_t;
     using BetaDataType  = ck_tile::half_t;
     using MeanDataType  = ck_tile::half_t;
@@ -23,11 +23,11 @@ struct LayerNormTypeConfig<ck_tile::half_t>
     using ComputeDataType = float;
 };

-template <>
-struct LayerNormTypeConfig<ck_tile::bf16_t>
+template <typename OutType>
+struct LayerNormTypeConfig<ck_tile::bf16_t, OutType>
 {
     using XDataType = ck_tile::bf16_t;
-    using YDataType = ck_tile::bf16_t;
+    using YDataType = OutType;
     using GammaDataType = ck_tile::bf16_t;
     using BetaDataType  = ck_tile::bf16_t;
     using MeanDataType  = ck_tile::bf16_t;
@@ -40,82 +40,17 @@ struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
 {
 };

-// this is used to pattern-match internal kernel implementation, not to instantiate kernel
-template <typename DataType_,
-          ck_tile::index_t Repeat_M_,         // each thread repeat along M
-          ck_tile::index_t Repeat_N_,         // each thread repeat along N
-          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
-          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
-          ck_tile::index_t Vector_N_,         // vector size along N
-          bool kPadN_,
-          bool kSaveMeanInvStd_,
-          bool kTwoPass_>
-struct layernorm2d_fwd_traits_
-{
-    using DataType = ck_tile::remove_cvref_t<DataType_>;
-
-    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
-    static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
-    static constexpr ck_tile::index_t total_warps =
-        (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
-
-    // num of warps along m
-    static constexpr ck_tile::index_t BlockWarps_M = []() {
-        if constexpr(is_warp_per_row)
-        {
-            static_assert(warpSize % ThreadPerBlock_N_ == 0);
-            return total_warps * (warpSize / ThreadPerBlock_N_);
-        }
-        else
-        {
-            // static_assert(warpSize % ThreadPerBlock_M_ == 0);
-            return total_warps / (ThreadPerBlock_N_ / warpSize);
-        }
-    }();
-
-    // num of warps along n
-    static constexpr ck_tile::index_t BlockWarps_N = []() {
-        if constexpr(is_warp_per_row)
-        {
-            static_assert(warpSize % ThreadPerBlock_N_ == 0);
-            return 1;
-        }
-        else
-        {
-            static_assert(ThreadPerBlock_N_ % warpSize == 0);
-            return ThreadPerBlock_N_ / warpSize;
-        }
-    }();
-
-    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
-    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
-
-    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
-    static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
-
-    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
-    static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
-
-    using BlockTile  = ck_tile::sequence<Block_M, Block_N>;
-    using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
-    using WarpTile   = ck_tile::sequence<Warp_M, Warp_N>;
-    using Vector     = ck_tile::sequence<1, Vector_N_>;
-
-    using Shape = ck_tile::Layernorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
-
-    static constexpr bool kPadN           = kPadN_;
-    static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
-    static constexpr bool kTwoPass        = kTwoPass_;
-};
-
-template <typename Traits_>
-float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
-
 // This is the public API, will be generated by script
 struct layernorm2d_fwd_traits
 {
-    std::string data_type;
+    std::string prec_i;
+    std::string prec_o;

     bool save_mean_var;
+
+    int fused_add;   // 0:no-add, 1:pre-add, 2:post-add
+    int fused_sweep; // 0:no-sweep,
 };

 float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
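The warp-partitioning math in the removed layernorm2d_fwd_traits_ is worth a worked example. Assuming a 64-lane wavefront (typical for gfx9-class GPUs; an assumption, not stated in the diff) and the common instance shape ThreadPerBlock_M=4, ThreadPerBlock_N=64, Vector_N=2, the "warp-per-row" branch applies and the numbers fall out as below; this compiles stand-alone:

// Worked example (sketch) of the warp-partitioning arithmetic in the removed
// layernorm2d_fwd_traits_ struct, for ThreadPerBlock_M=4, ThreadPerBlock_N=64,
// Vector_N=2 on an assumed 64-lane wavefront.
constexpr int warpSize_        = 64; // assumption: 64-lane wavefront
constexpr int ThreadPerBlock_M = 4, ThreadPerBlock_N = 64, Vector_N = 2;
constexpr int total_warps  = (ThreadPerBlock_M * ThreadPerBlock_N) / warpSize_; // 4
constexpr int BlockWarps_M = total_warps * (warpSize_ / ThreadPerBlock_N);      // 4 (warp-per-row)
constexpr int BlockWarps_N = 1;                                                 // one warp spans the row
constexpr int Warp_M = ThreadPerBlock_M / BlockWarps_M;                         // 1 thread-row per warp
constexpr int Warp_N = ThreadPerBlock_N / BlockWarps_N * Vector_N;              // 128 elements per warp
static_assert(BlockWarps_M == 4 && Warp_M == 1 && Warp_N == 128, "shape check");

int main() {}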
example/ck_tile/02_layernorm2d/script/perf_test.sh

@@ -2,37 +2,37 @@
 # run from top of ck folder
 EXE=build/bin/tile_example_layernorm2d_fwd

-$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
+$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
-$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
-$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
\ No newline at end of file
+$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
+$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
\ No newline at end of file
example/ck_tile/02_layernorm2d/script/smoke_test.sh

@@ -3,29 +3,31 @@
 EXE=./build/bin/tile_example_layernorm2d_fwd

 for pr_i in "fp16" "bf16" ; do
-$EXE -prec=$pr_i -m=99 -n=13
-$EXE -prec=$pr_i -m=17 -n=16
-$EXE -prec=$pr_i -m=1 -n=100
-$EXE -prec=$pr_i -m=4 -n=128
-$EXE -prec=$pr_i -m=80 -n=127
-$EXE -prec=$pr_i -m=22 -n=255 -stride=256
-$EXE -prec=$pr_i -m=7 -n=599
-$EXE -prec=$pr_i -m=19 -n=512
-$EXE -prec=$pr_i -m=33 -n=313 -stride=1000
-$EXE -prec=$pr_i -m=11 -n=510
-$EXE -prec=$pr_i -m=171 -n=676 -stride=818
-$EXE -prec=$pr_i -m=91 -n=636
-$EXE -prec=$pr_i -m=12 -n=768 -stride=800
-$EXE -prec=$pr_i -m=100 -n=766 -stride=812
-$EXE -prec=$pr_i -m=31 -n=1024
-$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004
-$EXE -prec=$pr_i -m=8 -n=1501
-$EXE -prec=$pr_i -m=3 -n=1826
-$EXE -prec=$pr_i -m=5 -n=2040
-$EXE -prec=$pr_i -m=7 -n=2734
-$EXE -prec=$pr_i -m=1 -n=3182
-$EXE -prec=$pr_i -m=9 -n=4096
-$EXE -prec=$pr_i -m=3 -n=8192
-$EXE -prec=$pr_i -m=1 -n=10547
-$EXE -prec=$pr_i -m=3 -n=17134
+for fadd in "0" "1" "2" ; do
+$EXE -prec_i=$pr_i -fadd=$fadd -m=99 -n=13
+$EXE -prec_i=$pr_i -fadd=$fadd -m=17 -n=16
+$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=100
+$EXE -prec_i=$pr_i -fadd=$fadd -m=4 -n=128
+$EXE -prec_i=$pr_i -fadd=$fadd -m=80 -n=127
+$EXE -prec_i=$pr_i -fadd=$fadd -m=22 -n=255 -stride=256
+$EXE -prec_i=$pr_i -fadd=$fadd -m=7 -n=599
+$EXE -prec_i=$pr_i -fadd=$fadd -m=19 -n=512
+$EXE -prec_i=$pr_i -fadd=$fadd -m=33 -n=313 -stride=1000
+$EXE -prec_i=$pr_i -fadd=$fadd -m=11 -n=510
+$EXE -prec_i=$pr_i -fadd=$fadd -m=171 -n=676 -stride=818
+$EXE -prec_i=$pr_i -fadd=$fadd -m=91 -n=636
+$EXE -prec_i=$pr_i -fadd=$fadd -m=12 -n=768 -stride=800
+$EXE -prec_i=$pr_i -fadd=$fadd -m=100 -n=766 -stride=812
+$EXE -prec_i=$pr_i -fadd=$fadd -m=31 -n=1024
+$EXE -prec_i=$pr_i -fadd=$fadd -m=64 -n=1000 -stride=1004
+$EXE -prec_i=$pr_i -fadd=$fadd -m=8 -n=1501
+$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=1826
+$EXE -prec_i=$pr_i -fadd=$fadd -m=5 -n=2040
+$EXE -prec_i=$pr_i -fadd=$fadd -m=7 -n=2734
+$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=3182
+$EXE -prec_i=$pr_i -fadd=$fadd -m=9 -n=4096
+$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=8192
+$EXE -prec_i=$pr_i -fadd=$fadd -m=1 -n=10547
+$EXE -prec_i=$pr_i -fadd=$fadd -m=3 -n=17134
+done
 done
include/ck_tile/core/tensor/null_tile_window.hpp

@@ -80,6 +80,13 @@ CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
     return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
 }

+template <typename WindowLengths, typename StaticTileDistribution>
+CK_TILE_DEVICE constexpr auto make_tile_window(const null_tile_window<WindowLengths>& t,
+                                               const StaticTileDistribution&)
+{
+    return t;
+}

 template <typename WindowLengths>
 CK_TILE_DEVICE void move_tile_window(null_tile_window<WindowLengths>&,
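The new overload lets the pipelines call make_tile_window uniformly on the shortcut windows: when fused add is disabled, the kernel hands the pipeline a null window, and requesting a distributed view of it simply returns the null window unchanged. A generic self-contained sketch of this pass-through idiom (type names here are placeholders, not ck_tile's):

#include <cstdio>

// Sketch of the "null window" pass-through pattern enabled above: generic
// code requests a distributed view of any window; the null window ignores
// the distribution and returns itself.
struct null_window {};
struct real_window { int data = 42; };

template <typename Dist>
null_window make_window(const null_window& w, const Dist&) { return w; } // no-op

template <typename Dist>
real_window make_window(const real_window& w, const Dist&) { return w; } // would apply Dist

int main()
{
    struct dist {};
    auto a = make_window(null_window{}, dist{}); // compiles, does nothing
    auto b = make_window(real_window{}, dist{});
    (void)a;
    std::printf("%d\n", b.data);
}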
include/ck_tile/ops/layernorm2d.hpp

@@ -9,4 +9,5 @@
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp

@@ -5,17 +5,20 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"

 namespace ck_tile {

 // host side args
 struct Layernorm2dFwdHostArgs
 {
-    const void* p_x;
+    const void* p_x;  // input
+    const void* p_sx; // shortcut input, set to nullptr if no
     const void* p_gamma;
     const void* p_beta;
-    void* p_y;
+    void* p_y;  // output
+    void* p_sy; // shortcut output, set to nullptr if no
     void* p_mean;
     void* p_invStd;
@@ -27,10 +30,11 @@ struct Layernorm2dFwdHostArgs
 };

 // TODO: Extract some type to wrapper class
-template <typename Pipeline_>
+template <typename Pipeline_, typename Epilogue_>
 struct Layernorm2dFwd
 {
     using Pipeline = remove_cvref_t<Pipeline_>;
+    using Epilogue = remove_cvref_t<Epilogue_>;
     using Problem  = typename Pipeline::Problem;

     using XDataType = remove_cvref_t<typename Problem::XDataType>;
@@ -41,17 +45,23 @@ struct Layernorm2dFwd
     using MeanDataType   = remove_cvref_t<typename Problem::MeanDataType>;
     using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;

+    // for simplicity, shortcut input/output type is same as X
+    using SXDataType = XDataType;
+    using SYDataType = XDataType;

     static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
     static constexpr bool kHasBeta  = !std::is_same_v<BetaDataType, null_type>;
-    static constexpr bool kSaveMeanInvStd = Problem::kSaveMeanInvStd;
-    static constexpr bool kSaveMean       = Problem::kSaveMeanInvStd;
-    static constexpr bool kSaveInvStd     = Problem::kSaveMeanInvStd;
+    static constexpr bool kSaveMeanInvStd = Problem::Traits::kSaveMeanInvStd;
+    static constexpr bool kSaveMean       = Problem::Traits::kSaveMeanInvStd;
+    static constexpr bool kSaveInvStd     = Problem::Traits::kSaveMeanInvStd;

     static constexpr index_t Block_M = Problem::BlockShape::Block_M;
     static constexpr index_t Block_N = Problem::BlockShape::Block_N;
     static constexpr bool kPadM = false; // always no need to pad along M
-    static constexpr bool kPadN    = Problem::kPadN;
-    static constexpr bool kTwoPass = Problem::kTwoPass;
+    static constexpr bool kPadN       = Problem::Traits::kPadN;
+    static constexpr bool kTwoPass    = Problem::Traits::kTwoPass;
+    static constexpr auto kFusedAdd   = Problem::Traits::kFusedAdd;
+    static constexpr auto kFusedSweep = Problem::Traits::kFusedSweep;

     static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
     static constexpr index_t Vector_N        = Problem::BlockShape::Vector_N;
@@ -62,11 +72,13 @@ struct Layernorm2dFwd
     struct Kargs
     {
-        const void* p_x;
+        const void* p_x;  // input
+        const void* p_sx; // shortcut input, set to nullptr if no
         const void* p_gamma;
         const void* p_beta;
-        void* p_y;
+        void* p_y;  // output
+        void* p_sy; // shortcut output, set to nullptr if no
         void* p_mean;
         void* p_invStd;
@@ -81,9 +93,11 @@ struct Layernorm2dFwd
     CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
     {
         return Kargs{hargs.p_x,
+                     hargs.p_sx,
                      hargs.p_gamma,
                      hargs.p_beta,
                      hargs.p_y,
+                     hargs.p_sy,
                      hargs.p_mean,
                      hargs.p_invStd,
                      hargs.epsilon,
@@ -113,17 +127,19 @@ struct Layernorm2dFwd
     CK_TILE_HOST static std::string GetName()
     {
         // clang-format off
+        #define _SS_  std::string
+        #define _TS_  std::to_string
         using S_ = typename Problem::BlockShape;
+        auto surfix = [&] () {
+            std::string n;
+            if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
+            if (kFusedSweep != Layernorm2dFusedSweepEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedSweepEnumName<kFusedSweep>::name;
+            if (kPadN) n += "_pn";
+            if (kSaveMeanInvStd) n += "_mv";
+            if (kTwoPass) n += "_2p";
+            return n; }();
-        #define _SS_  std::string
-        #define _TS_  std::to_string
         return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
                _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" +
                _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
                _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" +
                _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
@@ -153,6 +169,31 @@ struct Layernorm2dFwd
                 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
         }();

+        const auto sx_window = [&]() {
+            if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
+                         kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
+            {
+                const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
+                    static_cast<const SXDataType*>(kargs.p_sx),
+                    make_tuple(kargs.m, kargs.n),
+                    make_tuple(kargs.stride, 1),
+                    number<Vector_N>{},
+                    number<1>{});
+                // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel
+                // will check the max count dynamically
+                const auto tmp2_ = pad_tensor_view(
+                    tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
+                return make_tile_window(
+                    tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
+            }
+            else
+            {
+                return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
+            }
+        }();

         const auto gamma_window = [&]() {
             const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                 static_cast<const GammaDataType*>(kargs.p_gamma),
@@ -194,6 +235,28 @@ struct Layernorm2dFwd
                 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
         }();

+        auto sy_window = [&]() {
+            if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
+            {
+                auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
+                    static_cast<SYDataType*>(kargs.p_sy),
+                    make_tuple(kargs.m, kargs.n),
+                    make_tuple(kargs.stride, 1),
+                    number<Vector_N>{},
+                    number<1>{});
+                auto tmp2_ = pad_tensor_view(
+                    tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
+                return make_tile_window(
+                    tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
+            }
+            else
+            {
+                return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
+            }
+        }();

         auto mean_window = [&]() {
             if constexpr(kSaveMean)
             {
@@ -235,14 +298,17 @@ struct Layernorm2dFwd
         __shared__ char smem[GetSmemSize()];

         Pipeline{}(x_window,
+                   sx_window,
                    gamma_window,
                    beta_window,
                    y_window,
+                   sy_window,
                    mean_window,
                    inv_std_window,
                    static_cast<const ComputeDataType>(kargs.epsilon),
                    kargs.n,
-                   smem);
+                   smem,
+                   Epilogue{});
     }
 };
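The GetName() change appends a readable suffix that encodes the new trait flags ("_pras"/"_pra" for the fused-add mode, plus the existing "_pn", "_mv", "_2p"). A self-contained sketch of how the suffix composes (plain bools here stand in for the enum traits the real code reads from Problem::Traits):

#include <iostream>
#include <string>

// Sketch of the kernel-name suffix composition in GetName() above.
std::string make_suffix(bool pre_add_store, bool pad_n, bool save_mv, bool two_pass)
{
    std::string n;
    if(pre_add_store) n += "_pras"; // Layernorm2dFusedAddEnumName<PRE_ADD_STORE>::name
    if(pad_n)         n += "_pn";
    if(save_mv)       n += "_mv";
    if(two_pass)      n += "_2p";
    return n;
}

int main()
{
    // e.g. a padded, two-pass, prenorm kernel
    std::cout << "layernorm2d_fwd_fp16_..." << make_suffix(true, true, false, true) << "\n";
}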
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp

@@ -5,6 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
 #include <string>
 #include <type_traits>
@@ -24,14 +25,19 @@ struct Layernorm2dFwdPipelineOnePass
     using MeanDataType   = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
     using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
+    using SXDataType = XDataType;
+    using SYDataType = XDataType;

     static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
     static constexpr bool kHasBeta  = !std::is_same_v<BetaDataType, ck_tile::null_type>;
-    static constexpr bool kSaveMean   = Problem::kSaveMeanInvStd;
-    static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
+    static constexpr bool kSaveMean   = Problem::Traits::kSaveMeanInvStd;
+    static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;

     static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
     static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
-    static constexpr bool kPadN = Problem::kPadN;
+    static constexpr bool kPadN       = Problem::Traits::kPadN;
+    static constexpr auto kFusedAdd   = Problem::Traits::kFusedAdd;
+    static constexpr auto kFusedSweep = Problem::Traits::kFusedSweep;

     static constexpr const char* name = []() {
         if constexpr(kNeedCrossWarpSync)
@@ -46,20 +52,26 @@ struct Layernorm2dFwdPipelineOnePass
     }

     template <typename XWindow,
+              typename SXWindow,
               typename GammaWindow,
               typename BetaWindow,
               typename YWindow,
+              typename SYWindow,
               typename MeanWindow,
-              typename InvStdWindow>
+              typename InvStdWindow,
+              typename Epilogue>
     CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
+                                   const SXWindow& sx_window_,
                                    const GammaWindow& gamma_window_,
                                    const BetaWindow& beta_window_,
-                                   YWindow& y_window,
+                                   YWindow& y_window_,
+                                   const SYWindow& sy_window_,
                                    MeanWindow& mean_window,
                                    InvStdWindow& inv_std_window,
                                    ComputeDataType epsilon,
                                    ck_tile::index_t row_size,
-                                   void* smem) const
+                                   void* smem,
+                                   Epilogue) const
     {
         const auto x_window =
             make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
@@ -67,8 +79,14 @@ struct Layernorm2dFwdPipelineOnePass
             gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         const auto beta_window = make_tile_window(
             beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
+        const auto sx_window =
+            make_tile_window(sx_window_, Policy::template MakeXBlockTileDistribution<Problem>());
+        auto sy_window =
+            make_tile_window(sy_window_, Policy::template MakeXBlockTileDistribution<Problem>());

-        const auto x = load_tile(x_window);
+        auto x  = load_tile(x_window);
+        auto sx = load_tile(sx_window);

         int cur_count = 0;
         int max_count =
             block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
@@ -81,6 +99,17 @@ struct Layernorm2dFwdPipelineOnePass
         const auto gamma = load_tile(gamma_window);
         const auto beta  = load_tile(beta_window);

+        if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
+                     kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
+        {
+            sweep_tile(sx, [&](auto idx) {
+                // compute x = sx + x
+                x(idx) = type_convert<SYDataType>(sx(idx)) + type_convert<SYDataType>(x(idx));
+            });
+            if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
+                store_tile(sy_window, x);
+        }

         // compute welford each-thread->cross-lane->cross-warp
         auto [mean, var] = block_welford(x, cur_count, max_count);
         block_welford_sync(mean, var, cur_count);
@@ -100,8 +129,8 @@ struct Layernorm2dFwdPipelineOnePass
             store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));

         // layernorm computation
-        auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
-        sweep_tile(y, [&, mean_ = mean](auto idx) {
+        auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
+        sweep_tile(ln, [&, mean_ = mean](auto idx) {
             constexpr auto i_idx = make_tuple(idx[number<0>{}]);
             constexpr auto j_idx = make_tuple(idx[number<1>{}]);
@@ -109,11 +138,12 @@ struct Layernorm2dFwdPipelineOnePass
             const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);

             const auto x_ = type_convert<ComputeDataType>(x[idx]);
-            auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
+            auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;

-            y(idx) = type_convert<YDataType>(y_);
+            ln(idx) = ln_;
         });

-        store_tile(y_window, y);
+        Epilogue{}(y_window_, ln);
     }
 };
 } // namespace ck_tile
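The pipeline's statistics come from block_welford, which distributes Welford's online mean/variance update across lanes and warps ("each-thread -> cross-lane -> cross-warp" in the comment above). The underlying scalar recurrence, as a self-contained sketch (illustrative only; ck_tile's version additionally merges partial (mean, var, count) triples across threads):

#include <cstdio>
#include <initializer_list>

// Scalar Welford update: numerically stable running mean and variance.
struct Welford
{
    float mean = 0.f, m2 = 0.f;
    int count = 0;
    void push(float x)
    {
        ++count;
        const float delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean); // running sum of squared deviations
    }
    float variance() const { return count > 0 ? m2 / count : 0.f; }
};

int main()
{
    Welford w;
    for(float v : {1.f, 2.f, 3.f, 4.f})
        w.push(v);
    std::printf("mean=%f var=%f\n", w.mean, w.variance()); // mean=2.5 var=1.25
}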
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp

@@ -15,9 +15,7 @@ template <typename XDataType_,
           typename MeanDataType_,
           typename InvStdDataType_,
           typename BlockShape_,
-          bool kPadN_,
-          bool kSaveMeanInvStd_,
-          bool kTwoPass_>
+          typename Traits_>
 struct Layernorm2dFwdPipelineProblem
 {
     using XDataType = remove_cvref_t<XDataType_>;
@@ -32,9 +30,7 @@ struct Layernorm2dFwdPipelineProblem
     static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
     static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;

-    static constexpr bool kPadN           = kPadN_;
-    static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
-    static constexpr bool kTwoPass        = kTwoPass_;
+    using Traits = remove_cvref_t<Traits_>;
 };
 } // namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp

@@ -24,14 +24,19 @@ struct Layernorm2dFwdPipelineTwoPass
     using MeanDataType   = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
     using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
+    using SXDataType = XDataType;
+    using SYDataType = XDataType;

     static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
     static constexpr bool kHasBeta  = !std::is_same_v<BetaDataType, ck_tile::null_type>;
-    static constexpr bool kSaveMean   = Problem::kSaveMeanInvStd;
-    static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
+    static constexpr bool kSaveMean   = Problem::Traits::kSaveMeanInvStd;
+    static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;

     static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
     static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
-    static constexpr bool kPadN = Problem::kPadN;
+    static constexpr bool kPadN       = Problem::Traits::kPadN;
+    static constexpr auto kFusedAdd   = Problem::Traits::kFusedAdd;
+    static constexpr auto kFusedSweep = Problem::Traits::kFusedSweep;

     static constexpr const char* name = []() {
         if constexpr(kNeedCrossWarpSync)
@@ -46,20 +51,26 @@ struct Layernorm2dFwdPipelineTwoPass
     }

     template <typename XWindow,
+              typename SXWindow,
               typename GammaWindow,
               typename BetaWindow,
               typename YWindow,
+              typename SYWindow,
               typename MeanWindow,
-              typename InvStdWindow>
+              typename InvStdWindow,
+              typename Epilogue>
     CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
+                                   const SXWindow& sx_window_,
                                    const GammaWindow& gamma_window_,
                                    const BetaWindow& beta_window_,
                                    YWindow& y_window,
+                                   const SYWindow& sy_window_,
                                    MeanWindow& mean_window,
                                    InvStdWindow& inv_std_window,
                                    ComputeDataType epsilon,
                                    ck_tile::index_t row_size,
-                                   void* smem) const
+                                   void* smem,
+                                   Epilogue) const
     {
         auto x_window =
             make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
@@ -67,6 +78,10 @@ struct Layernorm2dFwdPipelineTwoPass
             gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         auto beta_window = make_tile_window(
             beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
+        auto sx_window =
+            make_tile_window(sx_window_, Policy::template MakeXBlockTileDistribution<Problem>());
+        auto sy_window =
+            make_tile_window(sy_window_, Policy::template MakeXBlockTileDistribution<Problem>());

         // Problem::BlockShape
         static constexpr index_t Block_N = Problem::BlockShape::Block_N;
@@ -93,9 +108,25 @@ struct Layernorm2dFwdPipelineTwoPass
         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
-            const auto x = load_tile(x_window);
-            block_welford(x, mean, var, cur_count, max_count);
+            auto x  = load_tile(x_window);
+            auto sx = load_tile(sx_window);
             move_tile_window(x_window, {0, Block_N});
+            move_tile_window(sx_window, {0, Block_N});
+
+            if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
+                         kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
+            {
+                sweep_tile(sx, [&](auto idx) {
+                    // compute x = sx + x
+                    x(idx) = type_convert<SYDataType>(sx(idx)) + type_convert<SYDataType>(x(idx));
+                });
+                if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
+                {
+                    store_tile(sy_window, x);
+                    move_tile_window(sy_window, {0, Block_N});
+                }
+            }
+            block_welford(x, mean, var, cur_count, max_count);
         }

         block_welford_sync(mean, var, cur_count);
@@ -121,6 +152,7 @@ struct Layernorm2dFwdPipelineTwoPass
         // x_window.foo();
         // gamma_window.foo();
         move_tile_window(x_window, {0, -Block_N});
+        move_tile_window(sx_window, {0, -Block_N});
         move_tile_window(gamma_window, {stride_to_right_most_window});
         move_tile_window(beta_window, {stride_to_right_most_window});
         move_tile_window(y_window, {0, stride_to_right_most_window});
@@ -128,14 +160,23 @@ struct Layernorm2dFwdPipelineTwoPass
         // layernorm computation
         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
-            const auto x = load_tile(x_window);
+            auto x  = load_tile(x_window);
+            auto sx = load_tile(sx_window);
+            if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
+                         kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
+            {
+                sweep_tile(sx, [&](auto idx) {
+                    // compute x = sx + x
+                    x(idx) = type_convert<SYDataType>(sx(idx)) + type_convert<SYDataType>(x(idx));
+                });
+            }

             // load gamma/beta (TODO: support no gamma/beta?)
             const auto gamma = load_tile(gamma_window);
             const auto beta  = load_tile(beta_window);

-            auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
-            sweep_tile(y, [&, mean_ = mean](auto idx) {
+            auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
+            sweep_tile(ln, [&, mean_ = mean](auto idx) {
                 constexpr auto i_idx = make_tuple(idx[number<0>{}]);
                 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
@@ -143,14 +184,15 @@ struct Layernorm2dFwdPipelineTwoPass
                 const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);

                 const auto x_ = type_convert<ComputeDataType>(x[idx]);
-                auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
+                auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;

-                y(idx) = type_convert<YDataType>(y_);
+                ln(idx) = ln_;
             });

-            store_tile(y_window, y);
+            Epilogue{}(y_window, ln);

             move_tile_window(x_window, {0, -Block_N});
+            move_tile_window(sx_window, {0, -Block_N});
             move_tile_window(gamma_window, {-Block_N});
             move_tile_window(beta_window, {-Block_N});
             move_tile_window(y_window, {0, -Block_N});
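Note the structure of the two-pass pipeline above: pass 1 walks the row tile by tile accumulating Welford statistics (recomputing the fused add on the fly, and storing the sum only in PRE_ADD_STORE mode), then pass 2 walks the row again to normalize, redoing the add because the sum is not otherwise kept. A scalar self-contained sketch of that shape:

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Scalar sketch of the two-pass structure: pass 1 computes mean/var over the
// fused sum, pass 2 revisits the data (recomputing the add) to normalize.
int main()
{
    std::vector<float> x = {1, 2, 3, 4, 5, 6}, sx(6, 0.25f), y(6);
    float mean = 0.f, var = 0.f;
    for(std::size_t i = 0; i < x.size(); ++i) // pass 1: statistics
        mean += (x[i] + sx[i]);
    mean /= x.size();
    for(std::size_t i = 0; i < x.size(); ++i)
        var += ((x[i] + sx[i]) - mean) * ((x[i] + sx[i]) - mean);
    var /= x.size();
    const float inv_std = 1.f / std::sqrt(var + 1e-12f);
    for(std::size_t i = 0; i < x.size(); ++i) // pass 2: recompute add, normalize
        y[i] = ((x[i] + sx[i]) - mean) * inv_std;
    std::printf("y[0]=%f\n", y[0]);
}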
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

enum class Layernorm2dFusedAddEnum
{
    NO_ADD = 0,
    // fused add before layernorm (prenorm), and store result to global
    PRE_ADD_STORE = 1,
    PRE_NORM_ADD  = PRE_ADD_STORE,
    // fused add before layernorm (postnorm), but not store result
    PRE_ADD       = 2,
    POST_NORM_ADD = PRE_ADD,
};

// clang-format off
template<Layernorm2dFusedAddEnum E> struct Layernorm2dFusedAddEnumName;
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
// clang-format on

enum class Layernorm2dFusedSweepEnum
{
    NO_SWEEP      = 0,
    RENORM        = 1,
    DYNAMIC_QUANT = 2,
};

// clang-format off
template<Layernorm2dFusedSweepEnum E> struct Layernorm2dFusedSweepEnumName;
template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::RENORM> { static constexpr const char * name = "renorm"; };
template<> struct Layernorm2dFusedSweepEnumName<Layernorm2dFusedSweepEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dequant"; };
// clang-format on

template <bool kPadN_,
          bool kSaveMeanInvStd_,
          bool kTwoPass_,
          Layernorm2dFusedAddEnum kFusedAdd_,
          Layernorm2dFusedSweepEnum kFusedSweep_>
struct Layernorm2dFwdTraits
{
    static constexpr bool kPadN           = kPadN_;
    static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
    static constexpr bool kTwoPass        = kTwoPass_;
    static constexpr Layernorm2dFusedAddEnum kFusedAdd     = kFusedAdd_;
    static constexpr Layernorm2dFusedSweepEnum kFusedSweep = kFusedSweep_;
};
} // namespace ck_tile
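Usage sketch: the pipeline problem now receives this single trait bundle (via Problem::Traits) instead of separate bool flags. Spelling out a prenorm, padded-N, one-pass configuration with the new header (the alias name below is illustrative; generate.py emits the real instances):

#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"

// Sketch only: a concrete trait bundle as a generated instance might define it.
using ExampleTraits =
    ck_tile::Layernorm2dFwdTraits<true,  // kPadN
                                  false, // kSaveMeanInvStd
                                  false, // kTwoPass
                                  ck_tile::Layernorm2dFusedAddEnum::PRE_ADD_STORE,
                                  ck_tile::Layernorm2dFusedSweepEnum::NO_SWEEP>;

static_assert(ExampleTraits::kFusedAdd == ck_tile::Layernorm2dFusedAddEnum::PRE_ADD_STORE,
              "prenorm configuration");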