gaoqiong / composable_kernel · Commits

Commit 6cc7d0de, authored Jun 30, 2023 by danyao12

    rename device ops

parent 38f48480
Showing 20 of 26 changed files with 354 additions and 349 deletions (+354 −349), page 1 of 2.
Changed files on this page:

example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v1.cpp  +4 −4
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp  +34 −34
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2_phased.cpp  +2 −2
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v1.cpp  +3 −3
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp  +3 −3
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v1.cpp  +7 −7
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp  +49 −49
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v1.cpp  +4 −4
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp  +34 −34
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v1.cpp  +3 −3
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp  +3 −3
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v1.cpp  +7 −7
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp  +49 −49
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp  +29 −28
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp  +29 −28
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_phased_v1.hpp  +28 −27
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp  +28 −27
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp  +28 −27
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v1.hpp  +5 −5
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp  +5 −5
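The change set is a mechanical rename of the attention device ops, applied to the kernel entry points, the device-op structs, and the type strings the ops report:

- Backward ops gain a loop-strategy tag matching their header: DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1/V2 becomes DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1/V2, DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1/V2, or DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1 depending on the file, and the kernel functions pick up the matching kloop_/qloop_/qloop_phased_ infix.
- Forward ops gain an explicit version suffix: DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle becomes DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 or _V2, with the kernel functions suffixed _v1/_v2.
- The grouped variants (DeviceGroupedMultiheadAttention*) are renamed the same way in the example files.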
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v1.cpp

@@ -103,7 +103,7 @@ static constexpr bool Deterministic = false;
 // If 64 < DIM <= 128, ues prototype2 2nd template.
 #if(DIM <= 32)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -173,7 +173,7 @@ using DeviceGemmInstance =
         Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -243,7 +243,7 @@ using DeviceGemmInstance =
         Deterministic>;
 // using DeviceGemmInstance =
-//     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+//     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
 //         NumDimG,
 //         NumDimM,
 //         NumDimN,
@@ -313,7 +313,7 @@ using DeviceGemmInstance =
 //         Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
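All of the example files updated in this commit share this head-dimension dispatch: a compile-time DIM selects the prototype1 template for DIM <= 64 and the prototype2 template for 64 < DIM <= 128. Below is a condensed, self-contained sketch of that pattern; the two structs are stand-ins for the CK device ops named in the diff, whose real instantiations take many more template arguments.

#include <cstdio>

#ifndef DIM
#define DIM 64 // head dimension; the examples set this per build
#endif

// Stand-ins for the renamed CK device ops (illustrative only).
struct KloopV1 { static constexpr const char* name = "Backward_Kloop_Xdl_CShuffle_V1"; };
struct KloopV2 { static constexpr const char* name = "Backward_Kloop_Xdl_CShuffle_V2"; };

// Same shape as the examples: prototype1 (V1) covers DIM <= 64,
// prototype2 (V2) covers 64 < DIM <= 128.
#if(DIM <= 32)
using DeviceGemmInstance = KloopV1;
#elif(DIM <= 64)
using DeviceGemmInstance = KloopV1;
#elif(DIM <= 128)
using DeviceGemmInstance = KloopV2;
#else
#error "unsupported head dimension"
#endif

int main() { std::printf("selected: %s\n", DeviceGemmInstance::name); }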
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp

(diff collapsed on the page; not shown)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2_phased.cpp

@@ -103,7 +103,7 @@ static constexpr bool Deterministic = false;
 // If 64 < DIM <= 128, ues prototype2 2nd template.
 #if(DIM <= 32)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -173,7 +173,7 @@ using DeviceGemmInstance =
         Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v1.cpp

@@ -79,7 +79,7 @@ static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -150,7 +150,7 @@ using DeviceGemmInstance =
         Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -221,7 +221,7 @@ using DeviceGemmInstance =
         Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp

@@ -79,7 +79,7 @@ static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -150,7 +150,7 @@ using DeviceGemmInstance =
         Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -221,7 +221,7 @@ using DeviceGemmInstance =
         Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v1.cpp

@@ -112,7 +112,7 @@ static constexpr bool Deterministic = false;
 // If 64 < DIM <= 128, ues prototype2 2nd template.
 #if(DIM <= 32)
 using DeviceGemmInstanceFWD =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -183,7 +183,7 @@ using DeviceGemmInstanceFWD =
         Deterministic>;
 using DeviceGemmInstanceBWD =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -253,7 +253,7 @@ using DeviceGemmInstanceBWD =
         Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstanceFWD =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -324,7 +324,7 @@ using DeviceGemmInstanceFWD =
         Deterministic>;
 using DeviceGemmInstanceBWD =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -394,7 +394,7 @@ using DeviceGemmInstanceBWD =
         Deterministic>;
 // using DeviceGemmInstanceBWD =
-//     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+//     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
 //         NumDimG,
 //         NumDimM,
 //         NumDimN,
@@ -464,7 +464,7 @@ using DeviceGemmInstanceBWD =
 //         Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstanceFWD =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -535,7 +535,7 @@ using DeviceGemmInstanceFWD =
         Deterministic>;
 using DeviceGemmInstanceBWD =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp

(diff collapsed on the page; not shown)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v1.cpp

@@ -102,7 +102,7 @@ static constexpr bool Deterministic = false;
 // If 64 < DIM <= 128, ues prototype2 2nd template.
 #if(DIM <= 32)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -172,7 +172,7 @@ using DeviceGemmInstance =
         Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -242,7 +242,7 @@ using DeviceGemmInstance =
         Deterministic>;
 // using DeviceGemmInstance =
-//     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+//     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
 //         NumDimG,
 //         NumDimM,
 //         NumDimN,
@@ -312,7 +312,7 @@ using DeviceGemmInstance =
 //         Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp

(diff collapsed on the page; not shown)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v1.cpp

@@ -79,7 +79,7 @@ static constexpr bool Deterministic = true;
 #if(DIM <= 32)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -150,7 +150,7 @@ using DeviceGemmInstance =
         Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -221,7 +221,7 @@ using DeviceGemmInstance =
         Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp

@@ -79,7 +79,7 @@ static constexpr bool Deterministic = true;
 #if(DIM <= 32)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -150,7 +150,7 @@ using DeviceGemmInstance =
         Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -221,7 +221,7 @@ using DeviceGemmInstance =
         Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v1.cpp

@@ -111,7 +111,7 @@ static constexpr bool Deterministic = true;
 // If 64 < DIM <= 128, ues prototype2 2nd template.
 #if(DIM <= 32)
 using DeviceGemmInstanceFWD =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -182,7 +182,7 @@ using DeviceGemmInstanceFWD =
         Deterministic>;
 using DeviceGemmInstanceBWD =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -252,7 +252,7 @@ using DeviceGemmInstanceBWD =
         Deterministic>;
 #elif(DIM <= 64)
 using DeviceGemmInstanceFWD =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -323,7 +323,7 @@ using DeviceGemmInstanceFWD =
         Deterministic>;
 using DeviceGemmInstanceBWD =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -393,7 +393,7 @@ using DeviceGemmInstanceBWD =
         Deterministic>;
 // using DeviceGemmInstanceBWD =
-//     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+//     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
 //         NumDimG,
 //         NumDimM,
 //         NumDimN,
@@ -463,7 +463,7 @@ using DeviceGemmInstanceBWD =
 //         Deterministic>;
 #elif(DIM <= 128)
 using DeviceGemmInstanceFWD =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -534,7 +534,7 @@ using DeviceGemmInstanceFWD =
         Deterministic>;
 using DeviceGemmInstanceBWD =
-    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp

(diff collapsed on the page; not shown)
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp

@@ -54,7 +54,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-    kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
+    kernel_batched_multihead_attention_backward_kloop_xdl_cshuffle_v1(
        const InputDataType* __restrict__ p_a_grid,
        const InputDataType* __restrict__ p_b_grid,
        ZDataType* __restrict__ p_z_grid,
@@ -277,7 +277,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
+struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -299,7 +299,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
+    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -949,7 +949,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
+            const auto kernel =
+                kernel_batched_multihead_attention_backward_kloop_xdl_cshuffle_v1<
                 GridwiseGemm,
                 InputDataType,
                 OutputDataType,
@@ -1284,7 +1285,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
+        str << "DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
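The first hunk above touches the kernel declaration sitting under CK's launch-bounds guard. As a rough, standalone HIP/CUDA-style illustration of that pattern (the macro values here are made up; CK defines them in its config headers, and these backward kernels pin the second argument to 1 block per CU):

#define CK_USE_LAUNCH_BOUNDS 1       // normally provided by CK's configuration
#define CK_MAX_THREAD_PER_BLOCK 256  // illustrative value

__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_sketch(float* __restrict__ p_out)
{
    // Bounding threads-per-block (and requesting one resident block per CU)
    // lets the compiler budget registers for exactly this occupancy target.
    p_out[blockIdx.x * blockDim.x + threadIdx.x] = 0.0f;
}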
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp

@@ -53,7 +53,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-    kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
+    kernel_batched_multihead_attention_backward_kloop_xdl_cshuffle_v2(
        const InputDataType* __restrict__ p_a_grid,
        const InputDataType* __restrict__ p_b_grid,
        ZDataType* __restrict__ p_z_grid,
@@ -276,7 +276,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
+struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -298,7 +298,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
+    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -951,7 +951,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
+            const auto kernel =
+                kernel_batched_multihead_attention_backward_kloop_xdl_cshuffle_v2<
                 GridwiseGemm,
                 InputDataType,
                 OutputDataType,
@@ -1284,7 +1285,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2"
+        str << "DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
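The last hunk of each header updates the operator's type string so the reported name matches the new struct name. A minimal host-side sketch of that GetTypeString pattern, with made-up tile values standing in for the real template parameters:

#include <iostream>
#include <sstream>
#include <string>

std::string GetTypeString()
{
    constexpr int BlockSize = 256; // illustrative; real values come from template args
    constexpr int MPerBlock = 128;

    auto str = std::stringstream();
    // clang-format off
    str << "DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2"
        << "<"
        << BlockSize << ", "
        << MPerBlock << ", ...>";
    // clang-format on
    return str.str();
}

int main() { std::cout << GetTypeString() << '\n'; }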
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_phased_v1.hpp

@@ -53,7 +53,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-    kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
+    kernel_batched_multihead_attention_backward_qloop_phased_xdl_cshuffle_v1(
        const InputDataType* __restrict__ p_a_grid,
        const InputDataType* __restrict__ p_b_grid,
        ZDataType* __restrict__ p_z_grid,
@@ -273,7 +273,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
+struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -295,7 +295,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
+    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -945,7 +945,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
+            const auto kernel =
+                kernel_batched_multihead_attention_backward_qloop_phased_xdl_cshuffle_v1<
                 GridwiseGemm,
                 InputDataType,
                 OutputDataType,
@@ -1277,7 +1278,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
+        str << "DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp

@@ -53,7 +53,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-    kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
+    kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1(
        const InputDataType* __restrict__ p_a_grid,
        const InputDataType* __restrict__ p_b_grid,
        ZDataType* __restrict__ p_z_grid,
@@ -273,7 +273,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
+struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -285,7 +285,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");

-    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
+    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -933,7 +933,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
+            const auto kernel =
+                kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1<
                 GridwiseGemm,
                 InputDataType,
                 OutputDataType,
@@ -1248,7 +1249,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
+        str << "DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp

@@ -52,7 +52,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-    kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
+    kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2(
        const InputDataType* __restrict__ p_a_grid,
        const InputDataType* __restrict__ p_b_grid,
        ZDataType* __restrict__ p_z_grid,
@@ -279,7 +279,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
+struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -291,7 +291,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");

-    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
+    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -950,7 +950,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
+            const auto kernel =
+                kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2<
                 GridwiseGemm,
                 InputDataType,
                 OutputDataType,
@@ -1279,7 +1280,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2"
+        str << "DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v1.hpp

@@ -51,7 +51,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batched_multiheadattention_forward_xdl_cshuffle(
+    kernel_batched_multiheadattention_forward_xdl_cshuffle_v1(
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        const FloatAB* __restrict__ p_b1_grid,
@@ -255,7 +255,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
+struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
     : public DeviceBatchedMultiheadAttentionForward<NumDimG,
                                                     NumDimM,
                                                     NumDimN,
@@ -295,7 +295,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle;
+    using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -746,7 +746,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
         auto launch_kernel =
             [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
-                const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
+                const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v1<
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
@@ -1116,7 +1116,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle"
+        str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
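Unlike the backward ops (which inherit plain BaseOperator), the forward ops derive from the interface DeviceBatchedMultiheadAttentionForward<...>, as the second hunk shows, so the newly distinguished _V1 and _V2 structs can stand behind one base. A toy sketch of that shape, assuming a simplified base with just a name query (the real interface exposes argument/invoker factories instead):

#include <iostream>
#include <memory>

// Simplified stand-in for the DeviceBatchedMultiheadAttentionForward interface.
struct FwdBase
{
    virtual ~FwdBase()                        = default;
    virtual const char* GetTypeString() const = 0;
};

struct FwdV1 : FwdBase
{
    const char* GetTypeString() const override
    {
        return "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1";
    }
};

struct FwdV2 : FwdBase
{
    const char* GetTypeString() const override
    {
        return "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2";
    }
};

int main()
{
    std::unique_ptr<FwdBase> op = std::make_unique<FwdV2>(); // pick a version at runtime
    std::cout << op->GetTypeString() << '\n';
}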
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp

@@ -51,7 +51,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batched_multiheadattention_forward_xdl_cshuffle(
+    kernel_batched_multiheadattention_forward_xdl_cshuffle_v2(
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        const FloatAB* __restrict__ p_b1_grid,
@@ -263,7 +263,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
+struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
     : public DeviceBatchedMultiheadAttentionForward<NumDimG,
                                                     NumDimM,
                                                     NumDimN,
@@ -303,7 +303,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle;
+    using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -761,7 +761,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
         auto launch_kernel =
             [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
-                const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
+                const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v2<
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
@@ -1133,7 +1133,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle"
+        str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "