Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
06f575a3
Commit
06f575a3
authored
Jul 21, 2024
by
danyao12
Browse files
refactor dropout
parent
99436cd4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
60 deletions
+63
-60
include/ck_tile/ops/fmha/block/block_dropout.hpp
include/ck_tile/ops/fmha/block/block_dropout.hpp
+11
-4
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+20
-18
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+16
-19
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+16
-19
No files found.
include/ck_tile/ops/fmha/block/block_dropout.hpp
View file @
06f575a3
...
...
@@ -8,8 +8,15 @@
namespace
ck_tile
{
struct
NullBlockDropout
template
<
bool
IsDropout_
,
bool
IsWG32_
,
bool
IsStoreRandval_
>
struct
BlockDropout
;
template
<
bool
IsWG32_
,
bool
IsStoreRandval_
>
struct
BlockDropout
<
false
,
IsWG32_
,
IsStoreRandval_
>
{
static
constexpr
bool
IsDropout
=
false
;
static
constexpr
bool
IsStoreRandval
=
IsStoreRandval_
;
template
<
typename
BlockGemm
,
bool
IsFwd
=
true
,
typename
RandValDramBlockWindowTmp
>
__host__
__device__
static
constexpr
auto
MakeRandvalDramWindow
(
RandValDramBlockWindowTmp
&
randval_dram_block_window_tmp
,
...
...
@@ -22,10 +29,10 @@ struct NullBlockDropout
}
};
template
<
bool
Is
Dropout_
=
true
,
bool
IsWG32_
=
true
,
bool
IsStoreRandval_
=
false
>
struct
BlockDropout
template
<
bool
Is
WG32_
,
bool
IsStoreRandval_
>
struct
BlockDropout
<
true
,
IsWG32_
,
IsStoreRandval_
>
{
static
constexpr
bool
IsDropout
=
IsDropout_
;
static
constexpr
bool
IsDropout
=
true
;
// true: 32*32 warp gemm
// false: 16*16 warp gemm
static
constexpr
bool
IsWG32
=
IsWG32_
;
...
...
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
06f575a3
...
...
@@ -915,27 +915,29 @@ struct FmhaBwdDQDKDVKernel
}();
// dropout
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint64_t
drop_seed
=
0
;
uint64_t
drop_offset
=
0
;
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
if
constexpr
(
kHasDropout
)
{
rp_undrop
=
kargs
.
rp_undrop
;
scale_rp_undrop
=
kargs
.
scale_rp_undrop
;
p_undrop_in_uint8_t
=
kargs
.
p_undrop_in_uint8_t
;
drop_seed
=
kargs
.
drop_seed
;
drop_offset
=
kargs
.
drop_offset
;
rp_undrop
=
kargs
.
rp_undrop
;
scale_rp_undrop
=
kargs
.
scale_rp_undrop
;
}
FmhaDropout
dropout
(
i_batch
,
i_nhead
,
kargs
.
num_head_q
,
drop_seed
,
drop_offset
,
rp_undrop
,
p_undrop_in_uint8_t
);
auto
dropout
=
[
&
,
i_nhead_
=
i_nhead
,
i_batch_
=
i_batch
]()
{
if
constexpr
(
kHasDropout
)
{
return
FmhaDropout
{
i_batch_
,
i_nhead_
,
kargs
.
num_head_q
,
kargs
.
drop_seed
,
kargs
.
drop_offset
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
};
}
else
{
return
FmhaDropout
{};
};
}();
auto
randval_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
constexpr
auto
randval_dram_window_lengths
=
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
06f575a3
...
...
@@ -749,25 +749,22 @@ struct FmhaFwdKernel
}();
// dropout
float
rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint64_t
drop_seed
=
0
;
uint64_t
drop_offset
=
0
;
if
constexpr
(
kHasDropout
)
{
rp_undrop
=
kargs
.
rp_undrop
;
p_undrop_in_uint8_t
=
kargs
.
p_undrop_in_uint8_t
;
drop_seed
=
kargs
.
drop_seed
;
drop_offset
=
kargs
.
drop_offset
;
}
FmhaDropout
dropout
(
i_batch
,
i_nhead
,
kargs
.
num_head_q
,
drop_seed
,
drop_offset
,
rp_undrop
,
p_undrop_in_uint8_t
);
auto
dropout
=
[
&
,
i_nhead_
=
i_nhead
,
i_batch_
=
i_batch
]()
{
if
constexpr
(
kHasDropout
)
{
return
FmhaDropout
{
i_batch_
,
i_nhead_
,
kargs
.
num_head_q
,
kargs
.
drop_seed
,
kargs
.
drop_offset
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
};
}
else
{
return
FmhaDropout
{};
};
}();
auto
randval_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
constexpr
auto
randval_dram_window_lengths
=
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
06f575a3
...
...
@@ -747,25 +747,22 @@ struct FmhaFwdSplitKVKernel
}();
// dropout
float
rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint64_t
drop_seed
=
0
;
uint64_t
drop_offset
=
0
;
if
constexpr
(
kHasDropout
)
{
rp_undrop
=
kargs
.
rp_undrop
;
p_undrop_in_uint8_t
=
kargs
.
p_undrop_in_uint8_t
;
drop_seed
=
kargs
.
drop_seed
;
drop_offset
=
kargs
.
drop_offset
;
}
FmhaDropout
dropout
(
i_batch
,
i_nhead
,
kargs
.
num_head_q
,
drop_seed
,
drop_offset
,
rp_undrop
,
p_undrop_in_uint8_t
);
auto
dropout
=
[
&
,
i_nhead_
=
i_nhead
,
i_batch_
=
i_batch
]()
{
if
constexpr
(
kHasDropout
)
{
return
FmhaDropout
{
i_batch_
,
i_nhead_
,
kargs
.
num_head_q
,
kargs
.
drop_seed
,
kargs
.
drop_offset
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
};
}
else
{
return
FmhaDropout
{};
};
}();
auto
randval_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
constexpr
auto
randval_dram_window_lengths
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment