gaoqiong / composable_kernel · Commits

Commit d0c65caa, authored Feb 20, 2023 by guangzlu
parent 54dfedcd

    added switch for lse storing in attn fwd
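Background, not part of the diff: in flash-attention-style forward kernels, the per-row log-sum-exp (LSE) of the softmax logits is commonly written to a side buffer so the backward pass can cheaply reconstruct the softmax; inference-only callers don't need it, which is what the new switch exploits. A small sketch of the stored quantity, with the usual max-shift for numerical stability (illustrative code, not from this repository):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // lse(s) = m + log(sum_i exp(s_i - m)), with m = max_i s_i.
    // Storing this per softmax row lets a backward pass recover
    // softmax(s)_i = exp(s_i - lse) without rerunning the reduction.
    float row_lse(const std::vector<float>& s)
    {
        const float m = *std::max_element(s.begin(), s.end());

        float sum = 0.f;
        for(float v : s)
            sum += std::exp(v - m);

        return m + std::log(sum);
    }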
Showing 1 changed file with 94 additions and 46 deletions

include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp  (+94 -46)
@@ -32,7 +32,8 @@ template <typename GridwiseGemm,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           bool HasMainKBlockLoop,
-          bool IsDropout>
+          bool IsDropout,
+          bool IsLseStoring>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -97,18 +98,16 @@ __global__ void
     const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
         arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
 
-    // unsigned short* p_z_grid_in = //
-    //     (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
-    //                                             : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
-    GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
+    GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
         arg_ptr[group_id].p_a_grid_ + a_batch_offset,
         arg_ptr[group_id].p_b_grid_ + b_batch_offset,
         arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
         arg_ptr[group_id].p_c_grid_ + c_batch_offset,
         arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
-        arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
-        // arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+        arg_ptr[group_id].p_lse_grid_ == nullptr
+            ? nullptr
+            : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
         p_shared,
         a_element_op,
         b_element_op,
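This hunk only threads the new flag through to `GridwiseGemm::Run` (and null-guards the LSE pointer the same way the Z pointer already was); the gridwise side is not part of this diff. A plausible consumption pattern, sketched here as an assumption (the names `store_lse_if_enabled` and `row` are illustrative): with `IsLseStoring` a compile-time `bool`, the LSE write can sit behind `if constexpr`, so the `false` instantiation contains no branch and no store at all.

    // Hypothetical sketch, not code from this commit: how a gridwise kernel
    // might consume IsLseStoring. Because the flag is a template parameter,
    // the compiler removes the store from the IsLseStoring == false variant.
    template <bool IsLseStoring, typename LSEDataType>
    __device__ void store_lse_if_enabled(LSEDataType* p_lse_grid,
                                         LSEDataType lse,
                                         long row)
    {
        if constexpr(IsLseStoring)
        {
            p_lse_grid[row] = lse; // compiled out when the switch is off
        }
    }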
@@ -589,6 +588,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
             const auto p_z_grid   = static_cast<ZDataType*>(p_z_vec[i]);
             const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
 
+            if(p_lse_grid == nullptr)
+            {
+                is_lse_storing_ = false;
+            }
+
             const auto& problem_desc = problem_desc_vec[i];
 
             const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
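The switch is derived from the per-group argument pointers rather than set explicitly: it defaults to `true` (see the member added in the next hunk) and any group that passes a null LSE pointer clears it, so a single runtime flag governs the whole grouped launch. A minimal host-side restatement of that rule; the vector name `p_lse_vec` is from this hunk, everything else is illustrative:

    #include <vector>

    // Illustrative restatement of this hunk's logic (not the source): the
    // switch defaults to on, and one null per-group LSE pointer turns LSE
    // storing off for the entire grouped launch.
    bool decide_lse_storing(const std::vector<void*>& p_lse_vec)
    {
        bool is_lse_storing = true; // mirrors `bool is_lse_storing_ = true;`
        for(void* p : p_lse_vec)
        {
            if(p == nullptr)
            {
                is_lse_storing = false;
            }
        }
        return is_lse_storing;
    }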
@@ -724,6 +728,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         unsigned long long offset_;
         GemmAccDataType p_dropout_rescale_;
         bool is_dropout_;
+
+        bool is_lse_storing_ = true;
     };
 
     // Invoker
@@ -756,37 +762,39 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         float ave_time = 0;
 
-        auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
-            const auto kernel =
-                kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
-                                                                 GemmAccDataType,
-                                                                 GroupKernelArg,
-                                                                 AElementwiseOperation,
-                                                                 BElementwiseOperation,
-                                                                 AccElementwiseOperation,
-                                                                 B1ElementwiseOperation,
-                                                                 CElementwiseOperation,
-                                                                 has_main_k_block_loop_,
-                                                                 is_dropout_>;
+        auto launch_kernel =
+            [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
+                const auto kernel =
+                    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
+                                                                     GemmAccDataType,
+                                                                     GroupKernelArg,
+                                                                     AElementwiseOperation,
+                                                                     BElementwiseOperation,
+                                                                     AccElementwiseOperation,
+                                                                     B1ElementwiseOperation,
+                                                                     CElementwiseOperation,
+                                                                     has_main_k_block_loop_,
+                                                                     is_dropout_,
+                                                                     is_lse_storing_>;
 
-            return launch_and_time_kernel(stream_config,
-                                          kernel,
-                                          dim3(arg.grid_size_),
-                                          dim3(BlockSize),
-                                          0,
-                                          cast_pointer_to_constant_address_space(arg.p_workspace_),
-                                          arg.group_count_,
-                                          arg.a_element_op_,
-                                          arg.b_element_op_,
-                                          arg.acc_element_op_,
-                                          arg.b1_element_op_,
-                                          arg.c_element_op_,
-                                          arg.p_dropout_in_16bits_,
-                                          arg.p_dropout_rescale_,
-                                          arg.seed_,
-                                          arg.offset_);
-        };
+                return launch_and_time_kernel(
+                    stream_config,
+                    kernel,
+                    dim3(arg.grid_size_),
+                    dim3(BlockSize),
+                    0,
+                    cast_pointer_to_constant_address_space(arg.p_workspace_),
+                    arg.group_count_,
+                    arg.a_element_op_,
+                    arg.b_element_op_,
+                    arg.acc_element_op_,
+                    arg.b1_element_op_,
+                    arg.c_element_op_,
+                    arg.p_dropout_in_16bits_,
+                    arg.p_dropout_rescale_,
+                    arg.seed_,
+                    arg.offset_);
+            };
 
         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
         // to concern Gemm0's loop
@@ -794,26 +802,66 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         {
             if(arg.is_dropout_)
             {
-                ave_time = launch_kernel(integral_constant<bool, true>{},
-                                         integral_constant<bool, true>{});
+                if(arg.is_lse_storing_)
+                {
+                    ave_time = launch_kernel(integral_constant<bool, true>{},
+                                             integral_constant<bool, true>{},
+                                             integral_constant<bool, true>{});
+                }
+                else
+                {
+                    ave_time = launch_kernel(integral_constant<bool, true>{},
+                                             integral_constant<bool, true>{},
+                                             integral_constant<bool, false>{});
+                }
             }
             else
             {
-                ave_time = launch_kernel(integral_constant<bool, true>{},
-                                         integral_constant<bool, false>{});
+                if(arg.is_lse_storing_)
+                {
+                    ave_time = launch_kernel(integral_constant<bool, true>{},
+                                             integral_constant<bool, false>{},
+                                             integral_constant<bool, true>{});
+                }
+                else
+                {
+                    ave_time = launch_kernel(integral_constant<bool, true>{},
+                                             integral_constant<bool, false>{},
+                                             integral_constant<bool, false>{});
+                }
             }
         }
         else if(!some_has_main_k_block_loop)
         {
             if(arg.is_dropout_)
             {
-                ave_time = launch_kernel(integral_constant<bool, false>{},
-                                         integral_constant<bool, true>{});
+                if(arg.is_lse_storing_)
+                {
+                    ave_time = launch_kernel(integral_constant<bool, false>{},
+                                             integral_constant<bool, true>{},
+                                             integral_constant<bool, true>{});
+                }
+                else
+                {
+                    ave_time = launch_kernel(integral_constant<bool, false>{},
+                                             integral_constant<bool, true>{},
+                                             integral_constant<bool, false>{});
+                }
             }
             else
             {
-                ave_time = launch_kernel(integral_constant<bool, false>{},
-                                         integral_constant<bool, false>{});
+                if(arg.is_lse_storing_)
+                {
+                    ave_time = launch_kernel(integral_constant<bool, false>{},
+                                             integral_constant<bool, false>{},
+                                             integral_constant<bool, true>{});
+                }
+                else
+                {
+                    ave_time = launch_kernel(integral_constant<bool, false>{},
+                                             integral_constant<bool, false>{},
+                                             integral_constant<bool, false>{});
+                }
             }
         }
         else
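With the third flag added, `launch_kernel` now dispatches over three compile-time booleans, so this if/else ladder selects among 2^3 kernel instantiations (four are shown per value of `some_has_main_k_block_loop`). The ladder is the standard runtime-to-compile-time bridge via `integral_constant`: each runtime `bool` is lifted into a distinct type whose `::value` can be forwarded as a non-type template argument. A minimal self-contained sketch of the same pattern using `std::integral_constant` (everything here is illustrative, nothing is from this file):

    #include <cstdio>
    #include <type_traits>

    // Each runtime bool is lifted into a distinct type, so the lambda can
    // forward the flags as non-type template arguments.
    template <bool HasMainLoop, bool IsDropout, bool IsLseStoring>
    void run_variant()
    {
        std::printf("variant<%d, %d, %d>\n", HasMainLoop, IsDropout, IsLseStoring);
    }

    int main()
    {
        const bool is_dropout = true, is_lse_storing = false; // runtime flags

        auto launch = [](auto dropout_, auto lse_) {
            // decltype(...)::value is a compile-time constant.
            run_variant<true, decltype(dropout_)::value, decltype(lse_)::value>();
        };

        // Runtime if/else ladder selecting among pre-instantiated variants,
        // the same shape as the dispatch in the hunk above (the first flag
        // is fixed to true here for brevity).
        if(is_dropout)
        {
            if(is_lse_storing)
                launch(std::integral_constant<bool, true>{},
                       std::integral_constant<bool, true>{});
            else
                launch(std::integral_constant<bool, true>{},
                       std::integral_constant<bool, false>{});
        }
        else
        {
            if(is_lse_storing)
                launch(std::integral_constant<bool, false>{},
                       std::integral_constant<bool, true>{});
            else
                launch(std::integral_constant<bool, false>{},
                       std::integral_constant<bool, false>{});
        }
        return 0;
    }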