Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3bb0cbe7
Commit
3bb0cbe7
authored
Jul 05, 2022
by
rocking
Browse files
We only use one block in K dimension.
Hence, we can simplify the indexing of global R/W.
parent
6d3ad8cd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
20 deletions
+12
-20
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
+0
-1
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp
+12
-19
No files found.
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
View file @
3bb0cbe7
...
@@ -199,7 +199,6 @@ struct DeviceLayernorm : public BaseOperator
...
@@ -199,7 +199,6 @@ struct DeviceLayernorm : public BaseOperator
gamma_grid_desc_m_k
,
gamma_grid_desc_m_k
,
beta_grid_desc_m_k
,
beta_grid_desc_m_k
,
y_grid_desc_m_k
,
y_grid_desc_m_k
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
,
arg
.
numBlockTileIteration
,
arg
.
epsilon_
,
arg
.
epsilon_
,
arg
.
in_dev_
,
arg
.
in_dev_
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp
View file @
3bb0cbe7
...
@@ -25,7 +25,6 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
...
@@ -25,7 +25,6 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const
GridDesc_M_K
gamma_grid_desc_m_k
,
const
GridDesc_M_K
gamma_grid_desc_m_k
,
const
GridDesc_M_K
beta_grid_desc_m_k
,
const
GridDesc_M_K
beta_grid_desc_m_k
,
const
GridDesc_M_K
y_grid_desc_m_k
,
const
GridDesc_M_K
y_grid_desc_m_k
,
index_t
block_group_size
,
index_t
num_k_block_tile_iteration
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
AccDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
...
@@ -37,7 +36,6 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
...
@@ -37,7 +36,6 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
gamma_grid_desc_m_k
,
gamma_grid_desc_m_k
,
beta_grid_desc_m_k
,
beta_grid_desc_m_k
,
y_grid_desc_m_k
,
y_grid_desc_m_k
,
block_group_size
,
num_k_block_tile_iteration
,
num_k_block_tile_iteration
,
epsilon
,
epsilon
,
p_x_global
,
p_x_global
,
...
@@ -119,7 +117,6 @@ struct GridwiseLayernorm_mk_to_mk
...
@@ -119,7 +117,6 @@ struct GridwiseLayernorm_mk_to_mk
const
GridDesc_M_K
&
gamma_grid_desc_m_k
,
const
GridDesc_M_K
&
gamma_grid_desc_m_k
,
const
GridDesc_M_K
&
beta_grid_desc_m_k
,
const
GridDesc_M_K
&
beta_grid_desc_m_k
,
const
GridDesc_M_K
&
y_grid_desc_m_k
,
const
GridDesc_M_K
&
y_grid_desc_m_k
,
index_t
block_group_size
,
index_t
num_k_block_tile_iteration
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
AccDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
...
@@ -171,8 +168,6 @@ struct GridwiseLayernorm_mk_to_mk
...
@@ -171,8 +168,6 @@ struct GridwiseLayernorm_mk_to_mk
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
blkgroup_id
=
block_global_id
/
block_group_size
;
const
index_t
block_local_id
=
block_global_id
%
block_group_size
;
const
auto
thread_cluster_idx
=
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
...
@@ -180,8 +175,6 @@ struct GridwiseLayernorm_mk_to_mk
...
@@ -180,8 +175,6 @@ struct GridwiseLayernorm_mk_to_mk
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
const
index_t
reduceSizePerBlock
=
K_BlockTileSize
*
num_k_block_tile_iteration
;
using
ThreadBufferLengths
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
constexpr
auto
thread_buffer_desc
=
make_naive_tensor_descriptor_packed
(
constexpr
auto
thread_buffer_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
...
@@ -197,9 +190,9 @@ struct GridwiseLayernorm_mk_to_mk
...
@@ -197,9 +190,9 @@ struct GridwiseLayernorm_mk_to_mk
1
,
1
,
true
>
(
true
>
(
x_grid_desc_m_k
,
x_grid_desc_m_k
,
make_multi_index
(
bl
kgroup
_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
make_multi_index
(
bl
ock_global
_id
*
M_BlockTileSize
+
block_local_id
*
reduceSizePerBlock
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_gamma_load
=
ThreadwiseTensorSliceTransfer_v2
<
GammaDataType
,
auto
threadwise_gamma_load
=
ThreadwiseTensorSliceTransfer_v2
<
GammaDataType
,
AccDataType
,
AccDataType
,
...
@@ -212,9 +205,9 @@ struct GridwiseLayernorm_mk_to_mk
...
@@ -212,9 +205,9 @@ struct GridwiseLayernorm_mk_to_mk
1
,
1
,
true
>
(
true
>
(
gamma_grid_desc_m_k
,
gamma_grid_desc_m_k
,
make_multi_index
(
bl
kgroup
_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
make_multi_index
(
bl
ock_global
_id
*
M_BlockTileSize
+
block_local_id
*
reduceSizePerBlock
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_beta_load
=
ThreadwiseTensorSliceTransfer_v2
<
BetaDataType
,
auto
threadwise_beta_load
=
ThreadwiseTensorSliceTransfer_v2
<
BetaDataType
,
AccDataType
,
AccDataType
,
...
@@ -227,9 +220,9 @@ struct GridwiseLayernorm_mk_to_mk
...
@@ -227,9 +220,9 @@ struct GridwiseLayernorm_mk_to_mk
1
,
1
,
true
>
(
true
>
(
beta_grid_desc_m_k
,
beta_grid_desc_m_k
,
make_multi_index
(
bl
kgroup
_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
make_multi_index
(
bl
ock_global
_id
*
M_BlockTileSize
+
block_local_id
*
reduceSizePerBlock
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_y_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
auto
threadwise_y_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
YDataType
,
YDataType
,
...
@@ -244,9 +237,9 @@ struct GridwiseLayernorm_mk_to_mk
...
@@ -244,9 +237,9 @@ struct GridwiseLayernorm_mk_to_mk
1
,
1
,
true
>
(
true
>
(
y_grid_desc_m_k
,
y_grid_desc_m_k
,
make_multi_index
(
bl
kgroup
_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
make_multi_index
(
bl
ock_global
_id
*
M_BlockTileSize
+
block_local_id
*
reduceSizePerBlock
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
),
thread_k_cluster_id
*
KThreadSliceSize
),
PassThroughOp
{});
PassThroughOp
{});
// Copy x from Cache
// Copy x from Cache
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment