Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b2df7018
Commit
b2df7018
authored
Aug 15, 2023
by
letaoqin
Browse files
fix ComputeBasePtrOfStridedBatch init bug
parent
4b0a5069
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
79 additions
and
68 deletions
+79
-68
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc
...ten_bias/run_batched_multihead_attention_bias_forward.inc
+70
-68
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
+9
-0
No files found.
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc
View file @
b2df7018
...
...
@@ -183,8 +183,8 @@ int run(int argc, char* argv[])
// TODO ANT: replace array with vector?
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
b0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B1DataType
*>
(
b1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
...
...
@@ -213,8 +213,8 @@ int run(int argc, char* argv[])
b1_element_op
,
c_element_op
,
p_drop
,
// dropout ratio
{
seed
,
offset
});
// dropout random seed and offset, offset should be at least the number of
// elements on a thread
{
seed
,
offset
});
// dropout random seed and offset, offset should be at
// least the number of elements on a thread
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
...
...
@@ -230,7 +230,9 @@ int run(int argc, char* argv[])
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
BatchCount
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
+
sizeof
(
DDataType
)
*
M
*
N
*
std
::
is_void
<
DDataType
>::
value
?
1
:
0
)
*
sizeof
(
DDataType
)
*
M
*
N
*
std
::
is_void
<
DDataType
>::
value
?
0
:
1
)
*
BatchCount
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
...
@@ -243,8 +245,8 @@ int run(int argc, char* argv[])
if
(
do_verification
)
{
// run for storing z tensor
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
b0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B1DataType
*>
(
b1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
...
...
@@ -273,8 +275,8 @@ int run(int argc, char* argv[])
b1_element_op
,
c_element_op
,
p_drop
,
// dropout ratio
{
seed
,
offset
});
// dropout random seed and offset, offset should be at least the number
// of elements on a thread
{
seed
,
offset
});
// dropout random seed and offset, offset should be
// at least the number of elements on a thread
c_device_buf
.
SetZero
();
lse_device_buf
.
SetZero
();
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
b2df7018
...
...
@@ -697,6 +697,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_grid_desc_g_m_n_
=
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
);
compute_base_ptr_of_batch_
=
ComputeBasePtrOfStridedBatch
(
a_grid_desc_g_m_k_
,
b_grid_desc_g_n_k_
,
d0_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
z_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
()));
d0_n_length_stride_
.
push_back
(
acc0_biases_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride_
.
push_back
(
acc0_biases_gs_ms_ns_strides
[
NumDimG
+
NumDimM
]);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment