gaoqiong / composable_kernel

Commit cdc6f6ba, authored Feb 13, 2023 by guangzlu
Parent: bf80ceee

    fixed some bugs in fwd drop verify

4 changed files, with 93 additions and 30 deletions (+93 −30)
Changed files:

  +0  −1   example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp
  +11 −4   example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
  +70 −15  include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp
  +12 −10  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp (+0 −1)

@@ -163,5 +163,4 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<

 #include "run_grouped_multihead_attention_forward.inc"

 int main(int argc, char* argv[]) { return run(argc, argv); }
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc (+11 −4)

@@ -10,6 +10,13 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = true;

+    float p_drop                    = 0.2;
+    float p_dropout                 = 1 - p_drop;
+    uint16_t p_dropout_in_16bits    = uint16_t(std::floor(p_dropout * 65535.0));
+    float rp_dropout                = 1.0 / p_dropout;
+    const unsigned long long seed   = 1;
+    const unsigned long long offset = 0;
+
     if(argc == 1)
     {
         // use default case

@@ -152,12 +159,12 @@ int run(int argc, char* argv[])
                   << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
                   << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
                   << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
-                  << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
+                  << "z_gs_ms_ns[" << i << "]: " << z_gs_ms_ns.mDesc << ", "
                   << "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc
                   << std::endl;
     }

-    z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
+    z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<ZDataType>{0});

     switch(init_method)
     {

@@ -238,7 +245,7 @@ int run(int argc, char* argv[])
         acc0_element_op,
         b1_element_op,
         c_element_op,
-        0,          // dropout ratio
+        p_drop,     // dropout ratio
         {0, 448});  // dropout random seed and offset, offset should be
                     // at least the number of elements on a thread
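The parameters added in the first hunk follow the usual inverted-dropout bookkeeping: the keep probability p_dropout = 1 - p_drop is stored once as a 16-bit integer threshold (for comparison against uniform 16-bit random draws) and once as a reciprocal scale, so kept elements can be rescaled and the expected value preserved. A minimal host-only sketch of the same arithmetic (standalone C++; nothing below is part of the CK API):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    const float p_drop    = 0.2f;       // probability of dropping an element
    const float p_dropout = 1 - p_drop; // probability of keeping an element

    // 16-bit keep threshold: an element survives when its uniform 16-bit
    // random draw falls at or below this value.
    const uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));

    // Inverted-dropout rescale applied to surviving elements.
    const float rp_dropout = 1.0f / p_dropout;

    std::printf("threshold = %u, scale = %g\n", p_dropout_in_16bits, rp_dropout);
    // prints: threshold = 52428, scale = 1.25
}

The {0, 448} argument in the last hunk is the random seed and offset for the device-side generator; per the inline comment, the offset should be at least the number of elements processed on a thread.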
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp (+70 −15)

@@ -37,7 +37,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_multiheadattention_forward_xdl_cshuffle(
+    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
         const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
         const index_t group_count,
         const AElementwiseOperation a_element_op,

@@ -92,14 +92,22 @@ __global__ void
                 arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
         const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
             static_cast<long_index_t>(
                 arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
+        const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
+            static_cast<long_index_t>(
+                arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
         const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
             static_cast<long_index_t>(
                 arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
+        // unsigned short* p_z_grid_in = //
+        //     (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
+        //     : arg_ptr[group_id].p_z_grid_ + z_batch_offset);

         GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
             arg_ptr[group_id].p_a_grid_ + a_batch_offset,
             arg_ptr[group_id].p_b_grid_ + b_batch_offset,
             arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
             arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+            arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
+                                                   : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
             arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
             p_shared,
             a_element_op,

@@ -111,6 +119,7 @@ __global__ void
             arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
             arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
             arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
+            arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, ////////
             arg_ptr[group_id].lse_grid_desc_m_,
             arg_ptr[group_id].block_2_ctile_map_,
             arg_ptr[group_id].c0_matrix_mask_,

@@ -140,6 +149,7 @@ template <index_t NumDimG,
           typename BDataType,
           typename B1DataType,
           typename CDataType,
+          typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,
           typename Acc1BiasDataType,

@@ -207,6 +217,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         BDataType,
         B1DataType,
         CDataType,
+        ZDataType,
         LSEDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,

@@ -246,6 +257,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         BDataType,
         B1DataType,
         CDataType,
+        ZDataType,
         LSEDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,

@@ -295,6 +307,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                                                        Number<B1K1>{});
     }

+    static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
+                                        const std::vector<index_t>& z_gs_ms_ns_strides_vec)
+    {
+        return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec);
+    }
+
     static auto MakeLSEGridDescriptor_M(index_t MRaw)
     {
         const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

@@ -325,10 +343,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
     using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
     using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
     using LSEGridDesc_M        = decltype(MakeLSEGridDescriptor_M(1));
+    using ZGridDesc_M_N        = decltype(MakeZGridDescriptor_M_N({}, {}));
     using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
     using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
     using B1GridDesc_G_N_K     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
     using CGridDesc_G_M_N      = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+    using ZGridDesc_G_M_N      = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));

     constexpr static auto make_MaskOutPredicate()
     {

@@ -349,11 +370,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                                  const BGridDesc_G_N_K& b_grid_desc_g_n_k,
                                  const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
                                  const CGridDesc_G_M_N& c_grid_desc_g_m_n,
+                                 const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
                                  index_t BatchStrideLSE)
             : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
               b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
               b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
               c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
+              z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
               BatchStrideLSE_(BatchStrideLSE)
         {
         }

@@ -378,6 +401,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
             return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

+        __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
+        {
+            return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
+        }
+
         __host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
         {
             return g_idx * static_cast<long_index_t>(BatchStrideLSE_);

@@ -388,6 +416,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         BGridDesc_G_N_K b_grid_desc_g_n_k_;
         B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
+        ZGridDesc_G_M_N z_grid_desc_g_m_n_;
         index_t BatchStrideLSE_;
     };

@@ -408,6 +437,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         BGridDesc_BK0_N_BK1,
         B1GridDesc_BK0_N_BK1,
         CGridDesc_M_N,
+        ZGridDesc_M_N,
         LSEGridDesc_M,
         NumGemmKPrefetchStage,
         BlockSize,

@@ -465,6 +495,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         const BDataType* p_b_grid_;
         const B1DataType* p_b1_grid_;
         CDataType* p_c_grid_;
+        ZDataType* p_z_grid_;
         LSEDataType* p_lse_grid_;

         // tensor descriptors for block/thread-wise copy

@@ -473,6 +504,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
+        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
+        ZGridDesc_M_N z_grid_desc_m_n_;
         LSEGridDesc_M lse_grid_desc_m_;

         // batch & stride

@@ -511,6 +545,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                  std::vector<const void*> p_b_vec,
                  std::vector<const void*> p_b1_vec,
                  std::vector<void*> p_c_vec,
+                 std::vector<void*> p_z_vec,
                  std::vector<void*> p_lse_vec,
                  std::vector<std::vector<const void*>> p_acc0_biases_vec,
                  std::vector<std::vector<const void*>> p_acc1_biases_vec,

@@ -550,6 +585,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                 const auto p_b_grid   = static_cast<const BDataType*>(p_b_vec[i]);
                 const auto p_b1_grid  = static_cast<const B1DataType*>(p_b1_vec[i]);
                 const auto p_c_grid   = static_cast<CDataType*>(p_c_vec[i]);
+                const auto p_z_grid   = static_cast<ZDataType*>(p_z_vec[i]);
                 const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);

                 const auto& problem_desc = problem_desc_vec[i];

@@ -562,6 +598,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                     problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
                 const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
                     problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
+                const auto z_grid_desc_m_n = MakeZGridDescriptor_M_N(
+                    problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
                 const auto lse_grid_desc_m =
                     DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);

@@ -573,11 +611,20 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                     problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
                 const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
                     problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
+                const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
+                    problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);

                 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         c_grid_desc_m_n);
+                // typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+                //     z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
+                auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+                    GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
+                        z_grid_desc_m_n);

                 const index_t BlockStart     = grid_size_;
                 const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
                 const index_t batch_count    = c_grid_desc_g_m_n.GetLength(I0);

@@ -591,6 +638,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                     b_grid_desc_g_n_k,
                     b1_grid_desc_g_n_k,
                     c_grid_desc_g_m_n,
+                    z_grid_desc_g_m_n,
                     type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));

                 // C0 mask

@@ -614,11 +662,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                     p_b_grid,
                     p_b1_grid,
                     p_c_grid,
+                    p_z_grid,
                     p_lse_grid,
                     a_grid_desc_ak0_m_ak1,
                     b_grid_desc_bk0_n_bk1,
                     b1_grid_desc_bk0_n_bk1,
                     c_grid_desc_mblock_mperblock_nblock_nperblock,
+                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                    z_grid_desc_m_n,
                     lse_grid_desc_m,
                     block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
                     compute_base_ptr_of_batch,

@@ -705,7 +756,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
             auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
                 const auto kernel =
-                    kernel_grouped_multiheadattention_forward_xdl_cshuffle<GridwiseGemm,
+                    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
                                                                      GemmAccDataType,
                                                                      GroupKernelArg,
                                                                      AElementwiseOperation,

@@ -891,6 +942,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                         std::vector<const void*> p_b_vec,
                         std::vector<const void*> p_b1_vec,
                         std::vector<void*> p_c_vec,
+                        std::vector<void*> p_z_vec,
                         std::vector<void*> p_lse_vec,
                         std::vector<std::vector<const void*>> p_acc0_biases_vec,
                         std::vector<std::vector<const void*>> p_acc1_biases_vec,

@@ -907,6 +959,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                         p_b_vec,
                         p_b1_vec,
                         p_c_vec,
+                        p_z_vec,
                         p_lse_vec,
                         p_acc0_biases_vec,
                         p_acc1_biases_vec,

@@ -928,6 +981,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                         std::vector<const void*> p_b_vec,
                         std::vector<const void*> p_b1_vec,
                         std::vector<void*> p_c_vec,
+                        std::vector<void*> p_z_vec,
                         std::vector<void*> p_lse_vec,
                         std::vector<std::vector<const void*>> p_acc0_biases_vec,
                         std::vector<std::vector<const void*>> p_acc1_biases_vec,

@@ -944,6 +998,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                         p_b_vec,
                         p_b1_vec,
                         p_c_vec,
+                        p_z_vec,
                         p_lse_vec,
                         p_acc0_biases_vec,
                         p_acc1_biases_vec,
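Two patterns in this file matter for the Z (dropout randomness) output the commit threads through. The new GetZBasePtr obtains a per-batch offset by asking the G_M_N descriptor for the linear offset of element (g_idx, 0, 0), and, because the Z buffer is optional, the kernel only adds that offset when the pointer is non-null: offsetting a null pointer is undefined behavior in C++. A reduced sketch of the guarded-pointer pattern (the struct and names below are illustrative stand-ins, not the CK types):

#include <cassert>
#include <cstdint>

using ZDataType = unsigned short;

// Reduced analogue of the per-group kernel argument block in the diff.
struct GroupArg
{
    ZDataType* p_z_grid_; // optional; null when the Z output is not requested
};

ZDataType* z_ptr_for_batch(const GroupArg& arg, std::int64_t z_batch_offset)
{
    // Mirrors the guarded expression added to the kernel: check for null
    // before applying the batch offset.
    return arg.p_z_grid_ == nullptr ? nullptr : arg.p_z_grid_ + z_batch_offset;
}

int main()
{
    GroupArg without_z{nullptr};
    assert(z_ptr_for_batch(without_z, 4096) == nullptr);

    static ZDataType storage[8192] = {};
    GroupArg with_z{storage};
    assert(z_ptr_for_batch(with_z, 4096) == storage + 4096);
}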
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+12 −10)

@@ -120,8 +120,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     // C desc for source in blockwise copy
     __host__ __device__ static constexpr auto
     MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
         const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
     {
         const auto M = z_grid_desc_m_n.GetLength(I0);
         const auto N = z_grid_desc_m_n.GetLength(I1);

@@ -140,7 +140,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
     }

     __host__ __device__ static constexpr auto
     MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M,
                                                       const index_t N) ////=> for z use
     {
         constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
         constexpr auto N3   = mfma.num_groups_per_blk;

@@ -1027,7 +1028,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                         true>(acc_thread_buf, ph, z_tenor_buffer);

                     z_thread_copy_vgpr_to_global.Run(
                         z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                         make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                         z_tenor_buffer,
                         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,

@@ -1041,7 +1043,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                 }
             }

-            //if constexpr(IsDropout) // dropout
+            // if constexpr(IsDropout) // dropout
             //{
             //    blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
             //}
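The commented-out block above is the blockwise dropout application that the forward-dropout verification path exercises. For a host-side reference, the same keep/scale convention can be replayed: draw a uniform 16-bit value per element, keep the element when the draw does not exceed p_dropout_in_16bits, and scale survivors by rp_dropout. A hedged sketch, assuming std::mt19937 as a stand-in for the device RNG and a "keep iff draw <= threshold" comparison; the device kernel uses its own counter-based generator with the seed/offset pair, so bitwise agreement would require replaying that generator instead:

#include <cmath>
#include <cstdint>
#include <random>
#include <vector>

// Host-side reference of inverted dropout with a 16-bit keep threshold.
void reference_dropout(std::vector<float>& acc, float p_drop, uint64_t seed)
{
    const float    p_dropout           = 1 - p_drop;
    const uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
    const float    rp_dropout          = 1.0f / p_dropout;

    std::mt19937 gen(seed);
    std::uniform_int_distribution<uint32_t> dist(0, 65535);

    for(float& x : acc)
    {
        const uint16_t draw = uint16_t(dist(gen));
        // Keep and rescale, or zero out.
        x = (draw <= p_dropout_in_16bits) ? x * rp_dropout : 0.0f;
    }
}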