Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
499dfe39
Commit
499dfe39
authored
Aug 04, 2023
by
letaoqin
Browse files
change name to NumD0Tensor
parent
381a7317
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
26 deletions
+26
-26
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_bias_xdl_cshuffle_v2.hpp
...vice/impl/device_batched_mha_fwd_bias_xdl_cshuffle_v2.hpp
+26
-26
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_bias_xdl_cshuffle_v2.hpp
View file @
499dfe39
...
@@ -312,13 +312,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -312,13 +312,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
"Number of dimension must be greater than 0"
);
static
constexpr
index_t
Num
Acc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
Num
D0Tensor
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
Num
Acc1Bias
=
Acc1BiasDataType
::
Size
();
static
constexpr
index_t
Num
D1Tensor
=
Acc1BiasDataType
::
Size
();
// TODO ANT: implement bias combination
// TODO ANT: implement bias combination
static_assert
(
Num
Acc0Bias
<=
1
,
"Acc0 Bias addition is max support one bias"
);
static_assert
(
Num
D0Tensor
<=
1
,
"Acc0 Bias addition is max support one bias"
);
static_assert
(
Num
Acc1Bias
==
0
,
"Acc1 Bias addition is unimplemented"
);
static_assert
(
Num
D1Tensor
==
0
,
"Acc1 Bias addition is unimplemented"
);
static_assert
(
Num
Acc1Bias
==
0
static_assert
(
Num
D1Tensor
==
0
?
true
?
true
:
std
::
is_same_v
<
ADataType
,
ck
::
tuple_element_t
<
0
,
Acc0BiasDataType
>>
);
:
std
::
is_same_v
<
ADataType
,
ck
::
tuple_element_t
<
0
,
Acc0BiasDataType
>>
);
using
DDataType
=
ADataType
;
using
DDataType
=
ADataType
;
...
@@ -580,8 +580,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -580,8 +580,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
CDataType
*
p_c_grid
,
CDataType
*
p_c_grid
,
ZDataType
*
p_z_grid
,
ZDataType
*
p_z_grid
,
LSEDataType
*
p_lse_grid
,
LSEDataType
*
p_lse_grid
,
const
std
::
array
<
void
*
,
Num
Acc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
Num
D0Tensor
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
Num
Acc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
Num
D1Tensor
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
@@ -593,11 +593,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -593,11 +593,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D0Tensor
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D0Tensor
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D1Tensor
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D1Tensor
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
...
@@ -610,7 +610,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -610,7 +610,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_b1_grid_
{
p_b1_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
p_d_grid_
{
Num
Acc0Bias
==
0
?
nullptr
p_d_grid_
{
Num
D0Tensor
==
0
?
nullptr
:
static_cast
<
const
DDataType
*>
(
p_acc0_biases
[
0
])},
:
static_cast
<
const
DDataType
*>
(
p_acc0_biases
[
0
])},
p_z_grid_
{
p_z_grid
},
p_z_grid_
{
p_z_grid
},
p_lse_grid_
{
p_lse_grid
},
p_lse_grid_
{
p_lse_grid
},
...
@@ -622,7 +622,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -622,7 +622,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
c_grid_desc_m_n_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_grid_desc_m_n_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
c_gs_ms_gemm1ns_strides
)},
d_grid_desc_m_n_
{
Num
Acc0Bias
==
0
d_grid_desc_m_n_
{
Num
D0Tensor
==
0
?
DGridDesc_M_N
{}
?
DGridDesc_M_N
{}
:
MakeZGridDescriptor_M_N
(
acc0_biases_gs_ms_ns_lengths
[
0
],
:
MakeZGridDescriptor_M_N
(
acc0_biases_gs_ms_ns_lengths
[
0
],
acc0_biases_gs_ms_ns_strides
[
0
])},
acc0_biases_gs_ms_ns_strides
[
0
])},
...
@@ -636,7 +636,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -636,7 +636,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
c_gs_ms_gemm1ns_strides
)},
d_grid_desc_g_m_n_
{
Num
Acc0Bias
==
0
?
DGridDesc_G_M_N
{}
d_grid_desc_g_m_n_
{
Num
D0Tensor
==
0
?
DGridDesc_G_M_N
{}
:
Transform
::
MakeCGridDescriptor_G_M_N
(
:
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_biases_gs_ms_ns_lengths
[
0
],
acc0_biases_gs_ms_ns_lengths
[
0
],
acc0_biases_gs_ms_ns_strides
[
0
])},
acc0_biases_gs_ms_ns_strides
[
0
])},
...
@@ -1058,8 +1058,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -1058,8 +1058,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
CDataType
*
p_c
,
CDataType
*
p_c
,
ZDataType
*
p_z
,
ZDataType
*
p_z
,
LSEDataType
*
p_lse
,
LSEDataType
*
p_lse
,
const
std
::
array
<
void
*
,
Num
Acc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
Num
D0Tensor
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
Num
Acc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
Num
D1Tensor
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
@@ -1071,11 +1071,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -1071,11 +1071,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D0Tensor
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D0Tensor
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D1Tensor
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D1Tensor
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
...
@@ -1128,8 +1128,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -1128,8 +1128,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
void
*
p_c
,
void
*
p_c
,
void
*
p_z
,
void
*
p_z
,
void
*
p_lse
,
void
*
p_lse
,
const
std
::
array
<
void
*
,
Num
Acc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
Num
D0Tensor
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
Num
Acc1Bias
>
p_acc1_biases
,
const
std
::
array
<
void
*
,
Num
D1Tensor
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
@@ -1141,11 +1141,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
...
@@ -1141,11 +1141,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D0Tensor
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D0Tensor
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D1Tensor
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
Acc1Bias
>
const
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
Num
D1Tensor
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment