gaoqiong / composable_kernel

Commit 2aa9cbee
Authored Jul 18, 2023 by ltqin

Merge branch 'attn-train-develop-qloop' into attn-train-develop-qloop-light

Parents: be38f68d, 41c659bb

Changes: 38. Showing 20 changed files with 119 additions and 162 deletions (+119, -162).
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v1.cpp         +1  -15
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp         +1  -1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2_phased.cpp  +1  -1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v1.cpp          +1  -1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp          +1  -1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v1.cpp            +2  -16
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp            +2  -2
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v1.cpp         +1  -15
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp         +1  -1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v1.cpp          +1  -1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp          +1  -1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v1.cpp            +2  -16
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp            +2  -2
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc         +1  -1
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc         +1  -1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp   +12 -29
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp   +8  -5
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_phased_v1.hpp  +8  -5
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp   +49 -39
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp   +23 -9

example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v1.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:

@@ -158,13 +158,6 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<8, 32, 1>,    // B1BlockTransfer
-        S<0, 2, 1>,
-        S<0, 2, 1>,
-        1,
-        4,
-        2,
-        false,
         1,              // CShuffleMXdlPerWavePerShuffle
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock

@@ -228,13 +221,6 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<8, 32, 1>,    // B1BlockTransfer
-        S<0, 2, 1>,
-        S<0, 2, 1>,
-        1,
-        4,
-        2,
-        false,
         1,              // CShuffleMXdlPerWavePerShuffle
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock

example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:

example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2_phased.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:

example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v1.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o

example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o

example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v1.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o

@@ -238,13 +238,6 @@ using DeviceGemmInstanceBWD =
         8,
         8,
         true,
-        S<8, 32, 1>,    // B1BlockTransfer
-        S<0, 2, 1>,
-        S<0, 2, 1>,
-        1,
-        4,
-        2,
-        false,
         1,              // CShuffleMXdlPerWavePerShuffle
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock

@@ -379,13 +372,6 @@ using DeviceGemmInstanceBWD =
         8,
         8,
         true,
-        S<8, 32, 1>,    // B1BlockTransfer
-        S<0, 2, 1>,
-        S<0, 2, 1>,
-        1,
-        4,
-        2,
-        false,
         1,              // CShuffleMXdlPerWavePerShuffle
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock

@@ -1396,7 +1382,7 @@ int run(int argc, char* argv[])
         }
         std::cout << "Checking z:\n";
-        pass &= ck::utils::check_err(z_fwd_gs_ms_ns.mData, z_bwd_gs_ms_ns.mData, 1);
+        pass &= ck::utils::check_integer_err(z_fwd_gs_ms_ns.mData, z_bwd_gs_ms_ns.mData, 1);
         std::cout << "Checking y:\n";
         pass &= ck::utils::check_err(

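Note on the hunk above: the z tensor comparison switches from the tolerance-based ck::utils::check_err to ck::utils::check_integer_err, presumably because z holds integer values (the random numbers backing the dropout mask), where any nonzero difference between the forward and backward runs is a real mismatch rather than rounding noise. A minimal sketch of such an exact, tolerance-free comparison (an illustration only, with hypothetical names, not CK's check_integer_err implementation):

#include <cstddef>
#include <iostream>
#include <vector>

// Exact element-wise comparison: appropriate when the tensors hold integer
// values (e.g. an RNG/dropout mask), unlike a float check with rtol/atol.
template <typename T>
bool check_exact(const std::vector<T>& result, const std::vector<T>& reference)
{
    if(result.size() != reference.size())
        return false;
    for(std::size_t i = 0; i < result.size(); ++i)
    {
        if(result[i] != reference[i]) // no tolerance applied
            return false;
    }
    return true;
}

int main()
{
    std::vector<int> z_fwd = {0, 7, 3, 0};
    std::vector<int> z_bwd = {0, 7, 3, 0};
    std::cout << std::boolalpha << check_exact(z_fwd, z_bwd) << '\n'; // true
    return 0;
}
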
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o

@@ -969,7 +969,7 @@ int run(int argc, char* argv[])
         }
         std::cout << "Checking z:\n";
-        pass &= ck::utils::check_err(z_fwd_gs_ms_ns.mData, z_bwd_gs_ms_ns.mData, 1);
+        pass &= ck::utils::check_integer_err(z_fwd_gs_ms_ns.mData, z_bwd_gs_ms_ns.mData, 1);
         std::cout << "Checking y:\n";
         pass &= ck::utils::check_err(

example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v1.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:

@@ -157,13 +157,6 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<8, 32, 1>,    // B1BlockTransfer
-        S<0, 2, 1>,
-        S<0, 2, 1>,
-        1,
-        4,
-        2,
-        false,
         1,              // CShuffleMXdlPerWavePerShuffle
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock

@@ -227,13 +220,6 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<8, 32, 1>,    // B1BlockTransfer
-        S<0, 2, 1>,
-        S<0, 2, 1>,
-        1,
-        4,
-        2,
-        false,
         1,              // CShuffleMXdlPerWavePerShuffle
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock

example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:

example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v1.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o

example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o

example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v1.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o

@@ -237,13 +237,6 @@ using DeviceGemmInstanceBWD =
         8,
         8,
         true,
-        S<8, 32, 1>,    // B1BlockTransfer
-        S<0, 2, 1>,
-        S<0, 2, 1>,
-        1,
-        4,
-        2,
-        false,
         1,              // CShuffleMXdlPerWavePerShuffle
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock

@@ -378,13 +371,6 @@ using DeviceGemmInstanceBWD =
         8,
         8,
         true,
-        S<8, 32, 1>,    // B1BlockTransfer
-        S<0, 2, 1>,
-        S<0, 2, 1>,
-        1,
-        4,
-        2,
-        false,
         1,              // CShuffleMXdlPerWavePerShuffle
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock

@@ -1420,7 +1406,7 @@ int run(int argc, char* argv[])
         }
         std::cout << "Checking z:\n";
-        pass &= ck::utils::check_err(z_fwd_tensors[i].mData, z_bwd_tensors[i].mData, 1);
+        pass &= ck::utils::check_integer_err(z_fwd_tensors[i].mData, z_bwd_tensors[i].mData, 1);
         std::cout << "Checking y:\n";
         pass &= ck::utils::check_err(

example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 /*
 Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o

@@ -994,7 +994,7 @@ int run(int argc, char* argv[])
         }
         std::cout << "Checking z:\n";
-        pass &= ck::utils::check_err(z_fwd_tensors[i].mData, z_bwd_tensors[i].mData, 1);
+        pass &= ck::utils::check_integer_err(z_fwd_tensors[i].mData, z_bwd_tensors[i].mData, 1);
         std::cout << "Checking y:\n";
         pass &= ck::utils::check_err(

example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 int run(int argc, char* argv[])
 {

example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 int run(int argc, char* argv[])
 {

include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once

@@ -89,7 +89,8 @@ __global__ void
             const unsigned long long seed,
             const unsigned long long offset)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -263,13 +264,6 @@ template <index_t NumDimG,
           index_t BBlockTransferSrcScalarPerVector,
           index_t BBlockTransferDstScalarPerVector_BK1,
           bool BBlockLdsExtraN,
-          typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
-          typename B1BlockTransferThreadClusterArrangeOrder,
-          typename B1BlockTransferSrcAccessOrder,
-          index_t B1BlockTransferSrcVectorDim,
-          index_t B1BlockTransferSrcScalarPerVector,
-          index_t B1BlockTransferDstScalarPerVector_BK1,
-          bool B1BlockLdsExtraN,
           index_t CShuffleMXdlPerWavePerShuffle,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,

@@ -641,7 +635,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
     };

     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
         InputDataType, // TODO: distinguish A/B datatype
         OutputDataType,
         ZDataType,

@@ -693,14 +687,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
         BBlockTransferDstScalarPerVector_BK1,
         true,
         BBlockLdsExtraN,
-        B1BlockTransferThreadClusterLengths_BK0_N_BK1,
-        B1BlockTransferThreadClusterArrangeOrder,
-        B1BlockTransferSrcAccessOrder,
-        B1BlockTransferSrcVectorDim,
-        B1BlockTransferSrcScalarPerVector,
-        B1BlockTransferDstScalarPerVector_BK1,
-        false,
-        B1BlockLdsExtraN,
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,

@@ -1046,11 +1032,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
     static bool IsSupportedArgument(const Argument& arg)
     {
-#if 0
+#if DEBUG_LOG
         arg.Print();
 #endif
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }

@@ -1079,14 +1067,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
         const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];

         // Check scalar per vector requirement
         const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
         const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
-        const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
         const auto c_extent_lowest = Gemm1NzRaw;

         if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
              b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
-             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
              c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
         {
             return false;

@@ -1097,13 +1083,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
             ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
         const auto b_stride_lowest =
             BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
-        const auto b1_stride_lowest =
-            B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
         const auto c_stride_lowest =
             arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous

-        if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
+        if(!(a_stride_lowest == 1 || b_stride_lowest == 1 ||
              c_stride_lowest == 1))
         {
             return false;
         }

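Note on the device-support hunks above: the change pairs a compile-time guard on the kernel body (the added __gfx940__/__gfx941__/__gfx942__ macros) with an extended runtime allow-list in IsSupportedArgument, so the same set of architectures is accepted in both places. A minimal sketch of the runtime side only, assuming the device name arrives as a plain std::string (an illustration, not CK's actual API):

#include <array>
#include <iostream>
#include <string>

// Runtime allow-list, extended from {gfx908, gfx90a} to also cover the gfx94x parts.
bool is_supported_device(const std::string& device_name)
{
    constexpr std::array supported = {"gfx908", "gfx90a", "gfx940", "gfx941", "gfx942"};
    for(const char* name : supported)
    {
        if(device_name == name)
            return true;
    }
    return false;
}

int main()
{
    std::cout << std::boolalpha;
    std::cout << is_supported_device("gfx942") << '\n';  // true
    std::cout << is_supported_device("gfx1030") << '\n'; // false
    return 0;
}
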
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once

@@ -88,7 +88,8 @@ __global__ void
             const unsigned long long seed,
             const unsigned long long offset)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -640,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
     };

     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
         InputDataType, // TODO: distinguish A/B datatype
         OutputDataType,
         ZDataType,

@@ -1047,11 +1048,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
     static bool IsSupportedArgument(const Argument& arg)
     {
-#if 0
+#if DEBUG_LOG
         arg.Print();
 #endif
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }

include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_phased_v1.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once

@@ -87,7 +87,8 @@ __global__ void
             const unsigned long long seed,
             const unsigned long long offset)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -637,7 +638,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
     };

     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1<
         InputDataType, // TODO: distinguish A/B datatype
         OutputDataType,
         ZDataType,

@@ -1040,11 +1041,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
     static bool IsSupportedArgument(const Argument& arg)
     {
-#if 0
+#if DEBUG_LOG
         arg.Print();
 #endif
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }

include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once

@@ -48,6 +48,7 @@ template <typename GridwiseGemm,
           typename ComputeBasePtrOfStridedBatch,
           typename C0MatrixMask,
           bool HasMainKBlockLoop,
+          bool IsDropout,
           bool Deterministic>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS

@@ -89,7 +90,8 @@ __global__ void
             const index_t raw_m_padded,
             const index_t raw_n_padded)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -119,7 +121,7 @@ __global__ void
     {
         for(index_t i = 0; i < nblock; i++)
         {
-            GridwiseGemm::template Run<HasMainKBlockLoop>(
+            GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
                 p_a_grid + a_batch_offset,
                 p_b_grid + b_batch_offset,
                 z_matrix_ptr,

@@ -154,36 +156,36 @@ __global__ void
     }
     else
     {
-        GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(p_a_grid + a_batch_offset,
                                           p_b_grid + b_batch_offset,
                                           z_matrix_ptr,
                                           p_b1_grid + b1_batch_offset,
                                           p_c_grid + c_batch_offset,
                                           p_lse_grid + lse_batch_offset,
                                           p_ygrad_grid + c_batch_offset,
                                           p_qgrad_grid + a_batch_offset,
                                           p_kgrad_grid + b_batch_offset,
                                           p_vgrad_grid + b1_batch_offset,
                                           p_shared,
                                           a_element_op,
                                           b_element_op,
                                           acc_element_op,
                                           b1_element_op,
                                           c_element_op,
                                           a_grid_desc_ak0_m_ak1,
                                           b_grid_desc_bk0_n_bk1,
                                           c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                                           b1_grid_desc_bk0_n_bk1,
                                           c_grid_desc_mblock_mperblock_nblock_nperblock,
                                           lse_grid_desc_m,
                                           ygrad_grid_desc_o0_m_o1,
                                           block_2_ctile_map,
                                           c0_matrix_mask,
                                           p_drop,
                                           ph,
                                           z_random_matrix_offset,
                                           raw_n_padded,
                                           0);
     }
 #else
     ignore = p_a_grid;

@@ -627,7 +629,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     };

     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1<
         InputDataType, // TODO: distinguish A/B datatype
         OutputDataType,
         ZDataType,

@@ -932,7 +934,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         float ave_time = 0;

-        auto launch_kernel = [&](auto has_main_k_block_loop_) {
+        auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
             const auto kernel =
                 kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1<
                     GridwiseGemm,

@@ -956,6 +958,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                     ComputeBasePtrOfStridedBatch,
                     C0MatrixMask,
                     has_main_k_block_loop_,
+                    is_dropout_,
                     Deterministic>;

             return launch_and_time_kernel(

@@ -997,9 +1000,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                 arg.m_raw_padded_,
                 arg.n_raw_padded_);
         };

-        ave_time = launch_kernel(integral_constant<bool, false>{});
+        if(arg.p_drop_ > 0.0){
+            ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
+        }
+        else
+        {
+            ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
+        }

         return ave_time;
     }

@@ -1019,8 +1024,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     static bool IsSupportedArgument(const Argument& arg)
     {
+#if DEBUG_LOG
+        arg.Print();
+#endif
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }

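Note on the IsDropout changes above: a runtime value (arg.p_drop_ > 0.0) is folded into a compile-time template parameter by branching once on the host and handing an integral_constant to a generic lambda, so the kernel is instantiated separately for the dropout and no-dropout paths and the dropout work can be compiled out when it is not needed. A self-contained sketch of that dispatch pattern, with hypothetical names standing in for the CK kernel and launcher (not the library's actual code):

#include <iostream>
#include <type_traits>

// Stand-in for the device op's kernel; templated on the dropout flag so the
// "no dropout" instantiation carries no dropout code at all.
template <bool IsDropout>
void run_kernel()
{
    if constexpr(IsDropout)
        std::cout << "instantiation with dropout path\n";
    else
        std::cout << "instantiation without dropout path\n";
}

int main()
{
    const double p_drop = 0.1; // runtime value, plays the role of arg.p_drop_

    // Generic lambda: is_dropout_ is a std::integral_constant, so its ::value
    // can be used as a template argument (mirrors launch_kernel(..., is_dropout_)).
    auto launch_kernel = [](auto is_dropout_) {
        run_kernel<decltype(is_dropout_)::value>();
    };

    // One runtime branch selects between the two compile-time specializations.
    if(p_drop > 0.0)
        launch_kernel(std::integral_constant<bool, true>{});
    else
        launch_kernel(std::integral_constant<bool, false>{});
    return 0;
}

The qloop_v2 device op below applies the same idea twice, dispatching over both HasMainKBlockLoop and IsDropout, which yields four kernel instantiations.
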
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once

@@ -47,6 +47,7 @@ template <typename GridwiseGemm,
           typename ComputeBasePtrOfStridedBatch,
           typename C0MatrixMask,
           bool HasMainKBlockLoop,
+          bool IsDropout,
           bool Deterministic>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS

@@ -88,7 +89,8 @@ __global__ void
             const index_t raw_m_padded,
             const index_t raw_n_padded)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -118,7 +120,7 @@ __global__ void
     {
         for(index_t i = 0; i < nblock; i++)
         {
-            GridwiseGemm::template Run<HasMainKBlockLoop>(
+            GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
                 p_a_grid + a_batch_offset,
                 p_b_grid + b_batch_offset,
                 z_matrix_ptr,

@@ -153,7 +155,7 @@ __global__ void
     }
     else
     {
-        GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(p_a_grid + a_batch_offset,
                                           p_b_grid + b_batch_offset,
                                           z_matrix_ptr,
                                           p_b1_grid + b1_batch_offset,

@@ -633,7 +635,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     };

     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2<
         InputDataType, // TODO: distinguish A/B datatype
         OutputDataType,
         ZDataType,

@@ -949,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         float ave_time = 0;

-        auto launch_kernel = [&](auto has_main_k_block_loop_) {
+        auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
             const auto kernel =
                 kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2<
                     GridwiseGemm,

@@ -973,6 +975,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                     ComputeBasePtrOfStridedBatch,
                     C0MatrixMask,
                     has_main_k_block_loop_,
+                    is_dropout_,
                     Deterministic>;

             return launch_and_time_kernel(

@@ -1020,11 +1023,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
         {
-            ave_time = launch_kernel(integral_constant<bool, true>{});
+            if(arg.p_drop_ > 0.0)
+                ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
+            else
+                ave_time = launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, false>{});
         }
         else
         {
-            ave_time = launch_kernel(integral_constant<bool, false>{});
+            if(arg.p_drop_ > 0.0)
+                ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, true>{});
+            else
+                ave_time = launch_kernel(integral_constant<bool, false>{}, integral_constant<bool, false>{});
         }

         return ave_time;

@@ -1046,8 +1055,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     static bool IsSupportedArgument(const Argument& arg)
     {
+#if DEBUG_LOG
+        arg.Print();
+#endif
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
