gaoqiong / composable_kernel · Commits

Commit 0353c29e, authored Sep 06, 2023 by danyao12

    uint8 dropout

Parent: b7b7e153

Showing 18 changed files with 455 additions and 492 deletions (+455 −492)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp  +8 −8
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp  +8 −8
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp  +8 −8
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp  +7 −7
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp  +7 −7
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp  +7 −7
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc  +6 −5
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc  +5 −4
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp  +118 −118
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp  +23 −24
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp  +18 −18
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp  +7 −7
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp  +7 −7
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp  +7 −7
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp  +10 −10
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp  +190 −241
include/ck/utility/philox_rand.hpp  +13 −0
library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp  +6 −6
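The change, per the commit title, narrows the dropout mask from 16-bit to 8-bit values throughout: the keep-probability threshold is quantized against 255 instead of 65535, the random Z buffers hold uint8_t instead of ushort, and the Philox generator gains a get_random_16x8 variant that yields sixteen bytes per call instead of eight 16-bit values. A minimal standalone sketch of the threshold arithmetic used in the examples (the main harness and the sample p_drop value are illustrative, not taken from the diff):

    #include <cstdint>
    #include <cmath>
    #include <cstdio>

    int main()
    {
        float p_drop    = 0.2f;       // dropout probability (illustrative value)
        float p_dropout = 1 - p_drop; // keep probability

        // old 16-bit quantization: keep element iff rand16 <= floor(p_dropout * 65535)
        uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
        // new 8-bit quantization: keep element iff rand8 <= floor(p_dropout * 255)
        uint8_t p_dropout_in_uint8_t = uint8_t(std::floor(p_dropout * 255.0));

        float rp_dropout = 1.0 / p_dropout; // survivors are rescaled by 1 / keep-prob
        std::printf("%u %u %f\n", p_dropout_in_16bits, p_dropout_in_uint8_t, rp_dropout);
    }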
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp

@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -249,7 +249,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -324,10 +324,10 @@ int run(int argc, char* argv[])
         exit(0);
     }

-    float p_dropout               = 1 - p_drop;
-    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
-    float rp_dropout              = 1.0 / p_dropout;
-    float alpha                   = 1.f / std::sqrt(K);
+    float p_dropout                = 1 - p_drop;
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
+    float rp_dropout               = 1.0 / p_dropout;
+    float alpha                    = 1.f / std::sqrt(K);

     std::cout << "do_verification: " << do_verification << std::endl;
     std::cout << "init_method: " << init_method << std::endl;
@@ -627,7 +627,7 @@ int run(int argc, char* argv[])
                               lse_g_m,
                               p_drop_g_m_n,
                               z_g_m_n,
-                              p_dropout_in_16bits,
+                              p_dropout_in_uint8_t,
                               rp_dropout);
     y_gs_ms_os.ForEach([&](auto& self, auto idx) {
         self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
@@ -687,7 +687,7 @@ int run(int argc, char* argv[])
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+        z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
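The host verification path these examples call keeps an element exactly when its stored random value is at or below the quantized threshold, then rescales it by rp_dropout. A self-contained restatement of that rule on flat buffers (plain std::vector stands in for the library's Tensor class; this is a sketch, not the library code):

    #include <cstdint>
    #include <vector>

    // Reference dropout rule as the examples verify it, on flat buffers:
    // keep element i iff z[i] <= threshold, rescaling survivors by rp_dropout.
    std::vector<float> reference_dropout(const std::vector<uint8_t>& z,
                                         const std::vector<float>& in,
                                         uint8_t p_dropout_in_uint8_t,
                                         float rp_dropout)
    {
        std::vector<float> out(in.size());
        for(std::size_t i = 0; i < in.size(); ++i)
            out[i] = z[i] <= p_dropout_in_uint8_t ? in[i] * rp_dropout : 0.f;
        return out;
    }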
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp

@@ -218,7 +218,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -250,7 +250,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -325,10 +325,10 @@ int run(int argc, char* argv[])
         exit(0);
     }

-    float p_dropout               = 1 - p_drop;
-    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
-    float rp_dropout              = 1.0 / p_dropout;
-    float alpha                   = 1.f / std::sqrt(K);
+    float p_dropout                = 1 - p_drop;
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
+    float rp_dropout               = 1.0 / p_dropout;
+    float alpha                    = 1.f / std::sqrt(K);

     std::cout << "do_verification: " << do_verification << std::endl;
     std::cout << "init_method: " << init_method << std::endl;
@@ -633,7 +633,7 @@ int run(int argc, char* argv[])
                               lse_g_m,
                               p_drop_g_m_n,
                               z_g_m_n,
-                              p_dropout_in_16bits,
+                              p_dropout_in_uint8_t,
                               rp_dropout);
     y_gs_ms_os.ForEach([&](auto& self, auto idx) {
         self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
@@ -693,7 +693,7 @@ int run(int argc, char* argv[])
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+        z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp

@@ -247,7 +247,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -279,7 +279,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -354,10 +354,10 @@ int run(int argc, char* argv[])
         exit(0);
     }

-    float p_dropout               = 1 - p_drop;
-    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
-    float rp_dropout              = 1.0 / p_dropout;
-    float alpha                   = 1.f / std::sqrt(K);
+    float p_dropout                = 1 - p_drop;
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
+    float rp_dropout               = 1.0 / p_dropout;
+    float alpha                    = 1.f / std::sqrt(K);

     std::cout << "do_verification: " << do_verification << std::endl;
     std::cout << "init_method: " << init_method << std::endl;
@@ -811,7 +811,7 @@ int run(int argc, char* argv[])
                               lse_g_m,
                               p_drop_g_m_n,
                               z_fwd_g_m_n,
-                              p_dropout_in_16bits,
+                              p_dropout_in_uint8_t,
                               rp_dropout);
     ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
@@ -854,7 +854,7 @@ int run(int argc, char* argv[])
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_bwd_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+        z_bwd_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp

@@ -216,7 +216,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -248,7 +248,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -311,9 +311,9 @@ int run(int argc, char* argv[])
         exit(0);
     }

-    float p_dropout               = 1 - p_drop;
-    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
-    float rp_dropout              = 1.0 / p_dropout;
+    float p_dropout                = 1 - p_drop;
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
+    float rp_dropout               = 1.0 / p_dropout;

     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
@@ -686,7 +686,7 @@ int run(int argc, char* argv[])
                                   lse_g_ms[i],
                                   p_drop_g_m_ns[i],
                                   z_g_m_ns[i],
-                                  p_dropout_in_16bits,
+                                  p_dropout_in_uint8_t,
                                   rp_dropout);
         y_tensors[i].ForEach([&](auto& self, auto idx) {
@@ -738,7 +738,7 @@ int run(int argc, char* argv[])
         auto ref_dropout         = ReferenceDropoutInstance{};
         auto ref_dropout_invoker = ref_dropout.MakeInvoker();
         auto ref_dropout_argment = ref_dropout.MakeArgument(
-            z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+            z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
         ref_dropout_invoker.Run(ref_dropout_argment);
         sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp

@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -249,7 +249,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -312,9 +312,9 @@ int run(int argc, char* argv[])
         exit(0);
     }

-    float p_dropout               = 1 - p_drop;
-    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
-    float rp_dropout              = 1.0 / p_dropout;
+    float p_dropout                = 1 - p_drop;
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
+    float rp_dropout               = 1.0 / p_dropout;

     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
@@ -699,7 +699,7 @@ int run(int argc, char* argv[])
                                   lse_g_ms[i],
                                   p_drop_g_m_ns[i],
                                   z_g_m_ns[i],
-                                  p_dropout_in_16bits,
+                                  p_dropout_in_uint8_t,
                                   rp_dropout);
         y_tensors[i].ForEach([&](auto& self, auto idx) {
@@ -751,7 +751,7 @@ int run(int argc, char* argv[])
        auto ref_dropout         = ReferenceDropoutInstance{};
        auto ref_dropout_invoker = ref_dropout.MakeInvoker();
        auto ref_dropout_argment = ref_dropout.MakeArgument(
-            z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+            z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
        ref_dropout_invoker.Run(ref_dropout_argment);
        sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp

@@ -246,7 +246,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -278,7 +278,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment = ref_dropout.MakeArgument(
-        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -341,9 +341,9 @@ int run(int argc, char* argv[])
         exit(0);
     }

-    float p_dropout               = 1 - p_drop;
-    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
-    float rp_dropout              = 1.0 / p_dropout;
+    float p_dropout                = 1 - p_drop;
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
+    float rp_dropout               = 1.0 / p_dropout;

     auto gemm_fwd    = DeviceGemmInstanceFWD{};
     auto invoker_fwd = gemm_fwd.MakeInvoker();
@@ -856,7 +856,7 @@ int run(int argc, char* argv[])
                                   lse_g_ms[i],
                                   p_drop_g_m_ns[i],
                                   z_fwd_g_m_ns[i],
-                                  p_dropout_in_16bits,
+                                  p_dropout_in_uint8_t,
                                   rp_dropout);
         int G0 = v_tensors[i].GetLengths()[0];
@@ -889,7 +889,7 @@ int run(int argc, char* argv[])
         auto ref_dropout         = ReferenceDropoutInstance{};
         auto ref_dropout_invoker = ref_dropout.MakeInvoker();
         auto ref_dropout_argment = ref_dropout.MakeArgument(
-            z_bwd_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+            z_bwd_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
         ref_dropout_invoker.Run(ref_dropout_argment);
         sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc

@@ -66,10 +66,10 @@ int run(int argc, char* argv[])
         exit(0);
     }

-    float p_dropout               = 1 - p_drop;
-    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
-    float rp_dropout              = 1.0 / p_dropout;
-    float alpha                   = 1.f / std::sqrt(K);
+    float p_dropout                = 1 - p_drop;
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
+    float rp_dropout               = 1.0 / p_dropout;
+    float alpha                    = 1.f / std::sqrt(K);

     std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
     std::vector<ck::index_t> a_gs_ms_ks_strides =
@@ -159,6 +159,7 @@ int run(int argc, char* argv[])
     a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
     b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
     b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
+    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());

     auto a_element_op  = AElementOp{};
     auto b0_element_op = B0ElementOp{};
@@ -322,7 +323,7 @@ int run(int argc, char* argv[])
         auto ref_dropout         = ReferenceDropoutInstance{};
         auto ref_dropout_invoker = ref_dropout.MakeInvoker();
         auto ref_dropout_argment = ref_dropout.MakeArgument(
-            z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+            z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
         ref_dropout_invoker.Run(ref_dropout_argment);
         // gemm1
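The forward examples now also upload the Z (random-mask) tensor to the device before launch, alongside A, B0, and B1. A condensed sketch of that upload pattern with CK's DeviceMem helper (the size expression follows the examples' usual convention and is an assumption here):

    // Condensed upload pattern: DeviceMem::ToDevice copies host bytes into the
    // GPU allocation; GetDeviceBuffer returns the raw device pointer that is
    // later passed to the kernel argument.
    DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
    void* p_z = z_device_buf.GetDeviceBuffer();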
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc

@@ -43,9 +43,9 @@ int run(int argc, char* argv[])
         exit(0);
     }

-    float p_dropout              = 1 - p_drop;
-    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
-    float rp_dropout             = 1.0 / p_dropout;
+    float p_dropout                = 1 - p_drop;
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
+    float rp_dropout               = 1.0 / p_dropout;

     float alpha = 1; // scaling after 1st gemm
@@ -217,6 +217,7 @@ int run(int argc, char* argv[])
         a_tensors_device[i]->ToDevice(a_gs_ms_ks.mData.data());
         b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
         b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());
+        z_tensors_device[i]->ToDevice(z_gs_ms_ns.mData.data());

         p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
         p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
@@ -390,7 +391,7 @@ int run(int argc, char* argv[])
         auto ref_dropout         = ReferenceDropoutInstance{};
         auto ref_dropout_invoker = ref_dropout.MakeInvoker();
         auto ref_dropout_argment = ref_dropout.MakeArgument(
-            z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+            z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
         ref_dropout_invoker.Run(ref_dropout_argment);
         // gemm 1
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp

@@ -16,111 +16,111 @@ struct BlockwiseDropout
     static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
     static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);

-    template <typename CThreadBuffer, bool using_sign_bit = false>
-    __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph)
-    {
-        auto execute_dropout = [&](bool keep, DataType val) {
-            if constexpr(using_sign_bit)
-                return keep ? val : -val;
-            else
-                return keep ? val * p_dropout_rescale : float(0);
-        };
-
-        constexpr int tmp_size = MRepeat * KRepeat;
-        int philox_calls       = tmp_size / 8;
-
-        ushort tmp[tmp_size];
-        for(int i = 0; i < philox_calls; i++)
-        {
-            ph.get_random_8x16((tmp + i * 8));
-        }
-        block_sync_lds();
-
-        int tmp_index = 0;
-        static_for<0, MRepeat, 1>{}([&](auto iM) {
-            static_for<0, KRepeat, 1>{}([&](auto iK) {
-                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
-                in_thread_buf(offset) =
-                    execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
-                tmp_index = tmp_index + 1;
-            });
-        });
-    }
-
-    template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
-    __host__ __device__ void
-    ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
-    {
-        auto execute_dropout = [&](bool keep, DataType val) {
-            if constexpr(using_sign_bit)
-                return keep ? val : -val;
-            else
-                return keep ? val * p_dropout_rescale : float(0);
-        };
-
-        constexpr int tmp_size = MRepeat * KRepeat;
-        int philox_calls       = tmp_size / 8;
-
-        ushort tmp[tmp_size];
-        for(int i = 0; i < philox_calls; i++)
-        {
-            ph.get_random_8x16((tmp + i * 8));
-        }
-        block_sync_lds();
-
-        int tmp_index = 0;
-        static_for<0, MRepeat, 1>{}([&](auto iM) {
-            static_for<0, KRepeat, 1>{}([&](auto iK) {
-                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
-                in_thread_buf(offset) =
-                    execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
-                z_thread_buf(offset) = tmp[tmp_index];
-                tmp_index = tmp_index + 1;
-            });
-        });
-    }
-
-    template <typename CThreadBuffer,
-              typename ZThreadBuffer,
-              bool using_sign_bit,
-              typename N0,
-              typename Offset>
-    __host__ __device__ void
-    ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
-    {
-        auto execute_dropout = [&](bool keep, DataType val) {
-            if constexpr(using_sign_bit)
-                return keep ? val : -val;
-            else
-                return keep ? val * p_dropout_rescale : float(0);
-        };
-
-        constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
-        int philox_calls       = tmp_size / 8;
-
-        ushort tmp[tmp_size];
-        for(int i = 0; i < philox_calls; i++)
-        {
-            ph.get_random_8x16((tmp + i * 8));
-        }
-        block_sync_lds();
-
-        constexpr auto iOffset = Number<tmp_size>{} * Offset{};
-        static_for<0, tmp_size, 1>{}([&](auto i) {
-            in_thread_buf(i + iOffset) =
-                execute_dropout(tmp[i.value] <= p_dropout_16bits, in_thread_buf(i + iOffset));
-            z_thread_buf(i) = tmp[i.value];
-        });
-    }
+    // template <typename CThreadBuffer, bool using_sign_bit = false>
+    // __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph)
+    // {
+    //     auto execute_dropout = [&](bool keep, DataType val) {
+    //         if constexpr(using_sign_bit)
+    //             return keep ? val : -val;
+    //         else
+    //             return keep ? val * p_dropout_rescale : float(0);
+    //     };
+    //
+    //     constexpr int tmp_size = MRepeat * KRepeat;
+    //     int philox_calls = tmp_size / 8;
+    //
+    //     ushort tmp[tmp_size];
+    //     for(int i = 0; i < philox_calls; i++)
+    //     {
+    //         ph.get_random_8x16((tmp + i * 8));
+    //     }
+    //     block_sync_lds();
+    //
+    //     int tmp_index = 0;
+    //     static_for<0, MRepeat, 1>{}([&](auto iM) {
+    //         static_for<0, KRepeat, 1>{}([&](auto iK) {
+    //             auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM,
+    //             iK))>{}; in_thread_buf(offset) =
+    //                 execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
+    //             tmp_index = tmp_index + 1;
+    //         });
+    //     });
+    // }
+
+    // template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
+    // __host__ __device__ void
+    // ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
+    // {
+    //     auto execute_dropout = [&](bool keep, DataType val) {
+    //         if constexpr(using_sign_bit)
+    //             return keep ? val : -val;
+    //         else
+    //             return keep ? val * p_dropout_rescale : float(0);
+    //     };
+    //
+    //     constexpr int tmp_size = MRepeat * KRepeat;
+    //     int philox_calls = tmp_size / 8;
+    //
+    //     ushort tmp[tmp_size];
+    //     for(int i = 0; i < philox_calls; i++)
+    //     {
+    //         ph.get_random_8x16((tmp + i * 8));
+    //     }
+    //     block_sync_lds();
+    //
+    //     int tmp_index = 0;
+    //     static_for<0, MRepeat, 1>{}([&](auto iM) {
+    //         static_for<0, KRepeat, 1>{}([&](auto iK) {
+    //             auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM,
+    //             iK))>{}; in_thread_buf(offset) =
+    //                 execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
+    //             z_thread_buf(offset) = tmp[tmp_index];
+    //             tmp_index = tmp_index + 1;
+    //         });
+    //     });
+    // }
+
+    // template <typename CThreadBuffer,
+    //           typename ZThreadBuffer,
+    //           bool using_sign_bit,
+    //           typename N0,
+    //           typename Offset>
+    // __host__ __device__ void
+    // ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
+    // {
+    //     auto execute_dropout = [&](bool keep, DataType val) {
+    //         if constexpr(using_sign_bit)
+    //             return keep ? val : -val;
+    //         else
+    //             return keep ? val * p_dropout_rescale : float(0);
+    //     };
+    //
+    //     constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
+    //     int philox_calls = tmp_size / 8;
+    //
+    //     ushort tmp[tmp_size];
+    //     for(int i = 0; i < philox_calls; i++)
+    //     {
+    //         ph.get_random_8x16((tmp + i * 8));
+    //     }
+    //     block_sync_lds();
+    //
+    //     constexpr auto iOffset = Number<tmp_size>{} * Offset{};
+    //     static_for<0, tmp_size, 1>{}([&](auto i) {
+    //         in_thread_buf(i + iOffset) =
+    //             execute_dropout(tmp[i.value] <= p_dropout_uint8_t, in_thread_buf(i + iOffset));
+    //         z_thread_buf(i) = tmp[i.value];
+    //     });
+    // }

     template <typename CThreadBuffer, typename Offset, bool using_sign_bit = false>
     __host__ __device__ void ApplyDropoutAttnBwd(CThreadBuffer& in_thread_buf,
@@ -138,12 +138,12 @@ struct BlockwiseDropout
         constexpr int tmp_size = MRepeat * KRepeat;
-        int philox_calls       = tmp_size / 8;
+        int philox_calls       = tmp_size / 16;

-        ushort tmp[tmp_size];
+        uint8_t tmp[tmp_size];
         for(int i = 0; i < philox_calls; i++)
         {
-            ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
+            ph.get_random_16x8((tmp + i * 16), element_global_1d_id + i * Offset{} * MRaw);
         }
         block_sync_lds();
@@ -153,7 +153,7 @@ struct BlockwiseDropout
             static_for<0, KRepeat, 1>{}([&](auto iK) {
                 auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                 in_thread_buf(offset) =
-                    execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
+                    execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
                 tmp_index = tmp_index + 1;
             });
         });
@@ -179,12 +179,12 @@ struct BlockwiseDropout
         constexpr int tmp_size = MRepeat * KRepeat;
-        int philox_calls       = tmp_size / 8;
+        int philox_calls       = tmp_size / 16;

-        ushort tmp[tmp_size];
+        uint8_t tmp[tmp_size];
         for(int i = 0; i < philox_calls; i++)
         {
-            ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
+            ph.get_random_16x8((tmp + i * 16), element_global_1d_id + i * Offset{} * MRaw);
         }
         block_sync_lds();
@@ -194,7 +194,7 @@ struct BlockwiseDropout
             static_for<0, KRepeat, 1>{}([&](auto iK) {
                 auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                 in_thread_buf(offset) =
-                    execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
+                    execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
                 z_thread_buf(offset) = tmp[tmp_index];
                 tmp_index = tmp_index + 1;
             });
@@ -213,7 +213,7 @@ struct BlockwiseDropout
         constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
         static_for<0, tmp_size, 1>{}([&](auto i) {
             in_thread_buf(i + Offset{}) =
-                execute_dropout(z_thread_buf(i) <= p_dropout_16bits, in_thread_buf(i + Offset{}));
+                execute_dropout(z_thread_buf(i) <= p_dropout_uint8_t, in_thread_buf(i + Offset{}));
         });
     }
@@ -225,18 +225,18 @@ struct BlockwiseDropout
     {
         constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
-        int philox_calls       = tmp_size / 8;
+        int philox_calls       = tmp_size / 16;

-        ushort tmp[tmp_size];
+        uint8_t tmp[tmp_size];
         for(int i = 0; i < philox_calls; i++)
         {
-            ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{});
+            ph.get_random_16x8((tmp + i * 16), element_global_1d_id + i * Offset{});
         }

         static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
     }

-    ushort p_dropout_16bits;
+    uint8_t p_dropout_uint8_t;
     DataType p_dropout_rescale;
 };
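The surviving ApplyDropoutAttnBwd path now fills a uint8_t scratch buffer sixteen values per Philox call (tmp_size / 16) instead of eight 16-bit values (tmp_size / 8), then compares each byte against the 8-bit threshold. A host-side sketch of that fill-and-compare shape, with a stub standing in for ck::philox (the stub, its seeding, and the flat buffer are assumptions for illustration):

    #include <cstdint>
    #include <random>

    // Stand-in for ck::philox::get_random_16x8: writes 16 bytes per call.
    // (Illustrative stub; the real generator is counter-based and deterministic.)
    struct philox_stub
    {
        std::mt19937 eng{0};
        void get_random_16x8(uint8_t* out, unsigned long long /*subsequence*/)
        {
            for(int j = 0; j < 16; ++j)
                out[j] = static_cast<uint8_t>(eng() & 0xff);
        }
    };

    // Shape of the rewritten mask loop: one generator call now covers 16 elements.
    template <int MRepeat, int KRepeat>
    void apply_dropout(float (&buf)[MRepeat * KRepeat],
                       philox_stub& ph,
                       uint8_t p_dropout_uint8_t,
                       float p_dropout_rescale)
    {
        constexpr int tmp_size = MRepeat * KRepeat;
        int philox_calls       = tmp_size / 16; // was tmp_size / 8 with 16-bit values
        uint8_t tmp[tmp_size];
        for(int i = 0; i < philox_calls; i++)
            ph.get_random_16x8(tmp + i * 16, /*subsequence=*/i);
        for(int i = 0; i < tmp_size; i++)
            buf[i] = tmp[i] <= p_dropout_uint8_t ? buf[i] * p_dropout_rescale : 0.f;
    }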
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp

@@ -40,7 +40,7 @@ template <typename GridwiseGemm,
           typename D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
           typename B1GridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
+          typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
           typename LSEGridDescriptor_M,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch,
@@ -73,15 +73,15 @@ __global__ void
             const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
             const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
-            const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
-                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
+            const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
             const LSEGridDescriptor_M lse_grid_desc_m,
             const Block2CTileMap block_2_ctile_map,
             const index_t batch_count,
             const index_t mblock,
             const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
             const C0MatrixMask c0_matrix_mask,
-            const ushort p_dropout_in_16bits,
+            const uint8_t p_dropout_in_uint8_t,
             const GemmAccDataType p_dropout_rescale,
             const unsigned long long seed,
             const unsigned long long offset,
@@ -145,11 +145,11 @@ __global__ void
             d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
             b1_grid_desc_bk0_n_bk1,
             c_grid_desc_mblock_mperblock_nblock_nperblock,
-            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
+            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
             lse_grid_desc_m,
             block_2_ctile_map,
             c0_matrix_mask,
-            p_dropout_in_16bits,
+            p_dropout_in_uint8_t,
             p_dropout_rescale,
             ph,
             z_random_matrix_offset,
@@ -178,11 +178,11 @@ __global__ void
             d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
             b1_grid_desc_bk0_n_bk1,
             c_grid_desc_mblock_mperblock_nblock_nperblock,
-            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
+            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
             lse_grid_desc_m,
             block_2_ctile_map,
             c0_matrix_mask,
-            p_dropout_in_16bits,
+            p_dropout_in_uint8_t,
             p_dropout_rescale,
             ph,
             z_random_matrix_offset,
@@ -207,14 +207,14 @@ __global__ void
     ignore = d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
     ignore = b1_grid_desc_bk0_n_bk1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
-    ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6;
+    ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
     ignore = lse_grid_desc_m;
     ignore = block_2_ctile_map;
     ignore = batch_count;
     ignore = mblock;
     ignore = compute_base_ptr_of_batch;
     ignore = c0_matrix_mask;
-    ignore = p_dropout_in_16bits;
+    ignore = p_dropout_in_uint8_t;
     ignore = p_dropout_rescale;
     ignore = seed;
     ignore = offset;
@@ -695,18 +695,17 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             }
         }

-        is_dropout_ = p_dropout > 0.0; //
-        p_dropout_  = 1.f - p_dropout;
-        p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
-        p_dropout_           = 1.f / p_dropout_;
-        p_dropout_rescale_   = type_convert<GemmAccDataType>(p_dropout_);
+        is_dropout_ = p_dropout > 0.0; //
+        p_dropout_  = 1.f - p_dropout;
+        p_dropout_in_uint8_t_ = uint8_t(std::floor(p_dropout_ * 255.0));
+        p_dropout_            = 1.f / p_dropout_;
+        p_dropout_rescale_    = type_convert<GemmAccDataType>(p_dropout_);

         seed_   = std::get<0>(seeds);
         offset_ = std::get<1>(seeds);

-        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_ =
-            GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
-                z_grid_desc_m_n_);
+        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
+            GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);

         m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
         n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
@@ -779,8 +778,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
-            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
+        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;

         // block-to-c-tile map
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
@@ -806,7 +805,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;

         float p_dropout_;
-        ushort p_dropout_in_16bits_;
+        uint8_t p_dropout_in_uint8_t_;
         GemmAccDataType p_dropout_rescale_;
         unsigned long long seed_;
         unsigned long long offset_;
@@ -864,7 +863,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                 typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
                 DeviceOp::B1GridDesc_BK0_N_BK1,
                 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
+                typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
                 DeviceOp::LSEGridDesc_M,
                 typename GridwiseGemm::DefaultBlock2CTileMap,
                 ComputeBasePtrOfStridedBatch,
@@ -897,14 +896,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                 arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                 arg.b1_grid_desc_bk0_n_bk1_,
                 arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
+                arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                 arg.lse_grid_desc_m_,
                 arg.block_2_ctile_map_,
                 arg.batch_count_,
                 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
                 arg.compute_base_ptr_of_batch_,
                 arg.c0_matrix_mask_,
-                arg.p_dropout_in_16bits_,
+                arg.p_dropout_in_uint8_t_,
                 arg.p_dropout_rescale_,
                 arg.seed_,
                 arg.offset_,
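A readability note on the Argument constructor: p_dropout_ is reused as scratch, holding the keep probability when the 8-bit threshold is quantized and then being overwritten with its reciprocal for the rescale factor. A free-standing sketch of that sequence, with plain locals in place of the members (struct and function names here are illustrative, not from the file):

    #include <cstdint>
    #include <cmath>

    struct DropoutParams
    {
        bool is_dropout;
        uint8_t p_dropout_in_uint8_t;
        float p_dropout_rescale;
    };

    // Mirrors the order of operations in the Argument constructor: quantize the
    // keep probability against 255 first, then take its reciprocal for rescaling.
    DropoutParams make_dropout_params(float p_dropout /* drop probability */)
    {
        DropoutParams p;
        p.is_dropout           = p_dropout > 0.0;
        float keep             = 1.f - p_dropout;
        p.p_dropout_in_uint8_t = uint8_t(std::floor(keep * 255.0));
        p.p_dropout_rescale    = 1.f / keep;
        return p;
    }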
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp

@@ -48,7 +48,7 @@ __global__ void
             const AccElementwiseOperation acc_element_op,
             const B1ElementwiseOperation b1_element_op,
             const CElementwiseOperation c_element_op,
-            const ushort p_dropout_in_16bits,
+            const uint8_t p_dropout_in_uint8_t,
             const GemmAccDataType p_dropout_rescale,
             const unsigned long long seed,
             const unsigned long long offset)
@@ -140,11 +140,11 @@ __global__ void
             arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
             arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
             arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-            arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
+            arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
             arg_ptr[group_id].lse_grid_desc_m_,
             arg_ptr[group_id].block_2_ctile_map_,
             arg_ptr[group_id].c0_matrix_mask_,
-            p_dropout_in_16bits,
+            p_dropout_in_uint8_t,
             p_dropout_rescale,
             ph,
             arg_ptr[group_id].z_random_matrix_offset_ +
@@ -178,11 +178,11 @@ __global__ void
             arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
             arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
             arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-            arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
+            arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
             arg_ptr[group_id].lse_grid_desc_m_,
             arg_ptr[group_id].block_2_ctile_map_,
             arg_ptr[group_id].c0_matrix_mask_,
-            p_dropout_in_16bits,
+            p_dropout_in_uint8_t,
             p_dropout_rescale,
             ph,
             arg_ptr[group_id].z_random_matrix_offset_ +
@@ -198,7 +198,7 @@ __global__ void
     ignore = acc_element_op;
     ignore = b1_element_op;
     ignore = c_element_op;
-    ignore = p_dropout_in_16bits;
+    ignore = p_dropout_in_uint8_t;
     ignore = p_dropout_rescale;
     ignore = seed;
     ignore = offset;
@@ -620,8 +620,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
-            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
+        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
         ZGridDesc_M_N z_grid_desc_m_n_;
         LSEGridDesc_M lse_grid_desc_m_;
@@ -774,8 +774,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
-            const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6 =
-                GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
+            const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+                GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
                     z_grid_desc_m_n);

             const index_t BlockStart = grid_size_;
@@ -819,7 +819,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                 d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                 b1_grid_desc_bk0_n_bk1,
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
-                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
+                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                 z_grid_desc_m_n,
                 lse_grid_desc_m,
                 block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
@@ -857,11 +857,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                                               d0_n_length_stride});
         }

-        use_dropout_ = p_dropout > 0.0; //
-        p_dropout_   = 1.f - p_dropout;
-        p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
-        p_dropout_           = 1.f / p_dropout_;
-        p_dropout_rescale_   = type_convert<GemmAccDataType>(p_dropout_);
+        use_dropout_ = p_dropout > 0.0; //
+        p_dropout_   = 1.f - p_dropout;
+        p_dropout_in_uint8_t_ = uint8_t(std::floor(p_dropout_ * 255.0));
+        p_dropout_            = 1.f / p_dropout_;
+        p_dropout_rescale_    = type_convert<GemmAccDataType>(p_dropout_);

         seed_   = std::get<0>(seeds);
         offset_ = std::get<1>(seeds);
@@ -880,7 +880,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         CElementwiseOperation c_element_op_;

         float p_dropout_;
-        ushort p_dropout_in_16bits_;
+        uint8_t p_dropout_in_uint8_t_;
         unsigned long long seed_;
         unsigned long long offset_;
         GemmAccDataType p_dropout_rescale_;
@@ -949,7 +949,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                 arg.acc_element_op_,
                 arg.b1_element_op_,
                 arg.c_element_op_,
-                arg.p_dropout_in_16bits_,
+                arg.p_dropout_in_uint8_t_,
                 arg.p_dropout_rescale_,
                 arg.seed_,
                 arg.offset_);
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp

@@ -120,8 +120,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
     static constexpr auto V_K0 = KPerBlock / V_K1 / V_K2;
     static constexpr auto V_N1 = NXdlPerWave;

     static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-    // get_random_8x16() generates 8 random numbers each time
-    static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+    // get_random_16x8() generates 16 random numbers each time
+    static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1409,8 +1409,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
     {
         const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
         const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
-        const ushort p_dropout_in_16bits =
-            __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
+        const uint8_t p_dropout_in_uint8_t =
+            __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
         const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                      rp_dropout);
@@ -1726,7 +1726,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                               decltype(thread_slice_desc_m_n)>{};
         auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-            p_dropout_in_16bits, rp_dropout};
+            p_dropout_in_uint8_t, rp_dropout};

         auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
             MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
@@ -1795,7 +1795,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        n2)); // NPerXdl

         StaticBuffer<AddressSpaceEnum::Vgpr,
-                     ushort,
+                     uint8_t,
                      z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                      true>
             z_tensor_buffer;
@@ -1805,7 +1805,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
             p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());

         auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-            ushort,
+            uint8_t,
             ZDataType,
             decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
             decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
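In the gridwise kernels the 8-bit threshold is broadcast from lane 0 with __builtin_amdgcn_readfirstlane, which carries a 32-bit value, so the new code casts to uint8_t before the intrinsic and lets the low byte travel through it. A device-side sketch of just that computation (HIP; assumes an AMDGCN target and, as the diff itself does, that std::floor is usable in device code; the function wrapper is illustrative):

    #include <hip/hip_runtime.h>
    #include <cmath>
    #include <cstdint>

    // Compute the uint8 keep-threshold once and broadcast lane 0's value so
    // every lane of the wavefront compares against an identical scalar.
    __device__ uint8_t broadcast_dropout_threshold(float p_drop)
    {
        const float p_dropout = 1.0f - p_drop;
        // readfirstlane moves a 32-bit scalar; the uint8_t cast happens first,
        // so only the quantized byte is broadcast.
        const uint8_t p_dropout_in_uint8_t =
            __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
        return p_dropout_in_uint8_t;
    }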
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp

@@ -133,8 +133,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
     static constexpr auto V_K0 = Gemm1NPerBlock / KPerBlock;
     static constexpr auto V_N1 = NXdlPerWave;

     static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-    // get_random_8x16() generates 8 random numbers each time
-    static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+    // get_random_16x8() generates 16 random numbers each time
+    static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1506,8 +1506,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
     {
         const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
         const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
-        const ushort p_dropout_in_16bits =
-            __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
+        const uint8_t p_dropout_in_uint8_t =
+            __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
         const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                      rp_dropout);
@@ -1848,7 +1848,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                               decltype(thread_slice_desc_m_n)>{};
         auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-            p_dropout_in_16bits, rp_dropout};
+            p_dropout_in_uint8_t, rp_dropout};

         auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
             MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
@@ -1917,7 +1917,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        n2)); // NPerXdl

         StaticBuffer<AddressSpaceEnum::Vgpr,
-                     ushort,
+                     uint8_t,
                      z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                      true>
             z_tensor_buffer;
@@ -1927,7 +1927,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
             p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());

         auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-            ushort,
+            uint8_t,
             ZDataType,
             decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
             decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp

@@ -119,8 +119,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     static constexpr auto V_K0 = KPerBlock / V_K1 / V_K2;
     static constexpr auto V_N1 = NXdlPerWave;

     static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-    // get_random_8x16() generates 8 random numbers each time
-    static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+    // get_random_16x8() generates 16 random numbers each time
+    static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1492,8 +1492,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     {
         const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
         const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
-        const ushort p_dropout_in_16bits =
-            __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
+        const uint8_t p_dropout_in_uint8_t =
+            __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
         const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                      rp_dropout);
@@ -1809,7 +1809,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                               decltype(thread_slice_desc_m_n)>{};
         auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-            p_dropout_in_16bits, rp_dropout};
+            p_dropout_in_uint8_t, rp_dropout};

         auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
             MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
@@ -1859,7 +1859,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                        n2)); // NPerXdl

         StaticBuffer<AddressSpaceEnum::Vgpr,
-                     ushort,
+                     uint8_t,
                      z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                      true>
             z_tensor_buffer;
@@ -1869,7 +1869,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());

         auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-            ushort,
+            uint8_t,
             ZDataType,
             decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
             decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp

@@ -132,8 +132,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     static constexpr auto V_K0 = Gemm1NPerBlock / KPerBlock;
     static constexpr auto V_N1 = NXdlPerWave;

     static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-    // get_random_8x16() generates 8 random numbers each time
-    static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+    // get_random_16x8() generates 16 random numbers each time
+    static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1478,9 +1478,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
         D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);

-    static constexpr auto d0_block_space_offset = k_block_space_size_aligned.value *
-        sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
+    static constexpr auto d0_block_space_offset =
+        k_block_space_size_aligned.value * sizeof(GemmDataType) /
+        D0Loader::template TypeTransform<D0DataType>::Size;

     // LDS allocation for C shuffle in LDS
     static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
@@ -1564,8 +1564,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     {
         const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
         const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
-        const ushort p_dropout_in_16bits =
-            __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
+        const uint8_t p_dropout_in_uint8_t =
+            __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
         const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                      rp_dropout);
@@ -1906,7 +1906,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                               decltype(thread_slice_desc_m_n)>{};
         auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-            p_dropout_in_16bits, rp_dropout};
+            p_dropout_in_uint8_t, rp_dropout};

         auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
             MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
@@ -1956,7 +1956,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                        n2)); // NPerXdl

         StaticBuffer<AddressSpaceEnum::Vgpr,
-                     ushort,
+                     uint8_t,
                      z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                      true>
             z_tensor_buffer;
@@ -1966,7 +1966,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());

         auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-            ushort,
+            uint8_t,
             ZDataType,
             decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
             decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp

(diff collapsed in this view; +190 −241)
include/ck/utility/philox_rand.hpp

@@ -84,6 +84,19 @@ class philox
         out_tmp[3] = tmp_ph.w;
     }

+    __device__ void get_random_16x8(uint8_t* out, const unsigned long long subsequence)
+    {
+        uint4 tmp_ph;
+        tmp_ph = get_philox_4x32(subsequence);
+
+        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
+
+        out_tmp[0] = tmp_ph.x;
+        out_tmp[1] = tmp_ph.y;
+        out_tmp[2] = tmp_ph.z;
+        out_tmp[3] = tmp_ph.w;
+    }
+
     __device__ void get_random_4x16(ushort* out, const unsigned long long subsequence)
     {
         uint4 tmp_ph;
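get_random_16x8 stores the same four 32-bit Philox words as the existing methods, only through a uint8_t pointer, so one call now yields sixteen mask bytes; the destination must hold at least 16 bytes and be suitably aligned for the 32-bit stores. A host-side model of the byte reinterpretation (the fixed words are purely illustrative, not the Philox output):

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    // Host-side model of get_random_16x8: four 32-bit words are written through
    // a byte pointer, yielding 16 mask bytes per call. memcpy models the
    // reinterpret_cast stores in the device method.
    void get_random_16x8_model(uint8_t* out)
    {
        const uint32_t words[4] = {0xdeadbeefu, 0x01234567u, 0x89abcdefu, 0x0badf00du};
        std::memcpy(out, words, sizeof(words));
    }

    int main()
    {
        uint8_t mask[16];
        get_random_16x8_model(mask);
        for(uint8_t b : mask)
            std::printf("%02x ", b); // each byte is compared against the uint8 threshold
        std::printf("\n");
    }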
library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp

@@ -25,19 +25,19 @@ struct ReferenceDropout : public device::BaseOperator
     Argument(const Tensor<RefDataType>& ref,
              const Tensor<InDataType>& in,
              Tensor<OutDataType>& out,
-             RefDataType p_dropout_in_16bits,
+             RefDataType p_dropout_in_uint8_t,
              float rp_dropout)
         : ref_(ref),
           in_(in),
           out_(out),
-          p_dropout_in_16bits_(p_dropout_in_16bits),
+          p_dropout_in_uint8_t_(p_dropout_in_uint8_t),
           rp_dropout_(rp_dropout)
     {
     }

     const Tensor<RefDataType>& ref_;
     const Tensor<InDataType>& in_;
     Tensor<OutDataType>& out_;
-    RefDataType p_dropout_in_16bits_;
+    RefDataType p_dropout_in_uint8_t_;
     float rp_dropout_;
 };
@@ -48,7 +48,7 @@ struct ReferenceDropout : public device::BaseOperator
     {
         arg.out_.ForEach([&](auto& self, auto idx) {
             self(idx) =
-                arg.ref_(idx) <= arg.p_dropout_in_16bits_
+                arg.ref_(idx) <= arg.p_dropout_in_uint8_t_
                     ? ck::type_convert<OutDataType>(ck::type_convert<float>(arg.in_(idx)) *
                                                     ck::type_convert<float>(arg.rp_dropout_))
                     : 0;
@@ -74,10 +74,10 @@ struct ReferenceDropout : public device::BaseOperator
     static auto MakeArgument(const Tensor<RefDataType>& ref,
                              const Tensor<InDataType>& in,
                              Tensor<OutDataType>& out,
-                             RefDataType p_dropout_in_16bits,
+                             RefDataType p_dropout_in_uint8_t,
                              float rp_dropout)
     {
-        return Argument{ref, in, out, p_dropout_in_16bits, rp_dropout};
+        return Argument{ref, in, out, p_dropout_in_uint8_t, rp_dropout};
     }

     static auto MakeInvoker() { return Invoker{}; }
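Under the new convention the reference operator is instantiated with an 8-bit Z type, so the threshold it receives is floor(keep_prob * 255). A hypothetical invocation following the examples' pattern (the namespace is assumed from the library's reference-op convention, and the tensor names stand in for whatever the caller has in scope):

    // Hypothetical invocation; ZDataType is now an 8-bit type and the threshold
    // is quantized against 255. Types, namespace, and tensor names are
    // placeholders, not taken from a specific example file.
    using ZDataType = uint8_t;
    using ReferenceDropoutInstance =
        ck::tensor_operation::host::ReferenceDropout<ZDataType, DataType, DataType>;

    auto ref_dropout         = ReferenceDropoutInstance{};
    auto ref_dropout_invoker = ref_dropout.MakeInvoker();
    auto ref_dropout_argment = ref_dropout.MakeArgument(
        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
    ref_dropout_invoker.Run(ref_dropout_argment);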