Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b7aa49a3
Commit
b7aa49a3
authored
Nov 22, 2023
by
rocking
Browse files
Add sweep once pipeline for small reduce size
parent
6fae79b6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
97 additions
and
5 deletions
+97
-5
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_x_impl.hpp
...ation/gpu/device/impl/device_normalization_bwd_x_impl.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_x.hpp
...n/gpu/grid/normalization/gridwise_normalization_bwd_x.hpp
+95
-3
No files found.
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_x_impl.hpp
View file @
b7aa49a3
...
@@ -266,8 +266,8 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
...
@@ -266,8 +266,8 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
dx_grid_desc_m_k_
=
Make2dDescriptor
(
lengths_
,
dxStrides_
,
numBlockTileIteration_
);
dx_grid_desc_m_k_
=
Make2dDescriptor
(
lengths_
,
dxStrides_
,
numBlockTileIteration_
);
// TODO - sweep once for small k
// TODO - sweep once for small k
//
isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize;
isSweeponce_
=
dy_grid_desc_m_k_
.
GetLength
(
Number
<
1
>
{})
<=
K_BlockTileSize
;
isSweeponce_
=
false
;
//
isSweeponce_ = false;
}
}
const
DYDataType
*
p_dy_
;
const
DYDataType
*
p_dy_
;
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_x.hpp
View file @
b7aa49a3
...
@@ -304,11 +304,103 @@ struct GridwiseNormalizationBwdX_mk_to_mk
...
@@ -304,11 +304,103 @@ struct GridwiseNormalizationBwdX_mk_to_mk
});
});
// Separate sweep once and sweep twice pipeline
// Separate sweep once and sweep twice pipeline
// Sweep once: for small k, if KThreadClusterSize * KThreadSliceSize > K
// we don't need to use loop to read x, dy, gamma twice
if
constexpr
(
SweepOnce
)
if
constexpr
(
SweepOnce
)
{
{
// TODO
threadwise_dy_load
.
Run
(
dy_grid_desc_m_k
,
}
// end of sweep once
dy_global_val_buf
,
else
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
dy_thread_buf
);
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
threadwise_gamma_load
.
Run
(
gamma_grid_desc_m_k
,
gamma_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
gamma_thread_buf
);
threadwise_mean_load
.
Run
(
mean_grid_desc_m_k
,
mean_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
mean_thread_buf
);
threadwise_inv_std_load
.
Run
(
inv_std_grid_desc_m_k
,
inv_std_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
inv_std_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
constexpr
auto
offset_m
=
Number
<
thread_buffer_desc_m
.
CalculateOffset
(
make_tuple
(
iM
))
>
{};
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset_m_k
=
Number
<
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
ds_thread_buf
(
offset_m
)
+=
dy_thread_buf
[
offset_m_k
]
*
gamma_thread_buf
[
offset_m_k
]
*
x_thread_buf
[
offset_m_k
];
db_thread_buf
(
offset_m
)
+=
dy_thread_buf
[
offset_m_k
]
*
gamma_thread_buf
[
offset_m_k
];
});
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
ds_thread_buf
(
I
));
block_sync_lds
();
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
db_thread_buf
(
I
));
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
constexpr
auto
offset_m
=
Number
<
thread_buffer_desc_m
.
CalculateOffset
(
make_tuple
(
iM
))
>
{};
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset_m_k
=
Number
<
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
ComputeDataType
b
=
db_thread_buf
[
offset_m
]
*
mean_thread_buf
[
offset_m_k
]
-
ds_thread_buf
[
offset_m
];
b
*=
inv_std_thread_buf
[
offset_m_k
]
*
inv_std_thread_buf
[
offset_m_k
]
*
inv_std_thread_buf
[
offset_m_k
]
/
reduce_size
;
ComputeDataType
c
=
-
b
*
mean_thread_buf
(
offset_m_k
);
c
-=
db_thread_buf
[
offset_m
]
*
inv_std_thread_buf
[
offset_m_k
]
/
reduce_size
;
dx_thread_buf
(
offset_m_k
)
=
dy_thread_buf
[
offset_m_k
]
*
gamma_thread_buf
[
offset_m_k
]
*
inv_std_thread_buf
[
offset_m_k
]
+
b
*
x_thread_buf
[
offset_m_k
]
+
c
;
});
});
threadwise_dx_store
.
Run
(
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
dx_thread_buf
,
dx_grid_desc_m_k
,
dx_global_val_buf
);
}
// end of sweep once
else
// Sweep Twice pipeline
{
{
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileSize
);
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileSize
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment