gaoqiong / composable_kernel

Commit 8e7b98eb, authored Aug 07, 2023 by letaoqin

group gemm add kernel

Parent: b158c537
Showing 3 changed files with 61 additions and 33 deletions:

example/52_flash_atten_bias/grouped_multihead_attention_bias_forward_v2.cpp (+1, -1)
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc (+26, -11)
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2r2.hpp (+34, -21)
example/52_flash_atten_bias/grouped_multihead_attention_bias_forward_v2.cpp
@@ -76,7 +76,7 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecia
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = true;
+static constexpr bool Deterministic = false;

 #if(DIM <= 32)
 using DeviceGemmInstance =
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc
@@ -57,7 +57,7 @@ int run(int argc, char* argv[])
 std::vector<const void*> p_b0;
 std::vector<const void*> p_b1;
 std::vector<void*> p_c;
-std::vector<const void*> p_d;
+std::vector<std::vector<const void*>> p_d;
 std::vector<void*> p_z;         // for result verification
 std::vector<void*> p_z_nullptr; // for time test
 std::vector<void*> p_lse;
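Note on the p_d change above: the bias pointer list now has one inner vector per group, matching the shape of the device op's p_acc0_biases parameter. A minimal host-side sketch of that layout (stand-in host buffer, not the example's device allocations):

#include <vector>

int main()
{
    const int group_count = 3;
    std::vector<float> d_buf(16, 0.f); // stand-in for a per-group device buffer

    std::vector<std::vector<const void*>> p_d;
    for(int i = 0; i < group_count; ++i)
        p_d.push_back({d_buf.data()}); // one D (bias) pointer per group

    return (p_d.size() == 3 && p_d[0].size() == 1) ? 0 : 1;
}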
@@ -122,8 +122,8 @@ int run(int argc, char* argv[])
 std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
 std::vector<ck::index_t> d_gs_ms_ns_strides =
     input_permute
-        ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
-        : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+        ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
+        : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]

 std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
 std::vector<ck::index_t> z_gs_ms_ns_strides =
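Note: the permuted and non-permuted stride sets describe the same logical [G0, G1, M, N] tensor, stored either as [G0, M, G1, N] or as [G0, G1, M, N] in memory. A small sketch of how each set linearizes an index (illustrative sizes, not the example's defaults):

#include <cstdio>
#include <vector>

int main()
{
    const int G1 = 3, M = 4, N = 5;
    const std::vector<int> permuted{M * G1 * N, N, G1 * N, 1}; // [G0, M, G1, N] in memory
    const std::vector<int> flat{G1 * M * N, M * N, N, 1};      // [G0, G1, M, N] in memory

    const int idx[4] = {0, 1, 2, 3}; // logical (g0, g1, m, n)
    auto offset = [&](const std::vector<int>& s) {
        return idx[0] * s[0] + idx[1] * s[1] + idx[2] * s[2] + idx[3] * s[3];
    };
    std::printf("permuted: %d, flat: %d\n", offset(permuted), offset(flat)); // 38 vs 33
    return 0;
}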
@@ -159,7 +159,7 @@ int run(int argc, char* argv[])
 Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
 Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
 Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
-Tensor<ZDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
+Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
 Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
 Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
@@ -176,6 +176,7 @@ int run(int argc, char* argv[])
           << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
           << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
           << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
+          << "d_gs_ms_ns[" << i << "]: " << d_gs_ms_ns.mDesc << ", "
           << "z_gs_ms_ns[" << i << "]: " << z_gs_ms_ns.mDesc << ", "
           << "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc << std::endl;
@@ -190,7 +191,7 @@ int run(int argc, char* argv[])
 a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
 b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
 b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
-d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-2, 2});
+d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1, 1});
 break;
 case 2:
 a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
@@ -243,7 +244,9 @@ int run(int argc, char* argv[])
 p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
 p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
 p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
-p_d.push_back(d_tensors_device[i]->GetDeviceBuffer());
+p_d.push_back({d_tensors_device[i]->GetDeviceBuffer()});
+// std::cout << "from host group id: " << i << " d address: " <<
+// d_tensors_device[i]->GetDeviceBuffer() << std::endl;
 p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
 p_z_nullptr.push_back(nullptr);
 p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
@@ -266,8 +269,8 @@ int run(int argc, char* argv[])
 p_c,
 p_z_nullptr,
 p_lse,
-std::vector<std::vector<const void*>>{p_d}, // p_acc0_biases
-{},                                         // p_acc1_biases
+p_d, // p_acc0_biases
+{},  // p_acc1_biases
 problem_descs,
 a_element_op,
 b0_element_op,
@@ -309,8 +312,8 @@ int run(int argc, char* argv[])
 p_c,
 p_z,
 p_lse,
-{}, // p_acc0_biases
-{}, // p_acc1_biases
+p_d, // p_acc0_biases
+{},  // p_acc1_biases
 problem_descs,
 a_element_op,
 b0_element_op,
@@ -344,6 +347,7 @@ int run(int argc, char* argv[])
 const auto& a_gs_ms_ks  = a_tensors[i];
 const auto& b0_gs_ns_ks = b0_tensors[i];
 const auto& b1_gs_os_ns = b1_tensors[i];
+const auto& d_gs_ms_ns  = d_tensors[i];
 auto& c_gs_ms_os_device_result = c_tensors[i];
 auto& z_gs_ms_ns_device_result = z_tensors[i];
 auto& lse_gs_ms_device_result  = lse_tensors[i];
@@ -358,7 +362,8 @@ int run(int argc, char* argv[])
 Tensor<ADataType> a_g_m_k({G0 * G1, M, K});
 Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
 Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
 Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N});        // scratch object after gemm0
+Tensor<AccDataType> d_g_m_n({G0 * G1, M, N});
 Tensor<ADataType> a1_g_m_n({G0 * G1, M, N});            // scratch object after softmax
 Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});       // scratch object after softmax
 Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
@@ -378,6 +383,10 @@ int run(int argc, char* argv[])
     b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
 });
+d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+    d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+});
 z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
     z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
 });
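Note: each ForEach above collapses the (g0, g1) pair into a single batch index g = g0 * G1 + g1, so the 4-D [G0, G1, M, N] views line up with the [G0 * G1, M, N] reference tensors; the added d_gs_ms_ns loop follows the same convention. A one-line illustration:

#include <cstdio>

int main()
{
    const int G1 = 3;
    const int g0 = 2, g1 = 1;
    std::printf("batch index: %d\n", g0 * G1 + g1); // prints 7
    return 0;
}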
@@ -390,6 +399,8 @@ int run(int argc, char* argv[])
 ref_gemm0_invoker.Run(ref_gemm0_argument);

+// bias
+acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += d_g_m_n(idx); });
+
 // masking
 const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
 acc0_g_m_n.ForEach([&](auto& self, auto idx) {
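Note: the added loop defines the reference semantics of the bias: D is accumulated element-wise into the gemm0 output before masking and softmax. The same arithmetic on flat arrays (stand-in values, not the example's generators):

#include <vector>

int main()
{
    const int M = 2, N = 2;
    std::vector<float> acc0{1.f, 2.f, 3.f, 4.f};  // stand-in for gemm0 output
    std::vector<float> d{0.5f, 0.5f, 0.5f, 0.5f}; // stand-in for bias D

    for(int i = 0; i < M * N; ++i)
        acc0[i] += d[i]; // mirrors: self(idx) += d_g_m_n(idx)

    return acc0[3] == 4.5f ? 0 : 1;
}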
@@ -470,6 +481,10 @@ int run(int argc, char* argv[])
     "Error: Incorrect results lse!", rtol, atol);
+if(!pass_)
+{
+    std::cout << "from group: " << i << std::endl;
+}
 pass &= pass_;
 }

 if(pass)
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2r2.hpp
@@ -14,7 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -39,7 +39,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
+kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2r2(
     const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
     const index_t group_count,
     const AElementwiseOperation a_element_op,
@@ -95,6 +95,8 @@ __global__ void
         arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
     static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
+const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
+    static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetDBasePtr(g_idx)));
 const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
     static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
 const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
     static_cast<long_index_t>(
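Note: the new d_batch_offset follows the existing pattern: the per-batch offset is computed and broadcast with __builtin_amdgcn_readfirstlane so every lane of the wavefront uses the same base pointer. A host-side sketch of the offset arithmetic only (hypothetical struct, simplified from compute_base_ptr_of_batch_):

#include <cstdint>

// Hypothetical stand-in: batch g starts batch_stride_d_ elements
// after batch g - 1.
struct ComputeBasePtrOfBatchSketch
{
    std::int64_t batch_stride_d_;
    std::int64_t GetDBasePtr(int g_idx) const { return g_idx * batch_stride_d_; }
};

int main()
{
    const ComputeBasePtrOfBatchSketch c{128};
    return c.GetDBasePtr(2) == 256 ? 0 : 1;
}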
@@ -109,6 +111,9 @@ __global__ void
 arg_ptr[group_id].p_b_grid_ + b_batch_offset,
 arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
 arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+arg_ptr[group_id].p_d_grid_ == nullptr ? nullptr
+                                       : arg_ptr[group_id].p_d_grid_ + d_batch_offset,
 arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                        : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
@@ -126,6 +131,7 @@ __global__ void
 arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
 arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
 arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
+arg_ptr[group_id].d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
 arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
 arg_ptr[group_id].lse_grid_desc_m_,
 arg_ptr[group_id].block_2_ctile_map_,
@@ -146,6 +152,8 @@ __global__ void
 arg_ptr[group_id].p_b_grid_ + b_batch_offset,
 arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
 arg_ptr[group_id].p_c_grid_ + c_batch_offset,
+arg_ptr[group_id].p_d_grid_ == nullptr ? nullptr
+                                       : arg_ptr[group_id].p_d_grid_ + d_batch_offset,
 arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                        : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
 arg_ptr[group_id].p_lse_grid_ == nullptr
@@ -162,6 +170,7 @@ __global__ void
 arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
 arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
 arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
+arg_ptr[group_id].d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
 arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
 arg_ptr[group_id].lse_grid_desc_m_,
 arg_ptr[group_id].block_2_ctile_map_,
@@ -330,6 +339,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
 static constexpr auto I0 = Number<0>{};
 static constexpr auto I1 = Number<1>{};
 static constexpr auto I2 = Number<2>{};
+static constexpr auto I3 = Number<3>{};
+static constexpr auto I4 = Number<4>{};
+static constexpr auto I5 = Number<5>{};
+static constexpr auto I6 = Number<6>{};
+static constexpr auto I7 = Number<7>{};
+static constexpr auto I8 = Number<8>{};
+static constexpr auto I9 = Number<9>{};

 using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
     Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
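Note: the added I3 through I9 give the struct compile-time indices for all ten dimensions of the m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 descriptors used elsewhere in this file. A sketch of what such a constant is, using std::integral_constant as a stand-in for ck's Number<>:

#include <type_traits>

template <int N>
using NumberSketch = std::integral_constant<int, N>; // stand-in for Number<N>

static constexpr auto I3 = NumberSketch<3>{};
static_assert(I3.value == 3, "usable wherever a compile-time index is required");

int main() { return 0; }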
@@ -495,7 +511,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
 };

 // GridwiseGemm
-using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
+using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2<
     ADataType, // TODO: distinguish A/B datatype
     ZDataType,
     GemmDataType,
@@ -647,20 +663,17 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
       b1_element_op_{b1_element_op},
       c_element_op_{c_element_op}
 {
     ignore = p_acc1_biases_vec; // TODO ANT: implement bias addition
     group_count_ = problem_desc_vec.size();

     if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
-         group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size()))
+         group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size() &&
+         (group_count_ == p_acc0_biases_vec.size() || p_acc0_biases_vec.size() == 0)))
     {
         throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
     }

-    if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
-    {
-        throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
-    }
-
     grid_size_ = 0;
     index_t z_random_matrix_offset = 0;
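Note: the reworked check accepts either no biases at all or exactly one bias list per group, which is what lets the caller pass p_d for p_acc0_biases and {} for p_acc1_biases. A standalone sketch of the same validation (hypothetical free function, not the constructor):

#include <cstddef>
#include <stdexcept>
#include <vector>

void check_bias_sizes(std::size_t group_count,
                      const std::vector<std::vector<const void*>>& p_acc0_biases_vec)
{
    if(!(group_count == p_acc0_biases_vec.size() || p_acc0_biases_vec.size() == 0))
    {
        throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
    }
}

int main()
{
    check_bias_sizes(4, {});                     // ok: bias omitted entirely
    check_bias_sizes(2, {{nullptr}, {nullptr}}); // ok: one bias list per group
    return 0;
}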
@@ -884,18 +897,18 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
 auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
-    const auto kernel = kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
+    const auto kernel = kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2r2<GridwiseGemm,
                                                                          GemmAccDataType,
                                                                          GroupKernelArg,
                                                                          AElementwiseOperation,
                                                                          BElementwiseOperation,
                                                                          AccElementwiseOperation,
                                                                          B1ElementwiseOperation,
                                                                          CElementwiseOperation,
                                                                          has_main_k_block_loop_,
                                                                          is_dropout_,
                                                                          is_lse_storing_,
                                                                          Deterministic>;

     return launch_and_time_kernel(stream_config,
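Note: only the kernel name changes here; launch_kernel still lifts the runtime flags into template parameters so each flag combination is a distinct kernel instantiation. A minimal sketch of that dispatch pattern (generic names, not the CK API):

#include <cstdio>
#include <type_traits>

template <bool IsDropout, bool IsLseStoring>
void kernel_sketch() // stand-in for one kernel instantiation per flag combo
{
    std::printf("dropout=%d lse=%d\n", IsDropout, IsLseStoring);
}

int main()
{
    const bool is_dropout = false, is_lse_storing = true;
    auto launch = [&](auto is_dropout_, auto is_lse_storing_) {
        kernel_sketch<is_dropout_.value, is_lse_storing_.value>();
    };
    if(!is_dropout && is_lse_storing)
        launch(std::integral_constant<bool, false>{}, std::integral_constant<bool, true>{});
    return 0;
}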