Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
855b604f
Commit
855b604f
authored
Aug 13, 2023
by
letaoqin
Browse files
grouped gemm remove multiple D
parent
98212afe
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
107 additions
and
139 deletions
+107
-139
example/52_flash_atten_bias/grouped_multihead_attention_bias_forward_v2.cpp
...tten_bias/grouped_multihead_attention_bias_forward_v2.cpp
+2
-2
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc
...ten_bias/run_batched_multihead_attention_bias_forward.inc
+1
-1
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc
...ten_bias/run_grouped_multihead_attention_bias_forward.inc
+5
-9
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
...n/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
+7
-7
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
+7
-5
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+84
-114
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
...ion/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
+1
-1
No files found.
example/52_flash_atten_bias/grouped_multihead_attention_bias_forward_v2.cpp
View file @
855b604f
...
@@ -53,8 +53,8 @@ using CShuffleDataType = F32;
...
@@ -53,8 +53,8 @@ using CShuffleDataType = F32;
using
CDataType
=
DataType
;
using
CDataType
=
DataType
;
using
ZDataType
=
U16
;
// INT32
using
ZDataType
=
U16
;
// INT32
using
LSEDataType
=
F32
;
using
LSEDataType
=
F32
;
using
Acc0BiasDataType
=
ck
::
Tuple
<
DDataType
>
;
using
Acc0BiasDataType
=
DDataType
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
void
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
static
constexpr
ck
::
index_t
NumDimM
=
1
;
...
...
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc
View file @
855b604f
...
@@ -137,7 +137,7 @@ int run(int argc, char* argv[])
...
@@ -137,7 +137,7 @@ int run(int argc, char* argv[])
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
DDataType
>
{
-
2
,
2
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
DDataType
>
{
-
1
,
1
});
break
;
break
;
case
2
:
case
2
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
...
...
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc
View file @
855b604f
...
@@ -57,7 +57,7 @@ int run(int argc, char* argv[])
...
@@ -57,7 +57,7 @@ int run(int argc, char* argv[])
std
::
vector
<
const
void
*>
p_b0
;
std
::
vector
<
const
void
*>
p_b0
;
std
::
vector
<
const
void
*>
p_b1
;
std
::
vector
<
const
void
*>
p_b1
;
std
::
vector
<
void
*>
p_c
;
std
::
vector
<
void
*>
p_c
;
std
::
vector
<
std
::
vector
<
const
void
*>
>
p_d
;
std
::
vector
<
const
void
*>
p_d
;
std
::
vector
<
void
*>
p_z
;
// for result verification
std
::
vector
<
void
*>
p_z
;
// for result verification
std
::
vector
<
void
*>
p_z_nullptr
;
// for time test
std
::
vector
<
void
*>
p_z_nullptr
;
// for time test
std
::
vector
<
void
*>
p_lse
;
std
::
vector
<
void
*>
p_lse
;
...
@@ -147,10 +147,8 @@ int run(int argc, char* argv[])
...
@@ -147,10 +147,8 @@ int run(int argc, char* argv[])
z_gs_ms_ns_strides
,
z_gs_ms_ns_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
lse_gs_ms_strides
,
lse_gs_ms_strides
,
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
{
d_gs_ms_ns_lengths
,
// acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_strides
,
// acc0_biases_gs_ms_ns_strides
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
{
d_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_lengths
{}});
// acc1_biases_gs_ms_os_strides
{}});
// acc1_biases_gs_ms_os_strides
...
@@ -167,7 +165,7 @@ int run(int argc, char* argv[])
...
@@ -167,7 +165,7 @@ int run(int argc, char* argv[])
flop
+=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
Batch
;
flop
+=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
Batch
;
num_byte
+=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
num_byte
+=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
+
sizeof
(
DDataType
)
*
M
*
N
*
(
Acc0BiasDataType
::
Size
()
?
0
:
1
))
*
sizeof
(
DDataType
)
*
M
*
N
*
(
std
::
is_void
<
Acc0BiasDataType
>
::
value
?
0
:
1
))
*
Batch
;
Batch
;
if
(
i
<
4
)
if
(
i
<
4
)
...
@@ -244,9 +242,7 @@ int run(int argc, char* argv[])
...
@@ -244,9 +242,7 @@ int run(int argc, char* argv[])
p_b0
.
push_back
(
b0_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_b0
.
push_back
(
b0_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_b1
.
push_back
(
b1_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_b1
.
push_back
(
b1_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_c
.
push_back
(
c_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_c
.
push_back
(
c_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_d
.
push_back
({
d_tensors_device
[
i
]
->
GetDeviceBuffer
()});
p_d
.
push_back
(
d_tensors_device
[
i
]
->
GetDeviceBuffer
());
// std::cout << "from host group id: " << i << " d address: " <<
// d_tensors_device[i]->GetDeviceBuffer() << std::endl;
p_z
.
push_back
(
z_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_z
.
push_back
(
z_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_z_nullptr
.
push_back
(
nullptr
);
p_z_nullptr
.
push_back
(
nullptr
);
p_lse
.
push_back
(
lse_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_lse
.
push_back
(
lse_tensors_device
[
i
]
->
GetDeviceBuffer
());
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
View file @
855b604f
...
@@ -111,11 +111,11 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
...
@@ -111,11 +111,11 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc0_biases_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_biases_gs_ms_ns_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc0_biases_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
acc0_biases_gs_ms_ns_strides
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc1_biases_gs_ms_os_lengths
;
std
::
vector
<
index_t
>
acc1_biases_gs_ms_os_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc1_biases_gs_ms_os_strides
;
std
::
vector
<
index_t
>
acc1_biases_gs_ms_os_strides
;
};
};
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
...
@@ -125,9 +125,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
...
@@ -125,9 +125,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std
::
vector
<
void
*>
p_c_vec
,
std
::
vector
<
void
*>
p_c_vec
,
std
::
vector
<
void
*>
p_z_vec
,
std
::
vector
<
void
*>
p_z_vec
,
std
::
vector
<
void
*>
p_lse_vec
,
std
::
vector
<
void
*>
p_lse_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>
>
p_acc0_biases_vec
,
std
::
vector
<
const
void
*>
p_acc0_biases_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>
>
p_acc1_biases_vec
,
std
::
vector
<
const
void
*>
p_acc1_biases_vec
,
std
::
vector
<
ProblemDesc
>
problem_desc_vec
,
std
::
vector
<
ProblemDesc
>
&
problem_desc_vec
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
B0ElementwiseOperation
b0_element_op
,
B0ElementwiseOperation
b0_element_op
,
Acc0ElementwiseOperation
acc0_element_op
,
Acc0ElementwiseOperation
acc0_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
855b604f
...
@@ -685,11 +685,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -685,11 +685,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
c_grid_desc_m_n_
);
D0GridDesc_M_N
d0_grid_desc_m_n
{};
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
d0_grid_desc_m_n
);
}
}
d0_n_length_stride_
.
push_back
(
acc0_biases_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride_
.
push_back
(
acc0_biases_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
...
@@ -724,6 +719,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -724,6 +719,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
std
::
cout
<<
"b_grid_desc_g_n_k_: "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"b_grid_desc_g_n_k_: "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"d0_grid_desc_g_m_n_: "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"d0_grid_desc_m_n_: "
<<
d0_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
d0_grid_desc_m_n_
.
GetLength
(
I1
)
<<
'\n'
;
std
::
cout
<<
"b1_grid_desc_g_n_k_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"b1_grid_desc_g_n_k_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
855b604f
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
855b604f
...
@@ -1291,7 +1291,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -1291,7 +1291,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
// add bias
// add bias
if
constexpr
(
std
::
is_void
<
D0DataType
>::
value
)
if
constexpr
(
!
std
::
is_void
<
D0DataType
>::
value
)
{
{
// get register
// get register
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment