gaoqiong / composable_kernel · Commits

Commit de6af588
Authored Aug 02, 2023 by letaoqin

v2 add bias

Parent: 8e26a612
Changes: 2 changed files, with 89 additions and 11 deletions (+89 / -11)

  example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc        (+13 / -4)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp (+76 / -7)
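Before the per-file diffs, a quick orientation grounded in the changes below: the v2r2 batched multi-head attention forward kernel gains an optional additive bias matrix D. The gridwise kernel (second file) loads a per-thread tile of D into registers and adds it to the Gemm0 accumulator before masking and softmax, while the example's host reference (first file) adds the same d_g_m_n tensor to acc0_g_m_n ahead of masking so CPU and GPU results stay comparable. The sketch below is only a minimal host-side illustration of that formula, assuming row-major float buffers; it is not code from this commit.

// Minimal host-side sketch (illustrative only, NOT the CK reference code) of the
// math this commit wires in: an additive bias D is applied to the Gemm0 result
// before softmax, i.e. S = alpha * Q * K^T + D, P = softmax_row(S), Out = P * V.
// Dropout and masking are omitted; all names and layouts here are assumptions.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void attention_bias_ref(const std::vector<float>& Q,   // M x Kdim
                        const std::vector<float>& K,   // N x Kdim
                        const std::vector<float>& V,   // N x O
                        const std::vector<float>& D,   // M x N (bias)
                        std::vector<float>& Out,       // M x O
                        int M, int N, int Kdim, int O, float alpha)
{
    std::vector<float> S(static_cast<std::size_t>(M) * N);
    for(int m = 0; m < M; ++m)
    {
        // Gemm0 plus bias: S[m][n] = alpha * <Q[m], K[n]> + D[m][n]
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < Kdim; ++k)
                acc += Q[m * Kdim + k] * K[n * Kdim + k];
            S[m * N + n] = alpha * acc + D[m * N + n];
        }

        // numerically stable row softmax
        float row_max = S[m * N];
        for(int n = 1; n < N; ++n)
            row_max = std::max(row_max, S[m * N + n]);
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            S[m * N + n] = std::exp(S[m * N + n] - row_max);
            sum += S[m * N + n];
        }

        // Gemm1: Out[m][o] = sum_n P[m][n] * V[n][o]
        for(int o = 0; o < O; ++o)
        {
            float acc = 0.f;
            for(int n = 0; n < N; ++n)
                acc += (S[m * N + n] / sum) * V[n * O + o];
            Out[m * O + o] = acc;
        }
    }
}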
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc  (view file @ de6af588)
@@ -23,7 +23,7 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = true;
 
-    float p_drop                    = 0.1;
+    float p_drop                    = 0;
     const unsigned long long seed   = 1;
     const unsigned long long offset = 0;
@@ -69,7 +69,7 @@ int run(int argc, char* argv[])
     float p_dropout               = 1 - p_drop;
     ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
     float rp_dropout              = 1.0 / p_dropout;
-    float alpha                   = 1.f / std::sqrt(K);
+    float alpha                   = 1.f; // / std::sqrt(K);
 
     std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
     std::vector<ck::index_t> a_gs_ms_ks_strides =
@@ -137,7 +137,7 @@ int run(int argc, char* argv[])
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-2, 2});
+        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1, 1});
         break;
     case 2:
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
@@ -297,6 +297,7 @@ int run(int argc, char* argv[])
     Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});
     Tensor<LSEDataType> lse_g_m_host_result(
         {BatchCount, M}); // scratch object after max + ln(sum)
+    Tensor<DDataType> d_g_m_n({G0 * G1, M, N});
     Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
     Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
@@ -310,6 +311,10 @@ int run(int argc, char* argv[])
     b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
         b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
     });
+    d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+        d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+    });
+
     z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
         z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
     });
@@ -322,6 +327,10 @@ int run(int argc, char* argv[])
     ref_gemm0_invoker.Run(ref_gemm0_argument);
+
+    //bias
+    acc0_g_m_n.ForEach([&](auto& self, auto idx) { self(idx) += d_g_m_n(idx); });
+
     // masking
     const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
     acc0_g_m_n.ForEach([&](auto& self, auto idx) {
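Taken together, the example changes above do two things: they wire the bias into the host reference (d_g_m_n is filled from d_gs_ms_ns and added to acc0_g_m_n before masking), and they neutralize the other knobs so the bias path can be verified in isolation: p_drop becomes 0, which makes the dropout rescale a no-op, and alpha is pinned to 1.f with the 1/sqrt(K) factor commented out. A small stand-alone check of those derived values (illustrative only, mirroring the variable names in run_batched_multihead_attention_bias_forward.inc):

// Illustrative only: the derived dropout/scale values implied by this commit's settings.
#include <cstdio>

int main()
{
    float p_drop     = 0;                // dropout disabled for the bias test
    float p_dropout  = 1 - p_drop;       // = 1
    float rp_dropout = 1.0f / p_dropout; // = 1, so the dropout rescale is a no-op
    float alpha      = 1.f;              // "/ std::sqrt(K)" commented out -> raw Q*K^T scale

    std::printf("p_dropout=%f rp_dropout=%f alpha=%f\n", p_dropout, rp_dropout, alpha);
    return 0;
}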
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2r2.hpp  (view file @ de6af588)
@@ -499,8 +499,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                               const index_t raw_n_padded,
                               const index_t block_idx_m)
     {
-        ignore = p_d_grid;
-        ignore = d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -900,6 +898,52 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
         // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;
 
+        const auto wave_id     = GetGemm0WaveIdx();
+        const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
+
+        // bias (d matrix)
+        constexpr auto d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+            make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
+                                                           I1, // NBlockId
+                                                           m0, // MRepeat
+                                                           n0, // NRepeat
+                                                           m1, // MWaveId
+                                                           n1, // NWaveId
+                                                           m2, // MPerXdl
+                                                           n2, // NGroupNum
+                                                           n3, // NInputNum
+                                                           n4)); // RegisterNum
+
+        auto d_threadwise_copy_globla_vgpr = ThreadwiseTensorSliceTransfer_v2<
+            DDataType,
+            DDataType,
+            decltype(d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+            decltype(d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+            Sequence<I1, // MBlockId
+                     I1, // NBlockID
+                     m0, // MRepeat
+                     n0, // NRepeat
+                     m1, // MWaveId
+                     n1, // NWaveId
+                     m2, // MPerXdl
+                     n2, // NGroupNum
+                     n3, // NInputNum
+                     n4>,
+            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
+            9,
+            4,
+            1,
+            false>(
+            d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+            make_multi_index(block_work_idx[I0], // MBlockId
+                             0,                  // NBlockId
+                             0,                  // mrepeat
+                             0,                  // nrepeat
+                             wave_id[I0],        // MWaveId
+                             wave_id[I1],        // NWaveId
+                             wave_m_n_id[I1],    // MPerXdl
+                             0,                  // group
+                             wave_m_n_id[I0],    // NInputIndex
+                             0));                // register number
+
         // z is random number matrix for dropout verify
         //
         // z vgpr copy to global
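The block added in the hunk above gives each thread a 10-dimensional view of the bias matrix that lines up with its slice of the Gemm0 accumulator: a packed thread descriptor over (MBlockId, NBlockId, MRepeat, NRepeat, MWaveId, NWaveId, MPerXdl, NGroupNum, NInputNum, RegisterNum), plus a ThreadwiseTensorSliceTransfer_v2 whose starting multi-index is derived from block_work_idx, the wave id, and the lane id. The plain-C++ sketch below illustrates only the general pattern (a per-thread window into D that is copied into registers and then stepped along N); it deliberately avoids the CK API, and every name and tile size in it is an illustrative assumption.

// Conceptual sketch (plain C++, not the CK API) of the per-thread pattern the
// new code sets up: each thread owns a small window into the bias matrix D in
// global memory, copies that tile into registers, and later slides the window
// along N for the next block. All names and tile sizes here are illustrative.
struct ThreadBiasWindow
{
    const float* d_global; // bias matrix in global memory, row-major M x N
    int ld_n;              // leading dimension (N)
    int row0, col0;        // starting coordinate, derived from block/wave/lane ids
    int tile_m, tile_n;    // per-thread tile shape

    // global -> "register" tile (a plain array stands in for VGPRs here)
    void load(float* d_regs) const
    {
        for(int i = 0; i < tile_m; ++i)
            for(int j = 0; j < tile_n; ++j)
                d_regs[i * tile_n + j] = d_global[(row0 + i) * ld_n + (col0 + j)];
    }

    // analogue of MoveSrcSliceWindow: advance the window to the next N block
    void move_window_n(int n_step) { col0 += n_step; }
};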
@@ -992,9 +1036,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                 static_cast<ushort*>(p_shared),
                 z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
 
-            const auto wave_id     = GetGemm0WaveIdx();
-            const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
-
             auto z_tmp_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                 ushort,
                 ushort,
@@ -1111,6 +1152,35 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                                                            acc_thread_buf,
                                                            num_k_block_main_loop);
 
+            // add bias
+            if(p_d_grid)
+            {
+                auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                    p_d_grid,
+                    d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
+
+                StaticBuffer<AddressSpaceEnum::Vgpr,
+                             DDataType,
+                             d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
+                             true>
+                    d_thread_buf;
+
+                d_threadwise_copy_globla_vgpr.Run(d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                                  d_grid_buf,
+                                                  d_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                                                  d_thread_buf);
+
+                static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
+                    // acc add bias
+                    acc_thread_buf(i) += d_thread_buf[i];
+                });
+
+                d_threadwise_copy_globla_vgpr.MoveSrcSliceWindow(
+                    d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                    make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+            }
+
             // 8d thread_desc in thread scope
             constexpr auto c_thread_lengths =
                 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
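The hunk above is the runtime half of the bias support: once the Gemm0 main loop has produced acc_thread_buf, the kernel checks p_d_grid, and only when a bias pointer was supplied does it stage a register tile of D (d_thread_buf), add it element-wise across the m0 * n0 * n2 * n4 per-thread accumulator entries, and advance the source window by one N block for the next outer iteration. A minimal plain-C++ equivalent of that guarded accumulate (names and the num_elems parameter are illustrative, not the kernel's API):

// Conceptual stand-in for the guarded bias add above (not CK code): when a D
// pointer exists, add the thread's register tile of D to its Gemm0 accumulator
// before masking/softmax; otherwise skip the step entirely.
// 'num_elems' plays the role of m0 * n0 * n2 * n4 in the kernel.
inline void add_bias_to_acc(float* acc_regs,
                            const float* d_regs,
                            int num_elems,
                            bool has_bias)
{
    if(!has_bias)
        return;                   // no bias pointer: leave the accumulator untouched
    for(int i = 0; i < num_elems; ++i)
        acc_regs[i] += d_regs[i]; // element-wise acc += bias over the thread tile
}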
@@ -1145,7 +1215,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
                 // do MNK padding or upper triangular masking
                 if constexpr(MaskOutUpperTriangle || PadN)
                 {
                     static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
                         auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
-
                         auto m_local =
@@ -1572,6 +1641,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
             });
         }
     }
};

} // namespace ck