gaoqiong / composable_kernel · Commits

Commit d7bb21c2
authored Sep 23, 2022 by wangshaojie6
parent 8daff431

optimize group layer norm

Showing 1 changed file with 149 additions and 58 deletions:

include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp  (+149 / -58)
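Read as a whole, the diff below optimizes the layernorm kernel by splitting each thread's K slice into XSrcVectorSize-wide chunks: x, gamma, beta, and y now live in tuples of XThreadBufferNumber small register buffers instead of one large StaticBuffer each; loads, Welford updates, and stores are issued chunk by chunk with a smaller slice-window step (K_BlockTileStepSize); and the per-element division by sqrt(var + epsilon) is hoisted into a single precomputed divisor per row.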
@@ -57,7 +57,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
 
     using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
-        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
@@ -73,8 +73,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
 
     static constexpr index_t M_BlockTileSize     = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize;
+    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
+
+    static constexpr auto XThreadBufferNumber     = Number<KThreadSliceSize / XSrcVectorSize>{};
+    static constexpr auto GammaThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
+    static constexpr auto BetaThreadBufferNumber  = Number<KThreadSliceSize / XSrcVectorSize>{};
+    static constexpr auto YThreadBufferNumber     = Number<KThreadSliceSize / XSrcVectorSize>{};
 
     __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
                                         int thread_k_cluster_id)
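For intuition, the new constants just split each thread's K slice into vector-sized chunks. A minimal sketch with hypothetical tile parameters (the values 64, 8, and 4 are illustrative only and do not come from the diff):

#include <cstdio>

// Hypothetical tile parameters, chosen only for illustration.
constexpr int KThreadClusterSize = 64; // threads along K in one block
constexpr int KThreadSliceSize   = 8;  // K elements owned by one thread
constexpr int XSrcVectorSize     = 4;  // elements per vector load

// Mirrors the constants added by this commit.
constexpr int K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize; // 512
constexpr int K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;   // 256
constexpr int XThreadBufferNumber = KThreadSliceSize / XSrcVectorSize;     // 2

int main()
{
    // Each thread now issues XThreadBufferNumber vector loads per K tile,
    // moving its slice window by K_BlockTileStepSize between loads.
    std::printf("tile=%d step=%d chunks per thread=%d\n",
                K_BlockTileSize, K_BlockTileStepSize, XThreadBufferNumber);
    return 0;
}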
@@ -116,6 +122,47 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
 
+        auto x_thread_buf = generate_tuple(
+            [&](auto i) {
+                ignore = i;
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    AccDataType,
+                                    MThreadSliceSize * XSrcVectorSize,
+                                    true>{};
+            },
+            Number<XThreadBufferNumber>{});
+
+        auto gamma_thread_buf = generate_tuple(
+            [&](auto i) {
+                ignore = i;
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    AccDataType,
+                                    MThreadSliceSize * GammaSrcVectorSize,
+                                    true>{};
+            },
+            Number<GammaThreadBufferNumber>{});
+
+        auto beta_thread_buf = generate_tuple(
+            [&](auto i) {
+                ignore = i;
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    AccDataType,
+                                    MThreadSliceSize * BetaSrcVectorSize,
+                                    true>{};
+            },
+            Number<BetaThreadBufferNumber>{});
+
+        auto y_thread_buf = generate_tuple(
+            [&](auto i) {
+                ignore = i;
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    AccDataType,
+                                    MThreadSliceSize * YDstVectorSize,
+                                    true>{};
+            },
+            Number<YThreadBufferNumber>{});
+
+#if 0
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             x_thread_buf;
 
@@ -129,6 +176,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             y_thread_buf;
+#endif
 
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
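generate_tuple builds the tuple of per-chunk register buffers at compile time. A simplified, self-contained stand-in for the idiom, using std::array in place of StaticBuffer and std::index_sequence in place of CK's generate_tuple; this is an illustrative sketch, not CK's implementation:

#include <array>
#include <cstddef>
#include <tuple>
#include <utility>

// Stand-in for CK's StaticBuffer: a small fixed-size buffer.
template <typename T, std::size_t N>
using SmallBuffer = std::array<T, N>;

// Stand-in for CK's generate_tuple: build a tuple by calling f(i)
// for each compile-time index i in [0, N).
template <typename F, std::size_t... Is>
auto generate_tuple_impl(F&& f, std::index_sequence<Is...>)
{
    return std::make_tuple(f(std::integral_constant<std::size_t, Is>{})...);
}

template <std::size_t N, typename F>
auto generate_tuple(F&& f)
{
    return generate_tuple_impl(f, std::make_index_sequence<N>{});
}

int main()
{
    constexpr std::size_t MThreadSliceSize    = 2; // hypothetical
    constexpr std::size_t XSrcVectorSize      = 4; // hypothetical
    constexpr std::size_t XThreadBufferNumber = 2; // KThreadSliceSize / XSrcVectorSize

    // Tuple of small per-chunk buffers, as in the new x_thread_buf.
    auto x_thread_buf = generate_tuple<XThreadBufferNumber>(
        [](auto) { return SmallBuffer<float, MThreadSliceSize * XSrcVectorSize>{}; });

    std::get<0>(x_thread_buf)[0] = 1.0f; // chunk 0, element 0
    return 0;
}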
@@ -142,9 +190,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         const auto thread_m_cluster_id = thread_cluster_idx[I0];
         const auto thread_k_cluster_id = thread_cluster_idx[I1];
 
-        using ThreadBufferLengths_M_K         = Sequence<MThreadSliceSize, KThreadSliceSize>;
+        // using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
+        // constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
+        //     make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
+        // auto red_num = slice/vector;
+
+        using ThreadBufferLengths_M_K         = Sequence<MThreadSliceSize, XSrcVectorSize>;
         constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
+            make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
 
         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                   AccDataType,
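The thread buffer descriptor shrinks from MThreadSliceSize x KThreadSliceSize to MThreadSliceSize x XSrcVectorSize, so CalculateOffset now indexes within one vector chunk. For a packed descriptor the offset arithmetic reduces to row-major indexing; a minimal sketch of that assumed behavior (not CK's actual descriptor code):

#include <cassert>

// Hypothetical shape matching the new thread buffer descriptor.
constexpr int MThreadSliceSize = 2;
constexpr int XSrcVectorSize   = 4;

// Packed 2-D offset: rows of length XSrcVectorSize, innermost stride 1.
constexpr int calculate_offset(int iM, int iK1)
{
    return iM * XSrcVectorSize + iK1;
}

int main()
{
    static_assert(calculate_offset(0, 3) == 3);
    static_assert(calculate_offset(1, 0) == XSrcVectorSize);
    assert(calculate_offset(1, 2) == 6);
    return 0;
}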
@@ -214,8 +267,11 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         // Copy x from Cache
         // one pass: fwd, second pass: bwd
+        // constexpr auto thread_copy_fwd_step_m_k =
+        //     make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
         constexpr auto thread_copy_fwd_step_m_k =
-            make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
+            make_multi_index(0, SweepOnce ? 0 : K_BlockTileStepSize);
         constexpr auto thread_copy_bwd_step_m_k =
             make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
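The forward step now advances by one vector chunk (K_BlockTileStepSize) rather than a full tile, while the backward rewinds later in the diff are doubled to 2 * thread_copy_bwd_step_m_k. A toy position trace under hypothetical sizes; the rewind interpretation is inferred from the diff, not documented in it:

#include <cstdio>

// Hypothetical geometry: K = 2048 split into 4 tiles of 512,
// each tile loaded in 2 chunks of 256 (illustrative values only).
constexpr int K_BlockTileSize     = 512;
constexpr int K_BlockTileStepSize = 256;
constexpr int XThreadBufferNumber = K_BlockTileSize / K_BlockTileStepSize;

int main()
{
    int k = 0; // slice-window position along K

    // First sweep (Welford): chunked forward loads across all tiles.
    for(int tile = 0; tile < 4; ++tile)
        for(int chunk = 0; chunk < XThreadBufferNumber; ++chunk)
        {
            std::printf("load chunk at k=%d\n", k);
            k += K_BlockTileStepSize; // thread_copy_fwd_step_m_k
        }

    // Second sweep walks tiles backward. After the forward chunk loads
    // within a tile the window sits one full tile ahead, so rewinding
    // two whole tiles lands at the start of the previous tile.
    k += 2 * -K_BlockTileSize; // 2 * thread_copy_bwd_step_m_k
    std::printf("previous tile starts at k=%d\n", k);
    return 0;
}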
@@ -238,16 +294,18 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
         {
-            threadwise_x_load.Run(x_grid_desc_m_k,
-                                  x_global_val_buf,
-                                  thread_buffer_desc_m_k,
-                                  make_tuple(I0, I0),
-                                  x_thread_buf);
-            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
-            threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
+            static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_x_load.Run(x_grid_desc_m_k,
+                                      x_global_val_buf,
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
+                                      x_thread_buf(i));
+                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
+                threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
+            });
         }
 
+#if 1
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
             if constexpr(I > 0)
                 block_sync_lds();

@@ -255,6 +313,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
             int count = threadwise_welford.cur_count_;
             BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
         });
+#endif
 
         auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
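Each chunk load now feeds threadwise_welford.Run directly, and BlockwiseWelford::Run then merges the partial statistics across the block. The underlying recurrence is the standard Welford update; a scalar sketch of it (not CK's ThreadwiseWelford):

#include <cstdio>

int main()
{
    const float x[] = {1.0f, 2.0f, 4.0f, 8.0f};

    // Welford running statistics: count, mean, and sum of squared deviations.
    int   count = 0;
    float mean  = 0.0f;
    float m2    = 0.0f;

    for(float xi : x)
    {
        ++count;
        float delta = xi - mean;
        mean += delta / count;       // update running mean
        m2   += delta * (xi - mean); // update running sum of squared deviations
    }

    float variance = m2 / count; // population variance, as layernorm uses
    std::printf("mean=%f var=%f\n", mean, variance);
    return 0;
}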
@@ -267,62 +326,94 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         {
             if constexpr(!SweepOnce)
             {
-                threadwise_x_load.Run(x_grid_desc_m_k,
-                                      x_global_val_buf,
-                                      thread_buffer_desc_m_k,
-                                      make_tuple(I0, I0),
-                                      x_thread_buf);
+                static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
+                    threadwise_x_load.Run(x_grid_desc_m_k,
+                                          x_global_val_buf,
+                                          thread_buffer_desc_m_k,
+                                          make_tuple(I0, I0),
+                                          x_thread_buf(i));
+                    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k,
+                                                         thread_copy_fwd_step_m_k);
+                });
             }
 
-            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
-                                      gamma_global_val_buf,
-                                      thread_buffer_desc_m_k,
-                                      make_tuple(I0, I0),
-                                      gamma_thread_buf);
+#if 1
+            static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
+                                          gamma_global_val_buf,
+                                          thread_buffer_desc_m_k,
+                                          make_tuple(I0, I0),
+                                          gamma_thread_buf(i));
+                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
+                                                         thread_copy_fwd_step_m_k);
+            });
+#endif
+
+#if 0
+            static_for<0, gamma_thread_buf.Size(), 1>{}([&](auto i){
+                gamma_thread_buf(i) = 1;
+                beta_thread_buf(i) = 1;
+            });
+#endif
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-
-                    // normalize
-                    y_thread_buf(Number<offset_m_k>{}) =
-                        (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
-                        sqrt(var_thread_buf(iM) + epsilon);
-
-                    // gamma
-                    y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
+                auto divisor = 1 / sqrt(var_thread_buf(iM) + epsilon);
+                static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
+                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                        constexpr auto offset_m_k =
+                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+#if 1
+                        // normalize
+                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                            (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
+                            divisor;
+#endif
+                        // gamma
+                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                            y_thread_buf(iK0)(Number<offset_m_k>{}) *
+                            gamma_thread_buf(iK0)(Number<offset_m_k>{});
+                    });
                 });
             });
 
-            threadwise_beta_load.Run(beta_grid_desc_m_k,
-                                     beta_global_val_buf,
-                                     thread_buffer_desc_m_k,
-                                     make_tuple(I0, I0),
-                                     beta_thread_buf);
+#if 1
+            static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_beta_load.Run(beta_grid_desc_m_k,
+                                         beta_global_val_buf,
+                                         thread_buffer_desc_m_k,
+                                         make_tuple(I0, I0),
+                                         beta_thread_buf(i));
+                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                        thread_copy_fwd_step_m_k);
+            });
+#endif
 
+            // beta
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-
-                    // beta
-                    y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
+                static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
+                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                        constexpr auto offset_m_k =
+                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+
+                        // beta
+                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                            y_thread_buf(iK0)(Number<offset_m_k>{}) +
+                            beta_thread_buf(iK0)(Number<offset_m_k>{});
+                    });
                 });
             });
 
-            threadwise_y_store.Run(thread_buffer_desc_m_k,
-                                   make_tuple(I0, I0),
-                                   y_thread_buf,
-                                   y_grid_desc_m_k,
-                                   y_global_val_buf);
+            static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_y_store.Run(thread_buffer_desc_m_k,
+                                       make_tuple(I0, I0),
+                                       y_thread_buf(i),
+                                       y_grid_desc_m_k,
+                                       y_global_val_buf);
+                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
+            });
 
-            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
+            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
+            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
+                                                     2 * thread_copy_bwd_step_m_k);
+            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                    2 * thread_copy_bwd_step_m_k);
+            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
         }
     }
 };
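Mathematically the epilogue is unchanged: y = (x - mean) / sqrt(var + epsilon) * gamma + beta. The diff's one numerical tweak is hoisting 1 / sqrt(var + epsilon) out of the inner loops as divisor and multiplying per element instead of dividing. A scalar sketch of that rewrite with made-up values (not the kernel's buffer layout):

#include <cmath>
#include <cstdio>

int main()
{
    // Hypothetical one-row example: mean and var of {1, 2, 3, 4}.
    const float x[]     = {1.0f, 2.0f, 3.0f, 4.0f};
    const float gamma[] = {1.0f, 1.0f, 1.0f, 1.0f};
    const float beta[]  = {0.0f, 0.0f, 0.0f, 0.0f};
    const float mean = 2.5f, var = 1.25f, epsilon = 1e-5f;

    // Hoisted out of the loop, as in the diff: one sqrt and one divide
    // per row instead of one per element.
    const float divisor = 1.0f / std::sqrt(var + epsilon);

    float y[4];
    for(int k = 0; k < 4; ++k)
    {
        y[k] = (x[k] - mean) * divisor; // normalize
        y[k] = y[k] * gamma[k];         // gamma (scale)
        y[k] = y[k] + beta[k];          // beta (shift)
        std::printf("y[%d]=%f\n", k, y[k]);
    }
    return 0;
}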