gaoqiong / composable_kernel

Commit 8ced5c4f, authored Sep 06, 2023 by danyao12

    bias examples sync with uint8 dropout

Parent: 0353c29e

Showing 5 changed files with 44 additions and 37 deletions.
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp  +8  -8
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp  +7  -7
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc  +9  -6
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc  +15 -11
include/ck/tensor_operation/gpu/grid/gridwise_batched_dropout.hpp             +5  -5
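The change repeated across these files moves the dropout threshold and the saved random tensor Z from 16-bit to 8-bit values: the keep probability is now scaled into the uint8 range (p_dropout * 255.0) rather than the uint16 range (p_dropout * 65535.0). Below is a minimal host-side sketch of that conversion, reusing the variable names from the examples; the keep/drop comparison at the end is an illustrative assumption, not the library's kernel code.

#include <cmath>
#include <cstdint>

using ZDataType = uint8_t; // the examples now store the random tensor Z as uint8

int main()
{
    const float p_drop    = 0.2f;       // probability of zeroing an element
    const float p_dropout = 1 - p_drop; // probability of keeping an element

    // old scheme: threshold scaled to the 16-bit random-number range
    const uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));

    // new scheme: threshold scaled to the 8-bit random-number range
    const ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));

    const float rp_dropout = 1.0f / p_dropout; // rescale factor applied to kept elements

    // illustrative keep/drop decision for one random byte z (assumption, not CK kernel code)
    const ZDataType z        = 37;
    const float p            = 1.25f;
    const float p_after_drop = (z <= p_dropout_in_uint8_t) ? p * rp_dropout : 0.0f;

    (void)p_dropout_in_16bits;
    (void)p_after_drop;
    return 0;
}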
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp

@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                            TensorLSE& lse_g_m,
                            TensorP& p_drop_g_m_n,
                            TensorZ& z_g_m_n,
-                           ZDataType p_dropout_in_16bits,
+                           ZDataType p_dropout_in_uint8_t,
                            float rp_dropout)
 {
     // S = alpha * Q * K^T

@@ -252,7 +252,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment =
-        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);

     // Y = P_dropout * V

@@ -328,7 +328,7 @@ int run(int argc, char* argv[])
     }

     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     float alpha                    = 1.f / std::sqrt(K);

@@ -655,7 +655,7 @@ int run(int argc, char* argv[])
                           lse_g_m,
                           p_drop_g_m_n,
                           z_g_m_n,
-                          p_dropout_in_16bits,
+                          p_dropout_in_uint8_t,
                           rp_dropout);

     y_gs_ms_os.ForEach([&](auto& self, auto idx) {
         self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);

@@ -715,7 +715,7 @@ int run(int argc, char* argv[])
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+        z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);

     // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
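Both backward examples route the new threshold through the host verification path: ReferenceDropoutInstance is built once, MakeArgument now receives p_dropout_in_uint8_t, and Run produces the reference dropped tensor. As a rough, standalone sketch of what such a reference dropout computes from the Z tensor and the threshold, here is a flat-array loop; the function name and the keep-when-z-is-below-threshold convention are assumptions for illustration, not the signature of ck's reference operator.

#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a reference dropout:
//   out[i] = (z[i] <= threshold) ? in[i] * rp_dropout : 0
// The real ReferenceDropoutInstance in composable_kernel operates on Tensor<> objects.
void reference_dropout_sketch(const uint8_t* z,
                              const float* p_in,
                              float* p_drop_out,
                              std::size_t n,
                              uint8_t p_dropout_in_uint8_t,
                              float rp_dropout)
{
    for(std::size_t i = 0; i < n; ++i)
    {
        p_drop_out[i] = (z[i] <= p_dropout_in_uint8_t) ? p_in[i] * rp_dropout : 0.0f;
    }
}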
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp

@@ -216,7 +216,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                            TensorLSE& lse_g_m,
                            TensorP& p_drop_g_m_n,
                            TensorZ& z_g_m_n,
-                           ZDataType p_dropout_in_16bits,
+                           ZDataType p_dropout_in_uint8_t,
                            float rp_dropout)
 {
     // S = alpha * Q * K^T

@@ -251,7 +251,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment =
-        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);

     // Y = P_dropout * V

@@ -315,7 +315,7 @@ int run(int argc, char* argv[])
     }

     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;

     auto gemm = DeviceGemmInstance{};

@@ -719,7 +719,7 @@ int run(int argc, char* argv[])
                           lse_g_ms[i],
                           p_drop_g_m_ns[i],
                           z_g_m_ns[i],
-                          p_dropout_in_16bits,
+                          p_dropout_in_uint8_t,
                           rp_dropout);

     y_tensors[i].ForEach([&](auto& self, auto idx) {

@@ -772,7 +772,7 @@ int run(int argc, char* argv[])
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+        z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);

     // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
     sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc

@@ -67,7 +67,7 @@ int run(int argc, char* argv[])
     }

     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     float alpha                    = 1.f / std::sqrt(K);

@@ -172,6 +172,7 @@ int run(int argc, char* argv[])
     b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
     b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
     d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
     z_device_buf.ToDevice(z_gs_ms_ns.mData.data());

     auto a_element_op  = AElementOp{};
     auto b0_element_op = B0ElementOp{};

@@ -322,7 +323,9 @@ int run(int argc, char* argv[])
     ref_gemm0_invoker.Run(ref_gemm0_argument);

     // bias
     acc0_g_m_n.ForEach(
         [&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });

     // masking
     const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
     acc0_g_m_n.ForEach([&](auto& self, auto idx) {

@@ -342,7 +345,7 @@ int run(int argc, char* argv[])
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+        z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);

     // gemm1
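In the forward host-verification path above, the bias step adds the D tensor elementwise into the gemm0 accumulator before masking, softmax, and dropout. A plain-loop analogue of that ForEach, assuming AccDataType is float and the bias has already been converted (ck::type_convert handles the narrower bias types in the real example):

#include <cstddef>

// Plain-array analogue of:
//   acc0_g_m_n.ForEach([&](auto& self, auto idx)
//       { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
void add_bias_sketch(float* acc0, const float* bias, std::size_t num_elements)
{
    for(std::size_t i = 0; i < num_elements; ++i)
    {
        acc0[i] += bias[i]; // S = Q * K^T + D, applied before masking and softmax
    }
}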
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc

@@ -44,7 +44,7 @@ int run(int argc, char* argv[])
     }

     float p_dropout                = 1 - p_drop;
-    uint16_t p_dropout_in_16bits   = uint16_t(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     float alpha                    = 1; // scaling after 1st gemm

@@ -163,8 +163,9 @@ int run(int argc, char* argv[])
     int Batch = G0 * G1;
     flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
     num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
                  sizeof(CDataType) * M * O +
                  sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
                 Batch;

@@ -237,6 +238,7 @@ int run(int argc, char* argv[])
     b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
     b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());
     d_tensors_device[i]->ToDevice(d_gs_ms_ns.mData.data());
     z_tensors_device[i]->ToDevice(z_gs_ms_ns.mData.data());

     p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
     p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());

@@ -396,7 +398,9 @@ int run(int argc, char* argv[])
     ref_gemm0_invoker.Run(ref_gemm0_argument);

     // bias
     acc0_g_m_n.ForEach(
         [&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });

     // masking
     const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
     acc0_g_m_n.ForEach([&](auto& self, auto idx) {

@@ -419,7 +423,7 @@ int run(int argc, char* argv[])
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+        z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);

     // gemm 1
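The flop and num_byte accumulators in the grouped example feed its performance summary: the two GEMMs contribute 2*M*N*K + 2*M*N*O flops per batch, and the byte count now also charges for the optional bias tensor when Acc0BiasDataType is not void. A worked sketch of that arithmetic with made-up sizes; ave_time_ms stands in for the measured kernel time, and the unit conversions follow the usual pattern of these examples:

#include <cstddef>
#include <cstdio>

int main()
{
    // Example problem sizes (illustrative values, not from the commit).
    const std::size_t M = 256, N = 256, K = 64, O = 64;
    const int G0 = 4, G1 = 6;
    const int Batch = G0 * G1;

    // Two GEMMs: S = Q*K^T (M x N x K) and Y = P*V (M x O x N), 2 flops per multiply-add.
    const std::size_t flop = (std::size_t(M) * N * K * 2 + std::size_t(M) * N * O * 2) * Batch;

    // Bytes moved, assuming fp16 A/B0/B1/C and an fp16 bias that is present (the
    // std::is_void<Acc0BiasDataType> check in the example disables the bias term when absent).
    const std::size_t elem = 2; // sizeof a half-precision element
    const std::size_t num_byte =
        (elem * M * K + elem * K * N + elem * N * O + elem * M * O + elem * M * N) * Batch;

    const float ave_time_ms = 1.0f; // placeholder for the measured kernel time
    const float tflops      = float(flop) / 1.0e9f / ave_time_ms;     // flop / (ms * 1e9) = TFLOP/s
    const float gb_per_sec  = float(num_byte) / 1.0e6f / ave_time_ms; // byte / (ms * 1e6) = GB/s

    std::printf("%.3f TFlops, %.3f GB/s\n", tflops, gb_per_sec);
    return 0;
}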
include/ck/tensor_operation/gpu/grid/gridwise_batched_dropout.hpp

@@ -57,8 +57,8 @@ struct GridwiseBatchedDropout
     static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
     static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-    // get_random_8x16() generates 8 random numbers each time
-    static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+    // get_random_16x8() generates 16 random numbers each time
+    static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
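The arithmetic behind this hunk: num_input_blks for the selected MFMA is 2 (per the comment), the old generator produced 8 random numbers per call for a dropout tile of 2 * 8 = 16, and the new get_random_16x8() produces 16 per call (presumably 8-bit values, hence the name) for a tile of 2 * 16 = 32. A compile-time restatement of that arithmetic, with the MFMA value hard-coded as an assumption:

// Assumed from the comment in the hunk: the selected mfma has num_input_blks == 2.
constexpr int DropoutNThread = 2;

// Old: get_random_8x16() -> 8 random numbers per call.
constexpr int DropoutTileOld = DropoutNThread * 8; // 16

// New: get_random_16x8() -> 16 random numbers per call.
constexpr int DropoutTileNew = DropoutNThread * 16; // 32

static_assert(DropoutTileOld == 16, "old tile covered 16 elements per iteration");
static_assert(DropoutTileNew == 32, "new tile covers 32 elements per iteration");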
@@ -241,7 +241,7 @@ struct GridwiseBatchedDropout
     // only used for providing ApplyDropoutAttnBwdSaveZ
     auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-        static_cast<unsigned short>(0.8f * 65535.f), static_cast<FloatGemmAcc>(1.0f / 0.8f)};
+        static_cast<unsigned short>(0.8f * 255.f), static_cast<FloatGemmAcc>(1.0f / 0.8f)};

     // z vgpr copy to global

@@ -260,7 +260,7 @@ struct GridwiseBatchedDropout
                               n2)); // NPerXdl

     StaticBuffer<AddressSpaceEnum::Vgpr,
-                 ushort,
+                 uint8_t,
                  z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                  true>
         z_tensor_buffer;

@@ -273,7 +273,7 @@ struct GridwiseBatchedDropout
     const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63

     auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-        ushort,
+        uint8_t,
         ZDataType,
         decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
         decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
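The last two hunks switch the per-thread Z staging buffer and the source type of the VGPR-to-global copy from ushort to uint8_t, so the same number of saved random values now occupies half the register space before being written out as ZDataType. A back-of-the-envelope size check; the element count here is illustrative, since the real one comes from z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize():

#include <cstddef>
#include <cstdint>

constexpr int kZElementsPerThread = 32; // illustrative per-thread element count

constexpr std::size_t kBytesAsUshort = kZElementsPerThread * sizeof(uint16_t); // 64 bytes before
constexpr std::size_t kBytesAsUint8  = kZElementsPerThread * sizeof(uint8_t);  // 32 bytes after

static_assert(kBytesAsUint8 * 2 == kBytesAsUshort,
              "uint8 storage halves the per-thread buffer for the same number of random values");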