Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d62f0358
Commit
d62f0358
authored
Oct 16, 2024
by
rocking
Browse files
Remove save mean and inv std
parent
29cff07e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
45 additions
and
56 deletions
+45
-56
example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp
+0
-14
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
+19
-28
example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp
example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp
+11
-2
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+2
-2
include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
...ps/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
+13
-10
No files found.
example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp
View file @
d62f0358
...
@@ -52,11 +52,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -52,11 +52,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host_ref
({
M
});
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host_ref
({
M
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host_ref
({
M
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host_ref
({
M
});
// TODO - move SAVE_MEAN_INV_STD to user args
#ifdef SAVE_MEAN_INV_STD
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host_dev
({
M
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host_dev
({
M
});
#endif
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
GammaDataType
>
{
-
.5
f
,
.5
f
}(
gamma_host
);
ck_tile
::
FillUniformDistribution
<
GammaDataType
>
{
-
.5
f
,
.5
f
}(
gamma_host
);
...
@@ -66,10 +61,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -66,10 +61,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
DeviceMem
gamma_buf
(
gamma_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
gamma_buf
(
gamma_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
beta_buf
(
beta_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
beta_buf
(
beta_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_buf
(
y_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_buf
(
y_host_dev
.
get_element_space_size_in_bytes
());
#ifdef SAVE_MEAN_INV_STD
ck_tile
::
DeviceMem
mean_buf
(
mean_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
invStd_buf
(
invStd_host_dev
.
get_element_space_size_in_bytes
());
#endif
x_buf
.
ToDevice
(
x_host
.
data
());
x_buf
.
ToDevice
(
x_host
.
data
());
gamma_buf
.
ToDevice
(
gamma_host
.
data
());
gamma_buf
.
ToDevice
(
gamma_host
.
data
());
...
@@ -81,13 +72,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -81,13 +72,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
gamma_buf
.
GetDeviceBuffer
(),
gamma_buf
.
GetDeviceBuffer
(),
beta_buf
.
GetDeviceBuffer
(),
beta_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
#ifdef SAVE_MEAN_INV_STD
mean_buf
.
GetDeviceBuffer
(),
invStd_buf
.
GetDeviceBuffer
(),
#else
nullptr
,
nullptr
,
nullptr
,
nullptr
,
#endif
epsilon
,
epsilon
,
M
,
M
,
N
};
N
};
...
...
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
View file @
d62f0358
...
@@ -8,48 +8,34 @@
...
@@ -8,48 +8,34 @@
#include "ck_tile/ops/layernorm2d.hpp"
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>
#include <string>
struct
layernorm2d_fwd_traits
{
std
::
string
data_type
;
};
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
LayerNormTypeConfig
;
struct
LayerNormTypeConfig
;
template
<
>
template
<
>
struct
LayerNormTypeConfig
<
ck_tile
::
half_t
>
struct
LayerNormTypeConfig
<
ck_tile
::
half_t
>
{
{
using
XDataType
=
ck_tile
::
half_t
;
using
XDataType
=
ck_tile
::
half_t
;
using
YDataType
=
ck_tile
::
half_t
;
using
YDataType
=
ck_tile
::
half_t
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
BetaDataType
=
ck_tile
::
half_t
;
using
BetaDataType
=
ck_tile
::
half_t
;
#ifdef SAVE_MEAN_INV_STD
using
MeanDataType
=
ck_tile
::
half_t
;
using
MeanDataType
=
ck_tile
::
half_t
;
using
InvStdDataType
=
ck_tile
::
half_t
;
using
InvStdDataType
=
ck_tile
::
half_t
;
#else
using
MeanDataType
=
ck_tile
::
null_type
;
using
InvStdDataType
=
ck_tile
::
null_type
;
#endif
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
};
};
template
<
>
template
<
>
struct
LayerNormTypeConfig
<
float
>
struct
LayerNormTypeConfig
<
float
>
{
{
using
XDataType
=
float
;
using
XDataType
=
float
;
using
YDataType
=
float
;
using
YDataType
=
float
;
using
GammaDataType
=
float
;
using
GammaDataType
=
float
;
using
BetaDataType
=
float
;
using
BetaDataType
=
float
;
#ifdef SAVE_MEAN_INV_STD
using
MeanDataType
=
float
;
using
MeanDataType
=
float
;
using
InvStdDataType
=
float
;
using
InvStdDataType
=
float
;
#else
using
MeanDataType
=
ck_tile
::
null_type
;
using
InvStdDataType
=
ck_tile
::
null_type
;
#endif
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
};
};
// runtime args
struct
layernorm2d_fwd_args
struct
layernorm2d_fwd_args
{
{
const
void
*
p_x
;
const
void
*
p_x
;
...
@@ -63,5 +49,10 @@ struct layernorm2d_fwd_args
...
@@ -63,5 +49,10 @@ struct layernorm2d_fwd_args
ck_tile
::
index_t
N
;
ck_tile
::
index_t
N
;
};
};
// host API
// This is the public API, will be generated by script
struct
layernorm2d_fwd_traits
{
std
::
string
data_type
;
};
float
layernorm2d_fwd
(
layernorm2d_fwd_traits
,
layernorm2d_fwd_args
,
const
ck_tile
::
stream_config
&
);
float
layernorm2d_fwd
(
layernorm2d_fwd_traits
,
layernorm2d_fwd_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp
View file @
d62f0358
...
@@ -14,6 +14,7 @@ template <typename InOutDataType,
...
@@ -14,6 +14,7 @@ template <typename InOutDataType,
ck_tile
::
index_t
NThread
,
ck_tile
::
index_t
NThread
,
ck_tile
::
index_t
VectorAccessSize
,
ck_tile
::
index_t
VectorAccessSize
,
bool
kPadN
,
bool
kPadN
,
bool
kSaveMeanInvStd
,
bool
kTwoPass
>
bool
kTwoPass
>
struct
layernorm_dispatch
struct
layernorm_dispatch
{
{
...
@@ -38,6 +39,7 @@ struct layernorm_dispatch
...
@@ -38,6 +39,7 @@ struct layernorm_dispatch
typename
LayerNormTypeConfig
<
InOutDataType
>::
InvStdDataType
,
typename
LayerNormTypeConfig
<
InOutDataType
>::
InvStdDataType
,
Shape
,
Shape
,
kPadN
,
kPadN
,
kSaveMeanInvStd
,
kTwoPass
>
;
kTwoPass
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
PipelineProblem
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
PipelineProblem
>
;
...
@@ -75,6 +77,13 @@ template <typename InOutDataType,
...
@@ -75,6 +77,13 @@ template <typename InOutDataType,
bool
kTwoPass
=
false
>
bool
kTwoPass
=
false
>
float
run_layernorm
(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
)
float
run_layernorm
(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
)
{
{
return
layernorm_dispatch
<
InOutDataType
,
NRepeat
,
NThread
,
VectorAccessSize
,
kPadN
,
kTwoPass
>::
// TODO - Add SaveMeanInvStd instance
Run
(
param
,
stream
);
constexpr
bool
kSaveMeanInvStd
=
false
;
return
layernorm_dispatch
<
InOutDataType
,
NRepeat
,
NThread
,
VectorAccessSize
,
kSaveMeanInvStd
,
kPadN
,
kTwoPass
>::
Run
(
param
,
stream
);
};
};
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
d62f0358
...
@@ -26,8 +26,8 @@ struct Layernorm2dFwd
...
@@ -26,8 +26,8 @@ struct Layernorm2dFwd
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveMean
=
!
std
::
is_same_v
<
MeanDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveMean
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
!
std
::
is_same_v
<
InvStdDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
ck_tile
::
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
static
constexpr
ck_tile
::
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
static
constexpr
ck_tile
::
index_t
kNPerBlock
=
Problem
::
BlockShape
::
kNPerBlock
;
static
constexpr
ck_tile
::
index_t
kNPerBlock
=
Problem
::
BlockShape
::
kNPerBlock
;
...
...
include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
View file @
d62f0358
...
@@ -16,19 +16,22 @@ template <typename XDataType_,
...
@@ -16,19 +16,22 @@ template <typename XDataType_,
typename
InvStdDataType_
,
typename
InvStdDataType_
,
typename
BlockShape_
,
typename
BlockShape_
,
bool
kPadN_
,
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
bool
kTwoPass_
>
bool
kTwoPass_
>
struct
BlockLayernorm2dFwdProblem
struct
BlockLayernorm2dFwdProblem
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment