Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ed8ef7e5
Commit
ed8ef7e5
authored
Jul 26, 2024
by
danyao12
Browse files
dropout patch for mrepeat 16*16
parent
94c957b3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
92 additions
and
28 deletions
+92
-28
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+1
-1
include/ck_tile/core/utility/philox_rand.hpp
include/ck_tile/core/utility/philox_rand.hpp
+17
-0
include/ck_tile/ops/fmha/block/block_dropout.hpp
include/ck_tile/ops/fmha/block/block_dropout.hpp
+74
-27
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
ed8ef7e5
...
...
@@ -478,7 +478,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
if
((
bias
==
"no"
or
bias
==
"alibi"
)
and
dbias
==
"t"
):
continue
if
((
hdim
<
=
128
and
(
"wg16"
in
dropout
))
or
(
hdim
==
256
and
(
"wg32"
in
dropout
))):
if
((
(
hdim
==
64
or
hdim
=
=
128
)
and
(
"wg16"
in
dropout
))
or
(
(
hdim
==
32
or
hdim
==
256
)
and
(
"wg32"
in
dropout
))):
continue
k
=
FmhaBwdDQDKDVKernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_tile
=
tile
,
F_spad
=
spad
,
F_skpad
=
skpad
,
F_dpad
=
dpad
,
F_dvpad
=
dvpad
,
...
...
include/ck_tile/core/utility/philox_rand.hpp
View file @
ed8ef7e5
...
...
@@ -53,6 +53,23 @@ class philox
out_tmp
[
3
]
=
tmp_ph
.
w
;
}
CK_TILE_HOST_DEVICE
void
get_random_8x8
(
uint8_t
*
out
,
const
unsigned
long
long
subsequence
,
const
index_t
start_idx
)
const
{
uint4
tmp_ph
;
tmp_ph
=
get_philox_4x32
(
subsequence
);
uint32x4_t
tmp
;
tmp
[
0
]
=
tmp_ph
.
x
;
tmp
[
1
]
=
tmp_ph
.
y
;
tmp
[
2
]
=
tmp_ph
.
z
;
tmp
[
3
]
=
tmp_ph
.
w
;
uint32_t
*
out_tmp
=
reinterpret_cast
<
uint32_t
*>
(
&
out
[
0
]);
out_tmp
[
0
]
=
tmp
[
start_idx
];
out_tmp
[
1
]
=
tmp
[
start_idx
+
2
];
}
CK_TILE_HOST_DEVICE
void
get_random_4x8
(
uint8_t
*
out
,
const
unsigned
long
long
subsequence
,
const
index_t
start_idx
)
const
...
...
include/ck_tile/ops/fmha/block/block_dropout.hpp
View file @
ed8ef7e5
...
...
@@ -60,10 +60,22 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
kMPerStep
=
MWarp
*
WG
::
kM
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
BlockGemm
::
BlockGemmShape
>
;
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
bool
MBwdWG16MultiIterCheck
=
(
!
IsFwd
)
&&
(
!
IsWG32
)
&&
(
kMPerBlock
>
16
);
constexpr
index_t
kMPerStep
=
[
&
]()
{
if
constexpr
(
MBwdWG16MultiIterCheck
)
{
return
MWarp
*
WG
::
kM
*
2
;
}
else
{
return
MWarp
*
WG
::
kM
;
}
}();
constexpr
index_t
kNPerStep
=
NWarp
*
WG
::
kN
;
const
auto
block_origin
=
randval_dram_block_window_tmp
.
get_window_origin
();
...
...
@@ -116,15 +128,27 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
return
randval_lds_block_desc
;
}
template
<
typename
BlockGemm
>
template
<
typename
BlockGemm
,
bool
IsFwd
=
true
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeRandValTileDistribution
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
1
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
BlockGemm
::
BlockGemmShape
>
;
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
bool
MBwdWG16MultiIterCheck
=
(
!
IsFwd
)
&&
(
!
IsWG32
)
&&
(
kMPerBlock
>
16
);
constexpr
index_t
MIterPerWarp
=
[
&
]()
{
if
constexpr
(
MBwdWG16MultiIterCheck
)
{
return
2
;
}
else
{
return
1
;
}
}();
constexpr
index_t
NIterPerWarp
=
1
;
constexpr
auto
randval_block_outer_part_dstr_encoding
=
tile_distribution_encoding
<
...
...
@@ -297,22 +321,34 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
typename
BlockGemm
::
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
using
BlockGemmShape
=
remove_cvref_t
<
typename
BlockGemm
::
BlockGemmShape
>
;
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
kMPerStep
=
MWarp
*
WG
::
kM
;
constexpr
index_t
kNPerStep
=
NWarp
*
WG
::
kN
;
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
using
BlockGemmShape
=
remove_cvref_t
<
typename
BlockGemm
::
BlockGemmShape
>
;
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
constexpr
bool
MBwdWG16MultiIterCheck
=
(
!
IsWG32
)
&&
(
kMPerBlock
>
16
);
constexpr
bool
MBwdWG16SingleIterCheck
=
(
!
IsWG32
)
&&
(
kMPerBlock
==
16
);
constexpr
index_t
kMPerStep
=
[
&
]()
{
if
constexpr
(
MBwdWG16MultiIterCheck
)
{
return
MWarp
*
WG
::
kM
*
2
;
}
else
{
return
MWarp
*
WG
::
kM
;
}
}();
constexpr
index_t
kNPerStep
=
NWarp
*
WG
::
kN
;
// register distribute
auto
randval
=
make_static_distributed_tensor
<
uint8_t
>
(
MakeRandValTileDistribution
<
BlockGemm
>
());
auto
randval
=
make_static_distributed_tensor
<
uint8_t
>
(
MakeRandValTileDistribution
<
BlockGemm
,
false
>
());
if
constexpr
(
IsWG32
)
static_assert
(
randval
.
kThreadElementSpaceSize
==
16
);
else
static_assert
(
randval
.
kThreadElementSpaceSize
==
4
);
static_assert
(
randval
.
kThreadElementSpaceSize
==
4
||
randval
.
kThreadElementSpaceSize
==
8
);
static_for
<
0
,
kNPerBlock
/
kNPerStep
,
1
>
{}([
&
](
auto
i_n0
)
{
static_for
<
0
,
kMPerBlock
/
kMPerStep
,
1
>
{}([
&
](
auto
i_m0
)
{
...
...
@@ -324,14 +360,14 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
}
else
{
block_row_start
=
start_m0_idx
/
32
;
block_col_start
=
(
start_n0_idx
/
32
)
+
get_warp_id
()
/
2
;
block_row_start
=
start_m0_idx
/
32
+
i_m0
;
block_col_start
=
(
start_n0_idx
/
32
)
+
get_warp_id
()
/
2
+
i_n0
*
2
;
}
uint2
rowcol
=
make_uint2
(
block_row_start
,
block_col_start
);
// generate random number
uint8_t
*
random_uint8_t_
;
if
constexpr
(
!
IsWG32
)
if
constexpr
(
MBwdWG16SingleIterCheck
)
{
uint8_t
random_uint8_t
[
4
];
// m0t0 ~m0t15/m0t32~m0t47: 0
...
...
@@ -344,6 +380,16 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
),
start_idx
);
random_uint8_t_
=
random_uint8_t
;
}
else
if
constexpr
(
MBwdWG16MultiIterCheck
)
{
uint8_t
random_uint8_t
[
8
];
// t0 ~t15/t32~t47: 0
// t16~t31/t48~t63: 1
const
index_t
start_idx
=
(
get_lane_id
()
>>
4
)
&
1
;
ph
.
get_random_8x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
),
start_idx
);
random_uint8_t_
=
random_uint8_t
;
}
else
{
uint8_t
random_uint8_t
[
16
];
...
...
@@ -356,10 +402,11 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
int
i_random_idx
=
0
;
sweep_tile_span
(
randval_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
randval_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
r_idx
=
ck_tile
::
make_tuple
(
idx0
,
idx1
);
randval
(
r_idx
)
=
random_uint8_t_
[
i_random_idx
++
];
constexpr
auto
p_idx0
=
tile_distributed_index
<
i_m0
,
idx0
.
impl_
.
at
(
1
),
idx0
.
impl_
.
at
(
2
)
>
{};
constexpr
auto
r_idx
=
ck_tile
::
make_tuple
(
idx0
,
idx1
);
randval
(
r_idx
)
=
random_uint8_t_
[
i_random_idx
++
];
constexpr
auto
p_idx0
=
tile_distributed_index
<
i_m0
+
idx0
.
impl_
.
at
(
0
),
idx0
.
impl_
.
at
(
1
),
idx0
.
impl_
.
at
(
2
)
>
{};
constexpr
auto
p_idx1
=
tile_distributed_index
<
i_n0
>
{};
constexpr
auto
p_idx
=
ck_tile
::
make_tuple
(
p_idx0
,
p_idx1
);
p_compute
(
p_idx
)
=
randval
[
r_idx
]
<=
p_undrop_in_uint8_t
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment