Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d0c65caa
Commit
d0c65caa
authored
Feb 20, 2023
by
guangzlu
Browse files
added switch for lse storing in attn fwd
parent
54dfedcd
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
94 additions
and
46 deletions
+94
-46
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp
...vice_grouped_multihead_attention_forward_xdl_cshuffle.hpp
+94
-46
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp
View file @
d0c65caa
...
...
@@ -32,7 +32,8 @@ template <typename GridwiseGemm,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainKBlockLoop
,
bool
IsDropout
>
bool
IsDropout
,
bool
IsLseStoring
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
...
@@ -97,18 +98,16 @@ __global__ void
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetLSEBasePtr
(
g_idx
)));
// unsigned short* p_z_grid_in = //
// (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
// : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
,
arg_ptr
[
group_id
].
p_lse_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_lse_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_lse_grid_
+
lse_batch_offset
,
// arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -589,6 +588,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
const
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_z_vec
[
i
]);
const
auto
p_lse_grid
=
static_cast
<
LSEDataType
*>
(
p_lse_vec
[
i
]);
if
(
p_lse_grid
==
nullptr
)
{
is_lse_storing_
=
false
;
}
const
auto
&
problem_desc
=
problem_desc_vec
[
i
];
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
...
...
@@ -724,6 +728,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
unsigned
long
long
offset_
;
GemmAccDataType
p_dropout_rescale_
;
bool
is_dropout_
;
bool
is_lse_storing_
=
true
;
};
// Invoker
...
...
@@ -756,7 +762,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
,
auto
is_lse_storing_
)
{
const
auto
kernel
=
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2
<
GridwiseGemm
,
GemmAccDataType
,
...
...
@@ -767,7 +774,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
B1ElementwiseOperation
,
CElementwiseOperation
,
has_main_k_block_loop_
,
is_dropout_
>
;
is_dropout_
,
is_lse_storing_
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
@@ -793,29 +801,69 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
if
(
all_has_main_k_block_loop
)
{
if
(
arg
.
is_dropout_
)
{
if
(
arg
.
is_lse_storing_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
else
{
if
(
arg
.
is_lse_storing_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
}
else
if
(
!
some_has_main_k_block_loop
)
{
if
(
arg
.
is_dropout_
)
{
if
(
arg
.
is_lse_storing_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
else
{
if
(
arg
.
is_lse_storing_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
}
else
{
throw
std
::
runtime_error
(
"wrong! all gemm problems have to simultaneously meet "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment