gaoqiong / composable_kernel · Commits

Commit aea324d2, authored Sep 20, 2023 by letaoqin

    Merge branch 'mha-train-develop' into mha-train-develop-grad-bias

Parents: 73611570, f04ec574
Showing 20 changed files with 361 additions and 275 deletions (+361, -275)
Changed files:
  example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt (+6, -6)
  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp (+8, -8)
  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp (+8, -8)
  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp (+8, -8)
  example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp (+7, -7)
  example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp (+7, -7)
  example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp (+7, -7)
  example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc (+6, -5)
  example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc (+5, -4)
  example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp (+8, -8)
  example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp (+7, -7)
  example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc (+34, -9)
  example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc (+37, -12)
  include/ck/host_utility/hip_check_error.hpp (+13, -0)
  include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp (+118, -118)
  include/ck/tensor_operation/gpu/device/impl/device_batched_dropout.hpp (+0, -1)
  include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp (+23, -24)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp (+47, -22)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_dropout.hpp (+5, -5)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp (+7, -9)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt

@@ -5,12 +5,12 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_grouped_multihead_attention_forward_v1 grouped_multihead_attention_forward_v1.cpp)
-add_example_executable(example_batched_multihead_attention_forward_v1 batched_multihead_attention_forward_v1.cpp)
-add_example_executable(example_grouped_multihead_attention_backward_v1 grouped_multihead_attention_backward_v1.cpp)
-add_example_executable(example_batched_multihead_attention_backward_v1 batched_multihead_attention_backward_v1.cpp)
-add_example_executable(example_grouped_multihead_attention_train_v1 grouped_multihead_attention_train_v1.cpp)
-add_example_executable(example_batched_multihead_attention_train_v1 batched_multihead_attention_train_v1.cpp)
+# add_example_executable(example_grouped_multihead_attention_forward_v1 grouped_multihead_attention_forward_v1.cpp)
+# add_example_executable(example_batched_multihead_attention_forward_v1 batched_multihead_attention_forward_v1.cpp)
+# add_example_executable(example_grouped_multihead_attention_backward_v1 grouped_multihead_attention_backward_v1.cpp)
+# add_example_executable(example_batched_multihead_attention_backward_v1 batched_multihead_attention_backward_v1.cpp)
+# add_example_executable(example_grouped_multihead_attention_train_v1 grouped_multihead_attention_train_v1.cpp)
+# add_example_executable(example_batched_multihead_attention_train_v1 batched_multihead_attention_train_v1.cpp)
 add_example_executable(example_grouped_multihead_attention_forward_v2 grouped_multihead_attention_forward_v2.cpp)
 add_example_executable(example_batched_multihead_attention_forward_v2 batched_multihead_attention_forward_v2.cpp)
 add_example_executable(example_grouped_multihead_attention_backward_v2 grouped_multihead_attention_backward_v2.cpp)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp

@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -249,7 +249,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment =
-        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -324,10 +324,10 @@ int run(int argc, char* argv[])
         exit(0);
     }
     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     float alpha                    = 1.f / std::sqrt(K);
     std::cout << "do_verification: " << do_verification << std::endl;
     std::cout << "init_method: " << init_method << std::endl;
@@ -631,7 +631,7 @@ int run(int argc, char* argv[])
                               lse_g_m,
                               p_drop_g_m_n,
                               z_g_m_n,
-                              p_dropout_in_16bits,
+                              p_dropout_in_uint8_t,
                               rp_dropout);
        y_gs_ms_os.ForEach([&](auto& self, auto idx) {
            self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
@@ -691,7 +691,7 @@ int run(int argc, char* argv[])
        auto ref_dropout         = ReferenceDropoutInstance{};
        auto ref_dropout_invoker = ref_dropout.MakeInvoker();
        auto ref_dropout_argment = ref_dropout.MakeArgument(
-           z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+           z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
        ref_dropout_invoker.Run(ref_dropout_argment);
        // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
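The recurring change in the example sources above is that the host-side reference dropout now takes an 8-bit keep-threshold (p_dropout * 255) instead of the former 16-bit one (p_dropout * 65535). The following is a standalone host-only sketch, not code from this repository, showing how such a quantized threshold reproduces the requested keep probability when compared against uniformly distributed random bytes; the variable names and the use of std::mt19937 in place of the philox generator are illustrative assumptions.

#include <cstdint>
#include <cmath>
#include <cstdio>
#include <random>

int main()
{
    float p_drop    = 0.2f;          // probability of dropping an element
    float p_dropout = 1.0f - p_drop; // probability of keeping it
    // Same quantization as in the examples above, but to 8 bits instead of 16:
    uint8_t threshold = static_cast<uint8_t>(std::floor(p_dropout * 255.0));

    std::mt19937 rng(42);
    std::uniform_int_distribution<int> rand_byte(0, 255); // stands in for philox output
    int kept = 0;
    const int total = 1000000;
    for(int i = 0; i < total; ++i)
        kept += (rand_byte(rng) <= threshold) ? 1 : 0; // keep rule: random byte <= threshold

    std::printf("requested keep prob %.3f, empirical %.3f\n", p_dropout, double(kept) / total);
    return 0;
}

Running this prints an empirical keep rate close to p_dropout, which is the property the 8-bit threshold has to preserve relative to the old 16-bit one.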
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp

@@ -218,7 +218,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -250,7 +250,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment =
-        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -325,10 +325,10 @@ int run(int argc, char* argv[])
         exit(0);
     }
     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     float alpha                    = 1.f / std::sqrt(K);
     std::cout << "do_verification: " << do_verification << std::endl;
     std::cout << "init_method: " << init_method << std::endl;
@@ -637,7 +637,7 @@ int run(int argc, char* argv[])
                               lse_g_m,
                               p_drop_g_m_n,
                               z_g_m_n,
-                              p_dropout_in_16bits,
+                              p_dropout_in_uint8_t,
                               rp_dropout);
        y_gs_ms_os.ForEach([&](auto& self, auto idx) {
            self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
@@ -697,7 +697,7 @@ int run(int argc, char* argv[])
        auto ref_dropout         = ReferenceDropoutInstance{};
        auto ref_dropout_invoker = ref_dropout.MakeInvoker();
        auto ref_dropout_argment = ref_dropout.MakeArgument(
-           z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+           z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
        ref_dropout_invoker.Run(ref_dropout_argment);
        // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp

@@ -247,7 +247,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -279,7 +279,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment =
-        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -354,10 +354,10 @@ int run(int argc, char* argv[])
         exit(0);
     }
     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     float alpha                    = 1.f / std::sqrt(K);
     std::cout << "do_verification: " << do_verification << std::endl;
     std::cout << "init_method: " << init_method << std::endl;
@@ -815,7 +815,7 @@ int run(int argc, char* argv[])
                               lse_g_m,
                               p_drop_g_m_n,
                               z_fwd_g_m_n,
-                              p_dropout_in_16bits,
+                              p_dropout_in_uint8_t,
                               rp_dropout);
        ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
@@ -858,7 +858,7 @@ int run(int argc, char* argv[])
        auto ref_dropout         = ReferenceDropoutInstance{};
        auto ref_dropout_invoker = ref_dropout.MakeInvoker();
        auto ref_dropout_argment = ref_dropout.MakeArgument(
-           z_bwd_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+           z_bwd_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
        ref_dropout_invoker.Run(ref_dropout_argment);
        // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp

@@ -216,7 +216,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -248,7 +248,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment =
-        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -311,9 +311,9 @@ int run(int argc, char* argv[])
         exit(0);
     }
     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
@@ -690,7 +690,7 @@ int run(int argc, char* argv[])
                                   lse_g_ms[i],
                                   p_drop_g_m_ns[i],
                                   z_g_m_ns[i],
-                                  p_dropout_in_16bits,
+                                  p_dropout_in_uint8_t,
                                   rp_dropout);
            y_tensors[i].ForEach([&](auto& self, auto idx) {
@@ -742,7 +742,7 @@ int run(int argc, char* argv[])
            auto ref_dropout         = ReferenceDropoutInstance{};
            auto ref_dropout_invoker = ref_dropout.MakeInvoker();
            auto ref_dropout_argment = ref_dropout.MakeArgument(
-               z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+               z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
            ref_dropout_invoker.Run(ref_dropout_argment);
            sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp

@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -249,7 +249,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment =
-        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -312,9 +312,9 @@ int run(int argc, char* argv[])
         exit(0);
     }
     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
@@ -703,7 +703,7 @@ int run(int argc, char* argv[])
                                   lse_g_ms[i],
                                   p_drop_g_m_ns[i],
                                   z_g_m_ns[i],
-                                  p_dropout_in_16bits,
+                                  p_dropout_in_uint8_t,
                                   rp_dropout);
            y_tensors[i].ForEach([&](auto& self, auto idx) {
@@ -755,7 +755,7 @@ int run(int argc, char* argv[])
            auto ref_dropout         = ReferenceDropoutInstance{};
            auto ref_dropout_invoker = ref_dropout.MakeInvoker();
            auto ref_dropout_argment = ref_dropout.MakeArgument(
-               z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+               z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
            ref_dropout_invoker.Run(ref_dropout_argment);
            sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp

@@ -246,7 +246,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -278,7 +278,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment =
-        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -341,9 +341,9 @@ int run(int argc, char* argv[])
         exit(0);
     }
     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     auto gemm_fwd    = DeviceGemmInstanceFWD{};
     auto invoker_fwd = gemm_fwd.MakeInvoker();
@@ -860,7 +860,7 @@ int run(int argc, char* argv[])
                                   lse_g_ms[i],
                                   p_drop_g_m_ns[i],
                                   z_fwd_g_m_ns[i],
-                                  p_dropout_in_16bits,
+                                  p_dropout_in_uint8_t,
                                   rp_dropout);
            int G0 = v_tensors[i].GetLengths()[0];
@@ -893,7 +893,7 @@ int run(int argc, char* argv[])
            auto ref_dropout         = ReferenceDropoutInstance{};
            auto ref_dropout_invoker = ref_dropout.MakeInvoker();
            auto ref_dropout_argment = ref_dropout.MakeArgument(
-               z_bwd_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+               z_bwd_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
            ref_dropout_invoker.Run(ref_dropout_argment);
            sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc

@@ -66,10 +66,10 @@ int run(int argc, char* argv[])
         exit(0);
     }
     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     float alpha                    = 1.f / std::sqrt(K);
     std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
     std::vector<ck::index_t> a_gs_ms_ks_strides =
@@ -159,6 +159,7 @@ int run(int argc, char* argv[])
     a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
     b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
     b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
+    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
     auto a_element_op  = AElementOp{};
     auto b0_element_op = B0ElementOp{};
@@ -322,7 +323,7 @@ int run(int argc, char* argv[])
        auto ref_dropout         = ReferenceDropoutInstance{};
        auto ref_dropout_invoker = ref_dropout.MakeInvoker();
        auto ref_dropout_argment = ref_dropout.MakeArgument(
-           z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+           z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
        ref_dropout_invoker.Run(ref_dropout_argment);
        // gemm1
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc

@@ -43,9 +43,9 @@ int run(int argc, char* argv[])
         exit(0);
     }
     float p_dropout                = 1 - p_drop;
-    uint16_t p_dropout_in_16bits   = uint16_t(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     float alpha                    = 1; // scaling after 1st gemm
@@ -217,6 +217,7 @@ int run(int argc, char* argv[])
        a_tensors_device[i]->ToDevice(a_gs_ms_ks.mData.data());
        b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
        b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());
+       z_tensors_device[i]->ToDevice(z_gs_ms_ns.mData.data());
        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
        p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
@@ -390,7 +391,7 @@ int run(int argc, char* argv[])
            auto ref_dropout         = ReferenceDropoutInstance{};
            auto ref_dropout_invoker = ref_dropout.MakeInvoker();
            auto ref_dropout_argment = ref_dropout.MakeArgument(
-               z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+               z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
            ref_dropout_invoker.Run(ref_dropout_argment);
            // gemm 1
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp

@@ -217,7 +217,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -252,7 +252,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment =
-        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -327,10 +327,10 @@ int run(int argc, char* argv[])
         exit(0);
     }
     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     float alpha                    = 1.f / std::sqrt(K);
     std::cout << "do_verification: " << do_verification << std::endl;
     std::cout << "init_method: " << init_method << std::endl;
@@ -659,7 +659,7 @@ int run(int argc, char* argv[])
                               lse_g_m,
                               p_drop_g_m_n,
                               z_g_m_n,
-                              p_dropout_in_16bits,
+                              p_dropout_in_uint8_t,
                               rp_dropout);
        y_gs_ms_os.ForEach([&](auto& self, auto idx) {
            self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
@@ -719,7 +719,7 @@ int run(int argc, char* argv[])
        auto ref_dropout         = ReferenceDropoutInstance{};
        auto ref_dropout_invoker = ref_dropout.MakeInvoker();
        auto ref_dropout_argment = ref_dropout.MakeArgument(
-           z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+           z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
        ref_dropout_invoker.Run(ref_dropout_argment);
        // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp

@@ -216,7 +216,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ZDataType p_dropout_in_16bits,
+                            ZDataType p_dropout_in_uint8_t,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -251,7 +251,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     auto ref_dropout         = ReferenceDropoutInstance{};
     auto ref_dropout_invoker = ref_dropout.MakeInvoker();
     auto ref_dropout_argment =
-        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_uint8_t, rp_dropout);
     ref_dropout_invoker.Run(ref_dropout_argment);
     // Y = P_dropout * V
@@ -314,9 +314,9 @@ int run(int argc, char* argv[])
         exit(0);
     }
     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
@@ -730,7 +730,7 @@ int run(int argc, char* argv[])
                                   lse_g_ms[i],
                                   p_drop_g_m_ns[i],
                                   z_g_m_ns[i],
-                                  p_dropout_in_16bits,
+                                  p_dropout_in_uint8_t,
                                   rp_dropout);
            y_tensors[i].ForEach([&](auto& self, auto idx) {
@@ -784,7 +784,7 @@ int run(int argc, char* argv[])
            auto ref_dropout         = ReferenceDropoutInstance{};
            auto ref_dropout_invoker = ref_dropout.MakeInvoker();
            auto ref_dropout_argment = ref_dropout.MakeArgument(
-               z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+               z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_uint8_t, rp_dropout);
            ref_dropout_invoker.Run(ref_dropout_argment);
            // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
            sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc

@@ -66,10 +66,10 @@ int run(int argc, char* argv[])
         exit(0);
     }
     float p_dropout                = 1 - p_drop;
-    ZDataType p_dropout_in_16bits  = ZDataType(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     float alpha                    = 1.f / std::sqrt(K);
     std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
     std::vector<ck::index_t> a_gs_ms_ks_strides =
@@ -172,6 +172,7 @@ int run(int argc, char* argv[])
     b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
     b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
     d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
+    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
     auto a_element_op  = AElementOp{};
     auto b0_element_op = B0ElementOp{};
@@ -243,6 +244,18 @@ int run(int argc, char* argv[])
     if(do_verification)
     {
+        // data objects for hipGraph verification
+        hipGraph_t graph;
+        hipGraphExec_t g_instance;
+        hipStream_t stream;
+
+        std::cout << "verification with hipGraph capturing and replaying ... " << std::endl;
+
+        HIP_CHECK_ERROR(hipStreamCreate(&stream));
+        HIP_CHECK_ERROR(hipGraphCreate(&graph, 0));
+        HIP_CHECK_ERROR(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
+
         // run for storing z tensor
         argument =
             gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
@@ -276,9 +289,19 @@ int run(int argc, char* argv[])
                           p_drop,          // dropout ratio
                           {seed, offset}); // dropout random seed and offset, offset should be
                                            // at least the number of elements on a thread
-        c_device_buf.SetZero();
-        lse_device_buf.SetZero();
-        invoker.Run(argument, StreamConfig{nullptr, false});
+        HIP_CHECK_ERROR(hipMemsetAsync(
+            c_device_buf.GetDeviceBuffer(), 0, c_device_buf.GetBufferSize(), stream));
+        HIP_CHECK_ERROR(hipMemsetAsync(
+            lse_device_buf.GetDeviceBuffer(), 0, lse_device_buf.GetBufferSize(), stream));
+        invoker.Run(argument, StreamConfig{stream, false});
+
+        HIP_CHECK_ERROR(hipStreamEndCapture(stream, &graph));
+        HIP_CHECK_ERROR(hipGraphInstantiate(&g_instance, graph, nullptr, nullptr, 0));
+        HIP_CHECK_ERROR(hipGraphLaunch(g_instance, stream));
+        HIP_CHECK_ERROR(hipStreamSynchronize(stream));
        c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
        z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
@@ -322,7 +345,9 @@ int run(int argc, char* argv[])
        ref_gemm0_invoker.Run(ref_gemm0_argument);
        // bias
        acc0_g_m_n.ForEach(
            [&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
        // masking
        const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
        acc0_g_m_n.ForEach([&](auto& self, auto idx) {
@@ -342,7 +367,7 @@ int run(int argc, char* argv[])
        auto ref_dropout         = ReferenceDropoutInstance{};
        auto ref_dropout_invoker = ref_dropout.MakeInvoker();
        auto ref_dropout_argment = ref_dropout.MakeArgument(
-           z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+           z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
        ref_dropout_invoker.Run(ref_dropout_argment);
        // gemm1
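The verification path in the two .inc files above now records the forward kernel into a hipGraph and replays it instead of launching directly on the default stream. The following is a self-contained sketch of that capture/instantiate/launch pattern using only standard HIP runtime calls; the dummy kernel, buffer sizes, and the local HIP_CHECK macro are illustrative assumptions, not part of this commit.

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(float* p, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p[i] = p[i] * 2.0f + 1.0f;
}

#define HIP_CHECK(call)                                                          \
    do                                                                           \
    {                                                                            \
        hipError_t e = (call);                                                   \
        if(e != hipSuccess)                                                      \
        {                                                                        \
            std::printf("HIP error: %s\n", hipGetErrorString(e));                \
            return 1;                                                            \
        }                                                                        \
    } while(0)

int main()
{
    const int n  = 1024;
    float* d_buf = nullptr;
    HIP_CHECK(hipMalloc(&d_buf, n * sizeof(float)));

    hipStream_t stream;
    hipGraph_t graph;
    hipGraphExec_t g_instance;
    HIP_CHECK(hipStreamCreate(&stream));

    // Record the work into a graph instead of executing it immediately.
    HIP_CHECK(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
    HIP_CHECK(hipMemsetAsync(d_buf, 0, n * sizeof(float), stream));
    hipLaunchKernelGGL(dummy_kernel, dim3(n / 256), dim3(256), 0, stream, d_buf, n);
    HIP_CHECK(hipStreamEndCapture(stream, &graph));

    // Instantiate once, then replay the captured work and wait for it.
    HIP_CHECK(hipGraphInstantiate(&g_instance, graph, nullptr, nullptr, 0));
    HIP_CHECK(hipGraphLaunch(g_instance, stream));
    HIP_CHECK(hipStreamSynchronize(stream));

    HIP_CHECK(hipGraphExecDestroy(g_instance));
    HIP_CHECK(hipGraphDestroy(graph));
    HIP_CHECK(hipStreamDestroy(stream));
    HIP_CHECK(hipFree(d_buf));
    return 0;
}

Because all asynchronous work inside the capture must target the capturing stream, the example diff also switches StreamConfig{nullptr, false} to StreamConfig{stream, false} and replaces the blocking SetZero() calls with hipMemsetAsync on that stream.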
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc

@@ -43,9 +43,9 @@ int run(int argc, char* argv[])
         exit(0);
     }
     float p_dropout                = 1 - p_drop;
-    uint16_t p_dropout_in_16bits   = uint16_t(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_uint8_t = ZDataType(std::floor(p_dropout * 255.0));
     float rp_dropout               = 1.0 / p_dropout;
     float alpha                    = 1; // scaling after 1st gemm
@@ -149,8 +149,8 @@ int run(int argc, char* argv[])
                                      lse_gs_ms_strides,
                                      d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
                                      d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
                                      {},                 // acc1_biases_gs_ms_os_lengths
                                      {}});               // acc1_biases_gs_ms_os_strides
        // C_m_o = A_m_k * B0_k_n * B1_n_o
        Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
@@ -163,10 +163,11 @@ int run(int argc, char* argv[])
        int Batch = G0 * G1;
        flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
-       num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                    sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
-                    sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
-                   Batch;
+       num_byte +=
+           (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
+            sizeof(CDataType) * M * O + sizeof(B1DataType) * N * O +
+            sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
+           Batch;
        if(i < 4)
        {
@@ -237,6 +238,7 @@ int run(int argc, char* argv[])
        b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
        b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());
        d_tensors_device[i]->ToDevice(d_gs_ms_ns.mData.data());
+       z_tensors_device[i]->ToDevice(z_gs_ms_ns.mData.data());
        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
        p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
@@ -301,6 +303,18 @@ int run(int argc, char* argv[])
     bool pass = true;
     if(do_verification)
     {
+        // data objects for hipGraph verification
+        hipGraph_t graph;
+        hipGraphExec_t g_instance;
+        hipStream_t stream;
+
+        std::cout << "verification with hipGraph capturing and replaying ... " << std::endl;
+
+        HIP_CHECK_ERROR(hipStreamCreate(&stream));
+        HIP_CHECK_ERROR(hipGraphCreate(&graph, 0));
+        HIP_CHECK_ERROR(hipStreamBeginCapture(stream, hipStreamCaptureModeRelaxed));
+
         argument =
             gemm.MakeArgument(p_a,
                               p_b0,
@@ -324,7 +338,16 @@ int run(int argc, char* argv[])
        gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
-       invoker.Run(argument, StreamConfig{nullptr, false});
+       invoker.Run(argument, StreamConfig{stream, false});
+
+       HIP_CHECK_ERROR(hipStreamEndCapture(stream, &graph));
+       HIP_CHECK_ERROR(hipGraphInstantiate(&g_instance, graph, nullptr, nullptr, 0));
+       HIP_CHECK_ERROR(hipGraphDebugDotPrint(graph, "grouped_fwd_debug.dot", 0x007f));
+       HIP_CHECK_ERROR(hipGraphLaunch(g_instance, stream));
+       HIP_CHECK_ERROR(hipStreamSynchronize(stream));
        for(std::size_t i = 0; i < group_count; i++)
        {
@@ -396,7 +419,9 @@ int run(int argc, char* argv[])
            ref_gemm0_invoker.Run(ref_gemm0_argument);
            // bias
            acc0_g_m_n.ForEach(
                [&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
            // masking
            const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
            acc0_g_m_n.ForEach([&](auto& self, auto idx) {
@@ -419,7 +444,7 @@ int run(int argc, char* argv[])
            auto ref_dropout         = ReferenceDropoutInstance{};
            auto ref_dropout_invoker = ref_dropout.MakeInvoker();
            auto ref_dropout_argment = ref_dropout.MakeArgument(
-               z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+               z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_uint8_t, rp_dropout);
            ref_dropout_invoker.Run(ref_dropout_argment);
            // gemm 1
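The hunk at line 163 above accumulates a per-problem FLOP and byte count that the examples later turn into TFLOPS and GB/s figures. The following host-only sketch reproduces that arithmetic; the problem dimensions and the fp16 element sizes are assumptions chosen for illustration.

#include <cstddef>
#include <cstdio>

int main()
{
    // Assumed attention problem: Q[M,K] * K^T[K,N] -> S[M,N], then P[M,N] * V[N,O] -> Y[M,O]
    const int M = 512, N = 512, K = 64, O = 64, G0 = 4, G1 = 8;
    const int Batch = G0 * G1;
    const std::size_t bytes_a = 2, bytes_b0 = 2, bytes_b1 = 2, bytes_c = 2, bytes_bias = 2; // fp16

    std::size_t flop = 0, num_byte = 0;
    // Two GEMMs, 2*M*N*K and 2*M*N*O multiply-adds each, per batch instance.
    flop += (std::size_t(M) * N * K * 2 + std::size_t(M) * N * O * 2) * Batch;
    // Bytes moved: A, B0, B1, C and (when present) the attention bias tensor.
    num_byte += (bytes_a * M * K + bytes_b0 * K * N + bytes_b1 * N * O + bytes_c * M * O +
                 bytes_bias * M * N) *
                Batch;

    const float ave_time_ms = 1.0f; // placeholder for the measured kernel time
    std::printf("%.3f TFLOPS, %.3f GB/s\n",
                flop / 1.0e9 / ave_time_ms,
                num_byte / 1.0e6 / ave_time_ms);
    return 0;
}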
include/ck/host_utility/hip_check_error.hpp

@@ -15,3 +15,16 @@ inline void hip_check_error(hipError_t x)
         throw std::runtime_error(ss.str());
     }
 }
+
+#define HIP_CHECK_ERROR(flag)                                                                \
+    do                                                                                       \
+    {                                                                                        \
+        hipError_t _tmpVal;                                                                  \
+        if((_tmpVal = flag) != hipSuccess)                                                   \
+        {                                                                                    \
+            std::ostringstream ostr;                                                         \
+            ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") "           \
+                 << hipGetErrorString(_tmpVal);                                              \
+            throw std::runtime_error(ostr.str());                                            \
+        }                                                                                    \
+    } while(0)
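The new HIP_CHECK_ERROR macro wraps a HIP runtime call and throws a std::runtime_error carrying the file, line, and hipGetErrorString message when the call fails. Below is a hedged usage sketch; the include path assumes the composable_kernel include directory is on the compiler's include path, and the surrounding try/catch is only one way a caller might consume the exception.

#include <hip/hip_runtime.h>
#include <sstream>
#include <stdexcept>
#include <cstdio>
#include "ck/host_utility/hip_check_error.hpp"

int main()
{
    try
    {
        void* p = nullptr;
        HIP_CHECK_ERROR(hipMalloc(&p, 1024));   // throws on any non-hipSuccess result
        HIP_CHECK_ERROR(hipMemset(p, 0, 1024));
        HIP_CHECK_ERROR(hipDeviceSynchronize());
        HIP_CHECK_ERROR(hipFree(p));
    }
    catch(const std::runtime_error& e)
    {
        std::printf("caught: %s\n", e.what()); // message carries __FILE__ and __LINE__
        return 1;
    }
    return 0;
}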
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
aea324d2
...
@@ -16,111 +16,111 @@ struct BlockwiseDropout
...
@@ -16,111 +16,111 @@ struct BlockwiseDropout
static
constexpr
index_t
MRepeat
=
ThreadSliceDesc_M_K
{}.
GetLength
(
I0
);
static
constexpr
index_t
MRepeat
=
ThreadSliceDesc_M_K
{}.
GetLength
(
I0
);
static
constexpr
index_t
KRepeat
=
ThreadSliceDesc_M_K
{}.
GetLength
(
I1
);
static
constexpr
index_t
KRepeat
=
ThreadSliceDesc_M_K
{}.
GetLength
(
I1
);
template
<
typename
CThreadBuffer
,
bool
using_sign_bit
=
false
>
//
template <typename CThreadBuffer, bool using_sign_bit = false>
__host__
__device__
void
ApplyDropout
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
)
//
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph)
{
//
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
//
auto execute_dropout = [&](bool keep, DataType val) {
if
constexpr
(
using_sign_bit
)
//
if constexpr(using_sign_bit)
return
keep
?
val
:
-
val
;
//
return keep ? val : -val;
else
//
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
//
return keep ? val * p_dropout_rescale : float(0);
};
//
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
//
constexpr int tmp_size = MRepeat * KRepeat;
int
philox_calls
=
tmp_size
/
8
;
//
int philox_calls = tmp_size / 8;
ushort
tmp
[
tmp_size
];
//
ushort tmp[tmp_size];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
//
for(int i = 0; i < philox_calls; i++)
{
//
{
ph
.
get_random_8x16
((
tmp
+
i
*
8
));
//
ph.get_random_8x16((tmp + i * 8));
}
//
}
block_sync_lds
();
//
block_sync_lds();
int
tmp_index
=
0
;
//
int tmp_index = 0;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
//
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
//
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
//
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM,
in_thread_buf
(
offset
)
=
//
iK))>{};
in_thread_buf(offset) =
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_
16bits
,
in_thread_buf
(
offset
));
//
execute_dropout(tmp[tmp_index] <= p_dropout_
uint8_t
, in_thread_buf(offset));
tmp_index
=
tmp_index
+
1
;
//
tmp_index = tmp_index + 1;
});
//
});
});
//
});
}
//
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
//
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__
__device__
void
//
__host__ __device__ void
ApplyDropout
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
ZThreadBuffer
&
z_thread_buf
)
//
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
{
//
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
//
auto execute_dropout = [&](bool keep, DataType val) {
if
constexpr
(
using_sign_bit
)
//
if constexpr(using_sign_bit)
return
keep
?
val
:
-
val
;
//
return keep ? val : -val;
else
//
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
//
return keep ? val * p_dropout_rescale : float(0);
};
//
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
//
constexpr int tmp_size = MRepeat * KRepeat;
int
philox_calls
=
tmp_size
/
8
;
//
int philox_calls = tmp_size / 8;
ushort
tmp
[
tmp_size
];
//
ushort tmp[tmp_size];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
//
for(int i = 0; i < philox_calls; i++)
{
//
{
ph
.
get_random_8x16
((
tmp
+
i
*
8
));
//
ph.get_random_8x16((tmp + i * 8));
}
//
}
block_sync_lds
();
//
block_sync_lds();
int
tmp_index
=
0
;
//
int tmp_index = 0;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
//
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
//
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
//
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM,
in_thread_buf
(
offset
)
=
//
iK))>{};
in_thread_buf(offset) =
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_
16bits
,
in_thread_buf
(
offset
));
//
execute_dropout(tmp[tmp_index] <= p_dropout_
uint8_t
, in_thread_buf(offset));
z_thread_buf
(
offset
)
=
tmp
[
tmp_index
];
//
z_thread_buf(offset) = tmp[tmp_index];
tmp_index
=
tmp_index
+
1
;
//
tmp_index = tmp_index + 1;
});
//
});
});
//
});
}
//
}
template
<
typename
CThreadBuffer
,
//
template <typename CThreadBuffer,
typename
ZThreadBuffer
,
//
typename ZThreadBuffer,
bool
using_sign_bit
,
//
bool using_sign_bit,
typename
N0
,
//
typename N0,
typename
Offset
>
//
typename Offset>
__host__
__device__
void
//
__host__ __device__ void
ApplyDropout
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
ZThreadBuffer
&
z_thread_buf
)
//
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
{
//
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
//
auto execute_dropout = [&](bool keep, DataType val) {
if
constexpr
(
using_sign_bit
)
//
if constexpr(using_sign_bit)
return
keep
?
val
:
-
val
;
//
return keep ? val : -val;
else
//
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
//
return keep ? val * p_dropout_rescale : float(0);
};
//
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
/
N0
{}.
value
;
//
constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
int
philox_calls
=
tmp_size
/
8
;
//
int philox_calls = tmp_size / 8;
ushort
tmp
[
tmp_size
];
//
ushort tmp[tmp_size];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
//
for(int i = 0; i < philox_calls; i++)
{
//
{
ph
.
get_random_8x16
((
tmp
+
i
*
8
));
//
ph.get_random_8x16((tmp + i * 8));
}
//
}
block_sync_lds
();
//
block_sync_lds();
constexpr
auto
iOffset
=
Number
<
tmp_size
>
{}
*
Offset
{};
//
constexpr auto iOffset = Number<tmp_size>{} * Offset{};
static_for
<
0
,
tmp_size
,
1
>
{}([
&
](
auto
i
)
{
//
static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf
(
i
+
iOffset
)
=
//
in_thread_buf(i + iOffset) =
execute_dropout
(
tmp
[
i
.
value
]
<=
p_dropout_
16bits
,
in_thread_buf
(
i
+
iOffset
));
//
execute_dropout(tmp[i.value] <= p_dropout_
uint8_t
, in_thread_buf(i + iOffset));
z_thread_buf
(
i
)
=
tmp
[
i
.
value
];
//
z_thread_buf(i) = tmp[i.value];
});
//
});
}
//
}
template
<
typename
CThreadBuffer
,
typename
Offset
,
bool
using_sign_bit
=
false
>
template
<
typename
CThreadBuffer
,
typename
Offset
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropoutAttnBwd
(
CThreadBuffer
&
in_thread_buf
,
__host__
__device__
void
ApplyDropoutAttnBwd
(
CThreadBuffer
&
in_thread_buf
,
...
@@ -138,12 +138,12 @@ struct BlockwiseDropout
        constexpr int tmp_size = MRepeat * KRepeat;

-       int philox_calls = tmp_size / 8;
+       int philox_calls = tmp_size / 16;

-       ushort tmp[tmp_size];
+       uint8_t tmp[tmp_size];
        for(int i = 0; i < philox_calls; i++)
        {
-           ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
+           ph.get_random_16x8((tmp + i * 16), element_global_1d_id + i * Offset{} * MRaw);
        }
        block_sync_lds();
...
@@ -153,7 +153,7 @@ struct BlockwiseDropout
            static_for<0, KRepeat, 1>{}([&](auto iK) {
                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                in_thread_buf(offset) =
-                   execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
+                   execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
                tmp_index = tmp_index + 1;
            });
        });
...
@@ -179,12 +179,12 @@ struct BlockwiseDropout
        constexpr int tmp_size = MRepeat * KRepeat;

-       int philox_calls = tmp_size / 8;
+       int philox_calls = tmp_size / 16;

-       ushort tmp[tmp_size];
+       uint8_t tmp[tmp_size];
        for(int i = 0; i < philox_calls; i++)
        {
-           ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
+           ph.get_random_16x8((tmp + i * 16), element_global_1d_id + i * Offset{} * MRaw);
        }
        block_sync_lds();
...
@@ -194,7 +194,7 @@ struct BlockwiseDropout
            static_for<0, KRepeat, 1>{}([&](auto iK) {
                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                in_thread_buf(offset) =
-                   execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
+                   execute_dropout(tmp[tmp_index] <= p_dropout_uint8_t, in_thread_buf(offset));
                z_thread_buf(offset) = tmp[tmp_index];
                tmp_index = tmp_index + 1;
            });
...
@@ -213,7 +213,7 @@ struct BlockwiseDropout
        constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;

        static_for<0, tmp_size, 1>{}([&](auto i) {
            in_thread_buf(i + Offset{}) =
-               execute_dropout(z_thread_buf(i) <= p_dropout_16bits, in_thread_buf(i + Offset{}));
+               execute_dropout(z_thread_buf(i) <= p_dropout_uint8_t, in_thread_buf(i + Offset{}));
        });
    }
...
@@ -225,18 +225,18 @@ struct BlockwiseDropout
    {
        constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;

-       int philox_calls = tmp_size / 8;
+       int philox_calls = tmp_size / 16;

-       ushort tmp[tmp_size];
+       uint8_t tmp[tmp_size];
        for(int i = 0; i < philox_calls; i++)
        {
-           ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{});
+           ph.get_random_16x8((tmp + i * 16), element_global_1d_id + i * Offset{});
        }

        static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
    }

-   ushort p_dropout_16bits;
+   uint8_t p_dropout_uint8_t;
    DataType p_dropout_rescale;
};
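The Z buffer in these routines holds the raw random bytes rather than the 0/1 mask: the forward pass saves them, and the mask can be reproduced later simply by repeating the comparison against the same threshold, with no second pass through the RNG. A small sketch of that idea with ordinary arrays; the philox generator is replaced by a placeholder that fills 16 bytes per call, mirroring the new get_random_16x8 (values and helper names here are illustrative, not the CK API):

    #include <cstdint>
    #include <cstdio>
    #include <vector>
    #include <random>

    // Placeholder for ck::philox::get_random_16x8: fills 16 bytes per call.
    void fill_random_16x8(std::mt19937& rng, uint8_t* dst)
    {
        std::uniform_int_distribution<int> dist(0, 255);
        for(int i = 0; i < 16; ++i)
            dst[i] = static_cast<uint8_t>(dist(rng));
    }

    int main()
    {
        constexpr int tmp_size  = 64;      // plays the role of MRepeat * KRepeat
        const uint8_t threshold = 204;     // floor(0.8 * 255)
        const float rescale     = 1.0f / 0.8f;

        std::mt19937 rng(42);
        std::vector<uint8_t> z(tmp_size);      // "Z" buffer: the saved random bytes
        std::vector<float> p(tmp_size, 1.0f);  // toy attention probabilities

        // Forward: one RNG call per 16 elements, save the bytes, apply dropout.
        for(int i = 0; i < tmp_size / 16; ++i)
            fill_random_16x8(rng, z.data() + i * 16);
        for(int i = 0; i < tmp_size; ++i)
            p[i] = (z[i] <= threshold) ? p[i] * rescale : 0.0f;

        // Replay (e.g. backward): no RNG needed, the saved bytes reproduce the mask.
        std::vector<float> grad(tmp_size, 1.0f);
        for(int i = 0; i < tmp_size; ++i)
            grad[i] = (z[i] <= threshold) ? grad[i] * rescale : 0.0f;

        std::printf("p[0]=%f grad[0]=%f\n", p[0], grad[0]);
    }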
...
include/ck/tensor_operation/gpu/device/impl/device_batched_dropout.hpp
@@ -69,7 +69,6 @@ __global__ void
        raw_n_padded);
#else
    ignore = p_z_grid;
-   ignore = a_grid_desc_ak0_m_ak1;
    ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
    ignore = block_2_ctile_map;
    ignore = batch_count;
...
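The `ignore = ...;` statements above are the stub branch of the preprocessor guard: when the real kernel body is compiled out, assigning every parameter to a discard object silences unused-parameter warnings without generating code. The same effect can be had in portable C++ with std::ignore; a tiny sketch by analogy (the `kernel_stub` name and types are hypothetical, and CK uses its own `ignore` helper rather than std::ignore):

    #include <tuple>

    // Compiled-out fallback: keep the signature, discard the arguments.
    template <typename GridDesc>
    void kernel_stub(const float* p_z_grid, const GridDesc& desc, int batch_count)
    {
        std::ignore = p_z_grid;    // no-op assignments; avoids unused-parameter warnings
        std::ignore = desc;
        std::ignore = batch_count;
    }

    int main()
    {
        kernel_stub(nullptr, 0, 1); // GridDesc deduced as int for this toy call
    }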
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
@@ -40,7 +40,7 @@ template <typename GridwiseGemm,
          typename D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
          typename B1GridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-         typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
+         typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
          typename LSEGridDescriptor_M,
          typename Block2CTileMap,
          typename ComputeBasePtrOfStridedBatch,
...
@@ -73,15 +73,15 @@ __global__ void
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
-           const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
-               z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
+           const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+               z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            const LSEGridDescriptor_M lse_grid_desc_m,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const index_t mblock,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
            const C0MatrixMask c0_matrix_mask,
-           const ushort p_dropout_in_16bits,
+           const uint8_t p_dropout_in_uint8_t,
            const GemmAccDataType p_dropout_rescale,
            const unsigned long long seed,
            const unsigned long long offset,
...
@@ -145,11 +145,11 @@ __global__ void
            d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            b1_grid_desc_bk0_n_bk1,
            c_grid_desc_mblock_mperblock_nblock_nperblock,
-           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
+           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            lse_grid_desc_m,
            block_2_ctile_map,
            c0_matrix_mask,
-           p_dropout_in_16bits,
+           p_dropout_in_uint8_t,
            p_dropout_rescale,
            ph,
            z_random_matrix_offset,
...
@@ -178,11 +178,11 @@ __global__ void
            d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            b1_grid_desc_bk0_n_bk1,
            c_grid_desc_mblock_mperblock_nblock_nperblock,
-           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
+           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            lse_grid_desc_m,
            block_2_ctile_map,
            c0_matrix_mask,
-           p_dropout_in_16bits,
+           p_dropout_in_uint8_t,
            p_dropout_rescale,
            ph,
            z_random_matrix_offset,
...
@@ -207,14 +207,14 @@ __global__ void
    ignore = d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
    ignore = b1_grid_desc_bk0_n_bk1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
-   ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6;
+   ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
    ignore = lse_grid_desc_m;
    ignore = block_2_ctile_map;
    ignore = batch_count;
    ignore = mblock;
    ignore = compute_base_ptr_of_batch;
    ignore = c0_matrix_mask;
-   ignore = p_dropout_in_16bits;
+   ignore = p_dropout_in_uint8_t;
    ignore = p_dropout_rescale;
    ignore = seed;
    ignore = offset;
...
@@ -695,18 +695,17 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
            }
        }

        is_dropout_ = p_dropout > 0.0; //
        p_dropout_  = 1.f - p_dropout;
-       p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
+       p_dropout_in_uint8_t_ = uint8_t(std::floor(p_dropout_ * 255.0));
        p_dropout_         = 1.f / p_dropout_;
        p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);

        seed_   = std::get<0>(seeds);
        offset_ = std::get<1>(seeds);

-       z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_ =
-           GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
-               z_grid_desc_m_n_);
+       z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
+           GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);

        m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
        n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
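This constructor is where the user-facing dropout probability is folded into the three values the kernels actually consume: the keep probability, its integer threshold (now 8-bit instead of 16-bit), and the rescale factor. A worked example under the assumption p_dropout = 0.2:

    #include <cstdint>
    #include <cmath>
    #include <cstdio>

    int main()
    {
        float p_dropout = 0.2f;            // probability of zeroing an element
        float p_keep    = 1.f - p_dropout; // 0.8

        // Old representation: 16-bit threshold. New representation: 8-bit threshold.
        uint16_t threshold_u16 = uint16_t(std::floor(p_keep * 65535.0)); // 52428
        uint8_t  threshold_u8  = uint8_t(std::floor(p_keep * 255.0));    // 204

        // Rescale so the expected value of kept elements is unchanged.
        float rescale = 1.f / p_keep; // 1.25

        // Effective keep probability of the 8-bit compare (random_byte <= 204)
        // is 205/256, i.e. about 0.8008: coarser than the 16-bit version, but close.
        std::printf("u16=%d u8=%d rescale=%f keep8=%f\n",
                    int(threshold_u16), int(threshold_u8), rescale, 205.0 / 256.0);
    }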
...
@@ -779,8 +778,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock_;
-       typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
-           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
+       typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;

        // block-to-c-tile map
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
...
@@ -806,7 +805,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;

        float p_dropout_;
-       ushort p_dropout_in_16bits_;
+       uint8_t p_dropout_in_uint8_t_;
        GemmAccDataType p_dropout_rescale_;
        unsigned long long seed_;
        unsigned long long offset_;
...
@@ -864,7 +863,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
                DeviceOp::B1GridDesc_BK0_N_BK1,
                typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-               typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
+               typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
                DeviceOp::LSEGridDesc_M,
                typename GridwiseGemm::DefaultBlock2CTileMap,
                ComputeBasePtrOfStridedBatch,
...
@@ -897,14 +896,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                arg.b1_grid_desc_bk0_n_bk1_,
                arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-               arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
+               arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                arg.lse_grid_desc_m_,
                arg.block_2_ctile_map_,
                arg.batch_count_,
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
                arg.compute_base_ptr_of_batch_,
                arg.c0_matrix_mask_,
-               arg.p_dropout_in_16bits_,
+               arg.p_dropout_in_uint8_t_,
                arg.p_dropout_rescale_,
                arg.seed_,
                arg.offset_,
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
@@ -5,6 +5,7 @@
 #include <iostream>
 #include <sstream>
+#include <cstring>

 #include "ck/utility/common_header.hpp"
 #include "ck/utility/philox_rand.hpp"
...
@@ -48,7 +49,7 @@ __global__ void
            const AccElementwiseOperation acc_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op,
-           const ushort p_dropout_in_16bits,
+           const uint8_t p_dropout_in_uint8_t,
            const GemmAccDataType p_dropout_rescale,
            const unsigned long long seed,
            const unsigned long long offset)
...
@@ -140,11 +141,11 @@ __global__ void
            arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
            arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
            arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-           arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
+           arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
            arg_ptr[group_id].lse_grid_desc_m_,
            arg_ptr[group_id].block_2_ctile_map_,
            arg_ptr[group_id].c0_matrix_mask_,
-           p_dropout_in_16bits,
+           p_dropout_in_uint8_t,
            p_dropout_rescale,
            ph,
            arg_ptr[group_id].z_random_matrix_offset_ +
...
@@ -178,11 +179,11 @@ __global__ void
            arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
            arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
            arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-           arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
+           arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
            arg_ptr[group_id].lse_grid_desc_m_,
            arg_ptr[group_id].block_2_ctile_map_,
            arg_ptr[group_id].c0_matrix_mask_,
-           p_dropout_in_16bits,
+           p_dropout_in_uint8_t,
            p_dropout_rescale,
            ph,
            arg_ptr[group_id].z_random_matrix_offset_ +
...
@@ -198,7 +199,7 @@ __global__ void
    ignore = acc_element_op;
    ignore = b1_element_op;
    ignore = c_element_op;
-   ignore = p_dropout_in_16bits;
+   ignore = p_dropout_in_uint8_t;
    ignore = p_dropout_rescale;
    ignore = seed;
    ignore = offset;
@@ -620,8 +621,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -620,8 +621,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5_
n6_
;
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
LSEGridDesc_M
lse_grid_desc_m_
;
LSEGridDesc_M
lse_grid_desc_m_
;
...
@@ -774,8 +775,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -774,8 +775,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
c_grid_desc_m_n
);
const
auto
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
=
const
auto
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_
M4_
N4_N5
_N6
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
z_grid_desc_m_n
);
z_grid_desc_m_n
);
const
index_t
BlockStart
=
grid_size_
;
const
index_t
BlockStart
=
grid_size_
;
...
@@ -819,7 +820,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -819,7 +820,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_
n4_n5
_n6
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m_n
,
z_grid_desc_m_n
,
lse_grid_desc_m
,
lse_grid_desc_m
,
block_2_ctile_map
.
CalculateGridSize
(
c_grid_desc_m_n
),
block_2_ctile_map
.
CalculateGridSize
(
c_grid_desc_m_n
),
...
@@ -857,11 +858,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -857,11 +858,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_n_length_stride
});
d0_n_length_stride
});
}
}
use_dropout_
=
p_dropout
>
0.0
;
//
use_dropout_
=
p_dropout
>
0.0
;
//
p_dropout_
=
1.
f
-
p_dropout
;
p_dropout_
=
1.
f
-
p_dropout
;
p_dropout_in_
16bits
_
=
uint
16
_t
(
std
::
floor
(
p_dropout_
*
6553
5.0
));
p_dropout_in_
uint8_t
_
=
uint
8
_t
(
std
::
floor
(
p_dropout_
*
25
5.0
));
p_dropout_
=
1.
f
/
p_dropout_
;
p_dropout_
=
1.
f
/
p_dropout_
;
p_dropout_rescale_
=
type_convert
<
GemmAccDataType
>
(
p_dropout_
);
p_dropout_rescale_
=
type_convert
<
GemmAccDataType
>
(
p_dropout_
);
seed_
=
std
::
get
<
0
>
(
seeds
);
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
...
@@ -880,7 +881,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -880,7 +881,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
float
p_dropout_
;
float
p_dropout_
;
u
shor
t
p_dropout_in_
16bits
_
;
u
int8_
t
p_dropout_in_
uint8_t
_
;
unsigned
long
long
seed_
;
unsigned
long
long
seed_
;
unsigned
long
long
offset_
;
unsigned
long
long
offset_
;
GemmAccDataType
p_dropout_rescale_
;
GemmAccDataType
p_dropout_rescale_
;
...
@@ -912,10 +913,34 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                some_has_main_k_block_loop |= y;
            }

-           hipGetErrorString(hipMemcpy(arg.p_workspace_,
-                                       arg.group_kernel_args_.data(),
-                                       arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
-                                       hipMemcpyHostToDevice));
+           hipStreamCaptureStatus status = hipStreamCaptureStatusNone;
+           HIP_CHECK_ERROR(hipStreamIsCapturing(stream_config.stream_id_, &status));
+           if(status == hipStreamCaptureStatusActive)
+           {
+               size_t copy_size = arg.group_kernel_args_.size() * sizeof(GroupKernelArg);
+
+               // ToDO: when to release this memory buffer?
+               char* persistent_ptr = new char[copy_size];
+
+               (void)std::memcpy(persistent_ptr, arg.group_kernel_args_.data(), copy_size);
+
+               HIP_CHECK_ERROR(hipMemcpyAsync(arg.p_workspace_,
+                                              persistent_ptr,
+                                              copy_size,
+                                              hipMemcpyHostToDevice,
+                                              stream_config.stream_id_));
+           }
+           else
+           {
+               HIP_CHECK_ERROR(hipMemcpyAsync(arg.p_workspace_,
+                                              arg.group_kernel_args_.data(),
+                                              arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
+                                              hipMemcpyHostToDevice,
+                                              stream_config.stream_id_));
+           }

            float ave_time = 0;
...
@@ -949,7 +974,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                        arg.acc_element_op_,
                        arg.b1_element_op_,
                        arg.c_element_op_,
-                       arg.p_dropout_in_16bits_,
+                       arg.p_dropout_in_uint8_t_,
                        arg.p_dropout_rescale_,
                        arg.seed_,
                        arg.offset_);
...
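The launcher change above makes the workspace upload usable under hipGraph stream capture: a synchronous hipMemcpy is not allowed while a stream is being captured, so the code now queries hipStreamIsCapturing and, during an active capture, stages the arguments in a heap buffer that outlives the call before issuing hipMemcpyAsync on the captured stream (the source must stay valid until the graph is replayed). A condensed sketch of the same pattern, assuming a filled std::vector<KernelArg> and a local HIP_CHECK helper rather than CK's HIP_CHECK_ERROR macro:

    #include <hip/hip_runtime.h>
    #include <cstring>
    #include <cstdio>
    #include <cstdlib>
    #include <vector>

    #define HIP_CHECK(call)                                                       \
        do {                                                                      \
            hipError_t err_ = (call);                                             \
            if(err_ != hipSuccess) {                                              \
                std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(err_)); \
                std::exit(EXIT_FAILURE);                                          \
            }                                                                     \
        } while(0)

    struct KernelArg { int m, n, k; }; // stand-in for GroupKernelArg

    // Copy host-side kernel arguments into a device workspace without breaking
    // an in-progress hipGraph stream capture.
    void upload_args(void* p_workspace, const std::vector<KernelArg>& args, hipStream_t stream)
    {
        const size_t copy_size = args.size() * sizeof(KernelArg);

        hipStreamCaptureStatus status = hipStreamCaptureStatusNone;
        HIP_CHECK(hipStreamIsCapturing(stream, &status));

        if(status == hipStreamCaptureStatusActive)
        {
            // Stage in a persistent heap buffer so the source outlives the
            // capture (leaked here for brevity, as in the original TODO).
            char* persistent_ptr = new char[copy_size];
            std::memcpy(persistent_ptr, args.data(), copy_size);
            HIP_CHECK(hipMemcpyAsync(p_workspace, persistent_ptr, copy_size,
                                     hipMemcpyHostToDevice, stream));
        }
        else
        {
            HIP_CHECK(hipMemcpyAsync(p_workspace, args.data(), copy_size,
                                     hipMemcpyHostToDevice, stream));
        }
    }

    int main()
    {
        std::vector<KernelArg> args(4, KernelArg{64, 64, 32});
        void* d_ws = nullptr;
        HIP_CHECK(hipMalloc(&d_ws, args.size() * sizeof(KernelArg)));

        hipStream_t stream;
        HIP_CHECK(hipStreamCreate(&stream));
        upload_args(d_ws, args, stream);
        HIP_CHECK(hipStreamSynchronize(stream));
        HIP_CHECK(hipStreamDestroy(stream));
        HIP_CHECK(hipFree(d_ws));
    }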
include/ck/tensor_operation/gpu/grid/gridwise_batched_dropout.hpp
@@ -57,8 +57,8 @@ struct GridwiseBatchedDropout
    static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;

    static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-   // get_random_8x16() generates 8 random numbers each time
-   static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+   // get_random_16x8() generates 16 random numbers each time
+   static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
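The DropoutTile change simply tracks the wider generator: with mfma.num_input_blks == 2 threads cooperating and 16 random bytes produced per philox call instead of 8, each dropout tile now covers 32 elements instead of 16. A compile-time restatement of that arithmetic with plain integers in place of CK's Number<> wrappers (values taken from the comments in the diff):

    // Plain-integer restatement of the DropoutTile sizing.
    constexpr int DropoutNThread    = 2;  // mfma.num_input_blks
    constexpr int RandomsPerCallOld = 8;  // get_random_8x16
    constexpr int RandomsPerCallNew = 16; // get_random_16x8

    constexpr int DropoutTileOld = DropoutNThread * RandomsPerCallOld; // 16
    constexpr int DropoutTileNew = DropoutNThread * RandomsPerCallNew; // 32

    static_assert(DropoutTileOld == 16 && DropoutTileNew == 32, "matches the diff comments");

    int main() {}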
...
@@ -241,7 +241,7 @@ struct GridwiseBatchedDropout
        // only used for providing ApplyDropoutAttnBwdSaveZ
        auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-           static_cast<unsigned short>(0.8f * 65535.f), static_cast<FloatGemmAcc>(1.0f / 0.8f)};
+           static_cast<unsigned short>(0.8f * 255.f), static_cast<FloatGemmAcc>(1.0f / 0.8f)};

        //
        // z vgpr copy to global
...
@@ -260,7 +260,7 @@ struct GridwiseBatchedDropout
                                     n2)); // NPerXdl
        StaticBuffer<AddressSpaceEnum::Vgpr,
-                    ushort,
+                    uint8_t,
                     z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                     true>
            z_tensor_buffer;
...
@@ -273,7 +273,7 @@ struct GridwiseBatchedDropout
        const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
        auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-           ushort,
+           uint8_t,
            ZDataType,
            decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
            decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
@@ -99,8 +99,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};
-   static constexpr auto I8 = Number<8>{};
-   static constexpr auto I9 = Number<9>{};

    static constexpr auto WaveSize = 64;
...
@@ -120,8 +118,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
    static constexpr auto V_K0 = KPerBlock / V_K1 / V_K2;
    static constexpr auto V_N1 = NXdlPerWave;

    static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
-   // get_random_8x16() generates 8 random numbers each time
-   static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
+   // get_random_16x8() generates 16 random numbers each time
+   static constexpr auto DropoutTile = Number<DropoutNThread * 16>{}; // 32

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
...
@@ -1450,8 +1448,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
    {
        const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
        const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
-       const ushort p_dropout_in_16bits =
-           __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
+       const uint8_t p_dropout_in_uint8_t =
+           __builtin_amdgcn_readfirstlane(uint8_t(std::floor(p_dropout * 255.0)));
        const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                     rp_dropout);
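In this backward kernel the softmax scale and the dropout rescale are combined into one elementwise multiplier, scale_rp_dropout = s_element_op.Value() * rp_dropout, rather than applied as two separate scalings. A small numeric sketch with a hypothetical Scale functor standing in for tensor_operation::element_wise::Scale, assuming p_drop = 0.1 and a softmax scale of 1/sqrt(128):

    #include <cmath>
    #include <cstdio>

    // Hypothetical stand-in for tensor_operation::element_wise::Scale.
    struct Scale
    {
        float value;
        float operator()(float x) const { return value * x; }
    };

    int main()
    {
        const float p_drop        = 0.1f;
        const float p_dropout     = 1.0f - p_drop;   // keep probability, 0.9
        const float rp_dropout    = 1.0f / p_dropout; // ~1.111
        const float softmax_scale = 1.0f / std::sqrt(128.0f);

        // One fused multiplier applied to the attention scores.
        Scale scale_rp_dropout{softmax_scale * rp_dropout};

        std::printf("threshold=%d fused_scale=%f scaled=%f\n",
                    int(std::floor(p_dropout * 255.0)), // 229
                    scale_rp_dropout.value,
                    scale_rp_dropout(2.0f));
    }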
...
@@ -1769,7 +1767,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                                                         decltype(thread_slice_desc_m_n)>{};
        auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-           p_dropout_in_16bits, rp_dropout};
+           p_dropout_in_uint8_t, rp_dropout};

        auto lse_grid_desc_mb_m0_m1_m2_m3_m4 =
            MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4<decltype(s_blockwise_gemm)>(lse_grid_desc_m);
...
@@ -1838,7 +1836,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                                     n2)); // NPerXdl
        StaticBuffer<AddressSpaceEnum::Vgpr,
-                    ushort,
+                    uint8_t,
                     z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
                     true>
            z_tensor_buffer;
...
@@ -1848,7 +1846,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
        auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-           ushort,
+           uint8_t,
            ZDataType,
            decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
            decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
...