gaoqiong / composable_kernel_ROCM · Commits

Commit f199c936, authored Dec 12, 2024 by AMD-dteng
Parent: ec07718a

    1. remove fmha change
    2. change buffer name from bias to xbias

Showing 8 changed files with 53 additions and 50 deletions:
- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py (+2, -2)
- example/ck_tile/02_layernorm2d/generate.py (+2, -2)
- example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp (+12, -9)
- example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp (+2, -2)
- include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp (+7, -7)
- include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp (+9, -9)
- include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp (+2, -2)
- include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp (+17, -17)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

```diff
@@ -410,8 +410,8 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
     if dtype == 'fp16' or dtype == 'bf16':
         return {
             '32'  : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1),
-            '64'  : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
-            # '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
+            # '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
+            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
             ## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
             '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
             '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
```
example/ck_tile/02_layernorm2d/generate.py

```diff
@@ -198,7 +198,7 @@ float layernorm2d_fwd_(const S& s, A a)
         static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
     using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
         typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XDataType,
-        typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BiasDataType,
+        typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XBiasDataType,
         typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::GammaDataType,
         typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BetaDataType,
         typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType,
@@ -330,7 +330,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
 @dataclass
 class k_problem:
     F_XDataType       : str
-    F_BiasDataType    : str
+    F_XBiasDataType   : str
     F_GammaDataType   : str
     F_BetaDataType    : str
     F_ComputeDataType : str
```
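Note on the second hunk: generate.py emits the fields of k_problem positionally into the template-argument list of Layernorm2dFwdPipelineProblem, which is why the Python field and the C++ template parameter have to be renamed in the same commit. A minimal, self-contained C++ sketch of that positional contract (illustrative names, not the generated source):

```cpp
#include <type_traits>

// Toy Problem mirroring just the first five parameters from this diff.
// Arguments bind by position, exactly as the generator emits them, so a
// field reordered or renamed on only one side would silently change meaning
// or fail to compile.
template <typename XDataType_,
          typename XBiasDataType_, // the renamed slot
          typename GammaDataType_,
          typename BetaDataType_,
          typename ComputeDataType_>
struct ProblemSketch
{
    using XBiasDataType = XBiasDataType_;
};

using P = ProblemSketch<float, float, float, float, float>;
static_assert(std::is_same_v<P::XBiasDataType, float>); // positional binding

int main() { return 0; }
```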
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp

```diff
@@ -109,7 +109,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using XDataType         = typename TypeConfig::XDataType;
     using YDataType         = typename TypeConfig::YDataType;
-    using BiasDataType      = typename TypeConfig::BiasDataType;
+    using XBiasDataType     = typename TypeConfig::XBiasDataType;
     using GammaDataType     = typename TypeConfig::GammaDataType;
     using BetaDataType      = typename TypeConfig::BetaDataType;
     using XResidualDataType = XDataType;
@@ -124,7 +124,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // host verify
     ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
-    ck_tile::HostTensor<BiasDataType> bias_host({n});
+    ck_tile::HostTensor<XBiasDataType> x_bias_host({n});
     ck_tile::HostTensor<GammaDataType> gamma_host({n});
     ck_tile::HostTensor<BetaDataType> beta_host({n});
@@ -145,12 +145,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
     ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
     ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
-    ck_tile::FillUniformDistribution<BiasDataType>{-.5f, .5f}(bias_host);
+    ck_tile::FillUniformDistribution<XBiasDataType>{-.5f, .5f}(x_bias_host);
     ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
     ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);

     ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem x_bias_buf(x_bias_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
@@ -161,7 +161,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());

     x_buf.ToDevice(x_host.data());
-    bias_buf.ToDevice(bias_host.data());
+    x_bias_buf.ToDevice(x_bias_host.data());
     gamma_buf.ToDevice(gamma_host.data());
     beta_buf.ToDevice(beta_host.data());
     x_residual_buf.ToDevice(x_residual_host.data());
@@ -191,7 +191,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
                               fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
                               fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
-                              bias_buf.GetDeviceBuffer(),
+                              x_bias_buf.GetDeviceBuffer(),
                               gamma_buf.GetDeviceBuffer(),
                               beta_buf.GetDeviceBuffer(),
@@ -218,8 +218,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return false;
     }

-    std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(BiasDataType) * n + sizeof(GammaDataType) * n +
-                           sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
+    std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XBiasDataType) * n +
+                           sizeof(GammaDataType) * n + sizeof(BetaDataType) * n +
+                           sizeof(YDataType) * m * n;

     float gb_per_sec = num_byte / 1.E6 / ave_time;
     std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
@@ -240,7 +241,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
         {
             for(int idx_n = 0; idx_n < N; ++idx_n)
             {
-                x_host(idx_m, idx_n) = ck_tile::type_convert<XDataType>(ck_tile::type_convert<ComputeDataType>(x_host(idx_m, idx_n)) + ck_tile::type_convert<ComputeDataType>(bias_host(idx_n)));
+                x_host(idx_m, idx_n) = ck_tile::type_convert<XDataType>(
+                    ck_tile::type_convert<ComputeDataType>(x_host(idx_m, idx_n)) +
+                    ck_tile::type_convert<ComputeDataType>(x_bias_host(idx_n)));
             }
         }
     }
```
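Note on the num_byte hunk: the traffic estimate counts one read each of x[m,n], x_bias[n], gamma[n], and beta[n], plus one write of y[m,n]; dividing bytes by 1e6 and by the average time in milliseconds yields GB/s. A worked sketch with made-up sizes (fp16 elements and a hypothetical kernel time; only the formula itself comes from the diff):

```cpp
#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t m = 1024, n = 8192;
    const std::size_t e = 2; // bytes per fp16 element for X/XBias/Gamma/Beta/Y

    const std::size_t num_byte = e * m * n  // x
                               + e * n      // x_bias
                               + e * n      // gamma
                               + e * n      // beta
                               + e * m * n; // y

    const float ave_time   = 0.05f;                        // ms, hypothetical
    const float gb_per_sec = num_byte / 1.E6 / ave_time;   // bytes/1e6 per ms == GB/s
    std::printf("%zu bytes -> %.1f GB/s\n", num_byte, gb_per_sec);
    return 0;
}
```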
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp

```diff
@@ -16,7 +16,7 @@ struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleData
 {
     using XDataType     = ck_tile::half_t;
     using YDataType     = OutType;
-    using BiasDataType  = ck_tile::half_t;
+    using XBiasDataType = ck_tile::half_t;
     using GammaDataType = ck_tile::half_t;
     using BetaDataType  = ck_tile::half_t;
     using MeanDataType  = ck_tile::half_t;
@@ -31,7 +31,7 @@ struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleData
 {
     using XDataType     = ck_tile::bf16_t;
     using YDataType     = OutType;
-    using BiasDataType  = ck_tile::bf16_t;
+    using XBiasDataType = ck_tile::bf16_t;
     using GammaDataType = ck_tile::bf16_t;
     using BetaDataType  = ck_tile::bf16_t;
     using MeanDataType  = ck_tile::bf16_t;
```
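Note: every other file in this commit names the bias type through this config (or through the Problem built from it), which is why a one-line rename here fans out across the kernel, both pipelines, and the examples. A minimal compile-time sketch of that pattern (float stands in for half_t/bf16_t; this is not the CK header):

```cpp
#include <cstddef>

struct LayerNormTypeConfigSketch
{
    using XDataType     = float;
    using XBiasDataType = float; // renamed; the old BiasDataType member is gone
    using GammaDataType = float;
    using BetaDataType  = float;
};

template <typename TypeConfig>
constexpr std::size_t bias_size()
{
    using XBiasDataType = typename TypeConfig::XBiasDataType; // new name: OK
    // using B = typename TypeConfig::BiasDataType; // old name: compile error
    return sizeof(XBiasDataType);
}

static_assert(bias_size<LayerNormTypeConfigSketch>() == sizeof(float));

int main() { return 0; }
```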
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp

```diff
@@ -15,7 +15,7 @@ struct Layernorm2dFwdHostArgs
     const void* p_x;          // [m ,n], input, fp16/bf16
     const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
     const void* p_x_scale;    // [1 ,n], smooth scale input, fp32, nullptr if not used
-    const void* p_bias;       // [1, n], bias, prec same as input
+    const void* p_x_bias;     // [1, n], bias, prec same as input
     const void* p_gamma;      // [1, n], gamma, prec same as input
     const void* p_beta;       // [1, n], beta, prec same as input
@@ -44,7 +44,7 @@ struct Layernorm2dFwd
     using Problem = typename Pipeline::Problem;

     using XDataType       = remove_cvref_t<typename Problem::XDataType>;
-    using BiasDataType    = remove_cvref_t<typename Problem::BiasDataType>;
+    using XBiasDataType   = remove_cvref_t<typename Problem::XBiasDataType>;
     using GammaDataType   = remove_cvref_t<typename Problem::GammaDataType>;
     using BetaDataType    = remove_cvref_t<typename Problem::BetaDataType>;
     using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
@@ -85,7 +85,7 @@ struct Layernorm2dFwd
     const void* p_x;          // [m ,n], input, fp16/bf16
     const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
     const void* p_x_scale;    // [1 ,n], smooth scale input, fp32, nullptr if not used
-    const void* p_bias;       // [1, n], bias, prec same as input
+    const void* p_x_bias;     // [1, n], bias, prec same as input
     const void* p_gamma;      // [1, n], gamma, prec same as input
     const void* p_beta;       // [1, n], beta, prec same as input
@@ -112,7 +112,7 @@ struct Layernorm2dFwd
         return Kargs{hargs.p_x,
                      hargs.p_x_residual,
                      hargs.p_x_scale,
-                     hargs.p_bias,
+                     hargs.p_x_bias,
                      hargs.p_gamma,
                      hargs.p_beta,
                      hargs.p_y,
@@ -234,11 +234,11 @@ struct Layernorm2dFwd
             }
         }();

-        const auto bias_window = [&]() {
+        const auto x_bias_window = [&]() {
             if constexpr(kBias == Layernorm2dBiasEnum::ADD_BIAS)
             {
                 const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
-                    static_cast<const BiasDataType*>(kargs.p_bias),
+                    static_cast<const XBiasDataType*>(kargs.p_x_bias),
                     make_tuple(kargs.n),
                     make_tuple(1),
                     number<Vector_N>{},
@@ -398,7 +398,7 @@ struct Layernorm2dFwd
         Pipeline{}(x_window,
                    x_residual_window,
-                   bias_window,
+                   x_bias_window,
                    gamma_window,
                    beta_window,
                    y_window,
```
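Note: the host-args struct and the kernel-args struct carry the same pointers, so MakeKargs simply forwards them field for field. A trimmed, self-contained sketch of that plumbing with only the members visible in this diff (illustrative names, not the CK types):

```cpp
// Host-side argument struct after the rename; everything not shown in the
// hunks above is omitted.
struct Layernorm2dFwdHostArgsSketch
{
    const void* p_x;          // [m, n] input
    const void* p_x_residual; // [m, n] shortcut input, nullptr if unused
    const void* p_x_scale;    // [1, n] smooth-scale input, nullptr if unused
    const void* p_x_bias;     // [1, n] bias, renamed from p_bias
    const void* p_gamma;      // [1, n] gamma
    const void* p_beta;       // [1, n] beta
};

struct KargsSketch
{
    const void* p_x;
    const void* p_x_residual;
    const void* p_x_scale;
    const void* p_x_bias;
    const void* p_gamma;
    const void* p_beta;
};

KargsSketch MakeKargsSketch(const Layernorm2dFwdHostArgsSketch& hargs)
{
    return KargsSketch{hargs.p_x,
                       hargs.p_x_residual,
                       hargs.p_x_scale,
                       hargs.p_x_bias, // the renamed field from this hunk
                       hargs.p_gamma,
                       hargs.p_beta};
}

int main()
{
    Layernorm2dFwdHostArgsSketch hargs{}; // zero-initialized pointers
    KargsSketch kargs = MakeKargsSketch(hargs);
    return kargs.p_x_bias == nullptr ? 0 : 1;
}
```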
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp

```diff
@@ -18,7 +18,7 @@ struct Layernorm2dFwdPipelineOnePass
     using Policy = ck_tile::remove_cvref_t<Policy_>;

     using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
-    using BiasDataType    = ck_tile::remove_cvref_t<typename Problem::BiasDataType>;
+    using XBiasDataType   = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
     using GammaDataType   = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
     using BetaDataType    = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
     using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
@@ -56,7 +56,7 @@ struct Layernorm2dFwdPipelineOnePass
     template <typename XWindow,
               typename XResidualWindow,
-              typename BiasWindow,
+              typename XBiasWindow,
               typename GammaWindow,
               typename BetaWindow,
               typename YWindow,
@@ -68,7 +68,7 @@ struct Layernorm2dFwdPipelineOnePass
               typename Epilogue>
     CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                    const XResidualWindow& x_residual_window_,
-                                   const BiasWindow& bias_window_,
+                                   const XBiasWindow& x_bias_window_,
                                    const GammaWindow& gamma_window_,
                                    const BetaWindow& beta_window_,
                                    YWindow& y_window_,
@@ -84,8 +84,8 @@ struct Layernorm2dFwdPipelineOnePass
     {
         const auto x_window =
             make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
-        const auto bias_window = make_tile_window(
-            bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
+        const auto x_bias_window = make_tile_window(
+            x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         const auto gamma_window = make_tile_window(
             gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         const auto beta_window = make_tile_window(
@@ -97,7 +97,7 @@ struct Layernorm2dFwdPipelineOnePass
         auto x      = load_tile(x_window);
         auto x_resi = load_tile(x_residual_window);
-        const auto bias = load_tile(bias_window);
+        const auto x_bias = load_tile(x_bias_window);

         int cur_count = 0;
         int max_count =
@@ -121,7 +121,7 @@ struct Layernorm2dFwdPipelineOnePass
                 {
                     // compute x = bias + x
                     constexpr auto j_idx = make_tuple(idx[number<1>{}]);
-                    acc(idx) = type_convert<ComputeDataType>(bias[j_idx]) + acc(idx);
+                    acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
                 }
                 // compute x = x_resi + x
                 acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
```
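Note on the bias add: j_idx keeps only the column component of the 2-D tile index (idx[number<1>{}]), so the [1, n] x_bias row is broadcast over every row of the [m, n] tile before normalization. A plain-C++ sketch of the same broadcast, with vectors standing in for ck_tile distributed tensors:

```cpp
#include <cstdio>
#include <vector>

int main()
{
    const int m = 2, n = 4;
    std::vector<float> acc    = {1, 2, 3, 4, 5, 6, 7, 8}; // [m, n] accumulator
    std::vector<float> x_bias = {10, 20, 30, 40};         // [1, n] row bias

    // acc(idx) = x_bias[j_idx] + acc(idx): the bias index depends on j only.
    for(int i = 0; i < m; ++i)
        for(int j = 0; j < n; ++j)
            acc[i * n + j] = x_bias[j] + acc[i * n + j];

    for(int i = 0; i < m; ++i)
    {
        for(int j = 0; j < n; ++j)
            std::printf("%5.1f ", acc[i * n + j]);
        std::printf("\n");
    }
    return 0;
}
```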
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp

```diff
@@ -8,7 +8,7 @@
 namespace ck_tile {

 template <typename XDataType_,
-          typename BiasDataType_,
+          typename XBiasDataType_,
           typename GammaDataType_,
           typename BetaDataType_,
           typename ComputeDataType_,
@@ -22,7 +22,7 @@ template <typename XDataType_,
 struct Layernorm2dFwdPipelineProblem
 {
     using XDataType       = remove_cvref_t<XDataType_>;
-    using BiasDataType    = remove_cvref_t<BiasDataType_>;
+    using XBiasDataType   = remove_cvref_t<XBiasDataType_>;
     using GammaDataType   = remove_cvref_t<GammaDataType_>;
     using BetaDataType    = remove_cvref_t<BetaDataType_>;
     using ComputeDataType = remove_cvref_t<ComputeDataType_>;
```
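Side note: remove_cvref_t, used above to normalize the incoming template arguments, strips references and cv-qualifiers so the Problem's aliases are plain value types; it behaves like C++20's std::remove_cvref_t. A minimal sketch under that assumption:

```cpp
#include <type_traits>

// Sketch of the utility's behavior, not the CK definition itself.
template <typename T>
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

static_assert(std::is_same_v<remove_cvref_t<const float&>, float>);

int main() { return 0; }
```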
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp

```diff
@@ -17,7 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass
     using Policy = ck_tile::remove_cvref_t<Policy_>;

     using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
-    using BiasDataType    = ck_tile::remove_cvref_t<typename Problem::BiasDataType>;
+    using XBiasDataType   = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
     using GammaDataType   = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
     using BetaDataType    = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
     using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
@@ -55,7 +55,7 @@ struct Layernorm2dFwdPipelineTwoPass
     template <typename XWindow,
               typename XResidualWindow,
-              typename BiasWindow,
+              typename XBiasWindow,
               typename GammaWindow,
               typename BetaWindow,
               typename YWindow,
@@ -67,7 +67,7 @@ struct Layernorm2dFwdPipelineTwoPass
               typename Epilogue>
     CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                    const XResidualWindow& x_residual_window_,
-                                   const BiasWindow& bias_window_,
+                                   const XBiasWindow& x_bias_window_,
                                    const GammaWindow& gamma_window_,
                                    const BetaWindow& beta_window_,
                                    YWindow& y_window,
@@ -83,8 +83,8 @@ struct Layernorm2dFwdPipelineTwoPass
     {
         auto x_window =
             make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
-        auto bias_window = make_tile_window(
-            bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
+        auto x_bias_window = make_tile_window(
+            x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         auto gamma_window = make_tile_window(
             gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         auto beta_window = make_tile_window(
@@ -121,11 +121,11 @@ struct Layernorm2dFwdPipelineTwoPass
         {
             auto x      = load_tile(x_window);
             auto x_resi = load_tile(x_residual_window);
-            const auto bias = load_tile(bias_window);
+            const auto x_bias = load_tile(x_bias_window);

             move_tile_window(x_window, {0, Block_N});
             move_tile_window(x_residual_window, {0, Block_N});
-            move_tile_window(bias_window, {Block_N});
+            move_tile_window(x_bias_window, {Block_N});

             auto acc = cast_tile<ComputeDataType>(x);
             if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
@@ -136,7 +136,7 @@ struct Layernorm2dFwdPipelineTwoPass
                     {
                         // compute x = bias + x
                         constexpr auto j_idx = make_tuple(idx[number<1>{}]);
-                        acc(idx) = type_convert<ComputeDataType>(bias[j_idx]) + acc(idx);
+                        acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
                     }
                     // compute x = x_resi + x
                     acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
@@ -179,7 +179,7 @@ struct Layernorm2dFwdPipelineTwoPass
         move_tile_window(x_window, {0, -Block_N});
         move_tile_window(x_residual_window, {0, -Block_N});
-        move_tile_window(bias_window, {-Block_N});
+        move_tile_window(x_bias_window, {-Block_N});
         move_tile_window(gamma_window, {stride_to_right_most_window});
         move_tile_window(beta_window, {stride_to_right_most_window});
         move_tile_window(y_window, {0, stride_to_right_most_window});
@@ -189,7 +189,7 @@ struct Layernorm2dFwdPipelineTwoPass
         {
             auto x      = load_tile(x_window);
             auto x_resi = load_tile(x_residual_window);
-            const auto bias = load_tile(bias_window);
+            const auto x_bias = load_tile(x_bias_window);

             auto acc = cast_tile<ComputeDataType>(x);
             if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
@@ -200,7 +200,7 @@ struct Layernorm2dFwdPipelineTwoPass
                 {
                     // compute x = bias + x
                     constexpr auto j_idx = make_tuple(idx[number<1>{}]);
-                    acc(idx) = type_convert<ComputeDataType>(bias[j_idx]) + acc(idx);
+                    acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
                 }
                 // compute x = x_resi + x
                 acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
@@ -229,7 +229,7 @@ struct Layernorm2dFwdPipelineTwoPass
         move_tile_window(x_window, {0, -Block_N});
         move_tile_window(x_residual_window, {0, -Block_N});
-        move_tile_window(bias_window, {-Block_N});
+        move_tile_window(x_bias_window, {-Block_N});
         move_tile_window(gamma_window, {-Block_N});
         move_tile_window(beta_window, {-Block_N});
         move_tile_window(y_window, {0, -Block_N});
```
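Note on the window movement: in the first pass the x and x_bias windows advance by Block_N per chunk while the statistics are accumulated; the negative offsets afterwards step the windows back so the second (normalization) pass revisits the same chunks. An integer-offset sketch of that bookkeeping (plain ints stand in for tile windows; sizes are hypothetical):

```cpp
#include <cstdio>

int main()
{
    const int N = 8192, Block_N = 1024, num_chunks = N / Block_N;

    int x_off = 0, x_bias_off = 0;
    for(int c = 0; c < num_chunks; ++c) // pass 1: accumulate mean/variance
    {
        // load_tile(x_window) / load_tile(x_bias_window) would read here
        x_off      += Block_N; // move_tile_window(x_window, {0, Block_N})
        x_bias_off += Block_N; // move_tile_window(x_bias_window, {Block_N})
    }
    for(int c = 0; c < num_chunks; ++c) // pass 2: normalize, stepping back
    {
        x_off      -= Block_N; // move_tile_window(x_window, {0, -Block_N})
        x_bias_off -= Block_N; // move_tile_window(x_bias_window, {-Block_N})
    }
    std::printf("back at offsets %d, %d\n", x_off, x_bias_off); // both 0
    return 0;
}
```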