gaoqiong / composable_kernel · Commits

Commit 64c9f790, authored Aug 15, 2023 by letaoqin

    fix times

parent b2df7018

Showing 3 changed files with 28 additions and 37 deletions (+28 -37)
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp              +1  -2
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc             +16 -17
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp   +11 -18
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp

@@ -50,10 +50,9 @@ using B1DataType = DataType;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using CDataType        = DataType;
-using DDataType        = F16;
 using ZDataType        = U16; // INT32
 using LSEDataType      = F32;
-using Acc0BiasDataType = DDataType;
+using Acc0BiasDataType = F16;
 using Acc1BiasDataType = void;
 static constexpr ck::index_t NumDimG = 2;
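With this change the example names the bias element type directly (Acc0BiasDataType = F16) rather than routing it through the now-deleted DDataType alias, while Acc1BiasDataType stays void to signal that no second-stage bias is used. As a minimal, standalone sketch (not CK code; the float alias below is only a stand-in for F16), this is how a void bias alias can be turned into a compile-time flag, which is what the std::is_void check later in this commit relies on:

    #include <iostream>
    #include <type_traits>

    using Acc0BiasDataType = float; // stand-in for the example's F16
    using Acc1BiasDataType = void;  // void means "no bias tensor"

    int main()
    {
        // Compile-time flags derived from the aliases, mirroring how the
        // example's byte-count formula uses std::is_void<Acc0BiasDataType>::value.
        constexpr bool has_d0_bias = !std::is_void<Acc0BiasDataType>::value;
        constexpr bool has_d1_bias = !std::is_void<Acc1BiasDataType>::value;
        std::cout << has_d0_bias << " " << has_d1_bias << "\n"; // prints: 1 0
        return 0;
    }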
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc

@@ -116,7 +116,7 @@ int run(int argc, char* argv[])
     Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
     Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
     Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
-    Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, z_gs_ms_ns_strides);
+    Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
     Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
@@ -137,25 +137,25 @@ int run(int argc, char* argv[])
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1, 1});
+        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-1, 1});
         break;
     case 2:
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
+        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
         break;
     case 3:
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1});
+        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         break;
     default:
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1});
+        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
     }

     DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
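The switch above only changes the template argument used for the bias tensor d_gs_ms_ns; the generator functors themselves are untouched by this commit. For orientation, a rough stand-in for what a GeneratorTensor_2-style functor looks like, inferred only from how it is invoked here (a hypothetical sketch, not CK's actual implementation):

    #include <cstdlib>
    #include <iostream>

    // Hypothetical sketch: a functor constructed with a value range that
    // returns one pseudo-random value per element, ignoring the element
    // indices passed to operator().
    template <typename T>
    struct RandomIntGenerator
    {
        int min_value;
        int max_value;

        template <typename... Is>
        T operator()(Is...) const
        {
            return static_cast<T>(min_value + std::rand() % (max_value - min_value));
        }
    };

    int main()
    {
        RandomIntGenerator<float> gen{-2, 2}; // mirrors GeneratorTensor_2<ADataType>{-2, 2}
        std::cout << gen(0, 0, 0) << "\n";    // one generated element value
        return 0;
    }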
@@ -163,7 +163,7 @@ int run(int argc, char* argv[])
     DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem c_device_buf(sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
-    DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
+    DeviceMem d_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
     DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
     DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms_device_result.mDesc.GetElementSpaceSize());
@@ -190,7 +190,7 @@ int run(int argc, char* argv[])
         static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
         static_cast<ZDataType*>(nullptr),
         static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
-        static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()), //
+        static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), //
         nullptr,
         a_gs_ms_ks_lengths,
         a_gs_ms_ks_strides,
@@ -227,17 +227,16 @@ int run(int argc, char* argv[])
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

-    std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-    std::size_t num_btype =
-        (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
-         sizeof(CDataType) * M * O +
-         sizeof(DDataType) * M * N * std::is_void<DDataType>::value ? 0 : 1) *
-        BatchCount;
+    std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
+    std::size_t num_bytes =
+        (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
+         sizeof(CDataType) * M * O +
+         sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
+        BatchCount;

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

-    float gb_per_sec = num_btype / 1.E6 / ave_time;
+    float gb_per_sec = num_bytes / 1.E6 / ave_time;

     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;
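This hunk appears to be the substance of the "fix times" commit message: in the old num_btype expression the conditional operator binds more loosely than * and +, so the entire byte sum became the condition of ?: and the result collapsed to 0 or 1, making the reported GB/s effectively zero. The new code parenthesizes the std::is_void conditional so it only gates the bias term, and renames the variable to num_bytes. A standalone sketch of the pitfall (illustrative values only, not the example's real sizes):

    #include <cstddef>
    #include <iostream>

    int main()
    {
        const std::size_t bytes_a = 1000, bytes_bias = 64;
        const bool bias_is_void   = false; // mirrors std::is_void<Acc0BiasDataType>::value

        // Old form: '?:' binds more loosely than '*' and '+', so the whole sum
        // becomes the condition and the expression evaluates to 0 or 1.
        std::size_t wrong = bytes_a + bytes_bias * bias_is_void ? 0 : 1;      // == 0

        // Fixed form: the conditional only scales the bias term.
        std::size_t right = bytes_a + bytes_bias * (bias_is_void ? 0 : 1);    // == 1064

        std::cout << wrong << " vs " << right << "\n";
        return 0;
    }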
@@ -252,7 +251,7 @@ int run(int argc, char* argv[])
         static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
         static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
         static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
-        static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()),
+        static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()),
         nullptr,
         a_gs_ms_ks_lengths,
         a_gs_ms_ks_strides,
@@ -293,7 +292,7 @@ int run(int argc, char* argv[])
     Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});
     Tensor<LSEDataType> lse_g_m_host_result({BatchCount, M}); // scratch object after max + ln(sum)
-    Tensor<DDataType> d_g_m_n({G0 * G1, M, N});
+    Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
     Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
     Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
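For reference, ave_time returned by invoker.Run is in milliseconds (the Perf line prints it with an "ms" unit), so the divisors in the unchanged tflops formula and the corrected gb_per_sec formula already fold in the ms-to-seconds conversion. A quick sanity check with made-up numbers:

    #include <cstdio>

    int main()
    {
        // Illustrative values only: 2 TFLOP of work and 4 GB of traffic in 1 ms.
        const double flop      = 2.0e12;
        const double num_bytes = 4.0e9;
        const double ave_time  = 1.0; // milliseconds

        const double tflops     = flop / 1.E9 / ave_time;      // flop / (1e12 * s) -> TFLOPS
        const double gb_per_sec = num_bytes / 1.E6 / ave_time; // bytes / (1e9 * s) -> GB/s

        std::printf("%.1f TFlops, %.1f GB/s\n", tflops, gb_per_sec); // 2000.0 TFlops, 4000.0 GB/s
        return 0;
    }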
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp

@@ -442,6 +442,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
     struct ComputeBasePtrOfStridedBatch
     {
+        ComputeBasePtrOfStridedBatch() {}
         ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                      const BGridDesc_G_N_K& b_grid_desc_g_n_k,
                                      const D0GridDesc_G_M_N& d0_grid_desc_g_m_n,
@@ -661,15 +662,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
               b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
           c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
                                 c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
-          batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
-          compute_base_ptr_of_batch_{a_grid_desc_g_m_k_,
-                                     b_grid_desc_g_n_k_,
-                                     d0_grid_desc_g_m_n_,
-                                     b1_grid_desc_g_n_k_,
-                                     c_grid_desc_g_m_n_,
-                                     z_grid_desc_g_m_n_,
-                                     type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
+          batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}
       {
           // TODO ANT: implement bias addition
           ignore = p_acc1_biases;
@@ -697,15 +690,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
                                                                        acc0_biases_gs_ms_ns_strides);
-            compute_base_ptr_of_batch_ =
-                ComputeBasePtrOfStridedBatch(a_grid_desc_g_m_k_,
-                                             b_grid_desc_g_n_k_,
-                                             d0_grid_desc_g_m_n_,
-                                             b1_grid_desc_g_n_k_,
-                                             c_grid_desc_g_m_n_,
-                                             z_grid_desc_g_m_n_,
-                                             type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
             d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]);
             d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]);
         }
@@ -731,6 +715,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             {
                 is_lse_storing_ = false;
             }
+
+            compute_base_ptr_of_batch_ =
+                ComputeBasePtrOfStridedBatch(a_grid_desc_g_m_k_,
+                                             b_grid_desc_g_n_k_,
+                                             d0_grid_desc_g_m_n_,
+                                             b1_grid_desc_g_n_k_,
+                                             c_grid_desc_g_m_n_,
+                                             z_grid_desc_g_m_n_,
+                                             type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
         }

         void Print() const
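Taken together, the hpp changes move the construction of compute_base_ptr_of_batch_ out of the member-initializer list and to the end of the constructor body, after d0_grid_desc_g_m_n_ has been (conditionally) rebuilt from the bias lengths and strides; the newly added default constructor of ComputeBasePtrOfStridedBatch is what makes that deferred assignment possible. A minimal sketch of the pattern, with hypothetical names standing in for the CK types:

    #include <iostream>

    // Hypothetical stand-in for a grid-descriptor type.
    struct Desc { int length = 0; };

    struct BasePtrHelper
    {
        BasePtrHelper() {}                   // default ctor so the member can be assigned later
        explicit BasePtrHelper(const Desc& d) : stride_(d.length) {}
        int stride_ = 0;
    };

    struct Argument
    {
        explicit Argument(bool has_bias)
        {
            if(has_bias)
            {
                d0_desc_.length = 128;       // only known after this branch runs
            }
            // Deferred construction: runs after d0_desc_ is final, mirroring how
            // compute_base_ptr_of_batch_ is now assigned in the constructor body.
            helper_ = BasePtrHelper(d0_desc_);
        }

        Desc d0_desc_;
        BasePtrHelper helper_;               // default-constructed first, assigned above
    };

    int main()
    {
        Argument arg(true);
        std::cout << arg.helper_.stride_ << "\n"; // 128
        return 0;
    }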