Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c03a937d
Commit
c03a937d
authored
Oct 28, 2024
by
dummycoderfe
Browse files
one block ok
parent
7db609fe
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
213 additions
and
180 deletions
+213
-180
example/ck_tile/06_layernorm2d_bwd/instances/layernorm2d_bwd_api.cpp
...tile/06_layernorm2d_bwd/instances/layernorm2d_bwd_api.cpp
+2
-21
example/ck_tile/06_layernorm2d_bwd/instances/layernorm2d_bwd_instance_common.hpp
...rnorm2d_bwd/instances/layernorm2d_bwd_instance_common.hpp
+0
-11
example/ck_tile/06_layernorm2d_bwd/layernorm2d_bwd.cpp
example/ck_tile/06_layernorm2d_bwd/layernorm2d_bwd.cpp
+33
-44
example/ck_tile/06_layernorm2d_bwd/layernorm2d_bwd.hpp
example/ck_tile/06_layernorm2d_bwd/layernorm2d_bwd.hpp
+29
-0
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-0
include/ck_tile/host/reference/reference_layernorm2d_bwd.hpp
include/ck_tile/host/reference/reference_layernorm2d_bwd.hpp
+53
-0
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp
.../layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp
+37
-21
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp
...rm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp
+21
-19
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp
...ernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp
+37
-64
No files found.
example/ck_tile/06_layernorm2d_bwd/instances/layernorm2d_bwd_api.cpp
View file @
c03a937d
...
@@ -4,25 +4,6 @@
...
@@ -4,25 +4,6 @@
#include <ck_tile/core.hpp>
#include <ck_tile/core.hpp>
#include "layernorm2d_bwd.hpp"
#include "layernorm2d_bwd.hpp"
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
bool
kPadN_
>
using
trait_
=
layernorm2d_bwd_traits_
<
DataType_
,
Repeat_M_
,
ThreadPerBlock_M_
,
ThreadPerBlock_N_
,
kPadN_
>
;
template
<
typename
data_type
>
float
layernorm2d_bwd_b16_
(
layernorm2d_bwd_traits
/*t*/
,
layernorm2d_bwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
return
layernorm2d_bwd_
<
trait_
<
data_type
,
1
,
1
,
64
,
true
>>
(
s
,
a
);
}
float
layernorm2d_bwd
(
layernorm2d_bwd_traits
t
,
float
layernorm2d_bwd
(
layernorm2d_bwd_traits
t
,
layernorm2d_bwd_args
a
,
layernorm2d_bwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
const
ck_tile
::
stream_config
&
s
)
...
@@ -31,11 +12,11 @@ float layernorm2d_bwd(layernorm2d_bwd_traits t,
...
@@ -31,11 +12,11 @@ float layernorm2d_bwd(layernorm2d_bwd_traits t,
float
r
=
-
1
;
float
r
=
-
1
;
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
{
{
return
layernorm2d_bwd_b16_
<
ck_tile
::
fp16_t
>
(
t
,
a
,
s
);
return
layernorm2d_bwd_b16_
<
ck_tile
::
fp16_t
>
{}
(
t
,
a
,
s
);
}
}
else
if
(
t
.
data_type
.
compare
(
"bf16"
)
==
0
)
else
if
(
t
.
data_type
.
compare
(
"bf16"
)
==
0
)
{
{
return
layernorm2d_bwd_b16_
<
ck_tile
::
bf16_t
>
(
t
,
a
,
s
);
return
layernorm2d_bwd_b16_
<
ck_tile
::
bf16_t
>
{}
(
t
,
a
,
s
);
}
}
if
(
r
<
0
)
if
(
r
<
0
)
throw
std
::
runtime_error
(
"Without supported instances!"
);
throw
std
::
runtime_error
(
"Without supported instances!"
);
...
...
example/ck_tile/06_layernorm2d_bwd/instances/layernorm2d_bwd_instance_common.hpp
View file @
c03a937d
...
@@ -11,17 +11,6 @@
...
@@ -11,17 +11,6 @@
using
S
=
ck_tile
::
stream_config
;
using
S
=
ck_tile
::
stream_config
;
using
A
=
layernorm2d_bwd_args
;
using
A
=
layernorm2d_bwd_args
;
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
bool
kPadN_
>
using
trait_
=
layernorm2d_bwd_traits_
<
DataType_
,
Repeat_M_
,
ThreadPerBlock_M_
,
ThreadPerBlock_N_
,
kPadN_
>
;
template
<
typename
Traits_
>
template
<
typename
Traits_
>
float
layernorm2d_bwd_
(
const
S
&
s
,
A
a
)
float
layernorm2d_bwd_
(
const
S
&
s
,
A
a
)
{
{
...
...
example/ck_tile/06_layernorm2d_bwd/layernorm2d_bwd.cpp
View file @
c03a937d
...
@@ -64,21 +64,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -64,21 +64,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
// host verify
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
dy_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
dy_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host
({
m
});
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host
({
m
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host
({
m
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host
({
m
});
ck_tile
::
HostTensor
<
GammaDataType
>
dgamma_host_dev
({
n
});
ck_tile
::
index_t
blockM
=
layernorm2d_bwd_block_m
<
XDataType
>
();
ck_tile
::
HostTensor
<
BetaDataType
>
dbeta_host_dev
({
n
});
ck_tile
::
index_t
reduce_m
=
(
m
+
blockM
-
1
)
/
blockM
;
ck_tile
::
HostTensor
<
GammaDataType
>
dgamma_host_ref
({
n
});
ck_tile
::
HostTensor
<
GammaDataType
>
dgamma_host_dev
({
reduce_m
,
n
});
ck_tile
::
HostTensor
<
BetaDataType
>
dbeta_host_ref
({
n
});
ck_tile
::
HostTensor
<
BetaDataType
>
dbeta_host_dev
({
reduce_m
,
n
});
ck_tile
::
HostTensor
<
GammaDataType
>
dgamma_host_ref
({
reduce_m
,
n
});
ck_tile
::
HostTensor
<
BetaDataType
>
dbeta_host_ref
({
reduce_m
,
n
});
// ck_tile::FillMonotonicSeq<YDataType>{}(dy_host);
ck_tile
::
FillUniformDistribution
<
YDataType
>
{
-
.5
f
,
.5
f
}(
dy_host
);
ck_tile
::
FillUniformDistribution
<
YDataType
>
{
-
.5
f
,
.5
f
}(
dy_host
);
// ck_tile::FillUniformDistribution<MeanDataType>{-.5f, .5f}(mean_host);
ck_tile
::
FillUniformDistribution
<
MeanDataType
>
{
-
.5
f
,
.5
f
}(
mean_host
);
ck_tile
::
FillMonotonicSeq
<
MeanDataType
>
{}(
mean_host
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
// ck_tile::FillMonotonicSeq<MeanDataType>{}(mean_host);
ck_tile
::
FillUniformDistribution
<
InvStdDataType
>
{
-
.5
f
,
.5
f
}(
invStd_host
);
ck_tile
::
FillUniformDistribution
<
InvStdDataType
>
{
-
.5
f
,
.5
f
}(
invStd_host
);
ck_tile
::
DeviceMem
x_buf
(
x_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
dy_buf
(
dy_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
dy_buf
(
dy_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
mean_buf
(
mean_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
mean_buf
(
mean_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
invStd_buf
(
invStd_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
invStd_buf
(
invStd_host
.
get_element_space_size_in_bytes
());
...
@@ -86,6 +92,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -86,6 +92,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
DeviceMem
dgamma_buf
(
dgamma_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
dgamma_buf
(
dgamma_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
dbeta_buf
(
dbeta_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
dbeta_buf
(
dbeta_host_dev
.
get_element_space_size_in_bytes
());
x_buf
.
ToDevice
(
x_host
.
data
());
dy_buf
.
ToDevice
(
dy_host
.
data
());
dy_buf
.
ToDevice
(
dy_host
.
data
());
mean_buf
.
ToDevice
(
mean_host
.
data
());
mean_buf
.
ToDevice
(
mean_host
.
data
());
invStd_buf
.
ToDevice
(
invStd_host
.
data
());
invStd_buf
.
ToDevice
(
invStd_host
.
data
());
...
@@ -94,13 +101,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -94,13 +101,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
layernorm2d_bwd_traits
traits
{
data_type
};
layernorm2d_bwd_traits
traits
{
data_type
};
layernorm2d_bwd_args
args
{
x_buf
.
GetDeviceBuffer
(),
layernorm2d_bwd_args
args
{
dy_buf
.
GetDeviceBuffer
(),
dy_buf
.
GetDeviceBuffer
(),
mean_buf
.
GetDeviceBuffer
(),
mean_buf
.
GetDeviceBuffer
(),
invStd_buf
.
GetDeviceBuffer
(),
invStd_buf
.
GetDeviceBuffer
(),
dgamma_buf
.
GetDeviceBuffer
(),
dgamma_buf
.
GetDeviceBuffer
(),
dbeta_buf
.
GetDeviceBuffer
(),
dbeta_buf
.
GetDeviceBuffer
(),
nullptr
,
m
,
m
,
n
,
n
,
stride
};
stride
};
...
@@ -118,41 +124,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -118,41 +124,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
if
(
do_validation
)
{
{
// // reference
// reference
// ck_tile::reference_layernorm2d_bwd<XDataType,
ck_tile
::
reference_layernorm2d_bwd_gamma_part
<
XDataType
,
// GammaDataType,
GammaDataType
,
// BetaDataType,
BetaDataType
,
// ComputeDataType,
ComputeDataType
,
// YDataType,
YDataType
,
// MeanDataType,
MeanDataType
,
// InvStdDataType>(
InvStdDataType
>
(
// x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
x_host
,
dy_host
,
mean_host
,
invStd_host
,
dgamma_host_ref
,
dbeta_host_ref
);
// y_buf.FromDevice(y_host_dev.data());
dgamma_buf
.
FromDevice
(
dgamma_host_dev
.
data
());
dbeta_buf
.
FromDevice
(
dbeta_host_dev
.
data
());
// auto [rtol, atol] = get_elimit<DataType>();
// if(stride == n)
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
();
// {
pass
=
ck_tile
::
check_err
(
// pass = ck_tile::check_err(
dgamma_host_dev
,
dgamma_host_ref
,
std
::
string
(
"GAMMA OUT Error: Incorrect results!"
),
rtol
,
atol
);
// y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
pass
&=
ck_tile
::
check_err
(
// }
dbeta_host_dev
,
dbeta_host_ref
,
std
::
string
(
"BETA OUT Error: Incorrect results!"
),
rtol
,
atol
);
// else
// {
// for(int i_r = 0; i_r < m; i_r++)
// {
// std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
// y_host_dev.begin() + i_r * stride + n);
// std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
// y_host_ref.begin() + i_r * stride + n);
// pass &= ck_tile::check_err(y_host_dev_row,
// y_host_ref_row,
// std::string("OUT[") + std::to_string(i_r) +
// std::string("] Error: Incorrect results!"),
// rtol,
// atol);
// }
// }
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
}
}
...
...
example/ck_tile/06_layernorm2d_bwd/layernorm2d_bwd.hpp
View file @
c03a937d
...
@@ -101,6 +101,17 @@ struct layernorm2d_bwd_traits_
...
@@ -101,6 +101,17 @@ struct layernorm2d_bwd_traits_
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
};
};
// Shorthand for layernorm2d_bwd_traits_: the five template arguments are
// forwarded to it unchanged and in the same order.
template <typename DataType_,
          ck_tile::index_t Repeat_M_,         // each thread repeat along M
          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
          bool kPadN_>
using trait_ = layernorm2d_bwd_traits_<DataType_,
                                       Repeat_M_,
                                       ThreadPerBlock_M_,
                                       ThreadPerBlock_N_,
                                       kPadN_>;
template
<
typename
Traits_
>
template
<
typename
Traits_
>
float
layernorm2d_bwd_
(
const
ck_tile
::
stream_config
&
s
,
layernorm2d_bwd_args
a
);
float
layernorm2d_bwd_
(
const
ck_tile
::
stream_config
&
s
,
layernorm2d_bwd_args
a
);
...
@@ -108,6 +119,24 @@ float layernorm2d_bwd_(const ck_tile::stream_config& s, layernorm2d_bwd_args a);
...
@@ -108,6 +119,24 @@ float layernorm2d_bwd_(const ck_tile::stream_config& s, layernorm2d_bwd_args a);
struct
layernorm2d_bwd_traits
struct
layernorm2d_bwd_traits
{
{
std
::
string
data_type
;
std
::
string
data_type
;
};
template
<
typename
data_type
>
struct
layernorm2d_bwd_b16_
{
/* data */
using
Trait
=
trait_
<
data_type
,
1
,
1
,
64
,
true
>
;
float
operator
()
(
layernorm2d_bwd_traits
/*t*/
,
layernorm2d_bwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
return
layernorm2d_bwd_
<
Trait
>
(
s
,
a
);
}
};
// Returns the Block_M constant of the trait that layernorm2d_bwd_b16_<data_type>
// is instantiated with. The host example divides the row count by this value
// (rounding up) to size the partial dgamma/dbeta buffers.
// The function body is closed with a plain `}`: a trailing `;` after a function
// definition is an empty declaration and triggers -Wextra-semi / -Wpedantic.
template <typename data_type>
ck_tile::index_t layernorm2d_bwd_block_m()
{
    return layernorm2d_bwd_b16_<data_type>::Trait::Block_M;
}
};
float
layernorm2d_bwd
(
layernorm2d_bwd_traits
,
layernorm2d_bwd_args
,
const
ck_tile
::
stream_config
&
);
float
layernorm2d_bwd
(
layernorm2d_bwd_traits
,
layernorm2d_bwd_args
,
const
ck_tile
::
stream_config
&
);
include/ck_tile/host.hpp
View file @
c03a937d
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_bwd.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/stream_config.hpp"
...
...
include/ck_tile/host/reference/reference_layernorm2d_bwd.hpp
0 → 100644
View file @
c03a937d
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace
ck_tile
{
// Host reference for the partial (per row-partition) dgamma / dbeta reduction of
// the layernorm2d backward pass.
//
// The M rows of x_m_n are split into PartM consecutive partitions, where PartM is
// the first length of dgamma_mpart_n and each partition holds
// MLoop = ceil(M / PartM) rows (the last one may be shorter or empty).
// For every partition p and column n:
//   dgamma_mpart_n(p, n) = sum over rows r in p of dy(r, n) * (x(r, n) - mean(r)) * inv_std(r)
//   dbeta_mpart_n(p, n)  = sum over rows r in p of dy(r, n)
// Sums are accumulated in ComputeDataType and converted on store.
//
// @param x_m_n           input x; its first two lengths are read as M and N
// @param dy_m_n          upstream gradient, indexed as (row, column) like x_m_n
// @param mean_m          per-row mean, indexed by row
// @param inv_std_m       per-row inverse standard deviation, indexed by row
// @param dgamma_mpart_n  output, indexed (partition, column); defines PartM
// @param dbeta_mpart_n   output, indexed (partition, column); assumed to have at
//                        least PartM rows — TODO confirm at call sites
//
// NOTE(review): the device kernel in this change offsets each block by
// block_id * Block_M. That matches this partitioning only when
// ceil(M / PartM) == Block_M — confirm for row counts that are not a multiple of Block_M.
template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename ComputeDataType,
          typename YDataType,
          typename MeanDataType,
          typename InvStdDataType>
CK_TILE_HOST void
reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataType>& x_m_n,
                                     const HostTensor<YDataType>& dy_m_n,
                                     const HostTensor<MeanDataType>& mean_m,
                                     const HostTensor<InvStdDataType>& inv_std_m,
                                     HostTensor<GammaDataType>& dgamma_mpart_n,
                                     HostTensor<BetaDataType>& dbeta_mpart_n)
{
    // Bind by reference: only two elements are read, no need to copy the container.
    const auto& lengths = x_m_n.mDesc.get_lengths();
    const auto M        = lengths[0];
    const auto N        = lengths[1];
    const auto PartM    = dgamma_mpart_n.mDesc.get_lengths()[0];

    // No output rows means nothing to compute; this also avoids dividing by zero below.
    if(PartM == 0)
    {
        return;
    }

    const auto MLoop = (M + PartM - 1) / PartM; // rows per partition, rounded up

    auto f = [&](auto m) {
        // Row range [m_begin, m_end) owned by partition m, clamped to M.
        const auto m_begin = m * MLoop;
        const auto m_end   = (m_begin + MLoop < M) ? (m_begin + MLoop) : M;

        // Counters take the type of the extents so no signed/unsigned comparison
        // (or truncation to int) happens.
        for(auto n = decltype(N){0}; n < N; ++n)
        {
            ComputeDataType gamma_acc = 0;
            ComputeDataType beta_acc  = 0;

            for(auto row = m_begin; row < m_end; ++row)
            {
                const ComputeDataType mean =
                    ck_tile::type_convert<ComputeDataType>(mean_m(row));
                const ComputeDataType inv_std =
                    ck_tile::type_convert<ComputeDataType>(inv_std_m(row));
                const ComputeDataType x =
                    ck_tile::type_convert<ComputeDataType>(x_m_n(row, n));
                const ComputeDataType dy =
                    ck_tile::type_convert<ComputeDataType>(dy_m_n(row, n));

                gamma_acc += dy * (x - mean) * inv_std;
                beta_acc += dy;
            }

            dgamma_mpart_n(m, n) = ck_tile::type_convert<GammaDataType>(gamma_acc);
            dbeta_mpart_n(m, n)  = ck_tile::type_convert<BetaDataType>(beta_acc);
        }
    };

    // One functor call per partition index in [0, PartM); the argument is
    // presumably the worker-thread count — confirm against make_ParallelTensorFunctor.
    make_ParallelTensorFunctor(f, PartM)(std::thread::hardware_concurrency());
}
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp
View file @
c03a937d
...
@@ -11,13 +11,13 @@ namespace ck_tile {
...
@@ -11,13 +11,13 @@ namespace ck_tile {
// host side args
// host side args
struct
Layernorm2dBwdGammaBetaHostArgs
struct
Layernorm2dBwdGammaBetaHostArgs
{
{
const
void
*
p_x
;
const
void
*
p_dY
;
const
void
*
p_dY
;
const
void
*
p_mean
;
const
void
*
p_mean
;
const
void
*
p_invStd
;
const
void
*
p_invStd
;
void
*
p_dGamma
;
void
*
p_dGamma
;
void
*
p_dBeta
;
void
*
p_dBeta
;
void
*
p_yMul
;
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
...
@@ -51,13 +51,13 @@ struct Layernorm2dBwdGammaBeta
...
@@ -51,13 +51,13 @@ struct Layernorm2dBwdGammaBeta
struct
Kargs
struct
Kargs
{
{
const
void
*
p_x
;
const
void
*
p_dY
;
const
void
*
p_dY
;
const
void
*
p_mean
;
const
void
*
p_mean
;
const
void
*
p_invStd
;
const
void
*
p_invStd
;
void
*
p_dGamma
;
void
*
p_dGamma
;
void
*
p_dBeta
;
void
*
p_dBeta
;
void
*
p_yMul
;
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
...
@@ -67,12 +67,12 @@ struct Layernorm2dBwdGammaBeta
...
@@ -67,12 +67,12 @@ struct Layernorm2dBwdGammaBeta
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
{
return
Kargs
{
hargs
.
p_dY
,
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_dY
,
hargs
.
p_mean
,
hargs
.
p_mean
,
hargs
.
p_invStd
,
hargs
.
p_invStd
,
hargs
.
p_dGamma
,
hargs
.
p_dGamma
,
hargs
.
p_dBeta
,
hargs
.
p_dBeta
,
hargs
.
p_yMul
,
hargs
.
m
,
hargs
.
m
,
hargs
.
n
,
hargs
.
n
,
hargs
.
stride
};
hargs
.
stride
};
...
@@ -119,7 +119,22 @@ struct Layernorm2dBwdGammaBeta
...
@@ -119,7 +119,22 @@ struct Layernorm2dBwdGammaBeta
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
{
const
auto
iM
=
get_block_id
()
*
Block_M
;
const
auto
block_id
=
get_block_id
();
const
auto
iM
=
block_id
*
Block_M
;
const
auto
x_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
));
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
// check the max count dynamically
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
false
,
false
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
const
auto
dy_window
=
[
&
]()
{
const
auto
dy_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
...
@@ -134,7 +149,7 @@ struct Layernorm2dBwdGammaBeta
...
@@ -134,7 +149,7 @@ struct Layernorm2dBwdGammaBeta
return
make_tile_window
(
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
}();
const
auto
mean_window
=
[
&
]()
{
const
auto
mean_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
MeanDataType
*>
(
kargs
.
p_mean
),
static_cast
<
const
MeanDataType
*>
(
kargs
.
p_mean
),
...
@@ -144,7 +159,7 @@ struct Layernorm2dBwdGammaBeta
...
@@ -144,7 +159,7 @@ struct Layernorm2dBwdGammaBeta
const
auto
tmp2_
=
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
false
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
false
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{}),
{
0
});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
}();
}();
const
auto
invstd_window
=
[
&
]()
{
const
auto
invstd_window
=
[
&
]()
{
...
@@ -156,36 +171,37 @@ struct Layernorm2dBwdGammaBeta
...
@@ -156,36 +171,37 @@ struct Layernorm2dBwdGammaBeta
const
auto
tmp2_
=
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
false
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
false
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{}),
{
0
});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
}();
}();
const
auto
dgamma_window
=
[
&
]()
{
auto
dgamma_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_dGamma
),
static_cast
<
GammaDataType
*>
(
kargs
.
p_dGamma
),
make_tuple
(
kargs
.
n
),
make_tuple
(
gridDim
.
x
,
kargs
.
n
),
make_tuple
(
1
));
make_tuple
(
kargs
.
n
,
1
));
const
auto
tmp2_
=
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
false
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
1
>
{},
number
<
Block_N
>
{}),
sequence
<
false
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
1
>
{},
number
<
Block_N
>
{}),
{
block_id
,
0
});
}();
}();
const
auto
dbeta_window
=
[
&
]()
{
auto
dbeta_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
BetaDataType
*>
(
kargs
.
p_dBeta
),
static_cast
<
BetaDataType
*>
(
kargs
.
p_dBeta
),
make_tuple
(
kargs
.
n
),
make_tuple
(
gridDim
.
x
,
kargs
.
n
),
make_tuple
(
1
));
make_tuple
(
kargs
.
n
,
1
));
const
auto
tmp2_
=
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
false
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
1
>
{},
number
<
Block_N
>
{}),
sequence
<
false
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
0
});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
1
>
{},
number
<
Block_N
>
{}),
{
block_id
,
0
});
}();
}();
__shared__
char
smem
[
GetSmemSize
()];
__shared__
char
smem
[
GetSmemSize
()];
Pipeline
{}(
dy_window
,
Pipeline
{}(
x_window
,
dy_window
,
mean_window
,
mean_window
,
invstd_window
,
invstd_window
,
dgamma_window
,
dgamma_window
,
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp
View file @
c03a937d
...
@@ -10,7 +10,7 @@ namespace ck_tile {
...
@@ -10,7 +10,7 @@ namespace ck_tile {
struct
Layernorm2dBwdGammaBetaPipelineDefaultPolicy
struct
Layernorm2dBwdGammaBetaPipelineDefaultPolicy
{
{
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
Make
Dy
BlockTileDistribution
()
CK_TILE_DEVICE
static
constexpr
auto
Make
X
BlockTileDistribution
()
{
{
using
S
=
typename
Problem
::
BlockShape
;
using
S
=
typename
Problem
::
BlockShape
;
...
@@ -18,11 +18,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
...
@@ -18,11 +18,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
tile_distribution_encoding
<
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
tuple
<
sequence
<
S
::
Repeat_M
,
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
>
,
tuple
<
sequence
<
S
::
Repeat_M
,
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
>
,
sequence
<
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
>>
,
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
2
,
2
>>
,
sequence
<
1
>
,
sequence
<
1
,
2
>
,
sequence
<
0
>>
{});
sequence
<
0
,
0
>>
{});
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeMeanBlockTileDistribution
()
CK_TILE_DEVICE
static
constexpr
auto
MakeMeanBlockTileDistribution
()
...
@@ -39,20 +39,22 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
...
@@ -39,20 +39,22 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
sequence
<
0
>>
{});
sequence
<
0
>>
{});
}
}
// template <typename Problem>
template
<
typename
Problem
>
// CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBetaBlockTileDistribution
()
// {
{
// using S = typename Problem::BlockShape;
using
S
=
typename
Problem
::
BlockShape
;
// return make_static_tile_distribution(
return
make_static_tile_distribution
(
// tile_distribution_encoding<
tile_distribution_encoding
<
// sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
sequence
<>
,
// tuple<sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
tuple
<
sequence
<
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
>
,
// tuple<sequence<0, 1>, sequence<0, 1>>,
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
>>
,
// tuple<sequence<1, 0>, sequence<2, 1>>,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
// sequence<0>,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
// sequence<0>>{});
sequence
<
2
>
,
// }
sequence
<
0
>>
{});
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
{
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp
View file @
c03a937d
...
@@ -24,7 +24,7 @@ struct Layernorm2dBwdGammaBetaPipeline
...
@@ -24,7 +24,7 @@ struct Layernorm2dBwdGammaBetaPipeline
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dBwdGammaBetaProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
const
char
*
name
=
[]()
{
static
constexpr
const
char
*
name
=
[]()
{
...
@@ -35,31 +35,13 @@ struct Layernorm2dBwdGammaBetaPipeline
...
@@ -35,31 +35,13 @@ struct Layernorm2dBwdGammaBetaPipeline
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
// template <typename DumpTensor_>
template
<
typename
XWindow
,
// CK_TILE_DEVICE void dump(const DumpTensor_& x) const
// {
// constexpr auto I0 = number<0>{};
// constexpr auto I1 = number<1>{};
// constexpr auto spans = DumpTensor_::get_distributed_spans();
// sweep_tile_span(spans[I1], [&](auto i1) {
// sweep_tile_span(spans[I0], [&](auto i0) {
// constexpr auto in_dstr_idx = make_tuple(i0, i1);
// auto v = ck_tile::type_convert<float>(x[in_dstr_idx]);
// index_t tid =
// (threadIdx.z * (blockDim.x * blockDim.y)) + (threadIdx.y * blockDim.x) + threadIdx.x;
// printf("%d %f\n", tid, v);
// });
// });
// }
template
<
typename
DYWindow
,
typename
MeanWindow
,
typename
MeanWindow
,
typename
InvStdWindow
,
typename
InvStdWindow
,
typename
DGammaWindow
,
typename
DGammaWindow
,
typename
DBetaWindow
>
typename
DBetaWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
DYWindow
&
dy_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XWindow
&
dy_window_
,
const
MeanWindow
&
mean_window_
,
const
MeanWindow
&
mean_window_
,
const
InvStdWindow
&
inv_std_window_
,
const
InvStdWindow
&
inv_std_window_
,
DGammaWindow
&
gamma_window_
,
DGammaWindow
&
gamma_window_
,
...
@@ -67,52 +49,43 @@ struct Layernorm2dBwdGammaBetaPipeline
...
@@ -67,52 +49,43 @@ struct Layernorm2dBwdGammaBetaPipeline
ck_tile
::
index_t
row_size
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
void
*
smem
)
const
{
{
const
auto
dy_window
=
make_tile_window
(
dy_window_
,
auto
gamma_beta_dist
=
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>();
Policy
::
template
MakeDyBlockTileDistribution
<
Problem
>());
auto
mean_dist
=
Policy
::
template
MakeMeanBlockTileDistribution
<
Problem
>();
const
auto
mean_window
=
make_tile_window
(
auto
x_dist
=
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>();
mean_window_
,
Policy
::
template
MakeMeanBlockTileDistribution
<
Problem
>());
const
auto
inv_std_window
=
make_tile_window
(
const
auto
x_window
=
make_tile_window
(
x_window_
,
x_dist
);
inv_std_window_
,
Policy
::
template
MakeMeanBlockTileDistribution
<
Problem
>());
const
auto
dy_window
=
make_tile_window
(
dy_window_
,
x_dist
);
// const auto gamma_window = make_tile_window(
const
auto
mean_window
=
make_tile_window
(
mean_window_
,
mean_dist
);
// gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const
auto
inv_std_window
=
make_tile_window
(
inv_std_window_
,
mean_dist
);
// const auto beta_window = make_tile_window(
const
auto
x_tile
=
load_tile
(
x_window
);
// beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const
auto
dy_tile
=
load_tile
(
dy_window
);
const
auto
mean_tile
=
load_tile
(
mean_window
);
const
auto
dy
=
load_tile
(
dy_window
);
const
auto
inv_std_tile
=
load_tile
(
inv_std_window
);
const
auto
mean
=
load_tile
(
mean_window
);
const
auto
inv_std
=
load_tile
(
inv_std_window
);
// auto y = make_static_distributed_tensor<YDataType>(dy.get_tile_distribution());
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
gamma_beta_dist
);
sweep_tile
(
mean
,
[
&
](
auto
idx
)
{
auto
beta_window
=
make_tile_window
(
beta_window_
,
gamma_beta_dist
);
auto
gamma_tile
=
make_static_distributed_tensor
<
GammaDataType
>
(
gamma_beta_dist
);
auto
beta_tile
=
make_static_distributed_tensor
<
BetaDataType
>
(
gamma_beta_dist
);
sweep_tile
(
x_tile
,
[
&
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
// constexpr auto j_idx = make_tuple(idx[number<1>{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
constexpr
auto
gb_idx
=
make_tuple
(
number
<
0
>
{},
idx
[
number
<
1
>
{}]);
index_t
tid
=
(
threadIdx
.
y
*
blockDim
.
x
)
+
threadIdx
.
x
;
auto
&
gamma
=
gamma_tile
(
gb_idx
);
const
auto
m
=
type_convert
<
float
>
(
mean
[
i_idx
]);
auto
&
beta
=
beta_tile
(
gb_idx
);
if
(
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
)
const
auto
x
=
type_convert
<
ComputeDataType
>
(
x_tile
[
idx
]);
printf
(
"%d %f
\n
"
,
tid
,
m
);
const
auto
dy
=
type_convert
<
ComputeDataType
>
(
dy_tile
[
idx
]);
const
auto
mean
=
type_convert
<
ComputeDataType
>
(
mean_tile
[
i_idx
]);
const
auto
inv_std
=
type_convert
<
ComputeDataType
>
(
inv_std_tile
[
i_idx
]);
beta
+=
type_convert
<
BetaDataType
>
(
dy
);
gamma
+=
type_convert
<
GammaDataType
>
(
dy
*
(
x
-
mean
)
*
inv_std
);
// index_t tid = (threadIdx.y * blockDim.x) + threadIdx.x;
// if(blockIdx.x < 3 && blockIdx.y == 0 && tid < 3) {
// printf("bid %d tid %d count %d gb %f %f\n",blockIdx.x, tid, count, type_convert<float>(g), type_convert<float>(b));
// }
});
});
// dump(dy);
store_tile
(
gamma_window
,
gamma_tile
);
// dump(mean);
store_tile
(
beta_window
,
beta_tile
);
// dump(inv_std);
*
reinterpret_cast
<
char
*>
(
smem
)
=
row_size
;
// layernorm computation
// auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
// sweep_tile(y, [&, mean_ = mean](auto idx) {
// constexpr auto i_idx = make_tuple(idx[number<0>{}]);
// constexpr auto j_idx = make_tuple(idx[number<1>{}]);
// const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
// const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
// const auto x_ = type_convert<ComputeDataType>(x[idx]);
// auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
// y(idx) = type_convert<YDataType>(y_);
// });
// store_tile(y_window, y);
}
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment