Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b8f08e67
Commit
b8f08e67
authored
Aug 09, 2023
by
letaoqin
Browse files
group change to multiple D
parent
fa066d60
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
86 additions
and
64 deletions
+86
-64
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2r2.hpp
.../device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2r2.hpp
+86
-64
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2r2.hpp
View file @
b8f08e67
...
@@ -95,13 +95,19 @@ __global__ void
...
@@ -95,13 +95,19 @@ __global__ void
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetCBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetCBasePtr
(
g_idx
)));
const
long_index_t
d_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetDBasePtr
(
g_idx
)));
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetZBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetZBasePtr
(
g_idx
)));
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetLSEBasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetLSEBasePtr
(
g_idx
)));
typename
GridwiseGemm
::
D0sGridPointer
p_d0s_grid
=
arg_ptr
[
group_id
].
p_d0s_grid_
;
static_for
<
0
,
p_d0s_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetD0BasePtr
(
g_idx
,
In
)));
p_d0s_grid
(
In
)
=
p_d0s_grid
(
In
)
+
d0_batch_offset
;
});
if
constexpr
(
Deterministic
)
if
constexpr
(
Deterministic
)
{
{
for
(
index_t
i
=
0
;
i
<
num_blocks_per_batch
;
i
++
)
for
(
index_t
i
=
0
;
i
<
num_blocks_per_batch
;
i
++
)
...
@@ -109,11 +115,9 @@ __global__ void
...
@@ -109,11 +115,9 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
p_d0s_grid
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_d_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_d_grid_
+
d_batch_offset
,
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
,
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
,
...
@@ -129,9 +133,9 @@ __global__ void
...
@@ -129,9 +133,9 @@ __global__ void
c_element_op
,
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
...
@@ -150,10 +154,9 @@ __global__ void
...
@@ -150,10 +154,9 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
,
IsLseStoring
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
p_d0s_grid
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_d_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_d_grid_
+
d_batch_offset
,
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
,
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
,
arg_ptr
[
group_id
].
p_lse_grid_
==
nullptr
arg_ptr
[
group_id
].
p_lse_grid_
==
nullptr
...
@@ -168,9 +171,9 @@ __global__ void
...
@@ -168,9 +171,9 @@ __global__ void
c_element_op
,
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
...
@@ -300,12 +303,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -300,12 +303,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
static
constexpr
index_t
NumD1Tensor
=
Acc1BiasDataType
::
Size
();
static
constexpr
index_t
NumD1Tensor
=
Acc1BiasDataType
::
Size
();
// TODO ANT: implement bias combination
// TODO ANT: implement bias combination
static_assert
(
NumD0Tensor
<=
1
,
"Bias0 addition is only support one bias"
);
static_assert
(
NumD1Tensor
==
0
,
"Bias addition is unimplemented"
);
static_assert
(
NumD1Tensor
==
0
,
"Bias addition is unimplemented"
);
static_assert
(
NumD0Tensor
==
0
?
true
:
std
::
is_same_v
<
ADataType
,
ck
::
tuple_element_t
<
0
,
Acc0BiasDataType
>>
);
using
DDataType
=
ADataType
;
#if 0
#if 0
// TODO ANT: use alias
// TODO ANT: use alias
...
@@ -424,20 +422,44 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -424,20 +422,44 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
}
}
}
}
// Builds the M x N grid descriptors for the Acc0 bias (D0) tensors.
//
// generate_tuple is driven with Number<NumD0Tensor>; for every index it hands to the
// callable, the matching lengths/strides entries are forwarded to
// Transform::MakeCGridDescriptor_M_N and the per-index results are returned together.
//
// NOTE(review): both vectors are indexed with operator[] (no bounds check), so callers
// presumably must pass at least NumD0Tensor entries in each -- confirm against callers.
static auto MakeD0sGridDescriptor_M_N(
    const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_lengths,
    const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_strides)
{
    // Descriptor factory for a single D0 tensor, selected by its compile-time index.
    auto make_single_d0_desc = [&](auto d0_idx) {
        return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[d0_idx],
                                                  acc0_biases_gs_ms_ns_strides[d0_idx]);
    };

    return generate_tuple(make_single_d0_desc, Number<NumD0Tensor>{});
}
// Builds the G x M x N grid descriptors for the Acc0 bias (D0) tensors.
//
// generate_tuple is driven with Number<NumD0Tensor>; for every index it hands to the
// callable, the matching lengths/strides entries are forwarded to
// Transform::MakeCGridDescriptor_G_M_N and the per-index results are returned together.
//
// NOTE(review): both vectors are indexed with operator[] (no bounds check), so callers
// presumably must pass at least NumD0Tensor entries in each -- confirm against callers.
static auto MakeD0sGridDescriptor_G_M_N(
    const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_lengths,
    const std::vector<std::vector<ck::index_t>>& acc0_biases_gs_ms_ns_strides)
{
    // Descriptor factory for a single D0 tensor, selected by its compile-time index.
    auto make_single_d0_desc = [&](auto d0_idx) {
        return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[d0_idx],
                                                    acc0_biases_gs_ms_ns_strides[d0_idx]);
    };

    return generate_tuple(make_single_d0_desc, Number<NumD0Tensor>{});
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
D0sGridDesc_M_N
=
decltype
(
MakeD0sGridDescriptor_M_N
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
({},
{}));
using
CGridDesc_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
CGridDesc_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
DGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
using
ZGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
using
ZGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B1
GridDesc_G_
N_K
=
decltype
(
Transform
::
Make
B1
GridDescriptor_G_
N_K
({},
{}));
using
D0s
GridDesc_G_
M_N
=
decltype
(
Make
D0s
GridDescriptor_G_
M_N
({},
{}));
using
C
GridDesc_G_
M_N
=
decltype
(
Transform
::
Make
C
GridDescriptor_G_
M_N
({},
{}));
using
B1
GridDesc_G_
N_K
=
decltype
(
Transform
::
Make
B1
GridDescriptor_G_
N_K
({},
{}));
using
D
GridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
C
GridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
constexpr
static
auto
make_MaskOutPredicate
()
constexpr
static
auto
make_MaskOutPredicate
()
{
{
...
@@ -460,16 +482,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -460,16 +482,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
{
{
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
const
BGridDesc_G_N_K
&
b_grid_desc_g_n_k
,
const
BGridDesc_G_N_K
&
b_grid_desc_g_n_k
,
const
D0sGridDesc_G_M_N
&
d0s_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
DGridDesc_G_M_N
&
d_grid_desc_g_m_n
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
index_t
BatchStrideLSE
)
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
d0s_grid_desc_g_m_n_
(
d0s_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
d_grid_desc_g_m_n_
(
d_grid_desc_g_m_n
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
BatchStrideLSE_
(
BatchStrideLSE
)
BatchStrideLSE_
(
BatchStrideLSE
)
{
{
...
@@ -495,9 +517,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -495,9 +517,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
return
c_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
return
c_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
}
__host__
__device__
constexpr
long_index_t
GetDBasePtr
(
index_t
g_idx
)
const
template
<
index_t
I
>
__host__
__device__
constexpr
long_index_t
GetD0BasePtr
(
index_t
g_idx
,
Number
<
I
>
d0_idx
)
const
{
{
return
d_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
return
d
0s
_grid_desc_g_m_n_
[
d0_idx
]
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
}
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
...
@@ -513,9 +537,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -513,9 +537,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
private:
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
D0sGridDesc_G_M_N
d0s_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
DGridDesc_G_M_N
d_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
index_t
BatchStrideLSE_
;
index_t
BatchStrideLSE_
;
...
@@ -524,6 +548,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -524,6 +548,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
Acc0BiasDataType
,
ZDataType
,
ZDataType
,
GemmDataType
,
GemmDataType
,
GemmAccDataType
,
GemmAccDataType
,
...
@@ -538,6 +563,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -538,6 +563,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
BGridDesc_BK0_N_BK1
,
D0sGridDesc_M_N
,
B1GridDesc_BK0_N_BK1
,
B1GridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
ZGridDesc_M_N
,
ZGridDesc_M_N
,
...
@@ -599,21 +625,20 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -599,21 +625,20 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
// pointers
// pointers
const
ADataType
*
p_a_grid_
;
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseGemm
::
D0sGridPointer
p_d0s_grid_
;
const
B1DataType
*
p_b1_grid_
;
const
B1DataType
*
p_b1_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
const
DDataType
*
p_d_grid_
;
ZDataType
*
p_z_grid_
;
ZDataType
*
p_z_grid_
;
LSEDataType
*
p_lse_grid_
;
LSEDataType
*
p_lse_grid_
;
// tensor descriptors for block/thread-wise copy
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
DGridDesc_M_N
d_grid_desc_m_n_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
...
@@ -650,7 +675,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -650,7 +675,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
// raw data
// raw data
int
raw_d0_n
_
;
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>
d0s_nl_ns_lengths_strides
_
;
};
};
// Argument
// Argument
...
@@ -700,9 +725,21 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -700,9 +725,21 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const
auto
p_b_grid
=
static_cast
<
const
BDataType
*>
(
p_b_vec
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
BDataType
*>
(
p_b_vec
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
B1DataType
*>
(
p_b1_vec
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
B1DataType
*>
(
p_b1_vec
[
i
]);
const
auto
p_c_grid
=
static_cast
<
CDataType
*>
(
p_c_vec
[
i
]);
const
auto
p_c_grid
=
static_cast
<
CDataType
*>
(
p_c_vec
[
i
]);
const
auto
p_d_grid
=
NumD0Tensor
==
0
?
nullptr
const
auto
&
problem_desc
=
problem_desc_vec
[
i
];
:
static_cast
<
const
DDataType
*>
(
p_acc0_biases_vec
[
i
][
0
]);
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
NumD0Tensor
>
d0s_nl_ns_lengths_strides
;
typename
GridwiseGemm
::
D0sGridPointer
p_d0s_grid
;
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
j
)
{
using
D0DataType
=
remove_cvref_t
<
tuple_element_t
<
j
.
value
,
Acc0BiasDataType
>>
;
// D0 pointer
p_d0s_grid
(
j
)
=
static_cast
<
const
D0DataType
*>
(
p_acc0_biases_vec
[
i
][
j
]);
// for check
d0s_nl_ns_lengths_strides
[
j
].
push_back
(
problem_desc
.
acc0_biases_gs_ms_ns_lengths
[
j
][
NumDimG
+
NumDimM
]);
d0s_nl_ns_lengths_strides
[
j
].
push_back
(
problem_desc
.
acc0_biases_gs_ms_ns_strides
[
j
][
NumDimG
+
NumDimM
]);
});
const
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_z_vec
[
i
]);
const
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_z_vec
[
i
]);
const
auto
p_lse_grid
=
static_cast
<
LSEDataType
*>
(
p_lse_vec
[
i
]);
const
auto
p_lse_grid
=
static_cast
<
LSEDataType
*>
(
p_lse_vec
[
i
]);
...
@@ -711,24 +748,22 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -711,24 +748,22 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
is_lse_storing_
=
false
;
is_lse_storing_
=
false
;
}
}
const
auto
&
problem_desc
=
problem_desc_vec
[
i
];
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem_desc
.
b0_gs_ns_ks_lengths
,
problem_desc
.
b0_gs_ns_ks_strides
);
problem_desc
.
b0_gs_ns_ks_lengths
,
problem_desc
.
b0_gs_ns_ks_strides
);
const
D0sGridDesc_M_N
d0s_grid_desc_m_n
{
DeviceOp
::
MakeD0sGridDescriptor_M_N
(
problem_desc
.
acc0_biases_gs_ms_ns_lengths
,
problem_desc
.
acc0_biases_gs_ms_ns_strides
)};
const
auto
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
GridwiseGemm
::
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
d0s_grid_desc_m_n
);
const
auto
b1_grid_desc_bk0_n_bk1
=
MakeB1GridDescriptor_BK0_N_BK1
(
const
auto
b1_grid_desc_bk0_n_bk1
=
MakeB1GridDescriptor_BK0_N_BK1
(
problem_desc
.
b1_gs_os_ns_lengths
,
problem_desc
.
b1_gs_os_ns_strides
);
problem_desc
.
b1_gs_os_ns_lengths
,
problem_desc
.
b1_gs_os_ns_strides
);
const
auto
c_grid_desc_m_n
=
Transform
::
MakeCGridDescriptor_M_N
(
const
auto
c_grid_desc_m_n
=
Transform
::
MakeCGridDescriptor_M_N
(
problem_desc
.
c_gs_ms_os_lengths
,
problem_desc
.
c_gs_ms_os_strides
);
problem_desc
.
c_gs_ms_os_lengths
,
problem_desc
.
c_gs_ms_os_strides
);
const
auto
d_grid_desc_m_n
=
NumD0Tensor
==
0
?
DGridDesc_M_N
{}
:
MakeZGridDescriptor_M_N
(
problem_desc
.
acc0_biases_gs_ms_ns_lengths
[
0
],
problem_desc
.
acc0_biases_gs_ms_ns_strides
[
0
]);
const
auto
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
d_grid_desc_m_n
);
const
auto
z_grid_desc_m_n
=
MakeZGridDescriptor_M_N
(
const
auto
z_grid_desc_m_n
=
MakeZGridDescriptor_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
lse_grid_desc_m
=
const
auto
lse_grid_desc_m
=
...
@@ -742,11 +777,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -742,11 +777,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
problem_desc
.
b1_gs_os_ns_lengths
,
problem_desc
.
b1_gs_os_ns_strides
);
problem_desc
.
b1_gs_os_ns_lengths
,
problem_desc
.
b1_gs_os_ns_strides
);
const
auto
c_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
const
auto
c_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
problem_desc
.
c_gs_ms_os_lengths
,
problem_desc
.
c_gs_ms_os_strides
);
problem_desc
.
c_gs_ms_os_lengths
,
problem_desc
.
c_gs_ms_os_strides
);
const
auto
d_grid_desc_g_m_n
=
NumD0Tensor
==
0
?
DGridDesc_G_M_N
{}
:
Transform
::
MakeCGridDescriptor_G_M_N
(
problem_desc
.
acc0_biases_gs_ms_ns_lengths
[
0
],
problem_desc
.
acc0_biases_gs_ms_ns_strides
[
0
]);
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
...
@@ -771,12 +801,15 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -771,12 +801,15 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const
index_t
BlockEnd
=
grid_size_
+
grid_size_grp
;
const
index_t
BlockEnd
=
grid_size_
+
grid_size_grp
;
// batch stride
// batch stride
const
auto
d0s_grid_desc_g_m_n
=
DeviceOp
::
MakeD0sGridDescriptor_G_M_N
(
problem_desc
.
acc0_biases_gs_ms_ns_lengths
,
problem_desc
.
acc0_biases_gs_ms_ns_strides
);
const
auto
compute_base_ptr_of_batch
=
ComputeBasePtrOfStridedBatch
(
const
auto
compute_base_ptr_of_batch
=
ComputeBasePtrOfStridedBatch
(
a_grid_desc_g_m_k
,
a_grid_desc_g_m_k
,
b_grid_desc_g_n_k
,
b_grid_desc_g_n_k
,
d0s_grid_desc_g_m_n
,
b1_grid_desc_g_n_k
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
c_grid_desc_g_m_n
,
d_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
...
@@ -805,17 +838,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -805,17 +838,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
group_kernel_args_
.
push_back
({
p_a_grid
,
group_kernel_args_
.
push_back
({
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_d0s_grid
,
p_b1_grid
,
p_b1_grid
,
p_c_grid
,
p_c_grid
,
p_d_grid
,
p_z_grid
,
p_z_grid
,
p_lse_grid
,
p_lse_grid
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
d_grid_desc_m_n
,
d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m_n
,
z_grid_desc_m_n
,
lse_grid_desc_m
,
lse_grid_desc_m
,
...
@@ -832,11 +864,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -832,11 +864,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
z_random_matrix_offset
=
z_random_matrix_offset
=
z_random_matrix_offset
+
raw_m_padded
*
raw_n_padded
*
batch_count
;
z_random_matrix_offset
+
raw_m_padded
*
raw_n_padded
*
batch_count
;
auto
raw_d0_m_n
=
NumD0Tensor
==
0
?
RawTransform
::
MakeCGridDescriptor_M_N
({},
{})
:
RawTransform
::
MakeCGridDescriptor_M_N
(
problem_desc
.
acc0_biases_gs_ms_ns_lengths
[
0
],
problem_desc
.
acc0_biases_gs_ms_ns_strides
[
0
]);
group_device_args_
.
push_back
(
group_device_args_
.
push_back
(
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
b0_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
problem_desc
.
b0_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
...
@@ -851,7 +878,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -851,7 +878,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
{
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
-
1
],
{
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
c_grid_desc_m_n
,
c_grid_desc_m_n
,
NumD0Tensor
==
0
?
0
:
raw_d0_m_n
.
GetLength
(
I1
)
});
d0s_nl_ns_lengths_strides
});
}
}
is_dropout_
=
p_dropout
>
0.0
;
//
is_dropout_
=
p_dropout
>
0.0
;
//
...
@@ -1067,11 +1094,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -1067,11 +1094,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
if
(
device_arg
.
raw_d0_n_
%
Acc0BiasTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
if
(
!
(
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
if
(
!
(
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
{
{
return
false
;
return
false
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment